package org.apache.seatunnel.transform.nlpmodel.embedding;

import groovyjarjarpicocli.CommandLine;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import lombok.NonNull;
import org.apache.seatunnel.api.configuration.ReadonlyConfig;
import org.apache.seatunnel.api.table.catalog.CatalogTable;
import org.apache.seatunnel.api.table.catalog.Column;
import org.apache.seatunnel.api.table.catalog.PhysicalColumn;
import org.apache.seatunnel.api.table.type.SeaTunnelRowAccessor;
import org.apache.seatunnel.api.table.type.SeaTunnelRowType;
import org.apache.seatunnel.api.table.type.VectorType;
import org.apache.seatunnel.transform.common.MultipleFieldOutputTransform;
import org.apache.seatunnel.transform.exception.TransformCommonError;
import org.apache.seatunnel.transform.nlpmodel.ModelProvider;
import org.apache.seatunnel.transform.nlpmodel.ModelTransformConfig;
import org.apache.seatunnel.transform.nlpmodel.embedding.remote.Model;
import org.apache.seatunnel.transform.nlpmodel.embedding.remote.custom.CustomModel;
import org.apache.seatunnel.transform.nlpmodel.embedding.remote.doubao.DoubaoModel;
import org.apache.seatunnel.transform.nlpmodel.embedding.remote.openai.OpenAIModel;
import org.apache.seatunnel.transform.nlpmodel.embedding.remote.qianfan.QianfanModel;

/**
 * Embedding transform: appends one {@code VECTOR_FLOAT_TYPE} column per configured output field,
 * filled by vectorizing the value of a source field with a remote embedding model.
 *
 * <p>The field mapping is read from {@code EmbeddingTransformConfig.VECTORIZATION_FIELDS}: each key
 * names a new output column and each value names the input column whose value is vectorized. The
 * model implementation is selected by {@code ModelTransformConfig.MODEL_PROVIDER} and is created in
 * {@link #open()}, or lazily on the first processed row if {@code open()} was not called.
 */
public class EmbeddingTransform extends MultipleFieldOutputTransform {
    private final ReadonlyConfig config;

    /** Output column names, in the iteration order of the vectorization-fields map. */
    private List<String> fieldNames;

    /** Input-row index of the source field for the output column at the same position. */
    private List<Integer> fieldOriginalIndexes;

    /** Embedding model; {@code null} until {@link #open()} succeeds. */
    private Model model;

    /** Vector dimension reported by the model; {@code null} until {@link #open()} succeeds. */
    private Integer dimension;

    /**
     * Creates the transform and resolves every configured source field against the input schema.
     *
     * @param readonlyConfig transform configuration
     * @param catalogTable input table whose physical row type is used to resolve source fields
     * @throws NullPointerException if either argument is {@code null}
     */
    public EmbeddingTransform(
            @NonNull ReadonlyConfig readonlyConfig, @NonNull CatalogTable catalogTable) {
        super(catalogTable);
        // These checks can only run after super(...), so a null catalogTable is handed to the
        // superclass constructor before it is rejected here.
        if (readonlyConfig == null) {
            throw new NullPointerException("config is marked non-null but is null");
        }
        if (catalogTable == null) {
            throw new NullPointerException("inputCatalogTable is marked non-null but is null");
        }
        this.config = readonlyConfig;
        initOutputFields(
                catalogTable.getTableSchema().toPhysicalRowDataType(),
                (Map) readonlyConfig.get(EmbeddingTransformConfig.VECTORIZATION_FIELDS));
    }

    /** Opens the model on first use if {@link #open()} has not already succeeded. */
    private void tryOpen() {
        if (model == null) {
            open();
        }
    }

    /**
     * Creates the configured embedding model and probes its vector dimension.
     *
     * <p>The model and dimension are stored only after both steps succeed. On failure the
     * partially created model is closed and {@code model} stays {@code null}, so the next
     * {@link #tryOpen()} attempts initialization again. It does not reuse a model whose dimension
     * was never determined.
     *
     * @throws IllegalArgumentException if the provider is unsupported (including {@code LOCAL}),
     *     or the provider is {@code CUSTOM} and no custom config is present
     * @throws RuntimeException wrapping any {@link IOException} raised while creating the model or
     *     probing its dimension
     */
    public void open() {
        ModelProvider modelProvider =
                (ModelProvider) config.get(ModelTransformConfig.MODEL_PROVIDER);
        Model candidate = null;
        try {
            candidate = createModel(modelProvider);
            Integer probedDimension = candidate.dimension();
            this.model = candidate;
            this.dimension = probedDimension;
        } catch (IOException e) {
            closeAfterFailure(candidate, e);
            throw new RuntimeException("Failed to initialize model", e);
        } catch (RuntimeException e) {
            closeAfterFailure(candidate, e);
            throw e;
        }
    }

    /**
     * Builds the provider-specific embedding model described by the configuration.
     *
     * @param modelProvider provider selected by {@code ModelTransformConfig.MODEL_PROVIDER}
     * @return a newly constructed, not yet probed model
     * @throws IOException if the model constructor fails with an I/O error
     * @throws IllegalArgumentException for {@code LOCAL} or any unsupported provider, or when a
     *     {@code CUSTOM} provider has no custom config
     */
    private Model createModel(ModelProvider modelProvider) throws IOException {
        switch (modelProvider) {
            case CUSTOM:
                {
                    ReadonlyConfig customConfig =
                            (ReadonlyConfig)
                                    config.getOptional(
                                                    ModelTransformConfig.CustomRequestConfig
                                                            .CUSTOM_CONFIG)
                                            .map(ReadonlyConfig::fromMap)
                                            .orElseThrow(
                                                    () ->
                                                            new IllegalArgumentException(
                                                                    "Custom config can't be null"));
                    return new CustomModel(
                            (String) config.get(ModelTransformConfig.MODEL),
                            embeddingPath(modelProvider),
                            (Map)
                                    customConfig.get(
                                            ModelTransformConfig.CustomRequestConfig
                                                    .CUSTOM_REQUEST_HEADERS),
                            (Map)
                                    customConfig.get(
                                            ModelTransformConfig.CustomRequestConfig
                                                    .CUSTOM_REQUEST_BODY),
                            (String)
                                    customConfig.get(
                                            ModelTransformConfig.CustomRequestConfig
                                                    .CUSTOM_RESPONSE_PARSE),
                            singleVectorizedInputNumber());
                }
            case OPENAI:
                return new OpenAIModel(
                        (String) config.get(ModelTransformConfig.API_KEY),
                        (String) config.get(ModelTransformConfig.MODEL),
                        embeddingPath(modelProvider),
                        singleVectorizedInputNumber());
            case DOUBAO:
                return new DoubaoModel(
                        (String) config.get(ModelTransformConfig.API_KEY),
                        (String) config.get(ModelTransformConfig.MODEL),
                        embeddingPath(modelProvider),
                        singleVectorizedInputNumber());
            case QIANFAN:
                return new QianfanModel(
                        (String) config.get(ModelTransformConfig.API_KEY),
                        (String) config.get(ModelTransformConfig.SECRET_KEY),
                        (String) config.get(ModelTransformConfig.MODEL),
                        embeddingPath(modelProvider),
                        (String) config.get(ModelTransformConfig.OAUTH_PATH),
                        singleVectorizedInputNumber());
            case LOCAL:
            default:
                throw new IllegalArgumentException("Unsupported model provider: " + modelProvider);
        }
    }

    /** Resolves the embedding endpoint path for the provider from the configured API path. */
    private String embeddingPath(ModelProvider modelProvider) {
        return modelProvider.usedEmbeddingPath((String) config.get(ModelTransformConfig.API_PATH));
    }

    /** Reads {@code EmbeddingTransformConfig.SINGLE_VECTORIZED_INPUT_NUMBER} from the config. */
    private Integer singleVectorizedInputNumber() {
        return (Integer) config.get(EmbeddingTransformConfig.SINGLE_VECTORIZED_INPUT_NUMBER);
    }

    /**
     * Closes a model whose initialization failed. Any runtime failure from {@code close()} is
     * attached to {@code failure} as a suppressed exception so the original cause is not masked.
     */
    private static void closeAfterFailure(Model candidate, Exception failure) {
        if (candidate == null) {
            return;
        }
        try {
            candidate.close();
        } catch (RuntimeException closeFailure) {
            failure.addSuppressed(closeFailure);
        }
    }

    /**
     * Resolves each configured source field to its index in the input row type.
     *
     * @param seaTunnelRowType physical row type of the input table
     * @param map output field name to source field name
     */
    private void initOutputFields(SeaTunnelRowType seaTunnelRowType, Map<String, String> map) {
        List<String> outputNames = new ArrayList<>(map.size());
        List<Integer> sourceIndexes = new ArrayList<>(map.size());
        for (Map.Entry<String, String> entry : map.entrySet()) {
            String sourceField = entry.getValue();
            int sourceIndex;
            try {
                sourceIndex = seaTunnelRowType.indexOf(sourceField);
            } catch (IllegalArgumentException e) {
                // indexOf reports an unknown field with IllegalArgumentException; surface it as
                // the transform's "cannot find input field" error instead.
                throw TransformCommonError.cannotFindInputFieldError(getPluginName(), sourceField);
            }
            outputNames.add(entry.getKey());
            sourceIndexes.add(sourceIndex);
        }
        this.fieldNames = outputNames;
        this.fieldOriginalIndexes = sourceIndexes;
    }

    /**
     * Vectorizes the configured source fields of one row, opening the model first if needed.
     *
     * @param seaTunnelRowAccessor accessor for the input row
     * @return one element per output column, as returned by the model's {@code vectorization}
     * @throws RuntimeException wrapping any failure raised while reading fields or vectorizing
     */
    @Override // org.apache.seatunnel.transform.common.MultipleFieldOutputTransform
    protected Object[] getOutputFieldValues(SeaTunnelRowAccessor seaTunnelRowAccessor) {
        tryOpen();
        try {
            Object[] sourceValues = new Object[fieldOriginalIndexes.size()];
            for (int i = 0; i < sourceValues.length; i++) {
                int sourceIndex = fieldOriginalIndexes.get(i);
                sourceValues[i] = seaTunnelRowAccessor.getField(sourceIndex);
            }
            return model.vectorization(sourceValues).toArray();
        } catch (Exception e) {
            throw new RuntimeException("Failed to data vectorization", e);
        }
    }

    /**
     * Describes the appended columns: one nullable {@code VECTOR_FLOAT_TYPE} column per output
     * field name, built with the model's dimension.
     *
     * <p>NOTE(review): {@code dimension} is {@code null} unless {@link #open()} has already
     * succeeded when this is called.
     */
    @Override // org.apache.seatunnel.transform.common.MultipleFieldOutputTransform
    protected Column[] getOutputColumns() {
        Column[] columns = new Column[fieldNames.size()];
        for (int i = 0; i < columns.length; i++) {
            // Empty default value and empty comment.
            columns[i] =
                    PhysicalColumn.of(
                            fieldNames.get(i),
                            VectorType.VECTOR_FLOAT_TYPE,
                            (Long) null,
                            dimension,
                            true,
                            "",
                            "");
        }
        return columns;
    }

    public String getPluginName() {
        return "Embedding";
    }

    /** Closes the model if one has been created. */
    public void close() {
        if (model != null) {
            model.close();
        }
    }
}
