package ai.libs.jaicore.ml.scikitwrapper;

import ai.libs.jaicore.ml.core.dataset.DatasetUtil;
import ai.libs.jaicore.processes.EOperatingSystem;
import ai.libs.jaicore.processes.ProcessUtil;
import ai.libs.python.IPythonConfig;
import ai.libs.python.PythonUtil;
import java.io.File;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import java.util.StringJoiner;
import org.api4.java.algorithm.Timeout;
import org.slf4j.Logger;

/**
 * Fluent builder for the command line that runs the scikit-learn wrapper Python script.
 *
 * <p>Usage: construct with the problem type and the script file, then select an execution mode
 * ({@link #withFitMode()}, {@link #withPredictMode()} or {@link #withFitAndPredictMode()}).
 * Next, set the files and target indices that mode requires (see {@link #checkRequirements()}).
 * Optionally add a seed, a timeout, a logger and extra arguments. Finally, call
 * {@link #toCommandArray()}. Instances are mutable and carry no synchronization.
 */
public class ScikitLearnWrapperCommandBuilder {
    private static final String PROBLEM_FLAG = "--problem";
    private static final String MODE_FLAG = "--mode";
    private static final String MODEL_FLAG = "--model";
    private static final String FIT_DATA_FLAG = "--fit";
    private static final String FIT_OUTPUT_FLAG = "--fitOutput";
    private static final String PREDICT_DATA_FLAG = "--predict";
    private static final String PREDICT_OUTPUT_FLAG = "--predictOutput";
    private static final String TARGETS_FLAG = "--targets";
    private static final String SEED_FLAG = "--seed";

    /** Optional logger; {@code null} unless {@link #withLogger(Logger)} was called, in which case nothing is logged. */
    private Logger logger;
    private IPythonConfig pythonConfiguration;
    private String problemTypeFlag;
    private File scriptFile;
    private EWrapperExecutionMode executionMode;
    protected String modelFile;
    protected String fitDataFile;
    protected String fitOutputFile;
    protected String predictDataFile;
    protected String predictOutputFile;
    private int[] targetIndices;
    private long seed;
    private Timeout timeout;
    protected List<String> additionalParameters;

    /**
     * Execution modes of the wrapper script. {@link #toString()} yields the value passed via
     * {@code --mode}.
     */
    public enum EWrapperExecutionMode {
        FIT("fit"),
        PREDICT("predict"),
        FIT_AND_PREDICT("fitAndPredict");

        private final String name;

        EWrapperExecutionMode(final String name) {
            this.name = name;
        }

        @Override
        public String toString() {
            return this.name;
        }
    }

    /**
     * Creates a builder for the given problem type and wrapper script.
     *
     * @param str value passed to the script via {@code --problem}
     * @param file the Python wrapper script to execute
     */
    public ScikitLearnWrapperCommandBuilder(final String str, final File file) {
        this.problemTypeFlag = str;
        this.scriptFile = file;
    }

    /** Sets the Python configuration handed to {@code PythonUtil} when building the command. */
    public ScikitLearnWrapperCommandBuilder withPythonConfig(final IPythonConfig iPythonConfig) {
        this.pythonConfiguration = iPythonConfig;
        return this;
    }

    /** Sets the logger used for informational messages (may be {@code null} to disable logging). */
    public ScikitLearnWrapperCommandBuilder withLogger(final Logger logger) {
        this.logger = logger;
        return this;
    }

    /** Replaces the wrapper script given in the constructor. */
    public ScikitLearnWrapperCommandBuilder withScriptFile(final File file) {
        this.scriptFile = file;
        return this;
    }

    private ScikitLearnWrapperCommandBuilder withMode(final EWrapperExecutionMode mode) {
        this.executionMode = mode;
        return this;
    }

    /** Selects {@link EWrapperExecutionMode#FIT}. */
    public ScikitLearnWrapperCommandBuilder withFitMode() {
        return withMode(EWrapperExecutionMode.FIT);
    }

    /** Selects {@link EWrapperExecutionMode#PREDICT}. */
    public ScikitLearnWrapperCommandBuilder withPredictMode() {
        return withMode(EWrapperExecutionMode.PREDICT);
    }

    /** Selects {@link EWrapperExecutionMode#FIT_AND_PREDICT}. */
    public ScikitLearnWrapperCommandBuilder withFitAndPredictMode() {
        return withMode(EWrapperExecutionMode.FIT_AND_PREDICT);
    }

    /** Sets the model file passed via {@code --model} (stored as an absolute path). */
    public ScikitLearnWrapperCommandBuilder withModelFile(final File file) {
        this.modelFile = file.getAbsolutePath();
        return this;
    }

    /**
     * Sets the training data file passed via {@code --fit} (stored as an absolute path).
     *
     * @throws IllegalArgumentException if the file does not exist
     */
    public ScikitLearnWrapperCommandBuilder withFitDataFile(final File file) {
        if (!file.getAbsoluteFile().exists()) {
            throw new IllegalArgumentException("Data file does not exist: " + file.getAbsolutePath());
        }
        this.fitDataFile = file.getAbsolutePath();
        return this;
    }

    /** Sets the file passed via {@code --fitOutput} (stored as an absolute path). */
    public ScikitLearnWrapperCommandBuilder withFitOutputFile(final File file) {
        this.fitOutputFile = file.getAbsolutePath();
        return this;
    }

    /** Sets the data file passed via {@code --predict} (stored as an absolute path). */
    public ScikitLearnWrapperCommandBuilder withPredictDataFile(final File file) {
        this.predictDataFile = file.getAbsolutePath();
        return this;
    }

    /** Sets the file passed via {@code --predictOutput} (stored as an absolute path). */
    public ScikitLearnWrapperCommandBuilder withPredictOutputFile(final File file) {
        this.predictOutputFile = file.getAbsolutePath();
        return this;
    }

    /** Sets the target column indices passed via {@code --targets}; the array is copied. */
    public ScikitLearnWrapperCommandBuilder withTargetIndices(final int... iArr) {
        // Copy so later changes to the caller's array do not alter this builder's state.
        this.targetIndices = iArr == null ? null : iArr.clone();
        return this;
    }

    /** Sets the seed passed via {@code --seed} (defaults to 0). */
    public ScikitLearnWrapperCommandBuilder withSeed(final long j) {
        this.seed = j;
        return this;
    }

    /** Sets a timeout; it is only enforced on Linux, via the {@code timeout} command. */
    public ScikitLearnWrapperCommandBuilder withTimeout(final Timeout timeout) {
        this.timeout = timeout;
        return this;
    }

    /** Sets extra arguments appended verbatim to the end of the command; the list is copied. */
    public ScikitLearnWrapperCommandBuilder withAdditionalCommandLineParameters(final List<String> list) {
        this.additionalParameters = list == null ? null : new ArrayList<>(list);
        return this;
    }

    /**
     * Validates that the builder is complete for the selected execution mode.
     *
     * @throws NullPointerException if the script file, problem type or execution mode is unset,
     *         or if a file or target setting required by the mode is missing
     * @throws IllegalArgumentException if the script file does not exist
     */
    public void checkRequirements() {
        Objects.requireNonNull(this.scriptFile, "scriptFile must be set");
        if (!this.scriptFile.exists()) {
            throw new IllegalArgumentException("The wrapped sklearn script " + this.scriptFile.getAbsolutePath() + " file does not exist");
        }
        Objects.requireNonNull(this.problemTypeFlag, "problem type must be set");
        Objects.requireNonNull(this.executionMode, "execution mode must be set (withFitMode/withPredictMode/withFitAndPredictMode)");
        switch (this.executionMode) {
            case FIT:
                checkRequirementsTrainMode();
                break;
            case PREDICT:
                checkRequirementsTestMode();
                break;
            case FIT_AND_PREDICT:
                checkRequirementsTrainTestMode();
                break;
            default:
                throw new IllegalStateException("Unsupported execution mode: " + this.executionMode);
        }
    }

    /** Checks the settings required in fit mode: fit data file, model file and target indices. */
    protected void checkRequirementsTrainMode() {
        Objects.requireNonNull(this.fitDataFile, "fit data file must be set in fit mode");
        Objects.requireNonNull(this.modelFile, "model file must be set in fit mode");
        Objects.requireNonNull(this.targetIndices, "target indices must be set in fit mode");
    }

    /** Checks the settings required in predict mode: model, predict data and predict output files, and target indices. */
    protected void checkRequirementsTestMode() {
        Objects.requireNonNull(this.modelFile, "model file must be set in predict mode");
        Objects.requireNonNull(this.predictDataFile, "predict data file must be set in predict mode");
        Objects.requireNonNull(this.predictOutputFile, "predict output file must be set in predict mode");
        Objects.requireNonNull(this.targetIndices, "target indices must be set in predict mode");
    }

    /** Checks the settings required in fit-and-predict mode: fit data, predict data and predict output files, and target indices. */
    protected void checkRequirementsTrainTestMode() {
        Objects.requireNonNull(this.fitDataFile, "fit data file must be set in fit-and-predict mode");
        Objects.requireNonNull(this.predictDataFile, "predict data file must be set in fit-and-predict mode");
        Objects.requireNonNull(this.predictOutputFile, "predict output file must be set in fit-and-predict mode");
        Objects.requireNonNull(this.targetIndices, "target indices must be set in fit-and-predict mode");
    }

    /**
     * Validates the configuration and builds the executable command.
     *
     * <p>On Linux with a timeout set, the arguments are prefixed with {@code timeout <seconds-2>}
     * (never less than 1). Optional flags are emitted only when their value is set. Additional
     * parameters come last.
     *
     * @return the command array produced by {@code PythonUtil#getExecutableCommandArray}
     * @throws NullPointerException if a required setting is missing (see {@link #checkRequirements()})
     * @throws IllegalArgumentException if the script file does not exist
     */
    public String[] toCommandArray() {
        checkRequirements();
        List<String> args = new ArrayList<>();
        if (this.timeout != null && ProcessUtil.getOS() == EOperatingSystem.LINUX) {
            long seconds = this.timeout.seconds();
            if (this.logger != null) {
                this.logger.info("Executing with timeout {}s", Long.valueOf(seconds));
            }
            args.add("timeout");
            // Keep a 2s margin below the configured timeout, but never emit 0 or a negative value:
            // GNU timeout treats 0 as "no timeout" and rejects negative durations.
            args.add(String.valueOf(Math.max(1L, seconds - 2)));
        }
        args.add("-u");
        args.add(this.scriptFile.getAbsolutePath());
        args.addAll(Arrays.asList(PROBLEM_FLAG, this.problemTypeFlag));
        args.addAll(Arrays.asList(MODE_FLAG, this.executionMode.toString()));
        addOptionIfSet(args, MODEL_FLAG, this.modelFile);
        addOptionIfSet(args, FIT_DATA_FLAG, this.fitDataFile);
        addOptionIfSet(args, FIT_OUTPUT_FLAG, this.fitOutputFile);
        addOptionIfSet(args, PREDICT_DATA_FLAG, this.predictDataFile);
        addOptionIfSet(args, PREDICT_OUTPUT_FLAG, this.predictOutputFile);
        args.addAll(Arrays.asList(SEED_FLAG, String.valueOf(this.seed)));
        if (this.targetIndices != null && this.targetIndices.length > 0) {
            // Render as "[0,1,2]" with no whitespace so it forms a single token.
            args.addAll(Arrays.asList(TARGETS_FLAG, Arrays.toString(this.targetIndices).replaceAll("\\s+", "")));
        }
        if (this.additionalParameters != null) {
            args.addAll(this.additionalParameters);
        }
        // NOTE(review): all arguments are joined into ONE space-separated string before being handed
        // to PythonUtil; paths containing spaces are not quoted here — confirm how PythonUtil splits it.
        return new PythonUtil(this.pythonConfiguration).getExecutableCommandArray(false, new String[] { String.join(" ", args) });
    }

    /** Appends {@code flag value} to {@code args} only when {@code value} is non-null. */
    private static void addOptionIfSet(final List<String> args, final String flag, final String value) {
        if (value != null) {
            args.add(flag);
            args.add(value);
        }
    }
}
