package com.alibaba.cloud.ai.graph.agent;

import com.alibaba.cloud.ai.graph.CompileConfig;
import com.alibaba.cloud.ai.graph.CompiledGraph;
import com.alibaba.cloud.ai.graph.OverAllState;
import com.alibaba.cloud.ai.graph.StateGraph;
import com.alibaba.cloud.ai.graph.action.AsyncEdgeAction;
import com.alibaba.cloud.ai.graph.action.AsyncNodeAction;
import com.alibaba.cloud.ai.graph.action.NodeAction;
import com.alibaba.cloud.ai.graph.exception.GraphStateException;
import com.alibaba.cloud.ai.graph.node.LlmNode;
import com.alibaba.cloud.ai.graph.node.ToolNode;
import com.alibaba.cloud.ai.graph.state.strategy.AppendStrategy;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.tool.ToolCallback;
import org.springframework.ai.tool.resolution.ToolCallbackResolver;

/**
 * A ReAct-style agent expressed as a {@link StateGraph} with two nodes: {@code "agent"} (an
 * {@link LlmNode}) and {@code "tool"} (a {@link ToolNode}). The graph runs START -> agent; after
 * each agent step {@link #think(OverAllState)} routes either to {@code "tool"} (which loops back
 * to {@code "agent"}) or to {@link StateGraph#END}.
 *
 * <p>The number of tool rounds per run is bounded by {@code max_iterations}, and an optional
 * {@code shouldContinueFunc} can stop the loop early. The round counter lives on this instance,
 * so a single agent should not execute several runs concurrently.
 */
public class ReactAgent {
    private static final String ROUTE_CONTINUE = "continue";
    private static final String ROUTE_END = "end";

    private String name;
    private final LlmNode llmNode;
    private final ToolNode toolNode;
    private CompiledGraph compiledGraph;
    private List<String> tools;
    /** Maximum number of agent -> tool rounds allowed in one run. */
    private int max_iterations;
    private CompileConfig compileConfig;
    private OverAllState state;
    private Function<OverAllState, Boolean> shouldContinueFunc;
    /** Tool rounds taken in the current run; reset to 0 whenever {@link #think} routes to end. */
    private int iterations = 0;
    /** Built at the end of construction, once every field it depends on has been assigned. */
    private final StateGraph graph;

    /**
     * Fluent builder for {@link ReactAgent}. Either {@link #tools(List)} or
     * {@link #resolver(ToolCallbackResolver)} must be supplied; {@code maxIterations} defaults to 10.
     */
    public static class Builder {
        private String name;
        private ChatClient chatClient;
        private List<ToolCallback> tools;
        private ToolCallbackResolver resolver;
        private int maxIterations = 10;
        private CompileConfig compileConfig;
        private OverAllState state;
        private Function<OverAllState, Boolean> shouldContinueFunc;

        /** Sets the graph name. */
        public Builder name(String str) {
            this.name = str;
            return this;
        }

        /** Sets the chat client used by the LLM node. */
        public Builder chatClient(ChatClient chatClient) {
            this.chatClient = chatClient;
            return this;
        }

        /** Sets an explicit list of tool callbacks (ignored if a resolver is also set). */
        public Builder tools(List<ToolCallback> list) {
            this.tools = list;
            return this;
        }

        /** Sets a tool callback resolver; takes precedence over {@link #tools(List)}. */
        public Builder resolver(ToolCallbackResolver toolCallbackResolver) {
            this.resolver = toolCallbackResolver;
            return this;
        }

        /** Sets the maximum number of tool rounds per run. */
        public Builder maxIterations(int i) {
            this.maxIterations = i;
            return this;
        }

        /** Sets the initial graph state; a default state is created when left null. */
        public Builder state(OverAllState overAllState) {
            this.state = overAllState;
            return this;
        }

        /** Sets the config used by {@link ReactAgent#getAndCompileGraph()}. */
        public Builder compileConfig(CompileConfig compileConfig) {
            this.compileConfig = compileConfig;
            return this;
        }

        /** Sets a predicate that, when it does not return {@code true}, ends the loop. */
        public Builder shouldContinueFunction(Function<OverAllState, Boolean> function) {
            this.shouldContinueFunc = function;
            return this;
        }

        /**
         * Builds the agent.
         *
         * @return a new agent; the resolver is used when set, otherwise the tool list
         * @throws IllegalArgumentException if neither tools nor a resolver was provided
         * @throws GraphStateException as declared by the agent constructors
         */
        public ReactAgent build() throws GraphStateException {
            if (this.resolver != null) {
                return new ReactAgent(this.name, this.chatClient, this.resolver, this.maxIterations, this.state, this.compileConfig, this.shouldContinueFunc);
            }
            if (this.tools != null) {
                return new ReactAgent(this.name, this.chatClient, this.tools, this.maxIterations, this.state, this.compileConfig, this.shouldContinueFunc);
            }
            throw new IllegalArgumentException("Either tools or resolver must be provided");
        }
    }

    /**
     * Adapts a compiled agent graph into a {@link NodeAction} for use inside a parent graph. The
     * parent's string value under {@code inputKeyFromParent} is sent to the child graph as a single
     * {@link UserMessage}, and the text of the child's final {@link AssistantMessage} is written
     * back under {@code outputKeyToParent}.
     */
    public static class SubGraphNodeAdapter implements NodeAction {
        private final String inputKeyFromParent;
        private final String outputKeyToParent;
        private final CompiledGraph childGraph;

        SubGraphNodeAdapter(String str, String str2, CompiledGraph compiledGraph) {
            this.inputKeyFromParent = str;
            this.outputKeyToParent = str2;
            this.childGraph = compiledGraph;
        }

        @Override
        public Map<String, Object> apply(OverAllState overAllState) throws Exception {
            String input = (String) overAllState.value(this.inputKeyFromParent).orElseThrow();
            List<?> messages = (List<?>) this.childGraph
                    .invoke(Map.of(ReflectAgent.MESSAGES, List.of(new UserMessage(input))))
                    .get()
                    .value(ReflectAgent.MESSAGES)
                    .orElseThrow();
            AssistantMessage last = (AssistantMessage) messages.get(messages.size() - 1);
            return Map.of(this.outputKeyToParent, last.getText());
        }
    }

    /** Creates an agent from pre-built nodes; the graph name is left {@code null}. */
    public ReactAgent(LlmNode llmNode, ToolNode toolNode, int i, OverAllState overAllState, CompileConfig compileConfig, Function<OverAllState, Boolean> function) throws GraphStateException {
        this(null, llmNode, toolNode, i, overAllState, compileConfig, function);
    }

    /** Creates an agent with an explicit tool list and default state/config. */
    public ReactAgent(String str, ChatClient chatClient, List<ToolCallback> list, int i) throws GraphStateException {
        this(str, chatClient, list, i, null, null, null);
    }

    /** Creates an agent with an explicit tool list. */
    public ReactAgent(String str, ChatClient chatClient, List<ToolCallback> list, int i, OverAllState overAllState, CompileConfig compileConfig, Function<OverAllState, Boolean> function) throws GraphStateException {
        this(str, newLlmNode(chatClient), ToolNode.builder().toolCallbacks(list).build(), i, overAllState, compileConfig, function);
    }

    /** Creates an agent whose tools come from a resolver, with default state/config. */
    public ReactAgent(String str, ChatClient chatClient, ToolCallbackResolver toolCallbackResolver, int i) throws GraphStateException {
        this(str, chatClient, toolCallbackResolver, i, null, null, null);
    }

    /** Creates an agent whose tools come from a resolver. */
    public ReactAgent(String str, ChatClient chatClient, ToolCallbackResolver toolCallbackResolver, int i, OverAllState overAllState, CompileConfig compileConfig, Function<OverAllState, Boolean> function) throws GraphStateException {
        this(str, newLlmNode(chatClient), ToolNode.builder().toolCallbackResolver(toolCallbackResolver).build(), i, overAllState, compileConfig, function);
    }

    /** Canonical constructor: every public constructor delegates here. */
    private ReactAgent(String name, LlmNode llmNode, ToolNode toolNode, int maxIterations, OverAllState state, CompileConfig compileConfig, Function<OverAllState, Boolean> shouldContinueFunc) throws GraphStateException {
        this.name = name;
        this.llmNode = llmNode;
        this.toolNode = toolNode;
        this.max_iterations = maxIterations;
        this.state = state;
        this.compileConfig = compileConfig;
        this.shouldContinueFunc = shouldContinueFunc;
        // Must run last: initGraph() reads name, state, llmNode and toolNode.
        this.graph = initGraph();
    }

    /** Builds the LLM node bound to the shared messages key. */
    private static LlmNode newLlmNode(ChatClient chatClient) throws GraphStateException {
        return LlmNode.builder().chatClient(chatClient).messagesKey(ReflectAgent.MESSAGES).build();
    }

    public StateGraph getStateGraph() {
        return this.graph;
    }

    /** Returns the last compiled graph, or {@code null} if no compile method has been called. */
    public CompiledGraph getCompiledGraph() throws GraphStateException {
        return this.compiledGraph;
    }

    /** Compiles the graph with the given config, caches the result and returns it. */
    public CompiledGraph getAndCompileGraph(CompileConfig compileConfig) throws GraphStateException {
        this.compiledGraph = getStateGraph().compile(compileConfig);
        return this.compiledGraph;
    }

    /** Compiles the graph with the configured {@code compileConfig} (or defaults), caching the result. */
    public CompiledGraph getAndCompileGraph() throws GraphStateException {
        if (this.compileConfig == null) {
            this.compiledGraph = getStateGraph().compile();
        } else {
            this.compiledGraph = getStateGraph().compile(this.compileConfig);
        }
        return this.compiledGraph;
    }

    /**
     * Wraps the compiled agent as a synchronous node for a parent graph.
     *
     * @throws IllegalStateException if the graph has not been compiled yet
     */
    public NodeAction asNodeAction(String str, String str2) {
        return new SubGraphNodeAdapter(str, str2, requireCompiledGraph());
    }

    /**
     * Wraps the compiled agent as an asynchronous node for a parent graph.
     *
     * @throws IllegalStateException if the graph has not been compiled yet
     */
    public AsyncNodeAction asAsyncNodeAction(String str, String str2) {
        return AsyncNodeAction.node_async(new SubGraphNodeAdapter(str, str2, requireCompiledGraph()));
    }

    /** Returns the compiled graph, failing fast instead of handing a null graph to an adapter. */
    private CompiledGraph requireCompiledGraph() {
        if (this.compiledGraph == null) {
            throw new IllegalStateException("ReactAgent not compiled yet");
        }
        return this.compiledGraph;
    }

    /**
     * Builds the agent/tool loop. When no state was supplied, a default one is created whose
     * messages key uses {@link AppendStrategy}.
     */
    private StateGraph initGraph() throws GraphStateException {
        if (this.state == null) {
            OverAllState overAllState = new OverAllState();
            overAllState.registerKeyAndStrategy(ReflectAgent.MESSAGES, new AppendStrategy());
            this.state = overAllState;
        }
        return new StateGraph(this.name, this.state)
                .addNode("agent", AsyncNodeAction.node_async(this.llmNode))
                .addNode("tool", AsyncNodeAction.node_async(this.toolNode))
                .addEdge(StateGraph.START, "agent")
                .addConditionalEdges("agent", AsyncEdgeAction.edge_async(this::think),
                        Map.of(ROUTE_CONTINUE, "tool", ROUTE_END, StateGraph.END))
                .addEdge("tool", "agent");
    }

    /**
     * Routing decision after each agent step. Counts tool rounds so {@code max_iterations} is
     * enforced, and resets the counter whenever the run is routed to end so the agent can be reused.
     */
    private String think(OverAllState overAllState) {
        String route = decideRoute(overAllState);
        if (ROUTE_CONTINUE.equals(route)) {
            this.iterations++;
        } else {
            this.iterations = 0;
        }
        return route;
    }

    /** Returns "continue" only if the round budget allows it and the last assistant message requests tools. */
    private String decideRoute(OverAllState overAllState) {
        if (this.iterations >= this.max_iterations) {
            return ROUTE_END;
        }
        // A null result from the user predicate is treated as "stop" rather than causing an NPE.
        if (this.shouldContinueFunc != null && !Boolean.TRUE.equals(this.shouldContinueFunc.apply(overAllState))) {
            return ROUTE_END;
        }
        List<?> messages = (List<?>) overAllState.value(ReflectAgent.MESSAGES).orElseThrow();
        if (messages.isEmpty()) {
            return ROUTE_END;
        }
        Object last = messages.get(messages.size() - 1);
        if (!(last instanceof AssistantMessage)) {
            return ROUTE_END;
        }
        return ((AssistantMessage) last).hasToolCalls() ? ROUTE_CONTINUE : ROUTE_END;
    }

    List<String> getTools() {
        return this.tools;
    }

    void setTools(List<String> list) {
        this.tools = list;
    }

    int getMax_iterations() {
        return this.max_iterations;
    }

    void setMax_iterations(int i) {
        this.max_iterations = i;
    }

    int getIterations() {
        return this.iterations;
    }

    void setIterations(int i) {
        this.iterations = i;
    }

    CompileConfig getCompileConfig() {
        return this.compileConfig;
    }

    void setCompileConfig(CompileConfig compileConfig) {
        this.compileConfig = compileConfig;
    }

    OverAllState getState() {
        return this.state;
    }

    void setState(OverAllState overAllState) {
        this.state = overAllState;
    }

    Function<OverAllState, Boolean> getShouldContinueFunc() {
        return this.shouldContinueFunc;
    }

    void setShouldContinueFunc(Function<OverAllState, Boolean> function) {
        this.shouldContinueFunc = function;
    }

    /** Returns a new {@link Builder}. */
    public static Builder builder() {
        return new Builder();
    }
}
