package io.github.lnyocly.ai4j.utils;

import com.alibaba.fastjson2.JSON;
import io.github.lnyocly.ai4j.annotation.FunctionCall;
import io.github.lnyocly.ai4j.annotation.FunctionParameter;
import io.github.lnyocly.ai4j.annotation.FunctionRequest;
import io.github.lnyocly.ai4j.platform.openai.tool.Tool;
import java.lang.reflect.Field;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import org.reflections.Reflections;
import org.reflections.scanners.Scanner;
import org.reflections.scanners.Scanners;
import org.reflections.util.ClasspathHelper;
import org.reflections.util.ConfigurationBuilder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:io/github/lnyocly/ai4j/utils/ToolUtil.class */
/**
 * Discovers classes annotated with {@link FunctionCall}, converts them into {@link Tool}
 * descriptors for function calling, and invokes them by function name.
 *
 * <p>Resolved descriptors and classes are cached in the public static maps of this class.
 * {@link #toolClassMap} and {@link #toolRequestMap} are filled as a side effect of
 * {@link #getFunctionEntity(String)}; {@link #invoke(String, String)} triggers that resolution
 * itself when a name has not been resolved yet.
 */
public class ToolUtil {
    private static final Logger log = LoggerFactory.getLogger(ToolUtil.class);

    /**
     * Classpath scanner configured with only the {@code TypesAnnotated} scanner. It is built when
     * this class is initialised; the empty package prefix presumably makes it cover every classpath
     * root (NOTE(review): confirm coverage of jar-packaged tools).
     */
    static Reflections reflections = new Reflections(new ConfigurationBuilder().setUrls(ClasspathHelper.forPackage("", new ClassLoader[0])).setScanners(new Scanner[]{Scanners.TypesAnnotated}));

    /** Cache of built {@link Tool} descriptors, keyed by function name. */
    public static Map<String, Tool> toolEntityMap = new ConcurrentHashMap<>();

    /** Function name to the {@link FunctionCall}-annotated class that implements it. */
    public static Map<String, Class<?>> toolClassMap = new ConcurrentHashMap<>();

    /** Function name to the nested {@link FunctionRequest} class its arguments deserialize into. */
    public static Map<String, Class<?>> toolRequestMap = new ConcurrentHashMap<>();

    /**
     * Invokes the tool registered under the given function name.
     *
     * <p>The tool class is instantiated through its no-arg constructor, {@code str2} is
     * deserialized into the tool's {@link FunctionRequest} class, and the value returned by the
     * tool's public {@code apply(request)} method is serialized back to JSON.
     *
     * @param str  the function name, as declared by {@link FunctionCall#name()}
     * @param str2 the JSON-encoded call arguments
     * @return the JSON serialization of the value returned by {@code apply}
     * @throws IllegalArgumentException if {@code str} is {@code null}, no tool is registered under
     *     it, or the tool declares no {@link FunctionRequest} nested class
     * @throws RuntimeException wrapping the exception thrown by the tool, or any reflection or
     *     JSON failure
     */
    public static String invoke(String str, String str2) {
        if (str == null) {
            throw new IllegalArgumentException("Tool function name must not be null");
        }
        long start = System.currentTimeMillis();
        Class<?> toolClass = toolClassMap.get(str);
        Class<?> requestClass = toolRequestMap.get(str);
        log.info("tool call function {}, argument {}", str, str2);

        if (toolClass == null || requestClass == null) {
            // The lookup maps are only filled by getFunctionEntity; resolve on demand so a tool can
            // be invoked even if it was never passed through getAllFunctionTools in this JVM.
            getFunctionEntity(str);
            toolClass = toolClassMap.get(str);
            requestClass = toolRequestMap.get(str);
        }
        if (toolClass == null) {
            throw new IllegalArgumentException("No tool function registered with name: " + str);
        }
        if (requestClass == null) {
            throw new IllegalArgumentException(
                    "Tool function " + str + " declares no @FunctionRequest nested class");
        }

        try {
            Method apply = toolClass.getMethod("apply", requestClass);
            Object tool = toolClass.getDeclaredConstructor().newInstance();
            Object request = JSON.parseObject(str2, requestClass);
            String response = JSON.toJSONString(apply.invoke(tool, request));
            log.info("response {}, cost {} ms", response, System.currentTimeMillis() - start);
            return response;
        } catch (InvocationTargetException e) {
            // Surface the tool's own failure instead of the reflective wrapper.
            Throwable cause = e.getCause() != null ? e.getCause() : e;
            throw new RuntimeException("Tool function " + str + " failed", cause);
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Resolves {@link Tool} descriptors for the given function names, using and filling
     * {@link #toolEntityMap}. Names that cannot be resolved (including {@code null} entries) are
     * skipped.
     *
     * @param list function names to resolve; may be {@code null}
     * @return the resolved tools in input order, or {@code null} when none could be resolved
     *     ({@code null} rather than an empty list, preserving the existing contract)
     */
    public static List<Tool> getAllFunctionTools(List<String> list) {
        if (list == null || list.isEmpty()) {
            return null;
        }
        List<Tool> tools = new ArrayList<>(list.size());
        for (String name : list) {
            if (name == null) {
                // ConcurrentHashMap rejects null keys; an unnamed tool can never match anyway.
                continue;
            }
            Tool tool = toolEntityMap.get(name);
            if (tool == null) {
                tool = getToolEntity(name);
            }
            if (tool != null) {
                toolEntityMap.put(name, tool);
                tools.add(tool);
            }
        }
        return tools.isEmpty() ? null : tools;
    }

    /**
     * Builds a {@code "function"}-typed {@link Tool} for the given function name.
     *
     * @param str the function name to look up
     * @return the tool descriptor, or {@code null} if no {@link FunctionCall} class has that name
     */
    public static Tool getToolEntity(String str) {
        Tool.Function functionEntity = getFunctionEntity(str);
        if (functionEntity == null) {
            return null;
        }
        Tool tool = new Tool();
        tool.setType("function");
        tool.setFunction(functionEntity);
        return tool;
    }

    /**
     * Scans for the {@link FunctionCall}-annotated class whose name equals {@code str} and builds
     * its function description. As a side effect it records the class in {@link #toolClassMap}
     * and its request class (if any) in {@link #toolRequestMap}.
     *
     * @param str the function name to look up
     * @return the function description, or {@code null} if no annotated class has that name
     */
    public static Tool.Function getFunctionEntity(String str) {
        for (Class<?> cls : reflections.getTypesAnnotatedWith(FunctionCall.class)) {
            FunctionCall functionCall = cls.getAnnotation(FunctionCall.class);
            if (functionCall == null || !functionCall.name().equals(str)) {
                continue;
            }
            Tool.Function function = new Tool.Function();
            function.setName(functionCall.name());
            function.setDescription(functionCall.description());
            setFunctionParameters(function, cls);
            toolClassMap.put(str, cls);
            return function;
        }
        return null;
    }

    /**
     * Builds the {@code "object"} parameter schema of {@code function} from the
     * {@link FunctionParameter}-annotated fields of every {@link FunctionRequest} class nested in
     * {@code cls}, and registers the request class in {@link #toolRequestMap}. If several nested
     * request classes exist, their properties are merged and the last one processed is registered.
     */
    private static void setFunctionParameters(Tool.Function function, Class<?> cls) {
        // LinkedHashMap keeps properties in the order getDeclaredFields returns them (HotSpot
        // typically uses declaration order, though the JDK does not guarantee it), so the
        // generated schema is stable instead of depending on hash order.
        Map<String, Tool.Function.Property> properties = new LinkedHashMap<>();
        List<String> required = new ArrayList<>();
        for (Class<?> nested : cls.getDeclaredClasses()) {
            if (nested.getAnnotation(FunctionRequest.class) == null) {
                continue;
            }
            toolRequestMap.put(function.getName(), nested);
            for (Field field : nested.getDeclaredFields()) {
                FunctionParameter parameter = field.getAnnotation(FunctionParameter.class);
                if (parameter == null) {
                    continue;
                }
                properties.put(field.getName(), buildProperty(field.getType(), parameter));
                if (parameter.required()) {
                    required.add(field.getName());
                }
            }
        }
        function.setParameters(new Tool.Function.Parameter("object", properties, required));
    }

    /** Creates the schema property for one annotated request field. */
    private static Tool.Function.Property buildProperty(Class<?> type, FunctionParameter parameter) {
        Tool.Function.Property property = new Tool.Function.Property();
        property.setType(mapJavaTypeToJsonSchemaType(type));
        property.setDescription(parameter.description());
        if (type.isEnum()) {
            property.setEnumValues(getEnumValues(type));
        }
        return property;
    }

    /**
     * Maps a Java field type to a JSON Schema {@code type} keyword.
     *
     * <p>Enums, {@code char}/{@link Character} and {@link CharSequence} types map to
     * {@code "string"}. Numeric primitives and {@link Number} subclasses map to {@code "number"}.
     * Booleans map to {@code "boolean"}. Arrays and collections map to {@code "array"}. Everything
     * else, including maps, maps to {@code "object"}.
     */
    private static String mapJavaTypeToJsonSchemaType(Class<?> cls) {
        if (cls.isEnum() || cls == char.class || cls == Character.class
                || CharSequence.class.isAssignableFrom(cls)) {
            return "string";
        }
        if (isNumericType(cls)) {
            return "number";
        }
        if (cls == boolean.class || cls == Boolean.class) {
            return "boolean";
        }
        if (cls.isArray() || Collection.class.isAssignableFrom(cls)) {
            return "array";
        }
        // Maps, POJOs and any unrecognised type are described as objects.
        return "object";
    }

    /** True for byte/short/int/long/float/double primitives and any {@link Number} subtype. */
    private static boolean isNumericType(Class<?> cls) {
        if (cls.isPrimitive()) {
            return cls != boolean.class && cls != char.class && cls != void.class;
        }
        return Number.class.isAssignableFrom(cls);
    }

    /** Returns the {@code toString()} of every constant of the given enum type, in declaration order. */
    private static List<String> getEnumValues(Class<?> cls) {
        Object[] constants = cls.getEnumConstants();
        List<String> values = new ArrayList<>(constants.length);
        for (Object constant : constants) {
            values.add(constant.toString());
        }
        return values;
    }
}
