package com.alibaba.cloud.ai.advisor;

import com.alibaba.cloud.ai.model.RerankModel;
import com.alibaba.cloud.ai.model.RerankRequest;
import com.alibaba.cloud.ai.model.RerankResponse;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.client.ChatClientRequest;
import org.springframework.ai.chat.client.ChatClientResponse;
import org.springframework.ai.chat.client.advisor.api.AdvisorChain;
import org.springframework.ai.chat.client.advisor.api.BaseAdvisor;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.ai.document.Document;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.ai.vectorstore.filter.Filter;
import org.springframework.ai.vectorstore.filter.FilterExpressionTextParser;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;

/**
 * Retrieval-augmented advisor that looks up candidate documents in a {@link VectorStore},
 * reranks them with a {@link RerankModel}, and appends the surviving documents to the user
 * message by rendering a {@link PromptTemplate}.
 *
 * <p>The raw (pre-rerank) similarity-search results are placed in the request context under
 * {@link #RETRIEVED_DOCUMENTS}, and {@link #after} copies that context entry into the chat
 * response metadata under the same key.
 *
 * <p>A per-request vector-store filter can be supplied as a string in the request context under
 * {@link #FILTER_EXPRESSION}; when absent or blank, the filter of the configured
 * {@link SearchRequest} is used.
 */
public class RetrievalRerankAdvisor implements BaseAdvisor {

    private static final Logger logger = LoggerFactory.getLogger(RetrievalRerankAdvisor.class);

    private static final PromptTemplate DEFAULT_PROMPT_TEMPLATE = new PromptTemplate("{query}\n\n"
            + "Context information is below, surrounded by ---------------------\n"
            + "---------------------\n"
            + "{question_answer_context}\n"
            + "---------------------\n"
            + "Given the context and provided history information and not prior knowledge,\n"
            + "reply to the user comment. If the answer is not in the context, inform\n"
            + "the user that you can't answer the question.\n");

    private static final Double DEFAULT_MIN_SCORE = 0.1d;

    private static final int DEFAULT_ORDER = 0;

    /** Context / response-metadata key holding the similarity-search results (before reranking). */
    public static final String RETRIEVED_DOCUMENTS = "qa_retrieved_documents";

    /** Context key holding an optional filter-expression string for the vector search. */
    public static final String FILTER_EXPRESSION = "qa_filter_expression";

    private final VectorStore vectorStore;

    private final RerankModel rerankModel;

    private final PromptTemplate promptTemplate;

    private final SearchRequest searchRequest;

    /** Reranked documents scoring below this value are discarded. Never {@code null}. */
    private final Double minScore;

    private final int order;

    /**
     * Creates an advisor with a default search request, the default prompt template and the
     * default minimum rerank score.
     */
    public RetrievalRerankAdvisor(VectorStore vectorStore, RerankModel rerankModel) {
        this(vectorStore, rerankModel, SearchRequest.builder().build(), DEFAULT_PROMPT_TEMPLATE, DEFAULT_MIN_SCORE);
    }

    /**
     * Creates an advisor with a default search request and prompt template.
     *
     * @param minScore minimum rerank score to keep a document; {@code null} selects the default
     */
    public RetrievalRerankAdvisor(VectorStore vectorStore, RerankModel rerankModel, Double minScore) {
        this(vectorStore, rerankModel, SearchRequest.builder().build(), DEFAULT_PROMPT_TEMPLATE, minScore);
    }

    /**
     * Creates an advisor with the given search request, the default prompt template and the
     * default minimum rerank score.
     */
    public RetrievalRerankAdvisor(VectorStore vectorStore, RerankModel rerankModel, SearchRequest searchRequest) {
        this(vectorStore, rerankModel, searchRequest, DEFAULT_PROMPT_TEMPLATE, DEFAULT_MIN_SCORE);
    }

    /**
     * Creates an advisor with the default order.
     *
     * @param minScore minimum rerank score to keep a document; {@code null} selects the default
     */
    public RetrievalRerankAdvisor(VectorStore vectorStore, RerankModel rerankModel, SearchRequest searchRequest,
            PromptTemplate promptTemplate, Double minScore) {
        this(vectorStore, rerankModel, searchRequest, promptTemplate, minScore, DEFAULT_ORDER);
    }

    /**
     * Creates a fully configured advisor.
     *
     * @param vectorStore store queried for candidate documents; must not be {@code null}
     * @param rerankModel model used to score the candidates; must not be {@code null}
     * @param searchRequest template request (top-k, threshold, default filter); must not be
     * {@code null}
     * @param promptTemplate template rendered with {@code query} and
     * {@code question_answer_context}; must not be {@code null}
     * @param minScore minimum rerank score to keep a document; {@code null} selects the default
     * @param order advisor order returned by {@link #getOrder()}
     * @throws IllegalArgumentException if any required argument is {@code null}
     */
    public RetrievalRerankAdvisor(VectorStore vectorStore, RerankModel rerankModel, SearchRequest searchRequest,
            PromptTemplate promptTemplate, Double minScore, int order) {
        Assert.notNull(vectorStore, "The vectorStore must not be null!");
        Assert.notNull(rerankModel, "The rerankModel must not be null!");
        Assert.notNull(searchRequest, "The searchRequest must not be null!");
        Assert.notNull(promptTemplate, "The promptTemplate must not be null!");
        this.vectorStore = vectorStore;
        this.rerankModel = rerankModel;
        this.promptTemplate = promptTemplate;
        this.searchRequest = searchRequest;
        // A null threshold would otherwise surface as an NPE only once documents are reranked.
        this.minScore = (minScore != null) ? minScore : DEFAULT_MIN_SCORE;
        this.order = order;
    }

    @Override
    public int getOrder() {
        return this.order;
    }

    /**
     * Resolves the filter expression for the vector search.
     *
     * @param context the request context
     * @return the parsed {@link #FILTER_EXPRESSION} entry when it is present and has text,
     * otherwise the configured search request's filter expression
     */
    protected Filter.Expression doGetFilterExpression(Map<String, Object> context) {
        Object expression = context.get(FILTER_EXPRESSION);
        if (expression != null && StringUtils.hasText(expression.toString())) {
            return new FilterExpressionTextParser().parse(expression.toString());
        }
        return this.searchRequest.getFilterExpression();
    }

    /**
     * Reranks the retrieved documents against the user's message text.
     *
     * @param chatClientRequest the current request; its user message text is the rerank query
     * @param documents the documents returned by the vector search
     * @return {@code documents} unchanged when it is null/empty or the rerank model returns no
     * results; otherwise the reranked documents scoring at least {@code minScore}, highest score
     * first
     */
    protected List<Document> doRerank(ChatClientRequest chatClientRequest, List<Document> documents) {
        if (CollectionUtils.isEmpty(documents)) {
            return documents;
        }
        String query = chatClientRequest.prompt().getUserMessage().getText();
        RerankResponse response = this.rerankModel.call(new RerankRequest(query, documents));
        logger.debug("reranked documents: {}", response);
        if (response == null || response.getResults() == null) {
            return documents;
        }
        double threshold = this.minScore.doubleValue();
        return response.getResults()
            .stream()
            // Skip null entries and unscored results instead of failing with an NPE.
            .filter(result -> result != null && result.getScore() != null
                    && result.getScore().doubleValue() >= threshold)
            // Descending by score.
            .sorted((left, right) -> Double.compare(right.getScore().doubleValue(), left.getScore().doubleValue()))
            .map(result -> result.getOutput())
            .collect(Collectors.toList());
    }

    /**
     * Retrieves documents for the user message, records them in the context under
     * {@link #RETRIEVED_DOCUMENTS}, reranks them and augments the user message with the
     * reranked documents' text joined by the platform line separator.
     */
    @Override
    public ChatClientRequest before(ChatClientRequest chatClientRequest, AdvisorChain advisorChain) {
        UserMessage userMessage = chatClientRequest.prompt().getUserMessage();
        String query = userMessage.getText();

        SearchRequest request = SearchRequest.from(this.searchRequest)
            .query(query)
            .filterExpression(doGetFilterExpression(chatClientRequest.context()))
            .build();
        List<Document> retrieved = this.vectorStore.similaritySearch(request);
        if (retrieved == null) {
            retrieved = List.of();
        }

        // Work on a copy: the incoming context map is not owned by this advisor (it may be shared
        // with the caller or unmodifiable).
        Map<String, Object> context = new HashMap<>(chatClientRequest.context());
        context.put(RETRIEVED_DOCUMENTS, retrieved);

        String documentContext = doRerank(chatClientRequest, retrieved).stream()
            .map(Document::getText)
            .collect(Collectors.joining(System.lineSeparator()));

        String augmentedText = this.promptTemplate
            .render(Map.of("query", query, "question_answer_context", documentContext));

        return chatClientRequest.mutate()
            .prompt(chatClientRequest.prompt().augmentUserMessage(augmentedText))
            .context(context)
            .build();
    }

    /**
     * Copies the {@link #RETRIEVED_DOCUMENTS} context entry into the chat response metadata,
     * creating an empty chat response when none is present.
     */
    @Override
    public ChatClientResponse after(ChatClientResponse chatClientResponse, AdvisorChain advisorChain) {
        ChatResponse.Builder builder = (chatClientResponse.chatResponse() == null) ? ChatResponse.builder()
                : ChatResponse.builder().from(chatClientResponse.chatResponse());
        builder.metadata(RETRIEVED_DOCUMENTS, chatClientResponse.context().get(RETRIEVED_DOCUMENTS));
        return ChatClientResponse.builder()
            .chatResponse(builder.build())
            .context(chatClientResponse.context())
            .build();
    }

}
