package com.alibaba.cloud.ai.advisor;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.springframework.ai.chat.client.ChatClientRequest;
import org.springframework.ai.chat.client.ChatClientResponse;
import org.springframework.ai.chat.client.advisor.api.AdvisorChain;
import org.springframework.ai.chat.client.advisor.api.BaseAdvisor;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.ai.document.Document;
import org.springframework.ai.rag.Query;
import org.springframework.ai.rag.retrieval.search.DocumentRetriever;
import org.springframework.util.Assert;

/* loaded from: input_file:com/alibaba/cloud/ai/advisor/DocumentRetrievalAdvisor.class */
public class DocumentRetrievalAdvisor implements BaseAdvisor {
    private static final int DEFAULT_ORDER = 0;
    private final DocumentRetriever retriever;
    private final PromptTemplate promptTemplate;
    private final int order;
    private static final PromptTemplate DEFAULT_PROMPT_TEMPLATE = new PromptTemplate("{query}\n\nContext information is below, surrounded by ---------------------\n---------------------\n{question_answer_context}\n---------------------\nGiven the context and provided history information and not prior knowledge,\nreply to the user comment. If the answer is not in the context, inform\nthe user that you can't answer the question.\n");
    public static String RETRIEVED_DOCUMENTS = "question_answer_context";

    public DocumentRetrievalAdvisor(DocumentRetriever documentRetriever) {
        this(documentRetriever, DEFAULT_PROMPT_TEMPLATE);
    }

    public DocumentRetrievalAdvisor(DocumentRetriever documentRetriever, PromptTemplate promptTemplate) {
        this(documentRetriever, promptTemplate, DEFAULT_ORDER);
    }

    public DocumentRetrievalAdvisor(DocumentRetriever documentRetriever, PromptTemplate promptTemplate, int i) {
        Assert.notNull(documentRetriever, "The retriever must not be null!");
        Assert.notNull(promptTemplate, "The promptTemplate must not be null!");
        this.retriever = documentRetriever;
        this.promptTemplate = promptTemplate;
        this.order = i;
    }

    public int getOrder() {
        return this.order;
    }

    public ChatClientRequest before(ChatClientRequest chatClientRequest, AdvisorChain advisorChain) {
        Map context = chatClientRequest.context();
        UserMessage userMessage = chatClientRequest.prompt().getUserMessage();
        List retrieve = this.retriever.retrieve(new Query(userMessage.getText(), chatClientRequest.prompt().getInstructions(), context));
        context.put(RETRIEVED_DOCUMENTS, retrieve);
        return chatClientRequest.mutate().prompt(chatClientRequest.prompt().augmentUserMessage(this.promptTemplate.render(Map.of("query", userMessage.getText(), "question_answer_context", (String) retrieve.stream().map((v0) -> {
            return v0.getText();
        }).collect(Collectors.joining(System.lineSeparator())))))).context(context).build();
    }

    public ChatClientResponse after(ChatClientResponse chatClientResponse, AdvisorChain advisorChain) {
        ChatResponse.Builder builder = chatClientResponse.chatResponse() == null ? ChatResponse.builder() : ChatResponse.builder().from(chatClientResponse.chatResponse());
        builder.metadata(RETRIEVED_DOCUMENTS, chatClientResponse.context().get(RETRIEVED_DOCUMENTS));
        return ChatClientResponse.builder().chatResponse(builder.build()).context(chatClientResponse.context()).build();
    }
}
