package org.apache.seatunnel.transform.nlpmodel.embedding.remote.qianfan;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import org.apache.http.client.config.RequestConfig;
import org.apache.http.client.methods.CloseableHttpResponse;
import org.apache.http.client.methods.HttpGet;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.client.methods.HttpUriRequest;
import org.apache.http.entity.StringEntity;
import org.apache.http.impl.client.CloseableHttpClient;
import org.apache.http.impl.client.HttpClients;
import org.apache.http.util.EntityUtils;
import org.apache.seatunnel.shade.com.fasterxml.jackson.core.type.TypeReference;
import org.apache.seatunnel.shade.com.fasterxml.jackson.databind.JsonNode;
import org.apache.seatunnel.shade.com.fasterxml.jackson.databind.node.ObjectNode;
import org.apache.seatunnel.shade.com.google.common.annotations.VisibleForTesting;
import org.apache.seatunnel.transform.nlpmodel.CustomConfigPlaceholder;
import org.apache.seatunnel.transform.nlpmodel.embedding.remote.AbstractModel;

/* loaded from: input_file:org/apache/seatunnel/transform/nlpmodel/embedding/remote/qianfan/QianfanModel.class */
public class QianfanModel extends AbstractModel {

    /**
     * Query string appended to {@code oauthPath}; the two placeholders receive the API key and
     * the secret key, in that order.
     */
    private static final String OAUTH_SUFFIX_PATH =
            "?grant_type=client_credentials&client_id=%s&client_secret=%s";

    /** Connect and socket timeout, in milliseconds, applied to every outgoing request. */
    private static final int TIMEOUT_MILLIS = 20000;

    private static final RequestConfig REQUEST_CONFIG =
            RequestConfig.custom()
                    .setConnectTimeout(TIMEOUT_MILLIS)
                    .setSocketTimeout(TIMEOUT_MILLIS)
                    .build();

    /**
     * {@code error_code} value after which the cached access token is re-fetched (presumably the
     * service's "access token invalid/expired" code — confirm against Qianfan docs).
     */
    private static final int ERROR_CODE_REFRESH_TOKEN = 110;

    /** Target type for each {@code embedding} array in the response. */
    private static final TypeReference<List<Float>> FLOAT_LIST_TYPE =
            new TypeReference<List<Float>>() {};

    private final CloseableHttpClient client;
    private final String apiKey;
    private final String secretKey;
    private final String model;
    private final String apiPath;
    private final String oauthPath;
    private String accessToken;

    /**
     * Creates a model and immediately fetches an access token from {@code oauthPath}.
     *
     * @param apiKey client id sent to the OAuth endpoint
     * @param secretKey client secret sent to the OAuth endpoint
     * @param model model name appended to {@code apiPath} to build the embedding URL
     * @param apiPath base URL of the embedding API (a trailing {@code /} is optional)
     * @param oauthPath URL of the OAuth token endpoint
     * @param vectorizedNumber forwarded unchanged to the {@link AbstractModel} constructor
     * @throws IOException if the access token cannot be obtained; the HTTP client is closed
     *     before the exception propagates
     */
    public QianfanModel(
            String apiKey,
            String secretKey,
            String model,
            String apiPath,
            String oauthPath,
            Integer vectorizedNumber)
            throws IOException {
        super(vectorizedNumber);
        this.apiKey = apiKey;
        this.secretKey = secretKey;
        this.model = model;
        this.apiPath = apiPath;
        this.oauthPath = oauthPath;
        this.client = HttpClients.createDefault();
        try {
            this.accessToken = getAccessToken();
        } catch (IOException | RuntimeException e) {
            // Do not leak the freshly created client when construction fails.
            closeClientAfterFailure(e);
            throw e;
        }
    }

    /**
     * Creates a model with a caller-supplied access token; no OAuth request is made here.
     *
     * @param apiKey client id used if the token must later be refreshed
     * @param secretKey client secret used if the token must later be refreshed
     * @param model model name appended to {@code apiPath} to build the embedding URL
     * @param apiPath base URL of the embedding API (a trailing {@code /} is optional)
     * @param vectorizedNumber forwarded unchanged to the {@link AbstractModel} constructor
     * @param oauthPath URL of the OAuth token endpoint
     * @param accessToken initial access token
     */
    public QianfanModel(
            String apiKey,
            String secretKey,
            String model,
            String apiPath,
            Integer vectorizedNumber,
            String oauthPath,
            String accessToken)
            throws IOException {
        super(vectorizedNumber);
        this.apiKey = apiKey;
        this.secretKey = secretKey;
        this.model = model;
        this.apiPath = apiPath;
        this.oauthPath = oauthPath;
        this.client = HttpClients.createDefault();
        this.accessToken = accessToken;
    }

    /** Closes {@link #client}, attaching any close failure to {@code failure} as suppressed. */
    private void closeClientAfterFailure(Exception failure) {
        try {
            this.client.close();
        } catch (IOException closeFailure) {
            failure.addSuppressed(closeFailure);
        }
    }

    /** Reads the whole response body, using UTF-8 when the response declares no charset. */
    private static String readBody(CloseableHttpResponse response) throws IOException {
        return response.getEntity() == null
                ? ""
                : EntityUtils.toString(response.getEntity(), StandardCharsets.UTF_8);
    }

    /**
     * Requests a new access token using the client-credentials grant.
     *
     * @return the {@code access_token} field of the OAuth response
     * @throws IOException on a non-200 status or when the response has no {@code access_token}
     */
    private String getAccessToken() throws IOException {
        // Only the fixed template is formatted, so a '%' in the configured path is harmless.
        HttpGet request =
                new HttpGet(
                        this.oauthPath
                                + String.format(OAUTH_SUFFIX_PATH, this.apiKey, this.secretKey));
        request.setConfig(REQUEST_CONFIG);
        try (CloseableHttpResponse response = this.client.execute(request)) {
            String body = readBody(response);
            if (response.getStatusLine().getStatusCode() != 200) {
                throw new IOException("Failed to Oauth for qianfan, response: " + body);
            }
            JsonNode token = OBJECT_MAPPER.readTree(body).get("access_token");
            if (token == null || token.isNull()) {
                throw new IOException(
                        "Failed to Oauth for qianfan, no access_token in response: " + body);
            }
            return token.asText();
        }
    }

    @Override // org.apache.seatunnel.transform.nlpmodel.embedding.remote.AbstractModel
    public List<List<Float>> vector(Object[] objArr) throws IOException {
        return vectorGeneration(objArr);
    }

    /**
     * Determines the embedding dimension by embedding a single sample string.
     *
     * @throws IOException if the request fails or the service returns no vectors
     */
    @Override // org.apache.seatunnel.transform.nlpmodel.embedding.remote.Model
    public Integer dimension() throws IOException {
        List<List<Float>> vectors = vectorGeneration(new Object[] {"dimension example"});
        if (vectors.isEmpty()) {
            throw new IOException("Failed to get vector dimension from qianfan: empty response");
        }
        return vectors.get(0).size();
    }

    /**
     * POSTs {@code data} to the embedding endpoint and parses one vector per {@code data[i]}
     * element of the response.
     *
     * <p>If the response carries {@code error_code} {@value #ERROR_CODE_REFRESH_TOKEN}, the access
     * token is refreshed before the exception is thrown, so a retry by the caller uses the new
     * token.
     *
     * @throws IOException on a non-200 status, an {@code error_code} in the body, or a body
     *     without a {@code data} array / {@code embedding} entries
     */
    private List<List<Float>> vectorGeneration(Object[] data) throws IOException {
        String base = this.apiPath.endsWith("/") ? this.apiPath : this.apiPath + "/";
        HttpPost request = new HttpPost(base + this.model + "?access_token=" + this.accessToken);
        request.setHeader("Content-Type", "application/json");
        request.setConfig(REQUEST_CONFIG);
        request.setEntity(
                new StringEntity(
                        OBJECT_MAPPER.writeValueAsString(createJsonNodeFromData(data)), "UTF-8"));

        String body;
        try (CloseableHttpResponse response = this.client.execute(request)) {
            body = readBody(response);
            if (response.getStatusLine().getStatusCode() != 200) {
                throw new IOException("Failed to get vector from qianfan, response: " + body);
            }
        }

        JsonNode root = OBJECT_MAPPER.readTree(body);
        JsonNode errorCode = root.get("error_code");
        if (errorCode != null) {
            if (errorCode.asInt() == ERROR_CODE_REFRESH_TOKEN) {
                this.accessToken = getAccessToken();
            }
            throw new IOException(
                    "Failed to get vector from qianfan, response: " + root.get("error_msg"));
        }

        JsonNode items = root.get("data");
        if (items == null || !items.isArray()) {
            throw new IOException(
                    "Failed to get vector from qianfan, no data array in response: " + body);
        }
        List<List<Float>> vectors = new ArrayList<>(items.size());
        for (JsonNode item : items) {
            JsonNode embedding = item.get("embedding");
            if (embedding == null) {
                throw new IOException(
                        "Failed to get vector from qianfan, missing embedding in response: "
                                + body);
            }
            vectors.add(OBJECT_MAPPER.readValue(embedding.traverse(), FLOAT_LIST_TYPE));
        }
        return vectors;
    }

    /**
     * Builds the request body: a single field named by
     * {@link CustomConfigPlaceholder#REPLACE_PLACEHOLDER_INPUT} holding {@code objArr} as an array.
     */
    @VisibleForTesting
    public ObjectNode createJsonNodeFromData(Object[] objArr) {
        ObjectNode root = OBJECT_MAPPER.createObjectNode();
        root.set(
                CustomConfigPlaceholder.REPLACE_PLACEHOLDER_INPUT,
                OBJECT_MAPPER.valueToTree(Arrays.asList(objArr)));
        return root;
    }

    @Override // java.io.Closeable, java.lang.AutoCloseable
    public void close() throws IOException {
        if (this.client != null) {
            this.client.close();
        }
    }
}
