package ai.libs.jaicore.ml.scikitwrapper.simple;

import ai.libs.jaicore.basic.ResourceUtil;
import ai.libs.jaicore.ml.core.dataset.serialization.ArffDatasetAdapter;
import ai.libs.jaicore.ml.core.learner.ASupervisedLearner;
import ai.libs.jaicore.ml.scikitwrapper.IScikitLearnWrapper;
import ai.libs.jaicore.ml.scikitwrapper.IScikitLearnWrapperConfig;
import ai.libs.jaicore.ml.scikitwrapper.ScikitLearnWrapperExecutionFailedException;
import ai.libs.python.IPythonConfig;
import ai.libs.python.PythonRequirementDefinition;
import ai.libs.python.PythonUtil;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.attribute.FileAttribute;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Stream;
import org.aeonbits.owner.ConfigFactory;
import org.api4.java.ai.ml.core.dataset.supervised.ILabeledDataset;
import org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance;
import org.api4.java.ai.ml.core.evaluation.IPrediction;
import org.api4.java.ai.ml.core.evaluation.IPredictionBatch;
import org.api4.java.ai.ml.core.exception.DatasetCreationException;
import org.api4.java.ai.ml.core.exception.PredictionException;
import org.api4.java.ai.ml.core.exception.TrainingException;
import org.api4.java.algorithm.Timeout;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Base class for lightweight scikit-learn wrappers that do not keep a fitted Python model between calls.
 *
 * <p>{@link #fit(ILabeledDataset)} only stores the training data. Each call to
 * {@link #executePipeline(ILabeledDataset)} does the following:
 * <ul>
 * <li>renders the Python script template, substituting {@code {{pipeline}}} with the estimator constructor call
 * and {@code {{import}}} with the import statements;</li>
 * <li>dumps the training and test data as ARFF files into a JVM-wide temporary directory;</li>
 * <li>runs the script with {@code --fit}, {@code --predict}, {@code --problem} and {@code --predictOutput}
 * arguments.</li>
 * </ul>
 * The file passed as {@code --predictOutput} (a {@code .json} temp file) is returned. Interpreting it is presumably
 * left to subclasses; that code is not part of this class.
 *
 * <p>Model serialization, multiple targets, seeds and timeouts are not supported. The corresponding
 * {@link IScikitLearnWrapper} methods only log a debug message.
 *
 * @param <P> the type of a single prediction
 * @param <B> the type of a batch of predictions
 */
public abstract class ASimpleScikitLearnWrapper<P extends IPrediction, B extends IPredictionBatch> extends ASupervisedLearner<ILabeledInstance, ILabeledDataset<? extends ILabeledInstance>, P, B> implements IScikitLearnWrapper {

    private static final String LOG_SERIALIZATION_NOT_IMPLEMENTED = "The simple scikit-learn classifier wrapper does not support model serialization.";

    /** Minimum Python version (3.5.0) passed, component by component, to the requirement check. */
    public static final int PYTHON_MINIMUM_REQUIRED_VERSION_REL = 3;
    public static final int PYTHON_MINIMUM_REQUIRED_VERSION_MAJ = 5;
    public static final int PYTHON_MINIMUM_REQUIRED_VERSION_MIN = 0;

    /** Modules handed to the Python requirement check as required. */
    protected static final String[] PYTHON_REQUIRED_MODULES = {"arff", "numpy", "json", "pickle", "os", "sys", "warnings", "scipy", "sklearn", "pandas"};
    /** Modules handed to the Python requirement check as optional (none). */
    protected static final String[] PYTHON_OPTIONAL_MODULES = new String[0];

    /**
     * Whether the Python requirement check has already succeeded in this JVM. It is guarded by the class lock.
     * NOTE(review): the result is cached JVM-wide regardless of which {@link IPythonConfig} was checked.
     */
    private static boolean pythonRequirementsFulfilled = false;
    /** JVM-wide directory holding ARFF dumps shared by all instances. It is guarded by the class lock. */
    private static File tempDir = null;

    private Logger logger = LoggerFactory.getLogger(ASimpleScikitLearnWrapper.class);
    private String pathExecutableTemplate = "sklearn/sklearn_template_windows.twig.py";
    protected IScikitLearnWrapperConfig sklearnClassifierConfig = ConfigFactory.create(IScikitLearnWrapperConfig.class);
    protected IPythonConfig pythonC;
    private PythonUtil putil;
    protected final String problem;
    protected final String constructorCall;
    protected final String imports;
    /** Script generated by the most recent {@link #executePipeline(ILabeledDataset)} call. */
    private File executable = null;
    /** Output file of the most recent {@link #executePipeline(ILabeledDataset)} call. */
    private File outputFile = null;
    protected ILabeledDataset<? extends ILabeledInstance> trainingData;

    /**
     * Creates a wrapper whose Python configuration is created by {@link ConfigFactory} from {@link IPythonConfig}.
     *
     * @param constructorCall Python expression substituted for {@code {{pipeline}}} in the script template
     * @param imports Python import statements substituted for {@code {{import}}} in the script template
     * @param problem value passed to the script as {@code --problem}
     * @throws IOException propagated from {@link #setPythonConfig(IPythonConfig)}
     * @throws InterruptedException if interrupted while checking the Python requirements
     */
    public ASimpleScikitLearnWrapper(final String constructorCall, final String imports, final String problem) throws IOException, InterruptedException {
        this(constructorCall, imports, problem, ConfigFactory.create(IPythonConfig.class));
    }

    /**
     * Creates a wrapper using the given Python configuration.
     *
     * @param constructorCall Python expression substituted for {@code {{pipeline}}} in the script template
     * @param imports Python import statements substituted for {@code {{import}}} in the script template
     * @param problem value passed to the script as {@code --problem}
     * @param pythonConfig configuration used to run Python
     * @throws IOException propagated from {@link #setPythonConfig(IPythonConfig)}
     * @throws InterruptedException if interrupted while checking the Python requirements
     */
    protected ASimpleScikitLearnWrapper(final String constructorCall, final String imports, final String problem, final IPythonConfig pythonConfig) throws IOException, InterruptedException {
        this.constructorCall = constructorCall;
        this.imports = imports;
        this.problem = problem;
        this.setPythonConfig(pythonConfig);
    }

    /**
     * Runs the Python requirement check once per JVM.
     *
     * <p>The method is static and synchronized because the cached flag is static. An instance-level lock would let
     * two wrappers run the check concurrently.
     */
    private static synchronized void ensurePythonRequirementsAreSatisfied(final IPythonConfig pythonConfig) throws InterruptedException {
        if (!pythonRequirementsFulfilled) {
            new PythonRequirementDefinition(PYTHON_MINIMUM_REQUIRED_VERSION_REL, PYTHON_MINIMUM_REQUIRED_VERSION_MAJ, PYTHON_MINIMUM_REQUIRED_VERSION_MIN, PYTHON_REQUIRED_MODULES, PYTHON_OPTIONAL_MODULES)
                    .check(pythonConfig);
            pythonRequirementsFulfilled = true;
        }
    }

    /**
     * Stores the training data. No Python code runs here; fitting happens inside each prediction run.
     *
     * @param trainingData the data to fit on during subsequent predictions
     */
    public void fit(final ILabeledDataset<? extends ILabeledInstance> trainingData) throws TrainingException, InterruptedException {
        this.trainingData = trainingData;
    }

    public String getLoggerName() {
        return this.logger.getName();
    }

    public void setLoggerName(final String name) {
        this.logger = LoggerFactory.getLogger(name);
    }

    /**
     * Returns the ARFF dump for the given dataset, serializing it first if no dump with that name exists yet.
     *
     * <p>Dumps live in a directory shared by all instances. The existence check and the write are therefore
     * performed under the class lock. Otherwise another thread could pick up a file that is still being written.
     *
     * @param dataset the dataset to dump
     * @param dataName base name of the dump file (without extension)
     * @return the dump file
     * @throws ScikitLearnWrapperExecutionFailedException if serialization fails
     * @throws IOException if the temporary directory cannot be created
     */
    private File getOrWriteDataFile(final ILabeledDataset<? extends ILabeledInstance> dataset, final String dataName) throws ScikitLearnWrapperExecutionFailedException, IOException {
        synchronized (ASimpleScikitLearnWrapper.class) {
            File datasetFile = this.getDatasetFile(dataName);
            if (datasetFile.exists()) {
                this.logger.debug("Reusing dataset: {}", dataName);
                return datasetFile;
            }
            try {
                this.logger.debug("Serializing {}x{} dataset to {}", dataset.size(), dataset.getNumAttributes(), dataName);
                new ArffDatasetAdapter().serializeDataset(datasetFile, dataset);
                this.logger.debug("Serialization completed.");
                return datasetFile;
            } catch (IOException e) {
                // Remove any partially written file so later calls do not reuse a corrupt dump.
                if (datasetFile.exists() && !datasetFile.delete()) {
                    this.logger.warn("Could not delete partially written dataset file {}", datasetFile);
                }
                throw new ScikitLearnWrapperExecutionFailedException("Could not dump data file for prediction", e);
            }
        }
    }

    /**
     * Returns the (possibly not yet existing) ARFF file for the given name inside the shared temp directory. If
     * configured, the file is registered for deletion on JVM exit.
     */
    private File getDatasetFile(final String dataName) throws IOException {
        File file = new File(getTempDir(), dataName + ".arff");
        if (this.sklearnClassifierConfig.getDeleteFileOnExit()) {
            file.deleteOnExit();
        }
        return file;
    }

    /** Lazily creates the JVM-wide dump directory, registering it for deletion on JVM exit. */
    private static synchronized File getTempDir() throws IOException {
        if (tempDir == null) {
            tempDir = Files.createTempDirectory("ailibs-dumps").toFile();
            tempDir.deleteOnExit();
        }
        return tempDir;
    }

    /**
     * Generates the Python script, dumps training and test data, and runs the script.
     *
     * @param testData the data to predict
     * @return the file passed to the script as {@code --predictOutput}
     * @throws IOException if reading the template or writing the script or temp files fails
     * @throws InterruptedException if interrupted while waiting for the Python process
     * @throws ScikitLearnWrapperExecutionFailedException if {@code fit} has not been called, data cannot be dumped,
     *         or the Python process exits with a non-zero status
     */
    public File executePipeline(final ILabeledDataset<? extends ILabeledInstance> testData) throws IOException, InterruptedException, ScikitLearnWrapperExecutionFailedException {
        if (this.trainingData == null) {
            throw new ScikitLearnWrapperExecutionFailedException("No training data available; fit must be called before predicting.");
        }
        this.executable = Files.createTempFile("sklearn-classifier-", ".py").toFile();
        this.executable.deleteOnExit();
        String script = ResourceUtil.readResourceFileToString(this.pathExecutableTemplate).replace("{{pipeline}}", this.constructorCall).replace("{{import}}", this.imports);
        // Python 3 decodes source files as UTF-8 by default, so do not depend on the platform charset here.
        try (BufferedWriter writer = Files.newBufferedWriter(this.executable.toPath(), StandardCharsets.UTF_8)) {
            writer.write(script);
        }

        this.outputFile = Files.createTempFile("sklearn-predictions", ".json").toFile();
        this.outputFile.deleteOnExit();
        File trainFile = this.getOrWriteDataFile(this.trainingData, this.getDataName(this.trainingData));
        File testFile = this.getOrWriteDataFile(testData, this.getDataName(testData));

        // Keep the list mutable: the callee's handling of it is not visible here.
        List<String> command = new ArrayList<>();
        command.add(this.executable.getCanonicalPath());
        command.add("--fit");
        command.add(trainFile.getCanonicalPath());
        command.add("--predict");
        command.add(testFile.getCanonicalPath());
        command.add("--problem");
        command.add(this.problem);
        command.add("--predictOutput");
        command.add(this.outputFile.getCanonicalPath());
        if (this.putil.executeScriptFile(command) != 0) {
            throw new ScikitLearnWrapperExecutionFailedException("Spawned python process has not terminated cleanly.");
        }
        return this.outputFile;
    }

    @Override
    public void setModelPath(final String modelPath) throws IOException {
        this.logger.debug(LOG_SERIALIZATION_NOT_IMPLEMENTED);
    }

    /** Not supported; always returns {@code null}. */
    @Override
    public File getModelPath() {
        this.logger.debug(LOG_SERIALIZATION_NOT_IMPLEMENTED);
        return null;
    }

    /** Not supported; always returns {@code null}. */
    @Override
    public File getModelFile() {
        this.logger.debug(LOG_SERIALIZATION_NOT_IMPLEMENTED);
        return null;
    }

    @Override
    public void setTargetIndices(final int... targetIndices) {
        this.logger.debug("The simple scikit-learn classifier wrapper does not support multiple targets.");
    }

    /** Returns the Python constructor call of the wrapped estimator. */
    @Override
    public String toString() {
        return this.constructorCall;
    }

    @Override
    public void setSeed(final long seed) {
        this.logger.debug("The simple scikit-learn classifier wrapper does not support setting a seed.");
    }

    @Override
    public void setTimeout(final Timeout timeout) {
        this.logger.debug("The simple scikit-learn classifier wrapper does not support setting a timeout.");
    }

    @Override
    public void fit(final String trainingDataName) throws TrainingException, InterruptedException {
        this.logger.debug("The simple scikit-learn classifier wrapper does not support fitting providing a path only.");
    }

    /**
     * Ignores the argument and returns the output file of the most recent pipeline run (or {@code null} if there
     * was none), after logging that this operation is not supported.
     */
    @Override
    public File getOutputFile(final String dataName) {
        this.logger.debug("The simple scikit-learn classifier wrapper does not support retrieving the output file.");
        return this.outputFile;
    }

    /** Sets the resource path of the script template used by subsequent pipeline runs. */
    @Override
    public void setPythonTemplate(final String pythonTemplatePath) throws IOException {
        this.pathExecutableTemplate = pythonTemplatePath;
    }

    /**
     * Installs the Python configuration and runs the (JVM-wide cached) Python requirement check.
     */
    @Override
    public void setPythonConfig(final IPythonConfig pythonConfig) throws IOException, InterruptedException {
        this.pythonC = pythonConfig;
        this.putil = new PythonUtil(pythonConfig);
        ensurePythonRequirementsAreSatisfied(pythonConfig);
    }

    @Override
    public void setScikitLearnWrapperConfig(final IScikitLearnWrapperConfig scikitLearnWrapperConfig) {
        this.sklearnClassifierConfig = scikitLearnWrapperConfig;
    }

    /** Returns the script generated by the most recent pipeline run, or {@code null} if there was none. */
    @Override
    public File getSKLearnScriptFile() {
        return this.executable;
    }

    /**
     * Derives a file-name-safe identifier from the dataset's hash code. Negative hashes are written as {@code 1}
     * followed by the digits; non-negative hashes as {@code 0} followed by the digits.
     */
    @Override
    public String getDataName(final ILabeledDataset<? extends ILabeledInstance> dataset) {
        String hash = "" + dataset.hashCode();
        return hash.startsWith("-") ? hash.replace("-", "1") : "0" + hash;
    }

    /**
     * Creates an empty dataset with the schema of the training data, to be filled with test instances.
     *
     * @throws IllegalStateException if {@code fit} has not been called yet
     */
    @SuppressWarnings("unchecked")
    private ILabeledDataset<ILabeledInstance> createEmptyTestDataset() throws DatasetCreationException, InterruptedException {
        if (this.trainingData == null) {
            throw new IllegalStateException("The learner has not been fit yet, so no test dataset schema is available.");
        }
        // Widen the element type so arbitrary ILabeledInstance objects can be added (same as the raw-typed original).
        return (ILabeledDataset<ILabeledInstance>) this.trainingData.createEmptyCopy();
    }

    /**
     * Predicts a batch of instances by wrapping them in an empty copy of the training dataset and delegating to
     * {@code predict(ILabeledDataset)}.
     *
     * @throws PredictionException if the empty dataset copy cannot be created, or propagated from the delegate
     * @throws IllegalStateException if {@code fit} has not been called yet
     */
    @Override
    public B predict(final ILabeledInstance[] instances) throws PredictionException, InterruptedException {
        ILabeledDataset<ILabeledInstance> testData;
        try {
            testData = this.createEmptyTestDataset();
        } catch (DatasetCreationException e) {
            throw new PredictionException("Could not create empty test dataset copy.", e);
        }
        for (ILabeledInstance instance : instances) {
            testData.add(instance);
        }
        return this.predict(testData);
    }

    /**
     * Predicts a single instance by wrapping it in an empty copy of the training dataset and returning the first
     * element of the resulting batch.
     *
     * @throws PredictionException if the empty dataset copy cannot be created, or propagated from the delegate
     * @throws IllegalStateException if {@code fit} has not been called yet
     */
    @Override
    @SuppressWarnings("unchecked")
    public P predict(final ILabeledInstance instance) throws PredictionException, InterruptedException {
        ILabeledDataset<ILabeledInstance> testData;
        try {
            testData = this.createEmptyTestDataset();
        } catch (DatasetCreationException e) {
            throw new PredictionException("Could not predict due to a DatasetCreationException", e);
        }
        testData.add(instance);
        return (P) this.predict(testData).get(0);
    }
}
