package com.alibaba.cloud.ai.graph.checkpoint.savers;

import com.alibaba.cloud.ai.graph.RunnableConfig;
import com.alibaba.cloud.ai.graph.checkpoint.BaseCheckpointSaver;
import com.alibaba.cloud.ai.graph.checkpoint.Checkpoint;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.mongodb.BasicDBObject;
import com.mongodb.ClientSessionOptions;
import com.mongodb.TransactionOptions;
import com.mongodb.WriteConcern;
import com.mongodb.client.ClientSession;
import com.mongodb.client.MongoClient;
import com.mongodb.client.MongoCollection;
import com.mongodb.client.MongoDatabase;
import com.mongodb.client.model.Filters;
import com.mongodb.client.model.ReplaceOptions;
import java.io.IOException;
import java.util.Collection;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.IntStream;
import org.bson.Document;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:com/alibaba/cloud/ai/graph/checkpoint/savers/MongoSaver.class */
/**
 * {@link BaseCheckpointSaver} that persists checkpoints in MongoDB.
 *
 * <p>All checkpoints of one thread are stored as a single document in collection
 * {@code "checkpoint_collection"} of database {@code "check_point_db"}. The document's
 * {@code _id} is {@code "mongo:checkpoint:document:"} followed by the thread id, and the field
 * {@code "checkpoint_content"} holds the checkpoint list serialized to a JSON string by Jackson.
 *
 * <p>NOTE(review): every operation opens a client session and starts a transaction with
 * majority write concern, but the collection calls are not given that session. They therefore
 * run outside the transaction, and {@link #put} is a non-atomic read-modify-write. Passing the
 * session (e.g. {@code collection.find(session, filter)}) would make it atomic. However,
 * MongoDB transactions require a replica set or sharded cluster, so that change is deliberately
 * not made here. Confirm the intended deployment before changing it.
 */
public class MongoSaver implements BaseCheckpointSaver {
    private static final Logger logger = LoggerFactory.getLogger(MongoSaver.class);

    private static final String DB_NAME = "check_point_db";
    private static final String COLLECTION_NAME = "checkpoint_collection";
    private static final String DOCUMENT_PREFIX = "mongo:checkpoint:document:";
    private static final String DOCUMENT_CONTENT_KEY = "checkpoint_content";

    /** Jackson type token for the stored checkpoint list; immutable, so created once. */
    private static final TypeReference<List<Checkpoint>> CHECKPOINT_LIST_TYPE =
            new TypeReference<List<Checkpoint>>() {
            };

    private final MongoClient client;
    private final MongoDatabase database;
    private final TransactionOptions txnOptions =
            TransactionOptions.builder().writeConcern(WriteConcern.MAJORITY).build();
    private final ObjectMapper objectMapper = new ObjectMapper();

    /**
     * Creates a saver backed by the given client.
     *
     * <p>A JVM shutdown hook is registered that closes {@code mongoClient}, so the saver takes
     * over the client's shutdown.
     *
     * @param mongoClient client used for all reads and writes; must not be {@code null}
     * @throws NullPointerException if {@code mongoClient} is {@code null}
     */
    public MongoSaver(MongoClient mongoClient) {
        // Validate before the first dereference (getDatabase) so the failure message is clear.
        this.client = Objects.requireNonNull(mongoClient, "mongoClient");
        this.database = mongoClient.getDatabase(DB_NAME);
        Runtime.getRuntime().addShutdownHook(new Thread(mongoClient::close));
    }

    /**
     * Returns every stored checkpoint for the config's thread, in stored order.
     *
     * @param runnableConfig config carrying the thread id
     * @return the deserialized checkpoints, or an empty list if the thread has no document
     * @throws IllegalArgumentException if the config has no thread id
     * @throws RuntimeException wrapping any driver or (de)serialization failure
     */
    @Override
    public Collection<Checkpoint> list(RunnableConfig runnableConfig) {
        String threadId = requireThreadId(runnableConfig);
        return inTransaction(session -> {
            Document document = findDocument(threadId);
            if (document == null) {
                return Collections.<Checkpoint>emptyList();
            }
            return readCheckpoints(document);
        });
    }

    /**
     * Looks up one checkpoint for the config's thread.
     *
     * <p>If the config carries a checkpoint id, the checkpoint with that id is returned.
     * Otherwise the result of {@code getLast} over the stored list is returned.
     *
     * @param runnableConfig config carrying the thread id and, optionally, a checkpoint id
     * @return the matching checkpoint, or empty if the thread has no document or no match
     * @throws IllegalArgumentException if the config has no thread id
     * @throws RuntimeException wrapping any driver or (de)serialization failure
     */
    @Override
    public Optional<Checkpoint> get(RunnableConfig runnableConfig) {
        String threadId = requireThreadId(runnableConfig);
        return inTransaction(session -> {
            Document document = findDocument(threadId);
            if (document == null) {
                return Optional.<Checkpoint>empty();
            }
            List<Checkpoint> stored = readCheckpoints(document);
            Optional<String> requestedId = runnableConfig.checkPointId();
            if (requestedId.isPresent()) {
                String targetId = requestedId.get();
                return stored.stream()
                        .filter(candidate -> candidate.getId().equals(targetId))
                        .findFirst();
            }
            return getLast(getLinkedList(stored), runnableConfig);
        });
    }

    /**
     * Stores a checkpoint for the config's thread.
     *
     * <ul>
     *   <li>No document yet: a new document holding only {@code checkpoint} is inserted, and a
     *       config pointing at {@code checkpoint.getId()} is returned.</li>
     *   <li>Document exists and the config has a checkpoint id: the stored checkpoint with that
     *       id is replaced in place, and {@code runnableConfig} is returned unchanged.</li>
     *   <li>Document exists and no checkpoint id: {@code checkpoint} is pushed to the head of
     *       the list, and a config pointing at {@code checkpoint.getId()} is returned.</li>
     * </ul>
     *
     * @param runnableConfig config carrying the thread id and, optionally, a checkpoint id
     * @param checkpoint checkpoint to store
     * @return the config to use for subsequent calls
     * @throws IllegalArgumentException if the config has no thread id
     * @throws RuntimeException wrapping any driver or serialization failure, or a
     *     {@link NoSuchElementException} when the requested checkpoint id is not stored
     */
    @Override
    public RunnableConfig put(RunnableConfig runnableConfig, Checkpoint checkpoint) throws Exception {
        String threadId = requireThreadId(runnableConfig);
        return inTransaction(session -> {
            MongoCollection<Document> collection = checkpointCollection();
            Document document = collection.find(idFilter(threadId)).first();

            if (document == null) {
                LinkedList<Checkpoint> fresh = new LinkedList<>();
                fresh.push(checkpoint);
                collection.insertOne(newDocument(threadId, fresh));
                return RunnableConfig.builder(runnableConfig).checkPointId(checkpoint.getId()).build();
            }

            List<Checkpoint> stored = readCheckpoints(document);
            LinkedList<Checkpoint> checkpoints = getLinkedList(stored);
            Optional<String> requestedId = runnableConfig.checkPointId();

            if (requestedId.isPresent()) {
                String targetId = requestedId.get();
                // Index is located in `stored` and applied to `checkpoints`; this assumes
                // getLinkedList keeps the original order (not visible here - confirm).
                int index = IntStream.range(0, stored.size())
                        .filter(i -> stored.get(i).getId().equals(targetId))
                        .findFirst()
                        .orElseThrow(() -> new NoSuchElementException(
                                String.format("Checkpoint with id %s not found!", targetId)));
                checkpoints.set(index, checkpoint);
                collection.replaceOne(Filters.eq("_id", documentId(threadId)),
                        newDocument(threadId, checkpoints));
                return runnableConfig;
            }

            checkpoints.push(checkpoint);
            collection.replaceOne(Filters.eq("_id", documentId(threadId)),
                    newDocument(threadId, checkpoints), new ReplaceOptions().upsert(true));
            return RunnableConfig.builder(runnableConfig).checkPointId(checkpoint.getId()).build();
        });
    }

    /**
     * Deletes the document holding all checkpoints of the config's thread.
     *
     * @param runnableConfig config carrying the thread id
     * @return {@code true} once the delete call has completed (also when nothing was stored)
     * @throws IllegalArgumentException if the config has no thread id
     * @throws RuntimeException wrapping any driver failure
     */
    @Override
    public boolean clear(RunnableConfig runnableConfig) {
        String threadId = requireThreadId(runnableConfig);
        return inTransaction(session -> {
            checkpointCollection().findOneAndDelete(idFilter(threadId));
            return true;
        });
    }

    /** A unit of work run by {@link #inTransaction}; may throw any exception. */
    @FunctionalInterface
    private interface SessionWork<T> {
        T run(ClientSession session) throws Exception;
    }

    /**
     * Runs {@code work} inside a session transaction.
     *
     * <p>The transaction is committed if {@code work} succeeds. If anything fails, the
     * transaction is aborted (only while still active) and the failure is rethrown wrapped in
     * a {@link RuntimeException}. The session is always closed.
     */
    private <T> T inTransaction(SessionWork<T> work) {
        ClientSessionOptions options =
                ClientSessionOptions.builder().defaultTransactionOptions(txnOptions).build();
        try (ClientSession session = client.startSession(options)) {
            session.startTransaction();
            try {
                T result = work.run(session);
                session.commitTransaction();
                return result;
            } catch (Exception e) {
                abortIfActive(session, e);
                throw new RuntimeException(e);
            }
        }
    }

    /**
     * Aborts the session's transaction if it is still active.
     *
     * <p>A failing abort is attached to {@code cause} as a suppressed exception, so it never
     * hides the original failure. The active check avoids the {@link IllegalStateException}
     * the driver raises when aborting after a commit attempt.
     */
    private static void abortIfActive(ClientSession session, Exception cause) {
        if (!session.hasActiveTransaction()) {
            return;
        }
        try {
            session.abortTransaction();
        } catch (RuntimeException abortFailure) {
            cause.addSuppressed(abortFailure);
        }
    }

    /** Returns the config's thread id, or throws if it is absent. */
    private static String requireThreadId(RunnableConfig runnableConfig) {
        return runnableConfig.threadId()
                .orElseThrow(() -> new IllegalArgumentException("threadId is not allow null"));
    }

    /** Builds the {@code _id} of the document storing the given thread's checkpoints. */
    private static String documentId(String threadId) {
        return DOCUMENT_PREFIX + threadId;
    }

    /** Filter matching the document of the given thread by {@code _id}. */
    private static BasicDBObject idFilter(String threadId) {
        return new BasicDBObject("_id", documentId(threadId));
    }

    /** Returns the collection that holds all checkpoint documents. */
    private MongoCollection<Document> checkpointCollection() {
        return database.getCollection(COLLECTION_NAME);
    }

    /** Fetches the thread's checkpoint document, or {@code null} if none exists. */
    private Document findDocument(String threadId) {
        return checkpointCollection().find(idFilter(threadId)).first();
    }

    /** Deserializes the checkpoint list stored in {@code document}. */
    private List<Checkpoint> readCheckpoints(Document document) throws IOException {
        return objectMapper.readValue(document.getString(DOCUMENT_CONTENT_KEY), CHECKPOINT_LIST_TYPE);
    }

    /** Builds the full document for a thread, with {@code checkpoints} serialized to JSON. */
    private Document newDocument(String threadId, List<Checkpoint> checkpoints) throws IOException {
        return new Document()
                .append("_id", documentId(threadId))
                .append(DOCUMENT_CONTENT_KEY, objectMapper.writeValueAsString(checkpoints));
    }
}
