package com.alibaba.cloud.ai.graph;

import com.alibaba.cloud.ai.graph.GraphRepresentation;
import com.alibaba.cloud.ai.graph.action.AsyncEdgeAction;
import com.alibaba.cloud.ai.graph.action.AsyncNodeAction;
import com.alibaba.cloud.ai.graph.action.AsyncNodeActionWithConfig;
import com.alibaba.cloud.ai.graph.checkpoint.config.SaverConfig;
import com.alibaba.cloud.ai.graph.checkpoint.constant.SaverConstant;
import com.alibaba.cloud.ai.graph.checkpoint.savers.MemorySaver;
import com.alibaba.cloud.ai.graph.exception.GraphRunnerException;
import com.alibaba.cloud.ai.graph.exception.GraphStateException;
import com.alibaba.cloud.ai.graph.internal.edge.Edge;
import com.alibaba.cloud.ai.graph.internal.edge.EdgeCondition;
import com.alibaba.cloud.ai.graph.internal.edge.EdgeValue;
import com.alibaba.cloud.ai.graph.internal.node.Node;
import com.alibaba.cloud.ai.graph.internal.node.SubCompiledGraphNode;
import com.alibaba.cloud.ai.graph.internal.node.SubStateGraphNode;
import com.alibaba.cloud.ai.graph.serializer.StateSerializer;
import com.alibaba.cloud.ai.graph.serializer.plain_text.PlainTextStateSerializer;
import com.alibaba.cloud.ai.graph.serializer.plain_text.gson.GsonStateSerializer;
import com.alibaba.cloud.ai.graph.serializer.plain_text.jackson.JacksonStateSerializer;
import com.alibaba.cloud.ai.graph.state.AgentStateFactory;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;

/* loaded from: input_file:com/alibaba/cloud/ai/graph/StateGraph.class */
/**
 * Mutable builder for a graph of named nodes connected by plain and conditional edges.
 *
 * <p>Nodes are registered through the {@code addNode} overloads and wired together with
 * {@link #addEdge(String, String)} and {@link #addConditionalEdges(String, AsyncEdgeAction, Map)}.
 * The reserved identifiers {@link #START} and {@link #END} mark the entry and exit of the graph.
 * {@link #compile(CompileConfig)} validates the structure and produces a {@link CompiledGraph}.
 *
 * <p>NOTE(review): no synchronization is performed here; instances are presumably built from a
 * single thread — confirm.
 */
public class StateGraph {

    /** Reserved id of the virtual exit node; it can be neither registered as a node nor used as an edge source. */
    public static final String END = "__END__";

    /** Reserved id of the virtual entry node; a valid graph must contain an edge whose source is this id. */
    public static final String START = "__START__";

    /** Registered nodes, in insertion order. */
    final Nodes nodes = new Nodes();

    /** Registered edges, in insertion order. */
    final Edges edges = new Edges();

    /** Graph state instance (legacy API); {@code null} when the graph was built from a factory. */
    private OverAllState overAllState;

    /** State factory; {@code null} when one of the deprecated {@link OverAllState} constructors was used. */
    private final OverAllStateFactory overAllStateFactory;

    /** Optional name; used as the default title by {@link #getGraph(GraphRepresentation.Type)}. */
    private final String name;

    /** Serializer for the graph state; also the source of {@link #getStateFactory()}. */
    private final PlainTextStateSerializer stateSerializer;

    /** Ordered list of the edges of a graph. */
    public static class Edges {

        /** Edges in insertion order. */
        public final List<Edge> elements;

        /**
         * Creates an edge list holding a copy of the given edges.
         *
         * @param edges edges to copy
         */
        public Edges(Collection<Edge> edges) {
            this.elements = new ArrayList<>(edges);
        }

        /** Creates an empty edge list. */
        public Edges() {
            this.elements = new ArrayList<>();
        }

        /**
         * Finds the first edge whose source is the given node.
         *
         * @param sourceId id of the source node
         * @return the first matching edge, or an empty {@link Optional} if there is none
         */
        public Optional<Edge> edgeBySourceId(String sourceId) {
            return elements.stream().filter(edge -> Objects.equals(edge.sourceId(), sourceId)).findFirst();
        }

        /**
         * Collects every edge for which {@code Edge#anyMatchByTargetId} accepts the given id.
         *
         * @param targetId id of the target node
         * @return unmodifiable list of matching edges, possibly empty
         */
        public List<Edge> edgesByTargetId(String targetId) {
            return elements.stream().filter(edge -> edge.anyMatchByTargetId(targetId)).toList();
        }
    }

    /** Structural errors detected while building or validating a graph. */
    public enum Errors {
        invalidNodeIdentifier("END is not a valid node id!"),
        invalidEdgeIdentifier("END is not a valid edge sourceId!"),
        duplicateNodeError("node with id: %s already exist!"),
        duplicateEdgeError("edge with id: %s already exist!"),
        duplicateConditionalEdgeError("conditional edge from '%s' already exist!"),
        edgeMappingIsEmpty("edge mapping is empty!"),
        missingEntryPoint("missing Entry Point"),
        entryPointNotExist("entryPoint: %s doesn't exist!"),
        finishPointNotExist("finishPoint: %s doesn't exist!"),
        missingNodeReferencedByEdge("edge sourceId '%s' refers to undefined node!"),
        missingNodeInEdgeMapping("edge mapping for sourceId: %s contains a not existent nodeId %s!"),
        invalidEdgeTarget("edge sourceId: %s has an initialized target value!"),
        duplicateEdgeTargetError("edge [%s] has duplicate targets %s!"),
        unsupportedConditionalEdgeOnParallelNode("parallel node doesn't support conditional branch, but on [%s] a conditional branch on %s have been found!"),
        illegalMultipleTargetsOnParallelNode("parallel node [%s] must have only one target, but %s have been found!"),
        interruptionNodeNotExist("node '%s' configured as interruption doesn't exist!"),
        // Appended last so the ordinals of the existing constants stay unchanged.
        nodeIdMismatch("node id '%s' doesn't match the id '%s' it is being registered with!");

        private final String errorMessage;

        Errors(String errorMessage) {
            this.errorMessage = errorMessage;
        }

        /**
         * Creates (does not throw) an exception whose message is this constant's template formatted
         * with {@link String#format}.
         *
         * @param args values substituted into the template's placeholders
         * @return the new exception
         */
        public GraphStateException exception(Object... args) {
            return new GraphStateException(String.format(errorMessage, args));
        }
    }

    /** Gson-based state serializer that creates {@link OverAllState} instances via its no-arg constructor. */
    static class GsonSerializer extends GsonStateSerializer {
        public GsonSerializer() {
            super(OverAllState::new, new GsonBuilder().serializeNulls().create());
        }

        Gson getGson() {
            return this.gson;
        }
    }

    /** Gson-based state serializer that creates states with a caller-supplied factory. */
    static class GsonSerializer2 extends GsonStateSerializer {
        public GsonSerializer2(AgentStateFactory<OverAllState> stateFactory) {
            super(stateFactory, new GsonBuilder().serializeNulls().create());
        }

        Gson getGson() {
            return this.gson;
        }
    }

    /** Jackson-based state serializer that creates {@link OverAllState} instances via its no-arg constructor. */
    static class JacksonSerializer extends JacksonStateSerializer {
        public JacksonSerializer() {
            super(OverAllState::new);
        }

        ObjectMapper getObjectMapper() {
            return this.objectMapper;
        }
    }

    /** Insertion-ordered set of the nodes of a graph. */
    public static class Nodes {

        /** Nodes in insertion order; duplicates are decided by {@code Node#equals}. */
        public final Set<Node> elements;

        /**
         * Creates a node set holding a copy of the given nodes.
         *
         * @param nodes nodes to copy
         */
        public Nodes(Collection<Node> nodes) {
            this.elements = new LinkedHashSet<>(nodes);
        }

        /** Creates an empty node set. */
        public Nodes() {
            this.elements = new LinkedHashSet<>();
        }

        /**
         * Tells whether a node with the given id is registered.
         *
         * @param id node id to look for
         * @return {@code true} if some node's {@code id()} equals {@code id}
         */
        public boolean anyMatchById(String id) {
            return elements.stream().anyMatch(node -> Objects.equals(node.id(), id));
        }

        /**
         * Returns the nodes that wrap an uncompiled sub-graph.
         *
         * @return unmodifiable list of {@link SubStateGraphNode}s, in insertion order
         */
        public List<SubStateGraphNode> onlySubStateGraphNodes() {
            return elements.stream()
                .filter(SubStateGraphNode.class::isInstance)
                .map(SubStateGraphNode.class::cast)
                .toList();
        }

        /**
         * Returns every node except those that wrap an uncompiled sub-graph.
         *
         * @return unmodifiable list of nodes, in insertion order
         */
        public List<Node> exceptSubStateGraphNodes() {
            return elements.stream().filter(node -> !(node instanceof SubStateGraphNode)).toList();
        }
    }

    /** Errors raised while a compiled graph is being run. */
    public enum RunnableErrors {
        missingNodeInEdgeMapping("cannot find edge mapping for id: '%s' in conditional edge with sourceId: '%s' "),
        missingNode("node with id: '%s' doesn't exist!"),
        missingEdge("edge with sourceId: '%s' doesn't exist!"),
        executionError("%s");

        private final String errorMessage;

        RunnableErrors(String errorMessage) {
            this.errorMessage = errorMessage;
        }

        /**
         * Creates (does not throw) an exception whose message is this constant's template formatted
         * with the given values.
         *
         * @param args values substituted into the template's placeholders
         * @return the new exception
         */
        public GraphRunnerException exception(String... args) {
            return new GraphRunnerException(String.format(errorMessage, (Object[]) args));
        }
    }

    /** Single place where every public constructor ends up; any argument may be {@code null}. */
    private StateGraph(String name, OverAllState overAllState, OverAllStateFactory overAllStateFactory,
            PlainTextStateSerializer stateSerializer) {
        this.name = name;
        this.overAllState = overAllState;
        this.overAllStateFactory = overAllStateFactory;
        this.stateSerializer = stateSerializer;
    }

    /**
     * @deprecated use a constructor taking an {@link OverAllStateFactory}.
     */
    @Deprecated
    public StateGraph(OverAllState overAllState, PlainTextStateSerializer plainTextStateSerializer) {
        this(null, overAllState, null, plainTextStateSerializer);
    }

    /**
     * Creates a named graph.
     *
     * @param name graph name
     * @param overAllStateFactory factory for the graph state
     * @param plainTextStateSerializer serializer for the graph state
     */
    public StateGraph(String name, OverAllStateFactory overAllStateFactory,
            PlainTextStateSerializer plainTextStateSerializer) {
        this(name, null, overAllStateFactory, plainTextStateSerializer);
    }

    /**
     * Creates a named graph that serializes its state with Gson.
     *
     * @param name graph name
     * @param overAllStateFactory factory for the graph state
     */
    public StateGraph(String name, OverAllStateFactory overAllStateFactory) {
        this(name, null, overAllStateFactory, new GsonSerializer());
    }

    /**
     * Creates an unnamed graph that serializes its state with Gson.
     *
     * @param overAllStateFactory factory for the graph state
     */
    public StateGraph(OverAllStateFactory overAllStateFactory) {
        this(null, null, overAllStateFactory, new GsonSerializer());
    }

    /**
     * Creates an unnamed graph.
     *
     * @param overAllStateFactory factory for the graph state
     * @param plainTextStateSerializer serializer for the graph state
     */
    public StateGraph(OverAllStateFactory overAllStateFactory, PlainTextStateSerializer plainTextStateSerializer) {
        this(null, null, overAllStateFactory, plainTextStateSerializer);
    }

    /**
     * @deprecated use {@link #StateGraph(String, OverAllStateFactory)}.
     */
    @Deprecated
    public StateGraph(String name, OverAllState overAllState) {
        this(name, overAllState, null, new GsonSerializer());
    }

    /**
     * @deprecated use {@link #StateGraph(OverAllStateFactory)}.
     */
    @Deprecated
    public StateGraph(OverAllState overAllState) {
        this(null, overAllState, null, new GsonSerializer());
    }

    /** Creates an unnamed graph with no state instance or factory that serializes its state with Gson. */
    public StateGraph() {
        this(null, null, null, new GsonSerializer());
    }

    /** @return the state instance, or {@code null} if none was supplied */
    public OverAllState getOverAllState() {
        return this.overAllState;
    }

    /**
     * Replaces the state instance.
     *
     * @param overAllState new state instance, may be {@code null}
     * @return this graph, for chaining
     */
    public StateGraph setOverAllState(OverAllState overAllState) {
        this.overAllState = overAllState;
        return this;
    }

    /** @return the graph name, or {@code null} if none was supplied */
    public String getName() {
        return this.name;
    }

    /** @return the serializer used for the graph state */
    public StateSerializer getStateSerializer() {
        return this.stateSerializer;
    }

    /** @return the state factory exposed by the configured serializer */
    public final AgentStateFactory<OverAllState> getStateFactory() {
        return this.stateSerializer.stateFactory();
    }

    /** @return the state factory passed to the constructor, or {@code null} if none was supplied */
    public final OverAllStateFactory getOverAllStateFactory() {
        return this.overAllStateFactory;
    }

    /**
     * Registers a node running the given action, adapted with {@code AsyncNodeActionWithConfig.of}.
     *
     * @param id node id
     * @param action action executed by the node
     * @return this graph, for chaining
     * @throws GraphStateException if {@code id} is {@link #END} or already registered
     */
    public StateGraph addNode(String id, AsyncNodeAction action) throws GraphStateException {
        return addNode(id, AsyncNodeActionWithConfig.of(action));
    }

    /**
     * Registers a node whose action factory returns {@code action} regardless of the compile config.
     *
     * @param id node id
     * @param action action executed by the node
     * @return this graph, for chaining
     * @throws GraphStateException if {@code id} is {@link #END} or already registered
     */
    public StateGraph addNode(String id, AsyncNodeActionWithConfig action) throws GraphStateException {
        return addNode(id, new Node(id, compileConfig -> action));
    }

    /**
     * Registers a prebuilt node.
     *
     * @param id id under which the node is registered; must equal {@code node.id()}
     * @param node node to register
     * @return this graph, for chaining
     * @throws GraphStateException if the node id is {@link #END}, differs from {@code id}, or is already registered
     */
    public StateGraph addNode(String id, Node node) throws GraphStateException {
        if (Objects.equals(node.id(), END)) {
            throw Errors.invalidNodeIdentifier.exception(END);
        }
        if (!Objects.equals(node.id(), id)) {
            throw Errors.nodeIdMismatch.exception(node.id(), id);
        }
        if (nodes.elements.contains(node)) {
            throw Errors.duplicateNodeError.exception(id);
        }
        nodes.elements.add(node);
        return this;
    }

    /**
     * Registers an already compiled graph as a node, wrapped in a {@link SubCompiledGraphNode}.
     *
     * @param id node id
     * @param compiledGraph compiled sub-graph
     * @return this graph, for chaining
     * @throws GraphStateException if {@code id} is {@link #END} or already registered
     */
    public StateGraph addNode(String id, CompiledGraph compiledGraph) throws GraphStateException {
        if (Objects.equals(id, END)) {
            throw Errors.invalidNodeIdentifier.exception(END);
        }
        SubCompiledGraphNode subGraphNode = new SubCompiledGraphNode(id, compiledGraph);
        if (nodes.elements.contains(subGraphNode)) {
            throw Errors.duplicateNodeError.exception(id);
        }
        nodes.elements.add(subGraphNode);
        return this;
    }

    /**
     * Registers an uncompiled graph as a node, wrapped in a {@link SubStateGraphNode}.
     *
     * <p>The sub-graph is validated first. On success, its key strategies that this graph's state does
     * not yet have are copied into this graph's state (when both states exist). The sub-graph is then
     * pointed at this graph's state instance, which is {@code null} if this graph has none.
     *
     * @param id node id
     * @param subGraph graph to embed
     * @return this graph, for chaining
     * @throws GraphStateException if {@code id} is {@link #END}, is already registered, or the sub-graph is invalid
     */
    public StateGraph addNode(String id, StateGraph subGraph) throws GraphStateException {
        if (Objects.equals(id, END)) {
            throw Errors.invalidNodeIdentifier.exception(END);
        }
        subGraph.validateGraph();
        SubStateGraphNode subGraphNode = new SubStateGraphNode(id, subGraph);
        // Reject duplicates before touching either graph's state, so a failed call has no side effects.
        if (nodes.elements.contains(subGraphNode)) {
            throw Errors.duplicateNodeError.exception(id);
        }
        OverAllState parentState = getOverAllState();
        OverAllState childState = subGraph.getOverAllState();
        // Graphs built from an OverAllStateFactory carry no state instance; there is nothing to merge into then.
        if (parentState != null && childState != null) {
            for (Map.Entry<String, KeyStrategy> entry : childState.keyStrategies().entrySet()) {
                if (!parentState.containStrategy(entry.getKey())) {
                    parentState.registerKeyAndStrategy(entry.getKey(), entry.getValue());
                }
            }
        }
        subGraph.setOverAllState(parentState);
        nodes.elements.add(subGraphNode);
        return this;
    }

    /**
     * Adds a plain edge. Repeated calls with the same source accumulate targets on one edge.
     *
     * @param sourceId id of the source node
     * @param targetId id of the target node
     * @return this graph, for chaining
     * @throws GraphStateException if {@code sourceId} is {@link #END}
     */
    public StateGraph addEdge(String sourceId, String targetId) throws GraphStateException {
        if (Objects.equals(sourceId, END)) {
            throw Errors.invalidEdgeIdentifier.exception(END);
        }
        Edge edge = new Edge(sourceId, new EdgeValue(targetId));
        int existingIndex = edges.elements.indexOf(edge);
        if (existingIndex < 0) {
            edges.elements.add(edge);
            return this;
        }
        // An equal edge already exists (presumably same sourceId — Edge#equals decides): merge the new target into it.
        List<EdgeValue> targets = new ArrayList<>(edges.elements.get(existingIndex).targets());
        targets.add(edge.target());
        edges.elements.set(existingIndex, new Edge(sourceId, targets));
        return this;
    }

    /**
     * Adds a conditional edge whose destination is chosen at run time by {@code condition} through {@code mappings}.
     *
     * @param sourceId id of the source node
     * @param condition action evaluated to pick a mapping key
     * @param mappings mapping from condition result to target node id; must be non-empty
     * @return this graph, for chaining
     * @throws GraphStateException if {@code sourceId} is {@link #END}, {@code mappings} is null or empty,
     *     or an equal edge is already registered
     */
    public StateGraph addConditionalEdges(String sourceId, AsyncEdgeAction condition, Map<String, String> mappings)
            throws GraphStateException {
        if (Objects.equals(sourceId, END)) {
            throw Errors.invalidEdgeIdentifier.exception(END);
        }
        if (mappings == null || mappings.isEmpty()) {
            throw Errors.edgeMappingIsEmpty.exception(sourceId);
        }
        Edge edge = new Edge(sourceId, new EdgeValue(new EdgeCondition(condition, mappings)));
        if (edges.elements.contains(edge)) {
            throw Errors.duplicateConditionalEdgeError.exception(sourceId);
        }
        edges.elements.add(edge);
        return this;
    }

    /**
     * Checks that an entry edge from {@link #START} exists, then validates it and every other edge
     * against the registered nodes.
     *
     * @throws GraphStateException if the entry edge is missing or any edge is invalid
     */
    void validateGraph() throws GraphStateException {
        Edge entryEdge = edges.edgeBySourceId(START).orElseThrow(() -> Errors.missingEntryPoint.exception());
        // Validate the entry edge first so its error, if any, takes precedence over other edges' errors.
        entryEdge.validate(nodes);
        for (Edge edge : edges.elements) {
            edge.validate(nodes);
        }
    }

    /**
     * Validates this graph and compiles it with the given configuration.
     *
     * @param compileConfig compile configuration, must not be {@code null}
     * @return the compiled graph
     * @throws GraphStateException if the graph is invalid
     */
    public CompiledGraph compile(CompileConfig compileConfig) throws GraphStateException {
        Objects.requireNonNull(compileConfig, "config cannot be null");
        validateGraph();
        return new CompiledGraph(this, compileConfig);
    }

    /**
     * Compiles with a Jackson state serializer and an in-memory saver registered as {@code SaverConstant.MEMORY}.
     *
     * @return the compiled graph
     * @throws GraphStateException if the graph is invalid
     */
    public CompiledGraph compile() throws GraphStateException {
        SaverConfig saverConfig = SaverConfig.builder().register(SaverConstant.MEMORY, new MemorySaver()).build();
        return compile(CompileConfig.builder()
            .plainTextStateSerializer(new JacksonSerializer())
            .saverConfig(saverConfig)
            .build());
    }

    /**
     * Renders this graph with the generator of the given representation type.
     *
     * @param type representation type whose generator is used
     * @param title title passed to the generator
     * @param printConditionalEdges boolean option forwarded unchanged to the generator
     *     (presumably whether conditional edges are printed — confirm)
     * @return the generated representation
     */
    public GraphRepresentation getGraph(GraphRepresentation.Type type, String title, boolean printConditionalEdges) {
        return new GraphRepresentation(type, type.generator.generate(nodes, edges, title, printConditionalEdges));
    }

    /**
     * Same as {@link #getGraph(GraphRepresentation.Type, String, boolean)} with the flag set to {@code true}.
     *
     * @param type representation type whose generator is used
     * @param title title passed to the generator
     * @return the generated representation
     */
    public GraphRepresentation getGraph(GraphRepresentation.Type type, String title) {
        return getGraph(type, title, true);
    }

    /**
     * Same as {@link #getGraph(GraphRepresentation.Type, String, boolean)} using this graph's name as the
     * title and the flag set to {@code true}.
     *
     * @param type representation type whose generator is used
     * @return the generated representation
     */
    public GraphRepresentation getGraph(GraphRepresentation.Type type) {
        return getGraph(type, this.name, true);
    }
}
