package com.alibaba.cloud.ai.graph.node;

import com.alibaba.cloud.ai.dashscope.rerank.DashScopeRerankOptions;
import com.alibaba.cloud.ai.graph.OverAllState;
import com.alibaba.cloud.ai.graph.action.NodeAction;
import com.alibaba.cloud.ai.model.RerankModel;
import com.alibaba.cloud.ai.model.RerankRequest;
import com.alibaba.cloud.ai.model.RerankResponse;
import jakarta.annotation.Nullable;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.function.Function;
import java.util.stream.Collectors;
import javax.validation.constraints.NotNull;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.document.Document;
import org.springframework.ai.rag.Query;
import org.springframework.ai.rag.postretrieval.document.DocumentPostProcessor;
import org.springframework.ai.rag.retrieval.search.VectorStoreDocumentRetriever;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.ai.vectorstore.filter.Filter;
import org.springframework.util.StringUtils;

/* loaded from: input_file:com/alibaba/cloud/ai/graph/node/KnowledgeRetrievalNode.class */
/**
 * Graph node that retrieves documents from a {@link VectorStore} for the user prompt, optionally
 * reranks them, and appends each document's formatted content to the prompt.
 *
 * <p>Every setting has a fixed value plus an optional state key. When the key is configured and
 * the {@link OverAllState} holds a value for it, that value is used for the current invocation
 * only; the configured value is never overwritten. The augmented prompt is returned under
 * {@code userPromptKey}, or under {@code "user_prompt"} when no key is configured.
 */
public class KnowledgeRetrievalNode implements NodeAction {

    private static final Logger logger = LoggerFactory.getLogger(KnowledgeRetrievalNode.class);

    private String userPromptKey;
    private String userPrompt;
    private String topKKey;
    private Integer topK;
    private String similarityThresholdKey;
    private Double similarityThreshold;
    private String filterExpressionKey;
    private Filter.Expression filterExpression;
    private String enableRankerKey;
    private Boolean enableRanker = false;
    private String rerankModelKey;
    private RerankModel rerankModel;
    private String rerankOptionsKey;
    private DashScopeRerankOptions rerankOptions;
    private String vectorStoreKey;
    private VectorStore vectorStore;

    /** Documents produced by the most recent {@link #apply} call, after optional reranking. */
    List<Document> documents;

    /** Fluent builder for {@link KnowledgeRetrievalNode}; unset values stay {@code null}. */
    public static class Builder {
        private String userPromptKey;
        private String userPrompt;
        private String topKKey;
        private Integer topK;
        private String similarityThresholdKey;
        private Double similarityThreshold;
        private String filterExpressionKey;
        private Filter.Expression filterExpression;
        private String enableRankerKey;
        // Matches the node's own default so build() does not replace it with null.
        private Boolean enableRanker = false;
        private String rerankModelKey;
        private RerankModel rerankModel;
        private String rerankOptionsKey;
        private DashScopeRerankOptions rerankOptions;
        private String vectorStoreKey;
        private VectorStore vectorStore;

        public Builder userPromptKey(String str) {
            this.userPromptKey = str;
            return this;
        }

        public Builder userPrompt(String str) {
            this.userPrompt = str;
            return this;
        }

        public Builder topKKey(String str) {
            this.topKKey = str;
            return this;
        }

        public Builder topK(Integer num) {
            this.topK = num;
            return this;
        }

        public Builder similarityThresholdKey(String str) {
            this.similarityThresholdKey = str;
            return this;
        }

        public Builder similarityThreshold(Double d) {
            this.similarityThreshold = d;
            return this;
        }

        public Builder filterExpressionKey(String str) {
            this.filterExpressionKey = str;
            return this;
        }

        public Builder filterExpression(Filter.Expression expression) {
            this.filterExpression = expression;
            return this;
        }

        public Builder enableRankerKey(String str) {
            this.enableRankerKey = str;
            return this;
        }

        public Builder enableRanker(Boolean bool) {
            this.enableRanker = bool;
            return this;
        }

        public Builder rerankModelKey(String str) {
            this.rerankModelKey = str;
            return this;
        }

        public Builder rerankModel(RerankModel rerankModel) {
            this.rerankModel = rerankModel;
            return this;
        }

        public Builder rerankOptionsKey(String str) {
            this.rerankOptionsKey = str;
            return this;
        }

        public Builder rerankOptions(DashScopeRerankOptions dashScopeRerankOptions) {
            this.rerankOptions = dashScopeRerankOptions;
            return this;
        }

        public Builder vectorStoreKey(String str) {
            this.vectorStoreKey = str;
            return this;
        }

        public Builder vectorStore(VectorStore vectorStore) {
            this.vectorStore = vectorStore;
            return this;
        }

        /** Creates a node carrying every value configured on this builder. */
        public KnowledgeRetrievalNode build() {
            KnowledgeRetrievalNode node = new KnowledgeRetrievalNode();
            node.userPromptKey = this.userPromptKey;
            node.userPrompt = this.userPrompt;
            node.topKKey = this.topKKey;
            node.topK = this.topK;
            node.similarityThresholdKey = this.similarityThresholdKey;
            node.similarityThreshold = this.similarityThreshold;
            node.filterExpressionKey = this.filterExpressionKey;
            node.filterExpression = this.filterExpression;
            node.enableRankerKey = this.enableRankerKey;
            node.enableRanker = this.enableRanker;
            node.rerankModelKey = this.rerankModelKey;
            node.rerankModel = this.rerankModel;
            node.rerankOptionsKey = this.rerankOptionsKey;
            node.rerankOptions = this.rerankOptions;
            node.vectorStoreKey = this.vectorStoreKey;
            node.vectorStore = this.vectorStore;
            return node;
        }
    }

    /**
     * {@link DocumentPostProcessor} that reorders documents using a {@link RerankModel}.
     *
     * <p>The output contains the input documents whose ids appear in the rerank response, in the
     * order of that response. When reranking is not possible (no query text) or fails, the input
     * list is returned unchanged.
     */
    public static class KnowledgeRetrievalDocumentRanker implements DocumentPostProcessor {
        private final RerankModel rerankModel;
        private final DashScopeRerankOptions rerankOptions;

        public KnowledgeRetrievalDocumentRanker(RerankModel rerankModel, DashScopeRerankOptions dashScopeRerankOptions) {
            this.rerankModel = rerankModel;
            this.rerankOptions = dashScopeRerankOptions;
        }

        /** Same as {@link #process(Query, List)}, so the ranker also works as a {@code BiFunction}. */
        public List<Document> apply(Query query, List<Document> list) {
            return process(query, list);
        }

        /**
         * Reranks {@code list} against the text of {@code query}.
         *
         * @param query the query whose text drives reranking; may be {@code null}
         * @param list  the documents to rerank; may be {@code null}
         * @return the reranked documents, the input list unchanged when there is nothing to rerank
         *         or reranking fails, or an empty list when {@code list} is {@code null}
         */
        @NotNull
        public List<Document> process(@Nullable Query query, @Nullable List<Document> list) {
            if (list == null) {
                return new ArrayList<>();
            }
            // Without query text there is nothing to rank against; keep the retrieval order
            // instead of discarding every document.
            if (list.isEmpty() || query == null || !StringUtils.hasText(query.text())) {
                return list;
            }
            try {
                RerankResponse response = this.rerankModel.call(new RerankRequest(query.text(), list, this.rerankOptions));
                // Keep the first document for a duplicated id rather than failing the whole rerank.
                Map<String, Document> byId = list.stream()
                        .collect(Collectors.toMap(Document::getId, Function.identity(), (first, duplicate) -> first));
                return response.getResults().stream()
                        .map(scored -> byId.get(scored.getOutput().getId()))
                        .filter(Objects::nonNull)
                        .collect(Collectors.toCollection(ArrayList::new));
            } catch (Exception e) {
                KnowledgeRetrievalNode.logger.error("rank error", e);
                return list;
            }
        }
    }

    public KnowledgeRetrievalNode() {
    }

    public KnowledgeRetrievalNode(String str, Integer num, Double d, Filter.Expression expression, Boolean bool, RerankModel rerankModel, DashScopeRerankOptions dashScopeRerankOptions, VectorStore vectorStore) {
        this.userPrompt = str;
        this.topK = num;
        this.similarityThreshold = d;
        this.filterExpression = expression;
        this.enableRanker = bool;
        this.rerankModel = rerankModel;
        this.rerankOptions = dashScopeRerankOptions;
        this.vectorStore = vectorStore;
    }

    /**
     * Retrieves documents for the prompt, optionally reranks them, and returns the prompt followed
     * by one {@code "Document: <content>\n"} line per document.
     *
     * @param overAllState graph state consulted for any configured {@code *Key} settings
     * @return a single-entry map keyed by {@code userPromptKey} (or {@code "user_prompt"})
     * @throws ClassCastException if a state value has an unexpected type
     */
    @Override // com.alibaba.cloud.ai.graph.action.NodeAction
    public Map<String, Object> apply(OverAllState overAllState) throws Exception {
        // Resolve into locals: writing state values back into the fields would make them the
        // fallback for later invocations (and race if the node instance is shared).
        String prompt = resolve(overAllState, this.userPromptKey, String.class, this.userPrompt);
        Integer topKValue = resolve(overAllState, this.topKKey, Integer.class, this.topK);
        Double threshold = resolve(overAllState, this.similarityThresholdKey, Double.class, this.similarityThreshold);
        Filter.Expression filter = resolve(overAllState, this.filterExpressionKey, Filter.Expression.class, this.filterExpression);
        Boolean rankerEnabled = resolve(overAllState, this.enableRankerKey, Boolean.class, this.enableRanker);
        RerankModel model = resolve(overAllState, this.rerankModelKey, RerankModel.class, this.rerankModel);
        DashScopeRerankOptions options = resolve(overAllState, this.rerankOptionsKey, DashScopeRerankOptions.class, this.rerankOptions);
        VectorStore store = resolve(overAllState, this.vectorStoreKey, VectorStore.class, this.vectorStore);

        Query query = new Query(prompt);
        List<Document> retrieved = buildRetriever(store, topKValue, threshold, filter).retrieve(query);
        // Boolean.TRUE.equals tolerates a null flag (e.g. set explicitly via the constructor).
        List<Document> result = Boolean.TRUE.equals(rankerEnabled)
                ? ranking(query, retrieved, new KnowledgeRetrievalDocumentRanker(model, options))
                : retrieved;
        this.documents = result;

        StringBuilder augmented = new StringBuilder(prompt);
        for (Document document : result) {
            augmented.append("Document: ").append(document.getFormattedContent()).append("\n");
        }
        Map<String, Object> output = new HashMap<>();
        output.put(StringUtils.hasLength(this.userPromptKey) ? this.userPromptKey : "user_prompt", augmented.toString());
        return output;
    }

    /**
     * Returns the state value stored under {@code key}, or {@code fallback} when {@code key} is not
     * configured or the state holds no value for it.
     *
     * @throws ClassCastException if the stored value is not an instance of {@code type}
     */
    private static <T> T resolve(OverAllState state, String key, Class<T> type, T fallback) {
        if (!StringUtils.hasLength(key)) {
            return fallback;
        }
        Object value = state.value(key).orElse(null);
        return value != null ? type.cast(value) : fallback;
    }

    /**
     * Builds a vector-store retriever, setting only the options that are non-null so that the
     * retriever's own defaults apply to anything left unconfigured.
     */
    private static VectorStoreDocumentRetriever buildRetriever(VectorStore store, Integer topKValue, Double threshold, Filter.Expression filter) {
        VectorStoreDocumentRetriever.Builder retrieverBuilder = VectorStoreDocumentRetriever.builder().vectorStore(store);
        if (threshold != null) {
            retrieverBuilder.similarityThreshold(threshold);
        }
        if (topKValue != null) {
            retrieverBuilder.topK(topKValue);
        }
        if (filter != null) {
            retrieverBuilder.filterExpression(filter);
        }
        return retrieverBuilder.build();
    }

    /**
     * Applies {@code documentPostProcessor} when there are at least two documents; on any failure
     * the error is logged and the original list is returned.
     */
    private List<Document> ranking(Query query, List<Document> list, DocumentPostProcessor documentPostProcessor) {
        if (list == null || list.size() <= 1) {
            return list;
        }
        try {
            return documentPostProcessor.process(query, list);
        } catch (Exception e) {
            logger.error("ranking error", e);
            return list;
        }
    }

    public static Builder builder() {
        return new Builder();
    }
}
