package org.apache.dolphinscheduler.plugin.task.mlflow;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import lombok.Generated;
import org.apache.commons.lang3.StringUtils;
import org.apache.dolphinscheduler.common.thread.ThreadUtils;
import org.apache.dolphinscheduler.common.utils.JSONUtils;
import org.apache.dolphinscheduler.common.utils.OSUtils;
import org.apache.dolphinscheduler.common.utils.PropertyUtils;
import org.apache.dolphinscheduler.plugin.task.api.AbstractTask;
import org.apache.dolphinscheduler.plugin.task.api.ShellCommandExecutor;
import org.apache.dolphinscheduler.plugin.task.api.TaskCallBack;
import org.apache.dolphinscheduler.plugin.task.api.TaskException;
import org.apache.dolphinscheduler.plugin.task.api.TaskExecutionContext;
import org.apache.dolphinscheduler.plugin.task.api.model.Property;
import org.apache.dolphinscheduler.plugin.task.api.model.TaskResponse;
import org.apache.dolphinscheduler.plugin.task.api.shell.ShellInterceptorBuilderFactory;
import org.apache.dolphinscheduler.plugin.task.api.utils.ParameterUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/apache/dolphinscheduler/plugin/task/mlflow/MlflowTask.class */
public class MlflowTask extends AbstractTask {

    @Generated
    private static final Logger log = LoggerFactory.getLogger(MlflowTask.class);
    private static final Pattern GIT_CHECK_PATTERN = Pattern.compile("^(git@|https?://)");
    private final ShellCommandExecutor shellCommandExecutor;
    private final TaskExecutionContext taskExecutionContext;
    private MlflowParameters mlflowParameters;

    public MlflowTask(TaskExecutionContext taskExecutionContext) {
        super(taskExecutionContext);
        this.taskExecutionContext = taskExecutionContext;
        this.shellCommandExecutor = new ShellCommandExecutor(this::logHandle, taskExecutionContext);
    }

    public static String getPresetRepository() {
        String string = PropertyUtils.getString(MlflowConstants.PRESET_REPOSITORY_KEY);
        if (StringUtils.isEmpty(string)) {
            string = MlflowConstants.PRESET_REPOSITORY;
        }
        return string;
    }

    public static String getPresetRepositoryVersion() {
        String string = PropertyUtils.getString(MlflowConstants.PRESET_REPOSITORY_VERSION_KEY);
        if (StringUtils.isEmpty(string)) {
            string = MlflowConstants.PRESET_REPOSITORY_VERSION;
        }
        return string;
    }

    public static String getVersionString(String str, String str2) {
        return StringUtils.isEmpty(str) ? "" : GIT_CHECK_PATTERN.matcher(str2).find() ? String.format("--version=%s", str) : "";
    }

    public void init() {
        this.mlflowParameters = (MlflowParameters) JSONUtils.parseObject(this.taskExecutionContext.getTaskParams(), MlflowParameters.class);
        log.info("Initialize MLFlow task params {}", JSONUtils.toPrettyJsonString(this.mlflowParameters));
        if (this.mlflowParameters == null || !this.mlflowParameters.checkParameters()) {
            throw new RuntimeException("MLFlow task params is not valid");
        }
    }

    public void handle(TaskCallBack taskCallBack) throws TaskException {
        try {
            TaskResponse run = this.shellCommandExecutor.run(ShellInterceptorBuilderFactory.newBuilder().properties(ParameterUtils.convert(getParamsMap())).appendScript(buildCommand()), taskCallBack);
            setExitStatusCode(this.mlflowParameters.getIsDeployDocker() ? checkDockerHealth() : run.getExitStatusCode());
            setProcessId(run.getProcessId());
            this.mlflowParameters.dealOutParam(this.shellCommandExecutor.getTaskOutputParams());
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            log.error("The current Mlflow task has been interrupted", e);
            setExitStatusCode(-1);
            throw new TaskException("The current Mlflow task has been interrupted", e);
        } catch (Exception e2) {
            log.error("Mlflow task error", e2);
            setExitStatusCode(-1);
            throw new TaskException("Execute Mlflow task failed", e2);
        }
    }

    public void cancel() throws TaskException {
        try {
            this.shellCommandExecutor.cancelApplication();
        } catch (Exception e) {
            throw new TaskException("cancel application error", e);
        }
    }

    public String buildCommand() {
        String str = "";
        if (this.mlflowParameters.getMlflowTaskType().equals(MlflowConstants.MLFLOW_TASK_TYPE_PROJECTS)) {
            str = buildCommandForMlflowProjects();
        } else if (this.mlflowParameters.getMlflowTaskType().equals(MlflowConstants.MLFLOW_TASK_TYPE_MODELS)) {
            str = buildCommandForMlflowModels();
        }
        log.info("mlflow task command: \n{}", str);
        return str;
    }

    private String buildCommandForMlflowProjects() {
        String format;
        ArrayList arrayList = new ArrayList();
        arrayList.add(String.format(MlflowConstants.EXPORT_MLFLOW_TRACKING_URI_ENV, this.mlflowParameters.getMlflowTrackingUri()));
        String versionString = this.mlflowParameters.isCustomProject().booleanValue() ? getVersionString(this.mlflowParameters.getMlflowProjectVersion(), this.mlflowParameters.getMlflowProjectRepository()) : getVersionString(getPresetRepositoryVersion(), getPresetRepository());
        String mlflowJobType = this.mlflowParameters.getMlflowJobType();
        boolean z = -1;
        switch (mlflowJobType.hashCode()) {
            case 515350529:
                if (mlflowJobType.equals(MlflowConstants.JOB_TYPE_BASIC_ALGORITHM)) {
                    z = false;
                    break;
                }
                break;
            case 1507063368:
                if (mlflowJobType.equals(MlflowConstants.JOB_TYPE_CUSTOM_PROJECT)) {
                    z = 2;
                    break;
                }
                break;
            case 1972511662:
                if (mlflowJobType.equals(MlflowConstants.JOB_TYPE_AUTOML)) {
                    z = true;
                    break;
                }
                break;
        }
        switch (z) {
            case false:
                arrayList.add(String.format(MlflowConstants.SET_DATA_PATH, this.mlflowParameters.getDataPath()));
                arrayList.add(String.format(MlflowConstants.SET_REPOSITORY, getPresetRepository() + MlflowConstants.PRESET_BASIC_ALGORITHM_PROJECT));
                format = String.format(MlflowConstants.MLFLOW_RUN_BASIC_ALGORITHM, this.mlflowParameters.getAlgorithm(), this.mlflowParameters.getParams(), this.mlflowParameters.getSearchParams(), this.mlflowParameters.getModelName(), this.mlflowParameters.getExperimentName());
                break;
            case true:
                arrayList.add(String.format(MlflowConstants.SET_DATA_PATH, this.mlflowParameters.getDataPath()));
                arrayList.add(String.format(MlflowConstants.SET_REPOSITORY, getPresetRepository() + MlflowConstants.PRESET_AUTOML_PROJECT));
                format = String.format(MlflowConstants.MLFLOW_RUN_AUTOML_PROJECT, this.mlflowParameters.getAutomlTool(), this.mlflowParameters.getParams(), this.mlflowParameters.getModelName(), this.mlflowParameters.getExperimentName());
                break;
            case true:
                arrayList.add(String.format(MlflowConstants.SET_REPOSITORY, this.mlflowParameters.getMlflowProjectRepository()));
                format = String.format(MlflowConstants.MLFLOW_RUN_CUSTOM_PROJECT, this.mlflowParameters.getParams(), this.mlflowParameters.getExperimentName());
                break;
            default:
                throw new TaskException("Unsupported mlflow job type: " + this.mlflowParameters.getMlflowJobType());
        }
        if (StringUtils.isNotEmpty(versionString)) {
            format = format + " " + versionString;
        }
        arrayList.add(format);
        return (String) arrayList.stream().collect(Collectors.joining("\n"));
    }

    protected String buildCommandForMlflowModels() {
        ArrayList arrayList = new ArrayList();
        arrayList.add(String.format(MlflowConstants.EXPORT_MLFLOW_TRACKING_URI_ENV, this.mlflowParameters.getMlflowTrackingUri()));
        String deployModelKey = this.mlflowParameters.getDeployModelKey();
        if (this.mlflowParameters.getDeployType().equals(MlflowConstants.MLFLOW_MODELS_DEPLOY_TYPE_MLFLOW)) {
            arrayList.add(String.format(MlflowConstants.MLFLOW_MODELS_SERVE, deployModelKey, this.mlflowParameters.getDeployPort()));
        } else if (this.mlflowParameters.getDeployType().equals(MlflowConstants.MLFLOW_MODELS_DEPLOY_TYPE_DOCKER)) {
            String str = "mlflow/" + this.mlflowParameters.getModelKeyName(":");
            String containerName = this.mlflowParameters.getContainerName();
            arrayList.add(String.format(MlflowConstants.MLFLOW_BUILD_DOCKER, deployModelKey, str));
            arrayList.add(String.format(MlflowConstants.DOCKER_RREMOVE_CONTAINER, containerName));
            arrayList.add(String.format(MlflowConstants.DOCKER_RUN, containerName, this.mlflowParameters.getDeployPort(), str));
        }
        return (String) arrayList.stream().collect(Collectors.joining("\n"));
    }

    private Map<String, Property> getParamsMap() {
        return this.taskExecutionContext.getPrepareParamsMap();
    }

    public int checkDockerHealth() {
        String format;
        log.info("checking container healthy ... ");
        String[] strArr = {"sh", "-c", String.format(MlflowConstants.DOCKER_HEALTH_CHECK, this.mlflowParameters.getContainerName())};
        int i = 0;
        while (true) {
            int i2 = i;
            if (i2 >= 20) {
                log.info("health check fail");
                return -1;
            }
            try {
                format = OSUtils.exeShell(strArr).replace("\n", "").replace("\"", "");
            } catch (Exception e) {
                format = String.format("error --- %s", e.getMessage());
            }
            log.info("container healthy status: {}", format);
            if (format.equals("healthy")) {
                log.info("container is healthy");
                return 0;
            }
            log.info("The health check has been running for {} seconds", Integer.valueOf((i2 * MlflowConstants.DOCKER_HEALTH_CHECK_INTERVAL) / 1000));
            ThreadUtils.sleep(5000L);
            i = i2 + 1;
        }
    }

    /* renamed from: getParameters, reason: merged with bridge method [inline-methods] */
    public MlflowParameters m1getParameters() {
        return this.mlflowParameters;
    }
}
