package com.alibaba.cloud.ai.graph.agent;

import com.alibaba.cloud.ai.graph.CompileConfig;
import com.alibaba.cloud.ai.graph.CompiledGraph;
import com.alibaba.cloud.ai.graph.OverAllState;
import com.alibaba.cloud.ai.graph.StateGraph;
import com.alibaba.cloud.ai.graph.action.AsyncEdgeAction;
import com.alibaba.cloud.ai.graph.action.AsyncNodeAction;
import com.alibaba.cloud.ai.graph.action.NodeAction;
import com.alibaba.cloud.ai.graph.exception.GraphStateException;
import com.alibaba.cloud.ai.graph.state.strategy.ReplaceStrategy;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.MessageType;

/**
 * Agent that loops between a "graph" node, which produces an answer, and a "reflection" node,
 * which reviews it.
 *
 * <p>After the graph node, {@link #graphCount(OverAllState)} either ends the run (once the
 * iteration limit is reached) or routes to the reflection node. After the reflection node,
 * {@link #apply(OverAllState)} ends the run when the last message is an assistant message and
 * otherwise routes back to the graph node.
 *
 * <p>State keys: {@link #MESSAGES} holds the conversation as a {@code List<Message>} and
 * {@link #ITERATION_NUM} holds the round counter. Both keys are registered with
 * {@link ReplaceStrategy}.
 *
 * <p>The lazily created state graph and compiled graph are cached without synchronization.
 */
public class ReflectAgent {
    private static final Logger logger = LoggerFactory.getLogger(ReflectAgent.class);

    /** State key holding the conversation messages. */
    public static final String MESSAGES = "messages";

    /** State key holding the number of graph → reflection rounds taken so far. */
    public static final String ITERATION_NUM = "iteration_num";

    /** Iteration limit used by the builder and by the two-argument factory method. */
    private static final int DEFAULT_MAX_ITERATIONS = 5;

    private static final String REFLECTION_NODE_ID = "reflection";
    private static final String GRAPH_NODE_ID = "graph";

    private final String name;
    private final NodeAction graph;
    private final NodeAction reflection;
    private final CompileConfig compileConfig;

    // Not final: createReflectionGraph(..., int) overwrites it and graphCount reads it.
    private int maxIterations;

    // Built on first use by getStateGraph() / getAndCompileGraph().
    private StateGraph stateGraph;
    private CompiledGraph compiledGraph;

    /** Fluent builder for {@link ReflectAgent}. */
    public static class Builder {
        private String name;
        private NodeAction graph;
        private NodeAction reflection;
        private int maxIterations = DEFAULT_MAX_ITERATIONS;
        private CompileConfig compileConfig;

        /** Sets the agent name. */
        public Builder name(String str) {
            this.name = str;
            return this;
        }

        /** Sets the node action that produces answers (required). */
        public Builder graph(NodeAction nodeAction) {
            this.graph = nodeAction;
            return this;
        }

        /** Sets the node action that reviews answers (required). */
        public Builder reflection(NodeAction nodeAction) {
            this.reflection = nodeAction;
            return this;
        }

        /** Sets the maximum number of graph → reflection rounds; a value of 0 or less disables reflection. */
        public Builder maxIterations(int i) {
            this.maxIterations = i;
            return this;
        }

        /** Sets the configuration used by {@link ReflectAgent#getAndCompileGraph()}; may be null. */
        public Builder compileConfig(CompileConfig compileConfig) {
            this.compileConfig = compileConfig;
            return this;
        }

        /**
         * Builds the agent.
         *
         * @return a new agent; its graph is not created until first requested
         * @throws IllegalArgumentException if the graph or reflection action is missing
         * @throws GraphStateException declared for source compatibility; not thrown here
         */
        public ReflectAgent build() throws GraphStateException {
            if (this.graph == null || this.reflection == null) {
                throw new IllegalArgumentException("Graph and reflection must be provided");
            }
            return new ReflectAgent(this.name, this.graph, this.reflection, this.maxIterations, this.compileConfig);
        }
    }

    /**
     * Creates an unnamed agent without a compile configuration.
     *
     * @param graph action for the answering node
     * @param reflection action for the reviewing node
     * @param maxIterations maximum number of graph → reflection rounds
     */
    public ReflectAgent(NodeAction graph, NodeAction reflection, int maxIterations) {
        this(null, graph, reflection, maxIterations, null);
    }

    /**
     * Creates an agent.
     *
     * @param name agent name (may be null)
     * @param graph action for the answering node
     * @param reflection action for the reviewing node
     * @param maxIterations maximum number of graph → reflection rounds
     * @param compileConfig configuration for {@link #getAndCompileGraph()}; null means default compile
     */
    public ReflectAgent(String name, NodeAction graph, NodeAction reflection, int maxIterations,
            CompileConfig compileConfig) {
        this.name = name;
        this.graph = graph;
        this.reflection = reflection;
        this.maxIterations = maxIterations;
        this.compileConfig = compileConfig;
    }

    /** Creates a new reflection graph from this agent's own actions and iteration limit. */
    public StateGraph createReflectionGraph() throws GraphStateException {
        return createReflectionGraph(this.graph, this.reflection, this.maxIterations);
    }

    /** Creates a new reflection graph for the given actions using the default iteration limit. */
    public StateGraph createReflectionGraph(NodeAction nodeAction, NodeAction nodeAction2) throws GraphStateException {
        return createReflectionGraph(nodeAction, nodeAction2, DEFAULT_MAX_ITERATIONS);
    }

    /**
     * Creates a new reflection graph: START → graph → (reflection | END), reflection → (graph | END).
     *
     * <p>Side effect: stores {@code maxIterations} on this agent, because the routing method
     * {@link #graphCount(OverAllState)} reads the limit from the instance.
     *
     * @param graphNode action for the answering node
     * @param reflectionNode action for the reviewing node
     * @param maxIterations maximum number of graph → reflection rounds
     * @return the uncompiled state graph
     * @throws GraphStateException if the graph definition is rejected by {@link StateGraph}
     */
    public StateGraph createReflectionGraph(NodeAction graphNode, NodeAction reflectionNode, int maxIterations)
            throws GraphStateException {
        this.maxIterations = maxIterations;
        logger.debug("Creating reflection graph with max iterations: {}", maxIterations);
        StateGraph reflectionGraph = new StateGraph(() -> {
            OverAllState overAllState = new OverAllState();
            overAllState.registerKeyAndStrategy(MESSAGES, new ReplaceStrategy());
            overAllState.registerKeyAndStrategy(ITERATION_NUM, new ReplaceStrategy());
            return overAllState;
        })
                .addNode(GRAPH_NODE_ID, AsyncNodeAction.node_async(graphNode))
                .addNode(REFLECTION_NODE_ID, AsyncNodeAction.node_async(reflectionNode))
                .addEdge(StateGraph.START, GRAPH_NODE_ID)
                .addConditionalEdges(GRAPH_NODE_ID, AsyncEdgeAction.edge_async(this::graphCount),
                        Map.of(REFLECTION_NODE_ID, REFLECTION_NODE_ID, StateGraph.END, StateGraph.END))
                .addConditionalEdges(REFLECTION_NODE_ID, AsyncEdgeAction.edge_async(this::apply),
                        Map.of(GRAPH_NODE_ID, GRAPH_NODE_ID, StateGraph.END, StateGraph.END));
        logger.info("Reflection graph created successfully with {} nodes", 2);
        return reflectionGraph;
    }

    /** Returns the cached state graph, creating it on first call. */
    public StateGraph getStateGraph() throws GraphStateException {
        if (this.stateGraph == null) {
            this.stateGraph = createReflectionGraph();
        }
        return this.stateGraph;
    }

    /** Returns the cached compiled graph, compiling it on first call. */
    public CompiledGraph getCompiledGraph() throws GraphStateException {
        if (this.compiledGraph == null) {
            getAndCompileGraph();
        }
        return this.compiledGraph;
    }

    /** Compiles the state graph with the given configuration, caching and returning the result. */
    public CompiledGraph getAndCompileGraph(CompileConfig compileConfig) throws GraphStateException {
        this.compiledGraph = getStateGraph().compile(compileConfig);
        return this.compiledGraph;
    }

    /**
     * Compiles the state graph with this agent's configuration, or with the default compile when
     * none was supplied. Caches and returns the result.
     */
    public CompiledGraph getAndCompileGraph() throws GraphStateException {
        if (this.compileConfig == null) {
            this.compiledGraph = getStateGraph().compile();
        } else {
            this.compiledGraph = getStateGraph().compile(this.compileConfig);
        }
        return this.compiledGraph;
    }

    /** Logs every message in {@link #MESSAGES} (type, then text); logs nothing when the key is absent. */
    public void printMessage(OverAllState overAllState) {
        for (Message message : messagesOf(overAllState)) {
            logger.info(message.getMessageType().name());
            logger.info(message.getText());
            logger.info("===================================");
        }
    }

    /**
     * Routing decision after the graph node.
     *
     * <p>An absent counter is treated as 0. If the counter has reached {@code maxIterations}, it
     * is reset to 0, the conversation is logged and {@link StateGraph#END} is returned. Otherwise
     * the counter is incremented and the reflection node is chosen. Checking the limit before the
     * first round means a limit of 0 or less skips reflection entirely.
     *
     * @param overAllState current graph state
     * @return {@code "reflection"} or {@link StateGraph#END}
     * @throws Exception declared for compatibility with the edge-action contract
     */
    public String graphCount(OverAllState overAllState) throws Exception {
        Optional<?> stored = overAllState.value(ITERATION_NUM);
        int current = stored.isPresent() ? (Integer) stored.get() : 0;
        if (stored.isPresent()) {
            logger.info("Current iteration: {} | Max iterations: {}", current, this.maxIterations);
        }
        if (current >= this.maxIterations) {
            logger.info("Iteration limit reached, stopping reflection cycle");
            // Reset so a later run on the same state starts counting from zero.
            overAllState.updateState(Map.of(ITERATION_NUM, 0));
            printMessage(overAllState);
            return StateGraph.END;
        }
        int next = current + 1;
        if (stored.isPresent()) {
            logger.debug("Incrementing iteration counter from {} to {}", current, next);
        } else {
            logger.debug("Initializing iteration counter to 1");
        }
        overAllState.updateState(Map.of(ITERATION_NUM, next));
        logger.debug("Updated iteration count: {}", next);
        return REFLECTION_NODE_ID;
    }

    /**
     * Routing decision after the reflection node.
     *
     * <p>Ends the run when there are no messages (including when the key is absent) or when the
     * last message is an assistant message. Otherwise returns to the graph node.
     *
     * @param overAllState current graph state
     * @return {@code "graph"} or {@link StateGraph#END}
     * @throws Exception declared for compatibility with the edge-action contract
     */
    public String apply(OverAllState overAllState) throws Exception {
        List<Message> messages = messagesOf(overAllState);
        int size = messages.size();
        logger.debug("Processing messages, found {} messages in state", size);
        if (messages.isEmpty()) {
            logger.info("No messages to process, ending reflection cycle");
            return StateGraph.END;
        }
        Message last = messages.get(size - 1);
        if (MessageType.ASSISTANT.equals(last.getMessageType())) {
            logger.info("Last message is from assistant: {}", last.getText());
            return StateGraph.END;
        }
        // Any non-assistant last message (not only USER) loops back to the graph node.
        logger.debug("Last message is from user, continuing to graph node");
        return GRAPH_NODE_ID;
    }

    /** Returns a new, empty {@link Builder}. */
    public static Builder builder() {
        return new Builder();
    }

    /** Reads {@link #MESSAGES} from the state, returning an empty list when the key is absent. */
    @SuppressWarnings("unchecked") // the MESSAGES key is populated with List<Message> by the graph nodes
    private static List<Message> messagesOf(OverAllState state) {
        Object raw = state.value(MESSAGES).orElse(null);
        return raw == null ? List.of() : (List<Message>) raw;
    }
}
