package org.apache.dolphinscheduler.plugin.task.pytorch;

import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
import lombok.Generated;
import org.apache.dolphinscheduler.common.utils.JSONUtils;
import org.apache.dolphinscheduler.plugin.task.api.AbstractTask;
import org.apache.dolphinscheduler.plugin.task.api.ShellCommandExecutor;
import org.apache.dolphinscheduler.plugin.task.api.TaskCallBack;
import org.apache.dolphinscheduler.plugin.task.api.TaskException;
import org.apache.dolphinscheduler.plugin.task.api.TaskExecutionContext;
import org.apache.dolphinscheduler.plugin.task.api.model.TaskResponse;
import org.apache.dolphinscheduler.plugin.task.api.parameters.AbstractParameters;
import org.apache.dolphinscheduler.plugin.task.api.shell.ShellInterceptorBuilderFactory;
import org.apache.dolphinscheduler.plugin.task.api.utils.ParameterUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/apache/dolphinscheduler/plugin/task/pytorch/PytorchTask.class */
/**
 * Task that runs a Python (PyTorch) script by generating a small shell script
 * and executing it through a {@link ShellCommandExecutor}.
 *
 * <p>The generated script has three parts, in order:
 * <ol>
 *   <li>it exports {@code PYTHONPATH};</li>
 *   <li>it optionally builds a Python environment via {@link PythonEnvManager};</li>
 *   <li>it invokes the configured script.</li>
 * </ol>
 */
public class PytorchTask extends AbstractTask {

    @Generated
    private static final Logger log = LoggerFactory.getLogger(PytorchTask.class);

    /** Executor for the generated shell script; created with {@code this::logHandle} as its log consumer. */
    private final ShellCommandExecutor shellCommandExecutor;

    /** Parsed task parameters; populated by {@link #init()}. */
    protected PytorchParameters pytorchParameters;

    protected TaskExecutionContext taskExecutionContext;

    /** Builds environment-creation and python commands; configured in {@link #init()}. */
    private PythonEnvManager pythonEnvManager;

    public PytorchTask(TaskExecutionContext taskExecutionContext) {
        super(taskExecutionContext);
        this.taskExecutionContext = taskExecutionContext;
        this.shellCommandExecutor = new ShellCommandExecutor(this::logHandle, taskExecutionContext);
    }

    /**
     * Parses and validates the task parameters, then configures the Python environment manager.
     *
     * @throws TaskException if the parameters are missing or fail {@code checkParameters()}
     */
    public void init() {
        this.pytorchParameters =
                JSONUtils.parseObject(this.taskExecutionContext.getTaskParams(), PytorchParameters.class);
        // Log the parsed task parameters only, not the whole execution context:
        // the context carries much more (e.g. prepare params) than this message claims to show.
        log.info("Initialize pytorch task params {}", JSONUtils.toPrettyJsonString(this.pytorchParameters));
        if (this.pytorchParameters == null || !this.pytorchParameters.checkParameters()) {
            throw new TaskException("python task params is not valid");
        }
        this.pythonEnvManager = new PythonEnvManager();
        this.pythonEnvManager.setPythonEnvTool(this.pytorchParameters.getPythonEnvTool());
        this.pythonEnvManager.setCondaPythonVersion(this.pytorchParameters.getCondaPythonVersion());
    }

    /**
     * Builds the execution script and runs it.
     *
     * <p>The task's prepare params are passed as script properties. On completion the exit
     * code, process id and output params are recorded on this task.
     *
     * @param taskCallBack callback forwarded to the shell executor
     * @throws TaskException if the run is interrupted or fails; the exit status is set to -1 first
     */
    public void handle(TaskCallBack taskCallBack) throws TaskException {
        try {
            TaskResponse response = this.shellCommandExecutor.run(
                    ShellInterceptorBuilderFactory.newBuilder()
                            .properties(ParameterUtils.convert(this.taskExecutionContext.getPrepareParamsMap()))
                            .appendScript(buildPythonExecuteCommand()),
                    taskCallBack);
            setExitStatusCode(response.getExitStatusCode());
            setProcessId(response.getProcessId());
            setTaskOutputParams(this.shellCommandExecutor.getTaskOutputParams());
        } catch (InterruptedException e) {
            // Restore the interrupt flag so callers up the stack can observe it.
            Thread.currentThread().interrupt();
            log.error("The current Pytorch task has been interrupted", e);
            setExitStatusCode(-1);
            throw new TaskException("The current Pytorch task has been interrupted", e);
        } catch (Exception e) {
            log.error("Pytorch task execute failed", e);
            setExitStatusCode(-1);
            throw new TaskException("Pytorch task execute failed", e);
        }
    }

    /**
     * No-op.
     *
     * <p>NOTE(review): presumably process termination is handled outside this task; confirm.
     */
    public void cancel() throws TaskException {
    }

    /**
     * Builds the newline-separated shell script that runs the configured Python script.
     *
     * <p>If the configured python path is a git path, the project is prepared under the task's
     * execute path. The parameters' python path is then replaced with the local checkout path,
     * which is a side effect on {@link #pytorchParameters}.
     *
     * <p>Parameter values are interpolated into the script without escaping; script params are
     * passed through as raw shell arguments.
     *
     * @return the shell script text
     * @throws Exception if preparing a git project fails
     */
    public String buildPythonExecuteCommand() throws Exception {
        List<String> lines = new ArrayList<>();
        String pythonPath = this.pytorchParameters.getPythonPath();
        if (GitProjectManager.isGitPath(pythonPath)) {
            GitProjectManager gitProjectManager = new GitProjectManager();
            gitProjectManager.setPath(pythonPath);
            gitProjectManager.setBaseDir(this.taskExecutionContext.getExecutePath());
            gitProjectManager.prepareProject();
            this.pytorchParameters.setPythonPath(gitProjectManager.getGitLocalPath());
        }
        lines.add(String.format("export PYTHONPATH=%s", this.pytorchParameters.getPythonPath()));
        if (isCreateEnvironment()) {
            lines.add(this.pythonEnvManager.getBuildEnvCommand(this.pytorchParameters.getRequirementPath()));
        }
        String scriptParams = this.pytorchParameters.getScriptParams();
        if (scriptParams == null || scriptParams.isEmpty()) {
            lines.add(String.format("%s %s", getPythonCommand(), this.pytorchParameters.getScriptPath()));
        } else {
            lines.add(String.format("%s %s %s", getPythonCommand(), this.pytorchParameters.getScriptPath(),
                    scriptParams));
        }
        return String.join("\n", lines);
    }

    /** Returns the env manager's python command when creating an environment, else the configured launcher. */
    private String getPythonCommand() {
        return isCreateEnvironment()
                ? this.pythonEnvManager.getPythonCommand()
                : this.pytorchParameters.getPythonLauncher();
    }

    /** Null-safe read of the "create environment" flag; an absent flag is treated as {@code false}. */
    private boolean isCreateEnvironment() {
        return Boolean.TRUE.equals(this.pytorchParameters.getIsCreateEnvironment());
    }

    public AbstractParameters getParameters() {
        return this.pytorchParameters;
    }
}
