package com.alibaba.cloud.ai.graph;

import com.alibaba.cloud.ai.graph.GraphRepresentation;
import com.alibaba.cloud.ai.graph.OverAllState;
import com.alibaba.cloud.ai.graph.StateGraph;
import com.alibaba.cloud.ai.graph.action.AsyncNodeAction;
import com.alibaba.cloud.ai.graph.action.AsyncNodeActionWithConfig;
import com.alibaba.cloud.ai.graph.checkpoint.BaseCheckpointSaver;
import com.alibaba.cloud.ai.graph.checkpoint.Checkpoint;
import com.alibaba.cloud.ai.graph.exception.GraphInitKeyErrorException;
import com.alibaba.cloud.ai.graph.exception.GraphInterruptException;
import com.alibaba.cloud.ai.graph.exception.GraphStateException;
import com.alibaba.cloud.ai.graph.internal.edge.Edge;
import com.alibaba.cloud.ai.graph.internal.edge.EdgeValue;
import com.alibaba.cloud.ai.graph.internal.node.Node;
import com.alibaba.cloud.ai.graph.internal.node.ParallelNode;
import com.alibaba.cloud.ai.graph.state.StateSnapshot;
import java.io.IOException;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.ExecutionException;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.bsc.async.AsyncGenerator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.util.CollectionUtils;

/**
 * Executable form of a {@link StateGraph}.
 *
 * <p>Compilation instantiates every node action once, indexes single-target edges by source node
 * id, and rewrites multi-target (fan-out) edges into a synthetic {@link ParallelNode}. Runs are
 * driven by {@link AsyncNodeGenerator} and exposed through {@link #stream}, {@link #invoke},
 * {@link #resume} and {@link #streamSnapshots}.
 *
 * <p>NOTE(review): node evaluation reads and writes the single {@link OverAllState} held by this
 * instance, so concurrent runs of the same compiled graph share that state — confirm intended.
 */
public class CompiledGraph {
    private static final Logger log = LoggerFactory.getLogger(CompiledGraph.class);

    /** Graph definition this instance was compiled from. */
    public final StateGraph stateGraph;

    /** Shared state: created by the graph's state factory when one is set, else taken from the graph. */
    private final OverAllState overAllState;

    /** Nodes, edges and interrupt settings produced by {@code ProcessedNodesEdgesAndConfig.process}. */
    private final ProcessedNodesEdgesAndConfig processedData;

    /** The caller's compile config, rebuilt with the processed interrupt-before/after node ids. */
    public final CompileConfig compileConfig;

    /** Node id to executable action, in insertion order (includes synthetic parallel nodes). */
    final Map<String, AsyncNodeActionWithConfig> nodes = new LinkedHashMap<>();

    /** Source node id to its outgoing edge (a fixed target or a conditional mapping). */
    final Map<String, EdgeValue> edges = new LinkedHashMap<>();

    /** Upper bound on generator steps per run; see {@link #setMaxIterations(int)}. */
    private int maxIterations = 25;

    /**
     * Runs one execution of the graph step by step as an {@link AsyncGenerator}.
     *
     * <p>Each {@link #next()} call performs at most one transition: emit START (after resolving the
     * entry point), run one node, emit END, or finish. When a node's partial state contains an
     * {@link AsyncGenerator}, that generator is embedded in the stream and its outputs are flagged
     * as sub-graph outputs.
     *
     * @param <Output> type of the emitted node outputs
     */
    public class AsyncNodeGenerator<Output extends NodeOutput> implements AsyncGenerator<Output> {
        /** Current state values as a plain map; also the result value of a finished run. */
        Map<String, Object> currentState;
        /** Node whose transition is being processed; {@code null} when resuming or after END. */
        String currentNodeId;
        /** Node to run next; {@code null} before the entry point is resolved and after END. */
        String nextNodeId;
        /** State given to this run (resume/interrupt metadata and output construction). */
        OverAllState overAllState;
        RunnableConfig config;
        /** Number of {@link #next()} calls so far, compared against {@code maxIterations}. */
        int iteration = 0;
        /** Set when an embedded generator completed and its producing node's output is pending. */
        boolean resumedFromEmbed = false;
        /**
         * True right after the START step, when {@link #currentNodeId} already names the entry
         * node although that node has not run yet.
         */
        private boolean entryNodePending = false;

        /**
         * Prepares a run.
         *
         * <p>For a resume request ({@code overAllState.isResume()}) the state and next node come
         * from the checkpoint saved for {@code runnableConfig}. Otherwise the inputs are validated
         * and merged into the initial state, and the run starts at START.
         *
         * @throws IllegalStateException on a resume request without a checkpoint saver or without
         *     a saved checkpoint
         * @throws GraphInitKeyErrorException when non-empty inputs fail {@code keyVerify()}
         */
        protected AsyncNodeGenerator(OverAllState overAllState, RunnableConfig runnableConfig) {
            if (overAllState.isResume()) {
                log.trace("RESUME REQUEST");
                Checkpoint checkpoint = CompiledGraph.this.compileConfig.checkpointSaver()
                        .orElseThrow(() -> new IllegalStateException("inputs cannot be null (ie. resume request) if no checkpoint saver is configured"))
                        .get(runnableConfig)
                        .orElseThrow(() -> new IllegalStateException("Resume request without a saved checkpoint!"));
                this.currentState = checkpoint.getState();
                this.config = runnableConfig.withCheckPointId(null);
                this.overAllState = overAllState.input(this.currentState);
                this.nextNodeId = checkpoint.getNextNodeId();
                // A null current node marks "resuming": interrupt-before is not re-applied.
                this.currentNodeId = null;
                log.trace("RESUME FROM {}", checkpoint.getNodeId());
                return;
            }
            log.trace("START");
            Map<String, Object> inputs = overAllState.data();
            boolean keysValid = overAllState.keyVerify();
            if (!CollectionUtils.isEmpty(inputs) && !keysValid) {
                throw new GraphInitKeyErrorException(Arrays.toString(inputs.keySet().toArray()) + " isn't included in the keyStrategies");
            }
            this.currentState = CompiledGraph.this.getInitialState(inputs, runnableConfig);
            this.overAllState = overAllState;
            this.nextNodeId = null;
            this.currentNodeId = StateGraph.START;
            this.config = runnableConfig;
        }

        /** Wraps this run's state into a {@link NodeOutput} labelled with {@code nodeId}. */
        @SuppressWarnings("unchecked")
        protected Output buildNodeOutput(String nodeId) {
            return (Output) NodeOutput.of(nodeId, this.overAllState);
        }

        /** Builds a {@link StateSnapshot} output for {@code checkpoint}. */
        @SuppressWarnings("unchecked")
        protected Output buildStateSnapshot(Checkpoint checkpoint) throws Exception {
            return (Output) StateSnapshot.of(CompiledGraph.this.overAllState(), checkpoint, this.config, CompiledGraph.this.stateGraph.getStateFactory());
        }

        /**
         * Returns a composed step for the first {@link AsyncGenerator} value found in
         * {@code partialState}, or empty when there is none.
         *
         * <p>When the embedded generator completes, its result (a {@code Map} or a
         * {@link NodeOutput}) is merged into the state, the next node is resolved, and
         * {@link #resumedFromEmbed} is set so the following {@link #next()} emits the node's output.
         */
        @SuppressWarnings("unchecked")
        private Optional<AsyncGenerator.Data<Output>> getEmbedGenerator(Map<String, Object> partialState) {
            return partialState.entrySet().stream()
                    .filter(entry -> entry.getValue() instanceof AsyncGenerator)
                    .findFirst()
                    .map(generatorEntry -> {
                        AsyncGenerator<Output> embedded = (AsyncGenerator<Output>) generatorEntry.getValue();
                        AsyncGenerator<Output> flagged = embedded.map(output -> {
                            output.setSubGraph(true);
                            return output;
                        });
                        return AsyncGenerator.Data.composeWith(flagged, result -> {
                            if (result != null) {
                                if (result instanceof Map) {
                                    updateState(partialState, generatorEntry, (Map<String, Object>) result);
                                } else if (result instanceof NodeOutput) {
                                    updateState(partialState, generatorEntry, ((NodeOutput) result).state().data());
                                } else {
                                    throw new IllegalArgumentException("Embedded generator must return a Map or NodeOutput");
                                }
                            }
                            CompiledGraph.this.overAllState().updateState(partialState);
                            this.nextNodeId = CompiledGraph.this.nextNodeId(this.currentNodeId, this.currentState);
                            this.resumedFromEmbed = true;
                        });
                    });
        }

        /**
         * Merges a node's partial state (minus the embedded generator entry) and then the embedded
         * generator's result into {@link #currentState}, using the graph's key strategies.
         */
        private void updateState(Map<String, Object> partialState, Map.Entry<String, Object> generatorEntry, Map<String, Object> embeddedResult) {
            // Explicit loop instead of Collectors.toMap, which throws NullPointerException on null values.
            Map<String, Object> withoutGenerator = new LinkedHashMap<>();
            for (Map.Entry<String, Object> entry : partialState.entrySet()) {
                if (!Objects.equals(entry.getKey(), generatorEntry.getKey())) {
                    withoutGenerator.put(entry.getKey(), entry.getValue());
                }
            }
            Map<String, Object> merged = OverAllState.updateState(this.currentState, withoutGenerator, CompiledGraph.this.overAllState().keyStrategies());
            CompiledGraph.this.overAllState().updateState(merged);
            this.currentState = OverAllState.updateState(merged, embeddedResult, CompiledGraph.this.overAllState().keyStrategies());
        }

        /**
         * Runs {@code action} and turns its partial state into the next generator step. This is
         * either an embedded generator or the node output after state merge, next-node resolution
         * and checkpointing. Checked failures inside the continuation are wrapped in
         * {@link CompletionException}.
         */
        private CompletableFuture<AsyncGenerator.Data<Output>> evaluateAction(AsyncNodeActionWithConfig action, OverAllState state) {
            return action.apply(state, this.config).thenApply(partialState -> {
                try {
                    Optional<AsyncGenerator.Data<Output>> embedded = getEmbedGenerator(partialState);
                    if (embedded.isPresent()) {
                        return embedded.get();
                    }
                    this.currentState = CompiledGraph.this.overAllState().updateState(partialState);
                    this.nextNodeId = CompiledGraph.this.nextNodeId(this.currentNodeId, this.currentState, CompiledGraph.this.overAllState());
                    return AsyncGenerator.Data.of(getNodeOutput());
                } catch (Exception e) {
                    throw new CompletionException(e);
                }
            });
        }

        /**
         * Saves a checkpoint for the current node (when a saver is configured) and returns the
         * output for it: a snapshot in {@link StreamMode#SNAPSHOTS} mode when a checkpoint was
         * saved, otherwise a plain node output.
         */
        private CompletableFuture<Output> getNodeOutput() throws Exception {
            Optional<Checkpoint> saved = CompiledGraph.this.addCheckpoint(this.config, this.currentNodeId, this.currentState, this.nextNodeId);
            Output output = (saved.isPresent() && this.config.streamMode() == StreamMode.SNAPSHOTS)
                    ? buildStateSnapshot(saved.get())
                    : buildNodeOutput(this.currentNodeId);
            return CompletableFuture.completedFuture(output);
        }

        /**
         * Performs one step of the run.
         *
         * <p>Finishes (with {@link #currentState} as result) when the iteration limit is exceeded or
         * nothing is left to run. Finishes without a result when an interrupt-after (for a node
         * that has actually run) or an interrupt-before applies. A {@link GraphInterruptException}
         * anywhere in a failure's cause chain finishes the run with the current node's output and
         * records the message on the state. Other failures are logged and returned as errors.
         */
        @Override
        public AsyncGenerator.Data<Output> next() {
            this.iteration++;
            if (this.iteration > CompiledGraph.this.maxIterations) {
                log.warn("Maximum number of iterations ({}) reached!", CompiledGraph.this.maxIterations);
                return AsyncGenerator.Data.done(this.currentState);
            }
            if (this.nextNodeId == null && this.currentNodeId == null) {
                return AsyncGenerator.Data.done(this.currentState);
            }
            try {
                if (this.resumedFromEmbed) {
                    CompletableFuture<Output> output = getNodeOutput();
                    this.resumedFromEmbed = false;
                    return AsyncGenerator.Data.of(output);
                }
                if (StateGraph.START.equals(this.currentNodeId)) {
                    this.nextNodeId = CompiledGraph.this.getEntryPoint(this.currentState);
                    this.currentNodeId = this.nextNodeId;
                    // currentNodeId now names the entry node although it has not run yet.
                    this.entryNodePending = true;
                    CompiledGraph.this.addCheckpoint(this.config, StateGraph.START, this.currentState, this.nextNodeId);
                    return AsyncGenerator.Data.of(buildNodeOutput(StateGraph.START));
                }
                if (StateGraph.END.equals(this.nextNodeId)) {
                    this.nextNodeId = null;
                    this.currentNodeId = null;
                    return AsyncGenerator.Data.of(buildNodeOutput(StateGraph.END));
                }
                // Interrupt-after only applies to a node that has actually executed.
                boolean currentNodeHasRun = !this.entryNodePending;
                this.entryNodePending = false;
                boolean interruptAfter = currentNodeHasRun && CompiledGraph.this.shouldInterruptAfter(this.currentNodeId, this.nextNodeId);
                if (interruptAfter || CompiledGraph.this.shouldInterruptBefore(this.nextNodeId, this.currentNodeId)) {
                    return AsyncGenerator.Data.done();
                }
                this.currentNodeId = this.nextNodeId;
                AsyncNodeActionWithConfig action = CompiledGraph.this.nodes.get(this.currentNodeId);
                if (action == null) {
                    throw StateGraph.RunnableErrors.missingNode.exception(this.currentNodeId);
                }
                return evaluateAction(action, CompiledGraph.this.overAllState()).get();
            } catch (Exception e) {
                GraphInterruptException interrupt = findInterrupt(e);
                if (interrupt != null) {
                    this.overAllState.setInterruptMessage(interrupt.getMessage());
                    return AsyncGenerator.Data.done(buildNodeOutput(this.currentNodeId));
                }
                if (e instanceof InterruptedException) {
                    // Preserve the caller's interrupt status.
                    Thread.currentThread().interrupt();
                }
                log.error(e.getMessage(), e);
                return AsyncGenerator.Data.error(e);
            }
        }
    }

    /**
     * Output mode of a run: {@code VALUES} emits node outputs; {@code SNAPSHOTS} emits a
     * {@link StateSnapshot} for every step whose checkpoint was saved.
     */
    public enum StreamMode {
        VALUES,
        SNAPSHOTS
    }

    /**
     * Compiles {@code stateGraph}: validates interrupt node ids, instantiates node actions and
     * indexes edges (fan-out edges become parallel nodes).
     *
     * <p>NOTE(review): the decompiler reported this constructor as originally {@code protected};
     * it is kept {@code public} so existing callers keep compiling.
     *
     * @param stateGraph graph definition to compile
     * @param compileConfig caller's configuration; node action factories receive this instance,
     *     while {@link #compileConfig} is rebuilt with the processed interrupt settings
     * @throws GraphStateException from graph processing, or when an interrupt names an unknown
     *     node or a fan-out edge cannot be compiled into a parallel node
     * @throws NullPointerException when a node has no action factory
     */
    public CompiledGraph(StateGraph stateGraph, CompileConfig compileConfig) throws GraphStateException {
        this.stateGraph = stateGraph;
        this.overAllState = Objects.nonNull(stateGraph.getOverAllStateFactory())
                ? stateGraph.getOverAllStateFactory().create()
                : stateGraph.getOverAllState();
        this.processedData = ProcessedNodesEdgesAndConfig.process(stateGraph, compileConfig);
        requireKnownInterruptNodes(this.processedData.interruptsBefore());
        requireKnownInterruptNodes(this.processedData.interruptsAfter());
        this.compileConfig = CompileConfig.builder(compileConfig)
                .interruptsBefore(this.processedData.interruptsBefore())
                .interruptsAfter(this.processedData.interruptsAfter())
                .build();
        for (Node node : this.processedData.nodes().elements) {
            Node.ActionFactory factory = node.actionFactory();
            Objects.requireNonNull(factory, String.format("action factory for node id '%s' is null!", node.id()));
            this.nodes.put(node.id(), factory.apply(compileConfig));
        }
        for (Edge edge : this.processedData.edges().elements) {
            List<EdgeValue> targets = edge.targets();
            if (targets.size() == 1) {
                this.edges.put(edge.sourceId(), targets.get(0));
            } else {
                registerParallelNode(edge, targets, compileConfig);
            }
        }
    }

    /** Throws when any of {@code nodeIds} is not a node of the processed graph. */
    private void requireKnownInterruptNodes(Iterable<String> nodeIds) throws GraphStateException {
        for (String nodeId : nodeIds) {
            if (!this.processedData.nodes().anyMatchById(nodeId)) {
                throw StateGraph.Errors.interruptionNodeNotExist.exception(nodeId);
            }
        }
    }

    /**
     * Replaces the fan-out {@code edge} with a {@link ParallelNode} that runs all known target
     * nodes and then continues to their single common successor.
     *
     * @throws GraphStateException when the branches continue to more than one node (conditional
     *     branch edges are reported specifically)
     */
    private void registerParallelNode(Edge edge, List<EdgeValue> targets, CompileConfig compileConfig) throws GraphStateException {
        List<Edge> allEdges = this.processedData.edges().elements;
        // A Supplier so the filtered target stream can be created twice.
        Supplier<Stream<EdgeValue>> knownTargets = () -> targets.stream().filter(target -> this.nodes.containsKey(target.id()));
        // Outgoing edge of each parallel branch (branches without one are skipped).
        List<Edge> branchEdges = knownTargets.get()
                .map(target -> allEdges.indexOf(new Edge(target.id())))
                .filter(index -> index >= 0)
                .map(allEdges::get)
                .toList();
        Set<String> joinTargets = branchEdges.stream()
                .map(branch -> branch.target().id())
                .collect(Collectors.toSet());
        if (joinTargets.size() > 1) {
            List<Edge> conditionalEdges = branchEdges.stream()
                    .filter(branch -> branch.target().value() != null)
                    .toList();
            if (!conditionalEdges.isEmpty()) {
                throw StateGraph.Errors.unsupportedConditionalEdgeOnParallelNode.exception(edge.sourceId(), conditionalEdges.stream().map(Edge::sourceId).toList());
            }
            throw StateGraph.Errors.illegalMultipleTargetsOnParallelNode.exception(edge.sourceId(), joinTargets);
        }
        List<AsyncNodeActionWithConfig> actions = knownTargets.get()
                .map(target -> this.nodes.get(target.id()))
                .toList();
        ParallelNode parallelNode = new ParallelNode(edge.sourceId(), actions, overAllState().keyStrategies());
        this.nodes.put(parallelNode.id(), parallelNode.actionFactory().apply(compileConfig));
        this.edges.put(edge.sourceId(), new EdgeValue(parallelNode.id()));
        // NOTE(review): throws NoSuchElementException when no branch has an outgoing edge.
        this.edges.put(parallelNode.id(), new EdgeValue(joinTargets.iterator().next()));
    }

    /** Returns the state instance shared by all runs of this compiled graph. */
    public OverAllState overAllState() {
        return this.overAllState;
    }

    /** Returns the configured checkpoint saver, or throws {@link IllegalStateException}. */
    private BaseCheckpointSaver requireCheckpointSaver() {
        return this.compileConfig.checkpointSaver()
                .orElseThrow(() -> new IllegalStateException("Missing CheckpointSaver!"));
    }

    /**
     * Returns a snapshot for every checkpoint the saver lists for {@code runnableConfig}.
     *
     * @throws IllegalStateException when no checkpoint saver is configured
     */
    public Collection<StateSnapshot> getStateHistory(RunnableConfig runnableConfig) {
        return requireCheckpointSaver().list(runnableConfig).stream()
                .map(checkpoint -> StateSnapshot.of(overAllState(), checkpoint, runnableConfig, this.stateGraph.getStateFactory()))
                .collect(Collectors.toList());
    }

    /**
     * Returns the snapshot of the checkpoint saved for {@code runnableConfig}.
     *
     * @throws IllegalStateException when no saver is configured or no checkpoint exists
     */
    public StateSnapshot getState(RunnableConfig runnableConfig) {
        return stateOf(runnableConfig).orElseThrow(() -> new IllegalStateException("Missing Checkpoint!"));
    }

    /**
     * Returns the snapshot of the checkpoint saved for {@code runnableConfig}, if any.
     *
     * @throws IllegalStateException when no checkpoint saver is configured
     */
    public Optional<StateSnapshot> stateOf(RunnableConfig runnableConfig) {
        return requireCheckpointSaver().get(runnableConfig)
                .map(checkpoint -> StateSnapshot.of(overAllState(), checkpoint, runnableConfig, this.stateGraph.getStateFactory()));
    }

    /**
     * Merges {@code map} into a copy of the checkpoint saved for {@code runnableConfig} and saves it.
     *
     * @param asNode when non-null, the next node is resolved from this node's outgoing edge
     * @return config pointing at the new checkpoint and carrying the resolved next node (or null)
     * @throws IllegalStateException when no saver is configured or no checkpoint exists
     */
    public RunnableConfig updateState(RunnableConfig runnableConfig, Map<String, Object> map, String asNode) throws Exception {
        BaseCheckpointSaver saver = requireCheckpointSaver();
        Checkpoint updated = saver.get(runnableConfig)
                .map(Checkpoint::new)
                .map(copy -> copy.updateState(map, overAllState().keyStrategies()))
                .orElseThrow(() -> new IllegalStateException("Missing Checkpoint!"));
        String nextNode = (asNode == null) ? null : nextNodeId(asNode, updated.getState());
        RunnableConfig savedConfig = saver.put(runnableConfig, updated);
        return RunnableConfig.builder(savedConfig).checkPointId(updated.getId()).nextNode(nextNode).build();
    }

    /** Same as {@link #updateState(RunnableConfig, Map, String)} without resolving a next node. */
    public RunnableConfig updateState(RunnableConfig runnableConfig, Map<String, Object> map) throws Exception {
        return updateState(runnableConfig, map, null);
    }

    /**
     * Sets the maximum number of generator steps per run (default 25).
     *
     * @throws IllegalArgumentException when {@code i <= 0}
     */
    public void setMaxIterations(int i) {
        if (i <= 0) {
            throw new IllegalArgumentException("maxIterations must be > 0!");
        }
        this.maxIterations = i;
    }

    /** Resolves the successor through {@code edgeValue}, evaluating conditions on the shared state. */
    private String nextNodeId(EdgeValue edgeValue, Map<String, Object> data, String nodeId) throws Exception {
        return nextNodeId(edgeValue, data, nodeId, overAllState());
    }

    /**
     * Resolves the node that follows {@code nodeId} through {@code edgeValue}.
     *
     * <p>A fixed edge yields its target id. A conditional edge evaluates its action against
     * {@code state} (or, when {@code state} is null, a state built from {@code data} by the graph's
     * state factory) and maps the returned label through the edge's mappings.
     *
     * @throws Exception {@code missingEdge} when there is no edge, {@code executionError} when the
     *     edge has neither id nor condition, {@code missingNodeInEdgeMapping} when the label is
     *     unmapped, or whatever the condition's future fails with
     */
    private String nextNodeId(EdgeValue edgeValue, Map<String, Object> data, String nodeId, OverAllState state) throws Exception {
        if (edgeValue == null) {
            throw StateGraph.RunnableErrors.missingEdge.exception(nodeId);
        }
        if (edgeValue.id() != null) {
            return edgeValue.id();
        }
        if (edgeValue.value() == null) {
            throw StateGraph.RunnableErrors.executionError.exception(String.format("invalid edge value for nodeId: [%s] !", nodeId));
        }
        OverAllState conditionState = (state != null) ? state : this.stateGraph.getStateFactory().apply(data);
        String label = edgeValue.value().action().apply(conditionState).get();
        String target = edgeValue.value().mappings().get(label);
        if (target == null) {
            throw StateGraph.RunnableErrors.missingNodeInEdgeMapping.exception(nodeId, label);
        }
        return target;
    }

    /** Resolves the successor of {@code nodeId} using the shared state for conditions. */
    private String nextNodeId(String nodeId, Map<String, Object> data) throws Exception {
        return nextNodeId(this.edges.get(nodeId), data, nodeId);
    }

    /** Resolves the successor of {@code nodeId} using {@code state} for conditions. */
    private String nextNodeId(String nodeId, Map<String, Object> data, OverAllState state) throws Exception {
        return nextNodeId(this.edges.get(nodeId), data, nodeId, state);
    }

    /** Resolves the first node after START ("entryPoint" is used as the id in errors). */
    private String getEntryPoint(Map<String, Object> data) throws Exception {
        return nextNodeId(this.edges.get(StateGraph.START), data, "entryPoint");
    }

    /**
     * Whether the run must stop before {@code nodeId}. Never true when {@code fromNodeId} is null,
     * which is the case right after resuming from a checkpoint.
     */
    private boolean shouldInterruptBefore(String nodeId, String fromNodeId) {
        if (fromNodeId == null) {
            return false;
        }
        return this.compileConfig.interruptsBefore().contains(nodeId);
    }

    /** Whether the run must stop after {@code nodeId}; {@code nextNodeId} is not consulted. */
    private boolean shouldInterruptAfter(String nodeId, String nextNodeId) {
        if (nodeId == null) {
            return false;
        }
        return this.compileConfig.interruptsAfter().contains(nodeId);
    }

    /**
     * Saves a checkpoint (state wrapped via {@link #cloneState}) when a saver is configured.
     *
     * @return the saved checkpoint, or empty when no saver is configured
     */
    private Optional<Checkpoint> addCheckpoint(RunnableConfig runnableConfig, String nodeId, Map<String, Object> state, String nextNodeId) throws Exception {
        if (this.compileConfig.checkpointSaver().isEmpty()) {
            return Optional.empty();
        }
        Checkpoint checkpoint = Checkpoint.builder().nodeId(nodeId).state(cloneState(state)).nextNodeId(nextNodeId).build();
        this.compileConfig.checkpointSaver().get().put(runnableConfig, checkpoint);
        return Optional.of(checkpoint);
    }

    /**
     * Merges {@code map} into the state of the checkpoint saved for {@code runnableConfig}, or
     * into an empty map when there is no saver or checkpoint.
     */
    Map<String, Object> getInitialState(Map<String, Object> map, RunnableConfig runnableConfig) {
        return this.compileConfig.checkpointSaver()
                .flatMap(saver -> saver.get(runnableConfig))
                .map(checkpoint -> OverAllState.updateState(checkpoint.getState(), map, overAllState().keyStrategies()))
                .orElseGet(() -> OverAllState.updateState(new HashMap<>(), map, overAllState().keyStrategies()));
    }

    /**
     * Wraps {@code map} in a new {@link OverAllState} for checkpointing.
     * NOTE(review): whether the map is deep-copied depends on the OverAllState constructor.
     */
    OverAllState cloneState(Map<String, Object> map) throws IOException, ClassNotFoundException, InstantiationException, IllegalAccessException {
        return new OverAllState(map);
    }

    /** Walks {@code error}'s cause chain (cycle-safe) and returns the first graph interrupt, or null. */
    private static GraphInterruptException findInterrupt(Throwable error) {
        Set<Throwable> visited = new HashSet<>();
        for (Throwable t = error; t != null && visited.add(t); t = t.getCause()) {
            if (t instanceof GraphInterruptException) {
                return (GraphInterruptException) t;
            }
        }
        return null;
    }

    /** Creates the generator for one run. */
    @SuppressWarnings({"unchecked", "rawtypes"})
    private AsyncGenerator<NodeOutput> newRun(OverAllState state, RunnableConfig config) {
        return new AsyncGenerator.WithEmbed(new AsyncNodeGenerator<NodeOutput>(state, config));
    }

    /** Drains {@code run} and returns the state of its last output, if any. */
    private static Optional<OverAllState> lastState(AsyncGenerator<NodeOutput> run) {
        return run.stream().reduce((previous, latest) -> latest).map(NodeOutput::state);
    }

    /**
     * Streams a run with {@code map} as input.
     *
     * @throws NullPointerException when {@code runnableConfig} is null
     */
    public AsyncGenerator<NodeOutput> stream(Map<String, Object> map, RunnableConfig runnableConfig) {
        Objects.requireNonNull(runnableConfig, "config cannot be null");
        return newRun(overAllState().input(map), runnableConfig);
    }

    /**
     * Streams a run from {@code overAllState} (which may be a resume request).
     *
     * @throws NullPointerException when {@code runnableConfig} is null
     */
    public AsyncGenerator<NodeOutput> stream(OverAllState overAllState, RunnableConfig runnableConfig) {
        Objects.requireNonNull(runnableConfig, "config cannot be null");
        return newRun(overAllState, runnableConfig);
    }

    /** Streams a run with {@code map} as input and a default config. */
    public AsyncGenerator<NodeOutput> stream(Map<String, Object> map) {
        return stream(overAllState().input(map), RunnableConfig.builder().build());
    }

    /** Streams a run from the shared state with a default config. */
    public AsyncGenerator<NodeOutput> stream() {
        return stream(overAllState(), RunnableConfig.builder().build());
    }

    /** Runs to completion with {@code map} as input; returns the state of the last output. */
    public Optional<OverAllState> invoke(Map<String, Object> map, RunnableConfig runnableConfig) {
        return lastState(stream(map, runnableConfig));
    }

    /** Runs to completion from {@code overAllState}; returns the state of the last output. */
    public Optional<OverAllState> invoke(OverAllState overAllState, RunnableConfig runnableConfig) {
        return lastState(stream(overAllState, runnableConfig));
    }

    /** Runs to completion with {@code map} as input and a default config. */
    public Optional<OverAllState> invoke(Map<String, Object> map) {
        return invoke(this.overAllState.input(map), RunnableConfig.builder().build());
    }

    /**
     * Resumes the run checkpointed for {@code runnableConfig} with {@code humanFeedback} attached.
     *
     * @throws IllegalStateException when no saver is configured or no checkpoint exists
     */
    public Optional<OverAllState> resume(OverAllState.HumanFeedback humanFeedback, RunnableConfig runnableConfig) {
        OverAllState resumeState = this.stateGraph.getStateFactory().apply(getState(runnableConfig).state().data());
        resumeState.withResume();
        resumeState.withHumanFeedback(humanFeedback);
        return invoke(resumeState, runnableConfig);
    }

    /**
     * Streams a run in {@link StreamMode#SNAPSHOTS} mode.
     *
     * @throws NullPointerException when {@code runnableConfig} is null
     */
    public AsyncGenerator<NodeOutput> streamSnapshots(Map<String, Object> map, RunnableConfig runnableConfig) {
        Objects.requireNonNull(runnableConfig, "config cannot be null");
        return newRun(overAllState().input(map), runnableConfig.withStreamMode(StreamMode.SNAPSHOTS));
    }

    /**
     * Renders the processed graph with {@code type}'s generator.
     *
     * @param title diagram title
     * @param generatorFlag boolean option passed through to the generator
     */
    public GraphRepresentation getGraph(GraphRepresentation.Type type, String title, boolean generatorFlag) {
        return new GraphRepresentation(type, type.generator.generate(this.processedData.nodes(), this.processedData.edges(), title, generatorFlag));
    }

    /** Renders the processed graph with the generator option set to {@code true}. */
    public GraphRepresentation getGraph(GraphRepresentation.Type type, String title) {
        return getGraph(type, title, true);
    }

    /** Renders the processed graph titled "Graph Diagram". */
    public GraphRepresentation getGraph(GraphRepresentation.Type type) {
        return getGraph(type, "Graph Diagram", true);
    }
}
