package com.primihub.sdk.task.factory;

import com.alibaba.fastjson.JSONObject;
import com.alibaba.fastjson.serializer.SerializerFeature;
import com.google.protobuf.ByteString;
import com.primihub.sdk.task.cache.CacheService;
import com.primihub.sdk.task.dataenum.ModelTypeEnum;
import com.primihub.sdk.task.param.TaskComponentParam;
import com.primihub.sdk.task.param.TaskParam;
import com.primihub.sdk.util.FreemarkerTemplate;
import io.grpc.Channel;
import java.nio.charset.StandardCharsets;
import java.util.Map;
import java_worker.PushTaskReply;
import java_worker.PushTaskRequest;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import primihub.rpc.Common;

/* loaded from: input_file:com/primihub/sdk/task/factory/AbstractComponentGRPCExecute.class */
/**
 * gRPC executor for component tasks.
 *
 * <p>Builds the component parameters, either by rendering a Freemarker template selected from the
 * task's {@link ModelTypeEnum} or by using caller-supplied template content. It then normalises
 * them as JSON, wraps them in a {@link PushTaskRequest}, and submits the request through the VM
 * node stub. The outcome (success, party count, error) is recorded on the supplied
 * {@link TaskParam}.
 */
public class AbstractComponentGRPCExecute extends AbstractGRPCExecuteFactory {
    private static final Logger log = LoggerFactory.getLogger(AbstractComponentGRPCExecute.class);

    /** Freemarker map key set for binary classification / regression model types. */
    private static final String TASK_NN_TYPE_KEY = "taskNNType";

    /** Parameter-map key under which the rendered component JSON is sent. */
    private static final String COMPONENT_PARAMS_KEY = "component_params";

    /** Task name used when the task has no model type. */
    private static final String DEFAULT_TASK_NAME = "taskModel";

    // NOTE(review): these request values were hard-coded in the original implementation; their
    // protocol meaning is not visible from this file — confirm against the worker contract.
    private static final String INTENDED_WORKER_ID = "1";
    private static final long SEQUENCE_NUMBER = 11L;
    private static final long CLIENT_PROCESSED_UP_TO = 22L;

    private CacheService cacheService;

    /** Returns the cache service injected into this executor (may be {@code null} if never set). */
    @Override // com.primihub.sdk.task.factory.AbstractGRPCExecuteFactory
    public CacheService getCacheService() {
        return this.cacheService;
    }

    /** Injects the cache service used by this executor. */
    @Override // com.primihub.sdk.task.factory.AbstractGRPCExecuteFactory
    public void setCacheService(CacheService cacheService) {
        this.cacheService = cacheService;
    }

    /**
     * Submits a component task over {@code channel}.
     *
     * <p>Never throws. Failures are reported through {@code taskParam.setSuccess(false)} and
     * {@code taskParam.setError(...)}.
     *
     * @param channel   gRPC channel to the VM node
     * @param taskParam task description; its content parameter is treated as a
     *                  {@link TaskComponentParam}
     */
    @Override // com.primihub.sdk.task.factory.AbstractGRPCExecuteFactory
    @SuppressWarnings("unchecked") // raw TaskParam from the factory signature; presumably always
                                   // carries a TaskComponentParam here — verify at dispatch site
    public void execute(Channel channel, TaskParam taskParam) {
        runComponentTask(channel, taskParam);
    }

    /** Renders, submits and (optionally) polls a component task, recording the outcome. */
    private void runComponentTask(Channel channel, TaskParam<TaskComponentParam> taskParam) {
        try {
            TaskComponentParam taskContentParam = taskParam.getTaskContentParam();
            applyTaskNNType(taskContentParam);

            String componentParams = renderComponentParams(taskContentParam);
            log.info("start taskParam:{} - freemarkerContent:{}", taskParam, componentParams);

            // Datasets are assembled before the task context, matching the original call order.
            Map<String, Common.Dataset> modelDatasets = assembleModelDatasets(taskContentParam.getFreemarkerMap());
            Common.TaskContext taskContext = assembleTaskContext(taskParam);
            PushTaskRequest request = buildPushTaskRequest(taskContentParam, componentParams, modelDatasets, taskContext);
            // Pass the object (not toString()) so it is only rendered when INFO is enabled.
            log.info("grpc PushTaskRequest :\n{}", request);

            PushTaskReply reply = (PushTaskReply) runVMNodeGrpc(stub -> stub.submitTask(request), channel);
            log.info("grpc结果:{}", reply);

            if (reply.getRetCode() == 0) {
                taskParam.setPartyCount(Integer.valueOf(reply.getPartyCount()));
                // A null flag is treated as "do not poll" instead of failing with an NPE.
                if (Boolean.TRUE.equals(taskParam.getOpenGetStatus())) {
                    continuouslyObtainTaskStatus(channel, taskContext, taskParam, reply.getPartyCount());
                }
            } else {
                taskParam.setSuccess(false);
                taskParam.setError(reply.getMsgInfo().toStringUtf8());
            }
        } catch (Exception e) {
            taskParam.setSuccess(false);
            // Some exceptions carry no message; fall back to toString() so a failed task always
            // has non-null error text.
            taskParam.setError(e.getMessage() != null ? e.getMessage() : e.toString());
            // Passing the throwable as the last argument logs the full stack trace via SLF4J.
            log.error("grpc Exception:{}", e.getMessage(), e);
        }
    }

    /** Sets the {@code taskNNType} template variable for binary classification / regression models. */
    private static void applyTaskNNType(TaskComponentParam param) {
        ModelTypeEnum modelType = param.getModelType();
        if (modelType == ModelTypeEnum.CLASSIFICATION_BINARY) {
            param.getFreemarkerMap().put(TASK_NN_TYPE_KEY, "classification");
        } else if (modelType == ModelTypeEnum.REGRESSION_BINARY) {
            param.getFreemarkerMap().put(TASK_NN_TYPE_KEY, "regression");
        }
    }

    /**
     * Produces the raw component-parameter text.
     *
     * <ul>
     *   <li>No templates content: render the model type's {@code .ftl} file with the freemarker map.</li>
     *   <li>Templates content marked untreated: render that content as a Freemarker template.</li>
     *   <li>Otherwise: use the templates content verbatim.</li>
     * </ul>
     */
    private static String renderComponentParams(TaskComponentParam param) {
        if (StringUtils.isEmpty(param.getTemplatesContent())) {
            return FreemarkerTemplate.getInstance()
                    .generateTemplateStr(param.getFreemarkerMap(), resolveFtlPath(param));
        }
        if (param.isUntreated()) {
            return FreemarkerTemplate.getInstance().generateTemplateStrFreemarkerContent(
                    resolveFtlPath(param), param.getTemplatesContent(), param.getFreemarkerMap());
        }
        return param.getTemplatesContent();
    }

    /**
     * Returns the inference or training template path for the task's model type.
     *
     * @throws IllegalArgumentException if the task has no model type
     */
    private static String resolveFtlPath(TaskComponentParam param) {
        ModelTypeEnum modelType = param.getModelType();
        if (modelType == null) {
            throw new IllegalArgumentException("modelType is required to resolve the component template");
        }
        return param.isInfer() ? modelType.getInferFtlPath() : modelType.getModelFtlPath();
    }

    /** Assembles the {@link PushTaskRequest} carrying the normalised component JSON. */
    private static PushTaskRequest buildPushTaskRequest(TaskComponentParam param,
                                                        String componentParams,
                                                        Map<String, Common.Dataset> modelDatasets,
                                                        Common.TaskContext taskContext) {
        // Round-trip through fastjson so the payload is compact JSON with null-valued keys kept.
        byte[] normalizedJson = JSONObject
                .toJSONString(JSONObject.parseObject(componentParams), SerializerFeature.WriteMapNullValue)
                .getBytes(StandardCharsets.UTF_8);
        Common.Params params = Common.Params.newBuilder()
                .putParamMap(COMPONENT_PARAMS_KEY,
                        Common.ParamValue.newBuilder().setValueString(ByteString.copyFrom(normalizedJson)).m1133build())
                .m1181build();

        ModelTypeEnum modelType = param.getModelType();
        Common.Task task = Common.Task.newBuilder()
                .setType(Common.TaskType.ACTOR_TASK)
                .setParams(params)
                .setName(modelType == null ? DEFAULT_TASK_NAME : modelType.getTypeName())
                .setLanguage(Common.Language.PYTHON)
                .setTaskInfo(taskContext)
                .putAllPartyDatasets(modelDatasets)
                .m1283build();

        return PushTaskRequest.newBuilder()
                .setIntendedWorkerId(ByteString.copyFrom(INTENDED_WORKER_ID.getBytes(StandardCharsets.UTF_8)))
                .setTask(task)
                .setSequenceNumber(SEQUENCE_NUMBER)
                .setClientProcessedUpTo(CLIENT_PROCESSED_UP_TO)
                .m694build();
    }
}
