package ai.libs.mlplan.core;

import ai.libs.jaicore.basic.FileUtil;
import ai.libs.jaicore.basic.SystemRequirementsNotMetException;
import ai.libs.jaicore.basic.sets.SetUtil;
import ai.libs.jaicore.ml.evaluation.evaluators.weka.factory.MonteCarloCrossValidationEvaluatorFactory;
import ai.libs.jaicore.ml.evaluation.evaluators.weka.splitevaluation.SimpleSLCSplitBasedClassifierEvaluator;
import ai.libs.jaicore.ml.weka.dataset.splitter.IDatasetSplitter;
import ai.libs.jaicore.ml.weka.dataset.splitter.MulticlassClassStratifiedSplitter;
import ai.libs.mlplan.multiclass.wekamlplan.IClassifierFactory;
import ai.libs.mlplan.multiclass.wekamlplan.sklearn.SKLearnClassifierFactory;
import java.io.BufferedReader;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedList;
import java.util.List;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:ai/libs/mlplan/core/MLPlanSKLearnBuilder.class */
public class MLPlanSKLearnBuilder extends AbstractMLPlanSingleLabelBuilder {
    private Logger logger;
    private static final String PYTHON_MINIMUM_REQUIRED_VERSION = "Python 3.5.0";
    private static final int PYTHON_MINIMUM_REQUIRED_VERSION_REL = 3;
    private static final int PYTHON_MINIMUM_REQUIRED_VERSION_MIN = 0;
    private static final String PYTHON_MODULE_NOT_FOUND_ERROR_MSG = "ModuleNotFoundError";
    private static final String RES_SKLEARN_UL_SEARCHSPACE_CONFIG = "automl/searchmodels/sklearn/ml-plan-ul.json";
    private static final String DEF_REQUESTED_HASCO_INTERFACE = "AbstractClassifier";
    private static final String[] PYTHON_REQUIRED_MODULES = {"numpy", "json", "pickle", "os", "sys", "warnings", "scipy.io.arff", "sklearn"};
    private static final String COMMAND_PYTHON = "python";
    private static final String[] COMMAND_PYTHON_VERSION = {COMMAND_PYTHON, "--version"};
    private static final String[] COMMAND_PYTHON_EXEC = {COMMAND_PYTHON, "-c"};
    private static final IDatasetSplitter DEF_SELECTION_HOLDOUT_SPLITTER = new MulticlassClassStratifiedSplitter();
    private static final IClassifierFactory DEF_CLASSIFIER_FACTORY = new SKLearnClassifierFactory();
    private static final String RES_SKLEARN_SEARCHSPACE_CONFIG = "automl/searchmodels/sklearn/sklearn-mlplan.json";
    private static final String FS_SEARCH_SPACE_CONFIG = "conf/mlplan-sklearn.json";
    private static final File DEF_SEARCH_SPACE_CONFIG = FileUtil.getExistingFileWithHighestPriority(RES_SKLEARN_SEARCHSPACE_CONFIG, new String[]{FS_SEARCH_SPACE_CONFIG});
    private static final String RES_SKLEARN_PREFERRED_COMPONENTS = "mlplan/sklearn-preferenceList.txt";
    private static final String FS_SKLEARN_PREFERRED_COMPONENTS = "conf/sklearn-preferenceList.txt";
    private static final File DEF_PREFERRED_COMPONENTS = FileUtil.getExistingFileWithHighestPriority(RES_SKLEARN_PREFERRED_COMPONENTS, new String[]{FS_SKLEARN_PREFERRED_COMPONENTS});
    private static final int PYTHON_MINIMUM_REQUIRED_VERSION_MAJ = 5;
    private static final MonteCarloCrossValidationEvaluatorFactory DEF_SEARCH_PHASE_EVALUATOR = new MonteCarloCrossValidationEvaluatorFactory().withNumMCIterations(PYTHON_MINIMUM_REQUIRED_VERSION_MAJ).withTrainFoldSize(0.7d).withSplitBasedEvaluator(new SimpleSLCSplitBasedClassifierEvaluator(LOSS_FUNCTION)).withDatasetSplitter(new MulticlassClassStratifiedSplitter());
    private static final MonteCarloCrossValidationEvaluatorFactory DEF_SELECTION_PHASE_EVALUATOR = new MonteCarloCrossValidationEvaluatorFactory().withNumMCIterations(PYTHON_MINIMUM_REQUIRED_VERSION_MAJ).withTrainFoldSize(0.7d).withSplitBasedEvaluator(new SimpleSLCSplitBasedClassifierEvaluator(LOSS_FUNCTION)).withDatasetSplitter(new MulticlassClassStratifiedSplitter());

    public MLPlanSKLearnBuilder() throws IOException {
        this(false);
    }

    public MLPlanSKLearnBuilder(boolean z) throws IOException {
        this.logger = LoggerFactory.getLogger(MLPlanSKLearnBuilder.class);
        if (!z) {
            checkPythonSetup();
        }
        withSearchSpaceConfigFile(DEF_SEARCH_SPACE_CONFIG);
        withPreferredComponentsFile(DEF_PREFERRED_COMPONENTS);
        withRequestedInterface(DEF_REQUESTED_HASCO_INTERFACE);
        withClassifierFactory(DEF_CLASSIFIER_FACTORY);
        withDatasetSplitterForSearchSelectionSplit(DEF_SELECTION_HOLDOUT_SPLITTER);
        withSearchPhaseEvaluatorFactory(DEF_SEARCH_PHASE_EVALUATOR);
        withSelectionPhaseEvaluatorFactory(DEF_SELECTION_PHASE_EVALUATOR);
        setPerformanceMeasureName(LOSS_FUNCTION.getClass().getSimpleName());
    }

    public MLPlanSKLearnBuilder withUnlimitedLengthPipelineSearchSpace() throws IOException {
        return (MLPlanSKLearnBuilder) withSearchSpaceConfigFile(FileUtil.getExistingFileWithHighestPriority(RES_SKLEARN_UL_SEARCHSPACE_CONFIG, new String[]{FS_SEARCH_SPACE_CONFIG}));
    }

    private void checkPythonSetup() {
        try {
            Process start = new ProcessBuilder(new String[PYTHON_MINIMUM_REQUIRED_VERSION_MIN]).command(COMMAND_PYTHON_VERSION).start();
            StringBuilder sb = new StringBuilder();
            BufferedReader bufferedReader = new BufferedReader(new InputStreamReader(start.getInputStream()));
            Throwable th = PYTHON_MINIMUM_REQUIRED_VERSION_MIN;
            while (true) {
                try {
                    try {
                        String readLine = bufferedReader.readLine();
                        if (readLine == null) {
                            break;
                        } else {
                            sb.append(readLine);
                        }
                    } finally {
                    }
                } finally {
                }
            }
            if (bufferedReader != null) {
                if (th != null) {
                    try {
                        bufferedReader.close();
                    } catch (Throwable th2) {
                        th.addSuppressed(th2);
                    }
                } else {
                    bufferedReader.close();
                }
            }
            String sb2 = sb.toString();
            if (!sb2.startsWith("Python ")) {
                throw new SystemRequirementsNotMetException("Could not detect valid python version.");
            }
            String[] split = sb2.substring(7).split("\\.");
            if (split.length != PYTHON_MINIMUM_REQUIRED_VERSION_REL) {
                throw new SystemRequirementsNotMetException("Could not parse python version to be of the shape X.X.X");
            }
            if (!isValidVersion(Integer.parseInt(split[PYTHON_MINIMUM_REQUIRED_VERSION_MIN]), Integer.parseInt(split[1]), Integer.parseInt(split[2]))) {
                throw new SystemRequirementsNotMetException("Python version does not conform the minimum required python version of Python 3.5.0");
            }
            LinkedList linkedList = new LinkedList(Arrays.asList(COMMAND_PYTHON_EXEC));
            StringBuilder sb3 = new StringBuilder();
            String[] strArr = PYTHON_REQUIRED_MODULES;
            int length = strArr.length;
            for (int i = PYTHON_MINIMUM_REQUIRED_VERSION_MIN; i < length; i++) {
                String str = strArr[i];
                if (!sb3.toString().isEmpty()) {
                    sb3.append(";");
                }
                sb3.append("import " + str);
            }
            linkedList.add(sb3.toString());
            StringBuilder sb4 = new StringBuilder();
            bufferedReader = new BufferedReader(new InputStreamReader(new ProcessBuilder(new String[PYTHON_MINIMUM_REQUIRED_VERSION_MIN]).command((String[]) linkedList.toArray(new String[PYTHON_MINIMUM_REQUIRED_VERSION_MIN])).start().getErrorStream()));
            Throwable th3 = PYTHON_MINIMUM_REQUIRED_VERSION_MIN;
            while (true) {
                try {
                    try {
                        String readLine2 = bufferedReader.readLine();
                        if (readLine2 == null) {
                            break;
                        } else {
                            sb4.append(readLine2);
                        }
                    } finally {
                    }
                } finally {
                }
            }
            if (bufferedReader != null) {
                if (th3 != null) {
                    try {
                        bufferedReader.close();
                    } catch (Throwable th4) {
                        th3.addSuppressed(th4);
                    }
                } else {
                    bufferedReader.close();
                }
            }
            if (!sb4.toString().isEmpty()) {
                LinkedList linkedList2 = new LinkedList();
                String[] strArr2 = PYTHON_REQUIRED_MODULES;
                int length2 = strArr2.length;
                for (int i2 = PYTHON_MINIMUM_REQUIRED_VERSION_MIN; i2 < length2; i2++) {
                    String str2 = strArr2[i2];
                    Process start2 = new ProcessBuilder(new String[PYTHON_MINIMUM_REQUIRED_VERSION_MIN]).command(COMMAND_PYTHON_EXEC[PYTHON_MINIMUM_REQUIRED_VERSION_MIN], COMMAND_PYTHON_EXEC[1], "import " + str2).start();
                    StringBuilder sb5 = new StringBuilder();
                    BufferedReader bufferedReader2 = new BufferedReader(new InputStreamReader(start2.getErrorStream()));
                    Throwable th5 = PYTHON_MINIMUM_REQUIRED_VERSION_MIN;
                    while (true) {
                        try {
                            try {
                                String readLine3 = bufferedReader2.readLine();
                                if (readLine3 == null) {
                                    break;
                                } else {
                                    sb5.append(readLine3);
                                }
                            } finally {
                            }
                        } finally {
                            if (bufferedReader2 != null) {
                                if (th5 != null) {
                                    try {
                                        bufferedReader2.close();
                                    } catch (Throwable th6) {
                                        th5.addSuppressed(th6);
                                    }
                                } else {
                                    bufferedReader2.close();
                                }
                            }
                        }
                    }
                    if (bufferedReader2 != null) {
                        if (th5 != null) {
                            try {
                                bufferedReader2.close();
                            } catch (Throwable th7) {
                                th5.addSuppressed(th7);
                            }
                        } else {
                            bufferedReader2.close();
                        }
                    }
                    if (!sb5.toString().isEmpty() && sb5.toString().contains(PYTHON_MODULE_NOT_FOUND_ERROR_MSG)) {
                        this.logger.debug("Could not load python module {}: {}", str2, sb5);
                        linkedList2.add(str2);
                    }
                }
                if (!linkedList2.isEmpty()) {
                    throw new SystemRequirementsNotMetException("Could not find required python modules: " + SetUtil.implode(linkedList2, ", "));
                }
            }
        } catch (IOException e) {
            throw new SystemRequirementsNotMetException("Could not check whether python is installed in the required version. Is python available as a command on your command line?");
        }
    }

    private boolean isValidVersion(int i, int i2, int i3) {
        if (i > PYTHON_MINIMUM_REQUIRED_VERSION_REL) {
            return true;
        }
        if (i != PYTHON_MINIMUM_REQUIRED_VERSION_REL || i2 <= PYTHON_MINIMUM_REQUIRED_VERSION_MAJ) {
            return i == PYTHON_MINIMUM_REQUIRED_VERSION_REL && i2 == PYTHON_MINIMUM_REQUIRED_VERSION_MAJ && i3 >= 0;
        }
        return true;
    }

    @Override // ai.libs.mlplan.core.AbstractMLPlanSingleLabelBuilder
    protected IDatasetSplitter getDefaultDatasetSplitter() {
        return new MulticlassClassStratifiedSplitter();
    }
}
