package com.primihub.sdk.task.factory;

import com.primihub.sdk.constant.TaskConstant;
import com.primihub.sdk.task.Functional;
import com.primihub.sdk.task.cache.CacheService;
import com.primihub.sdk.task.param.TaskParam;
import io.grpc.Channel;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;
import java.util.stream.Collectors;
import java_data_service.DataSetServiceGrpc;
import java_worker.TaskStatus;
import java_worker.TaskStatusReply;
import java_worker.VMNodeGrpc;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import primihub.rpc.Common;

/* loaded from: input_file:com/primihub/sdk/task/factory/AbstractGRPCExecuteFactory.class */
/**
 * Base class for task factories that drive PrimiHub tasks over gRPC.
 *
 * <p>Provides helpers to build protobuf task-context and dataset messages, to invoke the
 * VM-node and data-service blocking stubs with a bounded number of attempts, and to poll a
 * task's status until it fails or enough parties report success. Subclasses supply the
 * {@link CacheService} used to accumulate observed statuses and implement {@link #execute}.
 */
public abstract class AbstractGRPCExecuteFactory {
    private static final Logger log = LoggerFactory.getLogger(AbstractGRPCExecuteFactory.class);

    /** Cache-key template for accumulated task statuses; placeholders are filled from the task context. */
    public static final String TASK_STATUS_KEY = "ts:<taskId>:<jobId>";

    /** Total attempts per gRPC call: one initial call plus three retries. */
    private static final int GRPC_MAX_ATTEMPTS = 4;

    /** Deadline applied to each individual data-service call attempt, in seconds. */
    private static final long DATA_SERVICE_DEADLINE_SECONDS = 3L;

    /** Pause between two consecutive task-status polls, in milliseconds. */
    private static final long STATUS_POLL_INTERVAL_MS = 1000L;

    /**
     * Builds the protobuf task context identifying this task.
     *
     * @param taskParam source of the job id, request id and task id
     * @return a task context carrying those three ids
     */
    public Common.TaskContext assembleTaskContext(TaskParam taskParam) {
        // NOTE(review): m1333build() looks like a decompiler-renamed build(); kept unchanged so it
        // matches the project's generated builder class.
        return Common.TaskContext.newBuilder()
                .setJobId(taskParam.getJobId())
                .setRequestId(taskParam.getRequestId())
                .setTaskId(taskParam.getTaskId())
                .m1333build();
    }

    /**
     * Builds the per-role model datasets (label, guest, arbiter) from template parameters.
     *
     * <p>A role is included only when its FTL key is present in {@code map}; the value's
     * {@code toString()} becomes the dataset entry under {@code TASK_MODEL_PARTY_DATASETS_KEY}.
     *
     * @param map template parameters keyed by the {@code TaskConstant.FTL_KEY_*_DATASET} constants
     * @return a new mutable map from role key to dataset; empty if no role key is present
     */
    public Map<String, Common.Dataset> assembleModelDatasets(Map<String, Object> map) {
        Map<String, Common.Dataset> datasets = new HashMap<>();
        putPartyDataset(map, TaskConstant.FTL_KEY_LABEL_DATASET, datasets, TaskConstant.TASK_MODEL_LABEL_DATASET);
        putPartyDataset(map, TaskConstant.FTL_KEY_GUEST_DATASET, datasets, TaskConstant.TASK_MODEL_GUEST_DATASET);
        putPartyDataset(map, TaskConstant.FTL_KEY_ARBITER_DATASET, datasets, TaskConstant.TASK_MODEL_ARBITER_DATASET);
        return datasets;
    }

    /** Copies one role's dataset from {@code source[sourceKey]} into {@code target[targetKey]} when present. */
    private static void putPartyDataset(Map<String, Object> source, String sourceKey,
                                        Map<String, Common.Dataset> target, String targetKey) {
        if (source.containsKey(sourceKey)) {
            target.put(targetKey, Common.Dataset.newBuilder()
                    .putData(TaskConstant.TASK_MODEL_PARTY_DATASETS_KEY, source.get(sourceKey).toString())
                    .m940build());
        }
    }

    /**
     * Builds one MPC dataset per entry of {@code list}.
     *
     * @param list dataset identifiers, in party order
     * @return a new mutable map whose keys are {@code TASK_MODEL_MPC_DATASET} followed by the list index
     */
    public Map<String, Common.Dataset> assembleModelMpcDatasets(List<String> list) {
        Map<String, Common.Dataset> datasets = new HashMap<>();
        for (int index = 0; index < list.size(); index++) {
            datasets.put(TaskConstant.TASK_MODEL_MPC_DATASET + index, Common.Dataset.newBuilder()
                    .putData(TaskConstant.TASK_MPC_PARTY_DATASETS_KEY, list.get(index))
                    .m940build());
        }
        return datasets;
    }

    /**
     * Invokes {@code functional} against a fresh VM-node blocking stub, retrying on failure.
     *
     * @param functional the call to perform with the stub
     * @param channel    channel the stub is created on
     * @return whatever {@code functional} returns on the first successful attempt
     * @throws RuntimeException the last failure if every attempt failed; a checked failure is
     *                          wrapped in an {@link IllegalStateException} with it as the cause
     */
    public <Result> Result runVMNodeGrpc(Functional<VMNodeGrpc.VMNodeBlockingStub, Result> functional, Channel channel) {
        return runWithRetry(functional, channel, VMNodeGrpc::newBlockingStub);
    }

    /**
     * Invokes {@code functional} against a fresh data-service blocking stub, retrying on failure.
     *
     * <p>Each attempt gets its own stub with a {@value #DATA_SERVICE_DEADLINE_SECONDS}-second
     * deadline, so a retry never inherits an already-expired deadline.
     *
     * @param functional the call to perform with the stub
     * @param channel    channel the stub is created on
     * @return whatever {@code functional} returns on the first successful attempt
     * @throws RuntimeException the last failure if every attempt failed; a checked failure is
     *                          wrapped in an {@link IllegalStateException} with it as the cause
     */
    public <Result> Result runDataServiceGrpc(Functional<DataSetServiceGrpc.DataSetServiceBlockingStub, Result> functional, Channel channel) {
        return runWithRetry(functional, channel, ch -> (DataSetServiceGrpc.DataSetServiceBlockingStub)
                DataSetServiceGrpc.newBlockingStub(ch).withDeadlineAfter(DATA_SERVICE_DEADLINE_SECONDS, TimeUnit.SECONDS));
    }

    /**
     * Runs {@code functional} up to {@link #GRPC_MAX_ATTEMPTS} times, creating a new stub per attempt.
     * Returns the first successful result; otherwise rethrows the last failure.
     */
    private <S, R> R runWithRetry(Functional<S, R> functional, Channel channel, Function<Channel, S> stubFactory) {
        Exception lastFailure = null;
        for (int attempt = 1; attempt <= GRPC_MAX_ATTEMPTS; attempt++) {
            try {
                return functional.run(stubFactory.apply(channel));
            } catch (Exception e) {
                lastFailure = e;
                log.warn("gRPC call failed (attempt {}/{})", attempt, GRPC_MAX_ATTEMPTS, e);
            }
        }
        log.info("Over three anomalies thrown");
        if (lastFailure instanceof RuntimeException) {
            throw (RuntimeException) lastFailure;
        }
        throw new IllegalStateException("gRPC call failed after " + GRPC_MAX_ATTEMPTS + " attempts", lastFailure);
    }

    /**
     * Polls the VM node for the task's status until the task fails or enough parties succeed.
     *
     * <p>Every status reported by a party with a non-empty name is appended to the cache entry
     * for this task. Polling stops when such a status is {@code FAIL}: success is set to
     * {@code false` and the error is set to one {@code party:message} line per failed status.
     * Polling also stops when the cumulative number of cached {@code SUCCESS} statuses reaches
     * {@code i}, or when the polling thread is interrupted (success is set to {@code false} and
     * the interrupt flag is restored).
     *
     * <p>Any other exception marks the task unsuccessful and polling continues after the regular
     * poll interval. NOTE(review): a permanently unreachable node therefore keeps this loop
     * running; consider bounding it.
     *
     * <p>On exit the cache entry is invalidated and the task is marked as ended.
     *
     * @param channel     channel to the VM node
     * @param taskContext identifies the task; also supplies the cache key
     * @param taskParam   receives success/error/end flags
     * @param i           number of {@code SUCCESS} statuses required to consider the task done
     */
    public void continuouslyObtainTaskStatus(Channel channel, Common.TaskContext taskContext, TaskParam taskParam, int i) {
        String cacheKey = TASK_STATUS_KEY
                .replace("<taskId>", taskContext.getTaskId())
                .replace("<jobId>", taskContext.getJobId());
        CacheService cacheService = getCacheService();
        boolean polling = true;
        while (polling) {
            try {
                TaskStatusReply reply = runVMNodeGrpc(stub -> stub.fetchTaskStatus(taskContext), channel);
                polling = !handleStatusReply(reply, cacheKey, cacheService, taskParam, i);
            } catch (Exception e) {
                taskParam.setSuccess(false);
                log.error("taskid:{} - requestId:{} - failed to obtain task status",
                        taskParam.getTaskId(), taskParam.getRequestId(), e);
            }
            if (polling) {
                // Sleep after failures too, so an unreachable node is not hammered in a tight loop.
                try {
                    Thread.sleep(STATUS_POLL_INTERVAL_MS);
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                    taskParam.setSuccess(false);
                    polling = false;
                }
            }
        }
        cacheService.invalidate(cacheKey);
        taskParam.setEnd(true);
    }

    /**
     * Records one status reply and decides whether polling is finished.
     *
     * @return {@code true} when the task failed or reached {@code requiredSuccessCount} successes
     */
    private boolean handleStatusReply(TaskStatusReply reply, String cacheKey, CacheService cacheService,
                                      TaskParam taskParam, int requiredSuccessCount) {
        if (reply == null || reply.getTaskStatusList() == null || reply.getTaskStatusList().isEmpty()) {
            return false;
        }
        List<String> statusNames = reply.getTaskStatusList().stream()
                .filter(status -> status.getParty() != null && !"".equals(status.getParty()))
                .map(status -> status.getStatus().name())
                .collect(Collectors.toList());
        if (statusNames.isEmpty()) {
            return false;
        }
        log.info(reply.toString());
        // NOTE(review): statuses are appended on every poll; if the node re-reports the same
        // SUCCESS each time it is counted again — confirm fetchTaskStatus only returns new statuses.
        List<String> history = cacheService.get(cacheKey);
        history.addAll(statusNames);
        cacheService.put(cacheKey, history);

        if (statusNames.contains(TaskStatus.StatusCode.FAIL.name())) {
            taskParam.setSuccess(false);
            taskParam.setError(describeFailures(reply, taskParam));
            return true;
        }
        long successCount = getNumberOfSuccessfulTasks(cacheKey, cacheService).longValue();
        log.info("taskid:{} - requestId:{} - num:{} - success:{}",
                taskParam.getTaskId(), taskParam.getRequestId(), requiredSuccessCount, successCount);
        return requiredSuccessCount <= successCount;
    }

    /** Logs every FAIL status in {@code reply} and joins them as {@code party:message} lines. */
    private String describeFailures(TaskStatusReply reply, TaskParam taskParam) {
        StringBuilder errors = new StringBuilder();
        for (TaskStatus status : reply.getTaskStatusList()) {
            if (status.getStatus() != TaskStatus.StatusCode.FAIL) {
                continue;
            }
            log.info("taskid:{} - requestId:{} -fail:{}", taskParam.getTaskId(), taskParam.getRequestId(), status.toString());
            errors.append(status.getParty()).append(":").append(status.getMessage()).append("\n");
        }
        return errors.toString();
    }

    /**
     * Counts the {@code SUCCESS} entries accumulated in the cache under {@code str}.
     *
     * @param str          cache key (see {@link #TASK_STATUS_KEY})
     * @param cacheService cache holding the accumulated status names
     * @return number of cached {@code SUCCESS} status names; 0 if none or the entry is missing
     */
    public Long getNumberOfSuccessfulTasks(String str, CacheService cacheService) {
        List<String> statuses = cacheService.get(str);
        if (statuses == null || statuses.isEmpty()) {
            return 0L;
        }
        String successName = TaskStatus.StatusCode.SUCCESS.name();
        return statuses.stream().filter(successName::equals).count();
    }

    /** Executes the task described by {@code taskParam} over {@code channel}. */
    public abstract void execute(Channel channel, TaskParam taskParam);

    /** Returns the cache used to accumulate task statuses. */
    public abstract CacheService getCacheService();

    /** Sets the cache used to accumulate task statuses. */
    public abstract void setCacheService(CacheService cacheService);
}
