package com.alibaba.cloud.ai.graph.node;

import com.alibaba.cloud.ai.graph.OverAllState;
import com.alibaba.cloud.ai.graph.action.NodeAction;
import com.alibaba.cloud.ai.graph.agent.ReflectAgent;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.SystemPromptTemplate;
import org.springframework.util.StringUtils;

/* loaded from: input_file:com/alibaba/cloud/ai/graph/node/QuestionClassifierNode.class */
/**
 * Graph node that sends a piece of text to a chat model together with a list of candidate
 * categories and optional classification instructions, and stores the model's reply in the
 * node output.
 *
 * <p>The request consists of a system prompt rendered from {@link #CLASSIFIER_PROMPT_TEMPLATE},
 * two hard-coded user/assistant example exchanges, and the input text as the user message.
 *
 * <p>The node keeps the most recently resolved input text in a field and falls back to it when
 * the state holds no value under the configured key.
 */
public class QuestionClassifierNode implements NodeAction {
    private static final String CLASSIFIER_PROMPT_TEMPLATE = "\t### Job Description',\n\tYou are a text classification engine that analyzes text data and assigns categories based on user input or automatically determined categories.\n\t### Task\n\tYour task is to assign one category ONLY to the input text and only one category can be  returned in the output. Additionally, you need to extract the key words from the text that are related to the classification.\n\t### Format\n\tThe input text is: {inputText}. Categories are specified as a category list: {categories}. Classification instructions may be included to improve the classification accuracy: {classificationInstructions}.\n\t### Constraint\n\tDO NOT include anything other than the JSON array in your response.\n";
    private static final String QUESTION_CLASSIFIER_USER_PROMPT_1 = "\t{ \"input_text\": [\"I recently had a great experience with your company. The service was prompt and the staff was very friendly.\"],\n\t\"categories\": [\"Customer Service\", \"Satisfaction\", \"Sales\", \"Product\"],\n\t\"classification_instructions\": [\"classify the text based on the feedback provided by customer\"]}\n";
    // The comma after the "keywords" array is required: without it this example answer is not
    // valid JSON and the model is shown a malformed reply format.
    private static final String QUESTION_CLASSIFIER_ASSISTANT_PROMPT_1 = "\t```json\n\t\t{\"keywords\": [\"recently\", \"great experience\", \"company\", \"service\", \"prompt\", \"staff\", \"friendly\"],\n\t\t\"category_name\": \"Customer Service\"}\n\t```\n";
    private static final String QUESTION_CLASSIFIER_USER_PROMPT_2 = "\t{\"input_text\": [\"bad service, slow to bring the food\"],\n\t\"categories\": [\"Food Quality\", \"Experience\", \"Price\"],\n\t\"classification_instructions\": []}\n";
    private static final String QUESTION_CLASSIFIER_ASSISTANT_PROMPT_2 = "\t```json\n\t\t{\"keywords\": [\"bad service\", \"slow\", \"food\", \"tip\", \"terrible\", \"waitresses\"],\n\t\t\"category_name\": \"Experience\"}\n\t```\n";

    private final SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(CLASSIFIER_PROMPT_TEMPLATE);
    private final ChatClient chatClient;
    private final List<String> categories;
    private final List<String> classificationInstructions;
    private final String inputTextKey;

    /** Last input text resolved by {@link #apply}; used as fallback when the state has no value. */
    private String inputText;

    /** Fluent builder for {@link QuestionClassifierNode}. */
    public static class Builder {
        private String inputTextKey;
        private ChatClient chatClient;
        private List<String> categories;
        private List<String> classificationInstructions;

        /**
         * @param inputTextKey state key under which the text to classify is looked up
         * @return this builder
         */
        public Builder inputTextKey(String inputTextKey) {
            this.inputTextKey = inputTextKey;
            return this;
        }

        /**
         * @param chatClient client used to call the chat model
         * @return this builder
         */
        public Builder chatClient(ChatClient chatClient) {
            this.chatClient = chatClient;
            return this;
        }

        /**
         * @param categories candidate category names inserted into the system prompt
         * @return this builder
         */
        public Builder categories(List<String> categories) {
            this.categories = categories;
            return this;
        }

        /**
         * @param classificationInstructions extra instructions inserted into the system prompt
         * @return this builder
         */
        public Builder classificationInstructions(List<String> classificationInstructions) {
            this.classificationInstructions = classificationInstructions;
            return this;
        }

        /**
         * @return a node configured with the values supplied so far
         */
        public QuestionClassifierNode build() {
            return new QuestionClassifierNode(this.chatClient, this.inputTextKey, this.categories, this.classificationInstructions);
        }
    }

    /**
     * Creates a classifier node.
     *
     * @param chatClient client used to call the chat model
     * @param inputTextKey state key under which the text to classify is looked up
     * @param categories candidate category names; {@code null} is treated as an empty list
     * @param classificationInstructions extra instructions; {@code null} is treated as an empty list
     */
    public QuestionClassifierNode(ChatClient chatClient, String inputTextKey, List<String> categories,
            List<String> classificationInstructions) {
        this.chatClient = chatClient;
        this.inputTextKey = inputTextKey;
        // Map.of(...) used when rendering the prompt rejects null values, so unset lists
        // (the builder's default) are normalised to empty lists here.
        this.categories = categories != null ? categories : List.of();
        this.classificationInstructions = classificationInstructions != null ? classificationInstructions : List.of();
    }

    /**
     * Classifies the input text found in the state and returns the model's reply.
     *
     * @param overAllState state from which the input text (under the configured key) is read
     * @return a map containing the reply text under {@code "classifier_output"} and, when the
     *         state already holds a value under {@link ReflectAgent#MESSAGES}, the reply message
     *         under that key as well
     * @throws NullPointerException if no input text is available, neither in the state nor
     *         from an earlier invocation
     * @throws Exception declared by {@link NodeAction#apply}
     */
    @Override
    public Map<String, Object> apply(OverAllState overAllState) throws Exception {
        // Resolve the text once into a local so the whole request uses one consistent value
        // even if another invocation updates the field in the meantime.
        String text = this.inputText;
        if (StringUtils.hasLength(this.inputTextKey)) {
            text = (String) overAllState.value(this.inputTextKey).orElse(text);
            this.inputText = text;
        }
        Objects.requireNonNull(text,
                () -> "No input text to classify: state holds no value for key '" + this.inputTextKey + "'");
        String classifiedText = text;

        String systemPrompt = this.systemPromptTemplate.render(Map.of(
                "inputText", classifiedText,
                "categories", this.categories,
                "classificationInstructions", this.classificationInstructions));

        ChatResponse chatResponse = this.chatClient.prompt()
                .system(systemPrompt)
                .user(classifiedText)
                // Two example exchanges demonstrating the expected reply format.
                .messages(List.of(
                        new UserMessage(QUESTION_CLASSIFIER_USER_PROMPT_1),
                        new AssistantMessage(QUESTION_CLASSIFIER_ASSISTANT_PROMPT_1),
                        new UserMessage(QUESTION_CLASSIFIER_USER_PROMPT_2),
                        new AssistantMessage(QUESTION_CLASSIFIER_ASSISTANT_PROMPT_2)))
                .call()
                .chatResponse();

        AssistantMessage reply = chatResponse.getResult().getOutput();
        Map<String, Object> result = new HashMap<>();
        result.put("classifier_output", reply.getText());
        if (overAllState.value(ReflectAgent.MESSAGES).isPresent()) {
            result.put(ReflectAgent.MESSAGES, reply);
        }
        return result;
    }

    /**
     * @return a new, empty {@link Builder}
     */
    public static Builder builder() {
        return new Builder();
    }
}
