package ai.libs.jaicore.ml.scikitwrapper;

import ai.libs.jaicore.basic.FileUtil;
import ai.libs.jaicore.basic.ResourceUtil;
import ai.libs.jaicore.ml.core.EScikitLearnProblemType;
import ai.libs.jaicore.ml.core.dataset.serialization.ArffDatasetAdapter;
import ai.libs.jaicore.ml.core.learner.ASupervisedLearner;
import ai.libs.jaicore.processes.EOperatingSystem;
import ai.libs.jaicore.processes.ProcessIDNotRetrievableException;
import ai.libs.jaicore.processes.ProcessUtil;
import ai.libs.python.DefaultProcessListener;
import ai.libs.python.IPythonConfig;
import ai.libs.python.PythonRequirementDefinition;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.hash.Hashing;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.aeonbits.owner.ConfigCache;
import org.aeonbits.owner.ConfigFactory;
import org.apache.commons.lang3.ArrayUtils;
import org.api4.java.ai.ml.core.dataset.supervised.ILabeledDataset;
import org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance;
import org.api4.java.ai.ml.core.evaluation.IPrediction;
import org.api4.java.ai.ml.core.evaluation.IPredictionBatch;
import org.api4.java.ai.ml.core.exception.DatasetCreationException;
import org.api4.java.ai.ml.core.exception.PredictionException;
import org.api4.java.ai.ml.core.exception.TrainingException;
import org.api4.java.algorithm.Timeout;
import org.jtwig.JtwigModel;
import org.jtwig.JtwigTemplate;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:ai/libs/jaicore/ml/scikitwrapper/AScikitLearnWrapper.class */
public abstract class AScikitLearnWrapper<P extends IPrediction, B extends IPredictionBatch> extends ASupervisedLearner<ILabeledInstance, ILabeledDataset<? extends ILabeledInstance>, P, B> implements IScikitLearnWrapper {
    // Presumably the release/major/minor parts of the minimum supported Python version (3.5.0);
    // the constructor passes the same values (3, 5, 0) to the Python requirement check.
    public static final int PYTHON_MINIMUM_REQUIRED_VERSION_REL = 3;
    public static final int PYTHON_MINIMUM_REQUIRED_VERSION_MAJ = 5;
    public static final int PYTHON_MINIMUM_REQUIRED_VERSION_MIN = 0;
    // Python modules every wrapper requires; the constructor appends the problem type's own required modules.
    protected static final String[] PYTHON_REQUIRED_MODULES = {"arff", "numpy", "json", "pickle", "os", "sys", "warnings", "scipy", "sklearn", "pandas"};
    // Base list of optional modules (empty); the constructor appends the problem type's optional modules.
    protected static final String[] PYTHON_OPTIONAL_MODULES = new String[0];
    // Resource path of the script template that the constructor loads and renders.
    private static final String SCIKIT_LEARN_TEMPLATE = "sklearn/sklearn_template.twig.py";
    // Message used for the exceptions raised when running the python process or preparing its data fails.
    private static final String COULD_NOT_RUN_SCIKIT_LEARN_MODEL = "Could not run scikit-learn model.";
    // Supplies the temp folder, the model dump directory, the file extensions and the delete-on-exit flag.
    protected IScikitLearnWrapperConfig scikitLearnWrapperConfig;
    // Derived in the constructor from the SHA-256 hash string of the pipeline; used in script, model and result file names.
    protected final String configurationUID;
    // Problem type; provides the command line flag and additional python module requirements.
    protected EScikitLearnProblemType problemType;
    // Pipeline construction command rendered into the script template; also returned by toString().
    protected String pipeline;
    // Import statements rendered into the script template (an empty string is rendered when null).
    private String imports;
    // Model dump location; set by setModelPath(String) or when fit(File, String) launches a training run.
    protected File modelFile;
    // Empty copy of the last training dataset; predict(ILabeledInstance[]) uses it to build a dataset from single instances.
    protected ILabeledDataset<ILabeledInstance> data;
    // Seed handed to the command builder.
    protected long seed;
    // Timeout handed to the command builder (may be null if never set).
    protected Timeout timeout;
    // True when the OS is MAC or LINUX; passed to the DefaultProcessListener constructor.
    private boolean listenToPidFromProcess;
    // Instance logger; can be replaced via setLoggerName(String).
    protected Logger logger = LoggerFactory.getLogger(AScikitLearnWrapper.class);
    // Python configuration used for the requirement check and handed to the command builder.
    protected IPythonConfig pythonConfig = ConfigFactory.create(IPythonConfig.class, new Map[0]);
    // Target indices handed to the command builder; empty by default.
    protected int[] targetIndices = new int[0];

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Creates a wrapper for the given problem type and pipeline, checks the Python
     * requirements and renders the default script template.
     *
     * @param eScikitLearnProblemType problem type; contributes additional required/optional python modules
     * @param str the pipeline construction command; its SHA-256 hash string forms the configuration UID
     * @param str2 import statements to render into the script (may be null)
     * @throws IOException if the template resource or the script file cannot be handled
     * @throws InterruptedException declared by the signature; presumably raised by the requirement check
     */
    public AScikitLearnWrapper(EScikitLearnProblemType eScikitLearnProblemType, String str, String str2) throws IOException, InterruptedException {
        this.problemType = eScikitLearnProblemType;
        this.pipeline = str;
        this.imports = str2;

        // Derive the UID from the pipeline hash.
        String pipelineDigest = Hashing.sha256().hashString(this.pipeline, StandardCharsets.UTF_8).toString();
        if (pipelineDigest.startsWith("-")) {
            this.configurationUID = pipelineDigest.replace("-", "1");
        } else {
            this.configurationUID = "0" + pipelineDigest;
        }

        EOperatingSystem operatingSystem = ProcessUtil.getOS();
        this.listenToPidFromProcess = operatingSystem == EOperatingSystem.MAC || operatingSystem == EOperatingSystem.LINUX;

        // Make sure the working directories exist before anything is written.
        this.scikitLearnWrapperConfig = ConfigCache.getOrCreate(IScikitLearnWrapperConfig.class);
        this.scikitLearnWrapperConfig.getTempFolder().mkdirs();
        this.scikitLearnWrapperConfig.getModelDumpsDirectory().mkdirs();

        String[] requiredModules = (String[]) ArrayUtils.addAll(PYTHON_REQUIRED_MODULES, eScikitLearnProblemType.getPythonRequiredModules());
        String[] optionalModules = (String[]) ArrayUtils.addAll(PYTHON_OPTIONAL_MODULES, eScikitLearnProblemType.getPythonOptionalModules());
        PythonRequirementDefinition requirements = new PythonRequirementDefinition(
                PYTHON_MINIMUM_REQUIRED_VERSION_REL,
                PYTHON_MINIMUM_REQUIRED_VERSION_MAJ,
                PYTHON_MINIMUM_REQUIRED_VERSION_MIN,
                requiredModules,
                optionalModules);
        requirements.check(this.pythonConfig);

        setPythonTemplate(ResourceUtil.getResourceAsTempFile(SCIKIT_LEARN_TEMPLATE));
    }

    /**
     * Renders the given template into the script file of this configuration.
     *
     * <p>The template receives two values: {@code imports} (empty string when no imports
     * were given) and {@code pipeline}.
     *
     * @param str path of the template file to render
     * @throws IOException if the script file cannot be created or written
     * @throws AssertionError if no pipeline command has been set
     */
    @Override // ai.libs.jaicore.ml.scikitwrapper.IScikitLearnWrapper
    public void setPythonTemplate(String str) throws IOException {
        // Validate first so that a missing pipeline does not leave an empty script file behind.
        if (this.pipeline == null || this.pipeline.isEmpty()) {
            throw new AssertionError("Pipeline command for learner must be stated.");
        }

        File templateFile = new File(str);
        File scriptFile = getSKLearnScriptFile();
        if (!scriptFile.createNewFile()) {
            this.logger.debug("Script file for configuration UID {} already exists in {}", this.configurationUID, scriptFile.getAbsolutePath());
        }
        if (this.scikitLearnWrapperConfig.getDeleteFileOnExit()) {
            scriptFile.deleteOnExit();
        }

        Map<String, Object> templateValues = new HashMap<>();
        templateValues.put("imports", this.imports != null ? this.imports : "");
        templateValues.put("pipeline", this.pipeline);

        // Close the stream deterministically; the original left the file handle open.
        try (FileOutputStream scriptOut = new FileOutputStream(scriptFile)) {
            JtwigTemplate.fileTemplate(templateFile).render(JtwigModel.newModel(templateValues), scriptOut);
        }
    }

    /**
     * Points the wrapper at a model dump file.
     *
     * @param str path of the model file
     * @throws IOException declared by the interface; not raised by this implementation
     */
    @Override // ai.libs.jaicore.ml.scikitwrapper.IScikitLearnWrapper
    public void setModelPath(String str) throws IOException {
        File modelLocation = new File(str);
        this.modelFile = modelLocation;
    }

    /**
     * @return the currently configured model file, or null if none has been set yet
     */
    @Override // ai.libs.jaicore.ml.scikitwrapper.IScikitLearnWrapper
    public File getModelPath() {
        File currentModel = this.modelFile;
        return currentModel;
    }

    /**
     * Sets the target indices that are handed to the command builder.
     *
     * @param iArr the target indices; the array is stored as given (no copy)
     */
    @Override // ai.libs.jaicore.ml.scikitwrapper.IScikitLearnWrapper
    public void setTargetIndices(int... iArr) {
        int[] indices = iArr;
        this.targetIndices = indices;
    }

    /**
     * Sets the seed that is handed to the command builder.
     *
     * @param j the seed value
     */
    @Override // ai.libs.jaicore.ml.scikitwrapper.IScikitLearnWrapper
    public void setSeed(long j) {
        final long newSeed = j;
        this.seed = newSeed;
    }

    /**
     * Sets the timeout that is handed to the command builder.
     *
     * @param timeout the timeout to use
     */
    @Override // ai.libs.jaicore.ml.scikitwrapper.IScikitLearnWrapper
    public void setTimeout(Timeout timeout) {
        final Timeout newTimeout = timeout;
        this.timeout = newTimeout;
    }

    /**
     * Trains the model on the given dataset.
     *
     * <p>An empty copy of the dataset is remembered in {@code data}, the label check
     * is run on that copy, and the dataset is serialized (or an existing file reused)
     * before the file-based training is started.
     *
     * @param iLabeledDataset the training data
     * @throws TrainingException if the labels do not fit the problem type, or if the dataset
     *         cannot be copied/serialized or the python process fails
     * @throws InterruptedException if the thread is interrupted while training
     */
    public void fit(ILabeledDataset<? extends ILabeledInstance> iLabeledDataset) throws TrainingException, InterruptedException {
        try {
            String dataName = getDataName(iLabeledDataset);
            // The wildcard-typed copy needs an explicit (unchecked) cast to be stored in the field.
            @SuppressWarnings("unchecked")
            ILabeledDataset<ILabeledInstance> emptyCopy = (ILabeledDataset<ILabeledInstance>) iLabeledDataset.createEmptyCopy();
            this.data = emptyCopy;
            if (!doLabelsFitToProblemType(this.data)) {
                throw new TrainingException("The label of the given data " + iLabeledDataset.getRelationName() + " are not suitable for the selected problem type " + this.problemType.getName());
            }
            fit(getOrWriteDataFile(iLabeledDataset, dataName), dataName);
        } catch (DatasetCreationException | ScikitLearnWrapperExecutionFailedException e) {
            throw new TrainingException(COULD_NOT_RUN_SCIKIT_LEARN_MODEL, e);
        }
    }

    /**
     * Trains the model on a dataset identified by name; the data file is resolved in
     * the temp folder as {@code <name>.arff}.
     *
     * @param str the dataset name
     * @throws TrainingException if the python process fails
     * @throws InterruptedException if the thread is interrupted while training
     */
    @Override // ai.libs.jaicore.ml.scikitwrapper.IScikitLearnWrapper
    public void fit(String str) throws TrainingException, InterruptedException {
        File trainingFile = getDatasetFile(str);
        fit(trainingFile, str);
    }

    /**
     * Runs the python script in fit mode unless an output file for this dataset name
     * already exists, in which case nothing happens (the model file is left untouched).
     *
     * @param file the serialized training data
     * @param str the dataset name used to derive the model and output file names
     * @throws TrainingException if the python process fails
     * @throws InterruptedException if the thread is interrupted while the process runs
     */
    private void fit(File file, String str) throws TrainingException, InterruptedException {
        try {
            if (getOutputFile(str).exists()) {
                return;
            }
            File dumpDirectory = this.scikitLearnWrapperConfig.getModelDumpsDirectory();
            this.modelFile = new File(dumpDirectory, getModelFileName(str));
            String[] command = constructCommandLineParametersForFitMode(this.modelFile, file).toCommandArray();
            if (this.logger.isDebugEnabled()) {
                this.logger.debug("{} run train mode {}", Thread.currentThread().getName(), Arrays.toString(command));
            }
            runProcess(command);
        } catch (ScikitLearnWrapperExecutionFailedException e) {
            throw new TrainingException(COULD_NOT_RUN_SCIKIT_LEARN_MODEL, e);
        }
    }

    /**
     * Serializes the given dataset (or reuses an existing file) and obtains predictions for it.
     *
     * @param iLabeledDataset the data to predict
     * @return the prediction batch produced from the python output
     * @throws PredictionException if serialization or the python process fails
     * @throws InterruptedException if the thread is interrupted while predicting
     */
    @Override // ai.libs.jaicore.ml.core.learner.ASupervisedLearner
    public B predict(ILabeledDataset<? extends ILabeledInstance> iLabeledDataset) throws PredictionException, InterruptedException {
        try {
            final String datasetName = getDataName(iLabeledDataset);
            final File serializedData = getOrWriteDataFile(iLabeledDataset, datasetName);
            this.logger.info("Prediction dataset serialized, now acquiring predictions.");
            B predictions = predict(serializedData, datasetName);
            return predictions;
        } catch (ScikitLearnWrapperExecutionFailedException e) {
            throw new PredictionException(COULD_NOT_RUN_SCIKIT_LEARN_MODEL, e);
        }
    }

    /**
     * Obtains predictions for a dataset identified by name; the data file is resolved
     * in the temp folder as {@code <name>.arff}.
     *
     * @param str the dataset name
     * @return the prediction batch produced from the python output
     * @throws PredictionException if the python process fails
     * @throws InterruptedException if the thread is interrupted while predicting
     */
    public B predict(String str) throws PredictionException, InterruptedException {
        File datasetFile = getDatasetFile(str);
        return predict(datasetFile, str);
    }

    /**
     * Runs the python script in predict mode unless the output file for this dataset
     * name already exists, then converts the output file into a prediction batch.
     *
     * @param file the serialized data to predict
     * @param str the dataset name used to derive the output file name
     * @return the prediction batch read from the output file
     * @throws PredictionException if the python process or the output handling fails
     * @throws InterruptedException if the thread is interrupted while the process runs
     */
    private B predict(File file, String str) throws PredictionException, InterruptedException {
        try {
            File resultFile = getOutputFile(str);
            boolean resultAlreadyPresent = resultFile.exists();
            if (!resultAlreadyPresent) {
                String[] command = constructCommandLineParametersForPredictMode(this.modelFile, file, resultFile).toCommandArray();
                if (this.logger.isDebugEnabled()) {
                    this.logger.debug("Run test mode with {}", Arrays.toString(command));
                }
                runProcess(command);
            }
            return handleOutput(resultFile);
        } catch (ScikitLearnWrapperExecutionFailedException | TrainingException e) {
            throw new PredictionException(COULD_NOT_RUN_SCIKIT_LEARN_MODEL, e);
        }
    }

    /**
     * Predicts the given instances by collecting them into an empty copy of the
     * remembered training dataset and delegating to the dataset-based prediction.
     *
     * @param iLabeledInstanceArr the instances to predict
     * @return the prediction batch for the instances
     * @throws NullPointerException if no model file or training data has been set yet
     * @throws PredictionException if the dataset cannot be replicated or prediction fails
     * @throws InterruptedException if the thread is interrupted while predicting
     */
    @Override // ai.libs.jaicore.ml.core.learner.ASupervisedLearner
    public B predict(ILabeledInstance[] iLabeledInstanceArr) throws PredictionException, InterruptedException {
        Objects.requireNonNull(this.modelFile, "Model has not been trained.");
        Objects.requireNonNull(this.data, "Model has not been trained.");
        this.logger.info("Predicting {} instances.", Integer.valueOf(iLabeledInstanceArr.length));

        ILabeledDataset<ILabeledInstance> dataset;
        try {
            // Typed explicitly so that instances can be added; the decompiled lambda
            // referenced an undefined register variable and could not compile.
            @SuppressWarnings("unchecked")
            ILabeledDataset<ILabeledInstance> emptyCopy = (ILabeledDataset<ILabeledInstance>) this.data.createEmptyCopy();
            dataset = emptyCopy;
        } catch (DatasetCreationException e) {
            throw new PredictionException("Could not replicate labeled dataset instance", e);
        }

        for (ILabeledInstance instance : iLabeledInstanceArr) {
            dataset.add(instance);
        }
        return predict(dataset);
    }

    /**
     * Predicts a single instance by wrapping it into a one-element array and
     * returning the first entry of the resulting batch.
     *
     * @param iLabeledInstance the instance to predict
     * @return the prediction for the instance
     * @throws PredictionException if prediction fails
     * @throws InterruptedException if the thread is interrupted while predicting
     */
    @Override // ai.libs.jaicore.ml.core.learner.ASupervisedLearner
    public P predict(ILabeledInstance iLabeledInstance) throws PredictionException, InterruptedException {
        ILabeledInstance[] singleInstance = {iLabeledInstance};
        return (P) predict(singleInstance).get(0);
    }

    /**
     * Trains on the first dataset and predicts the second one in a single python run.
     *
     * <p>An empty copy of the training dataset is remembered in {@code data}; both
     * datasets are serialized (or existing files reused) before the file-based
     * variant is invoked.
     *
     * @param iLabeledDataset the training data
     * @param iLabeledDataset2 the data to predict
     * @return the prediction batch for the second dataset
     * @throws TrainingException if a dataset cannot be copied/serialized or the python process fails
     * @throws PredictionException if the output cannot be turned into predictions
     * @throws InterruptedException if the thread is interrupted
     */
    @Override // ai.libs.jaicore.ml.core.learner.ASupervisedLearner
    public B fitAndPredict(ILabeledDataset<? extends ILabeledInstance> iLabeledDataset, ILabeledDataset<? extends ILabeledInstance> iLabeledDataset2) throws TrainingException, PredictionException, InterruptedException {
        try {
            String dataName = getDataName(iLabeledDataset);
            // The wildcard-typed copy needs an explicit (unchecked) cast to be stored in the field.
            @SuppressWarnings("unchecked")
            ILabeledDataset<ILabeledInstance> emptyCopy = (ILabeledDataset<ILabeledInstance>) iLabeledDataset.createEmptyCopy();
            this.data = emptyCopy;
            File orWriteDataFile = getOrWriteDataFile(iLabeledDataset, dataName);
            String dataName2 = getDataName(iLabeledDataset2);
            File orWriteDataFile2 = getOrWriteDataFile(iLabeledDataset2, dataName2);
            this.logger.info("Prediction dataset serialized, now acquiring predictions.");
            return fitAndPredict(orWriteDataFile, dataName, orWriteDataFile2, dataName2);
        } catch (DatasetCreationException | ScikitLearnWrapperExecutionFailedException e) {
            throw new TrainingException(COULD_NOT_RUN_SCIKIT_LEARN_MODEL, e);
        }
    }

    /**
     * Runs the python script in fit-and-predict mode unless an output file exists for
     * either dataset name, then converts the prediction output file into a batch.
     *
     * @param file the serialized training data
     * @param str the training dataset name
     * @param file2 the serialized data to predict
     * @param str2 the prediction dataset name
     * @return the prediction batch read from the prediction output file
     * @throws TrainingException if the python process fails or output handling reports it
     * @throws PredictionException if the output cannot be turned into predictions
     * @throws InterruptedException if the thread is interrupted while the process runs
     */
    public B fitAndPredict(File file, String str, File file2, String str2) throws TrainingException, PredictionException, InterruptedException {
        try {
            File trainOutput = getOutputFile(str);
            File predictOutput = getOutputFile(str2);
            boolean anyOutputPresent = trainOutput.exists() || predictOutput.exists();
            if (!anyOutputPresent) {
                String[] command = constructCommandLineParametersForFitAndPredictMode(file, file2, predictOutput).toCommandArray();
                if (this.logger.isDebugEnabled()) {
                    this.logger.debug("{} run fitAndPredict mode {}", Thread.currentThread().getName(), Arrays.toString(command));
                }
                runProcess(command);
            }
            return handleOutput(predictOutput);
        } catch (ScikitLearnWrapperExecutionFailedException e) {
            throw new TrainingException(COULD_NOT_RUN_SCIKIT_LEARN_MODEL, e);
        }
    }

    /**
     * Builds the model file name as {@code <configurationUID>_<name><pickle extension>}.
     *
     * @param str the dataset name
     * @return the file name (without directory) of the model dump
     */
    protected String getModelFileName(String str) {
        String baseName = this.configurationUID + "_" + str;
        return baseName + this.scikitLearnWrapperConfig.getPickleFileExtension();
    }

    /**
     * Derives a name for the dataset from its hash code: a leading minus sign is
     * replaced by "1", otherwise "0" is prepended.
     *
     * @param iLabeledDataset the dataset to name
     * @return the derived dataset name
     */
    @Override // ai.libs.jaicore.ml.scikitwrapper.IScikitLearnWrapper
    public String getDataName(ILabeledDataset<? extends ILabeledInstance> iLabeledDataset) {
        String hashText = "" + iLabeledDataset.hashCode();
        if (hashText.startsWith("-")) {
            return hashText.replace("-", "1");
        }
        return "0" + hashText;
    }

    /**
     * Returns the ARFF file for the given dataset name, serializing the dataset into
     * it if the file does not exist yet.
     *
     * @param iLabeledDataset the dataset to serialize
     * @param str the dataset name that determines the file name
     * @return the existing or newly written data file
     * @throws ScikitLearnWrapperExecutionFailedException if the dataset cannot be written
     */
    private synchronized File getOrWriteDataFile(ILabeledDataset<? extends ILabeledInstance> iLabeledDataset, String str) throws ScikitLearnWrapperExecutionFailedException {
        this.logger.debug("Serializing {}x{} dataset to {}", iLabeledDataset.size(), iLabeledDataset.getNumAttributes(), str);
        final File arffFile = getDatasetFile(str);
        if (this.scikitLearnWrapperConfig.getDeleteFileOnExit()) {
            arffFile.deleteOnExit();
        }

        if (!arffFile.exists()) {
            try {
                new ArffDatasetAdapter().serializeDataset(arffFile, iLabeledDataset);
            } catch (IOException e) {
                throw new ScikitLearnWrapperExecutionFailedException("Could not dump data file for prediction", e);
            }
            this.logger.debug("Serializating completed.");
            return arffFile;
        }

        this.logger.debug("Reusing dataset: {}", str);
        return arffFile;
    }

    /**
     * Resolves the data file for a dataset name as {@code <name>.arff} inside the temp folder.
     *
     * @param str the dataset name
     * @return the (possibly not yet existing) data file
     */
    private synchronized File getDatasetFile(String str) {
        String fileName = str + ".arff";
        return new File(this.scikitLearnWrapperConfig.getTempFolder(), fileName);
    }

    /**
     * Decides whether the labels of the given dataset are suitable for this wrapper's
     * problem type. {@code fit(ILabeledDataset)} calls this with the empty copy of the
     * training data and throws a TrainingException when it returns false.
     *
     * @param iLabeledDataset the dataset whose labels are checked
     * @return true if the labels fit the problem type
     */
    protected abstract boolean doLabelsFitToProblemType(ILabeledDataset<? extends ILabeledInstance> iLabeledDataset);

    /**
     * Creates a command builder for this wrapper's problem type and script file and
     * applies the common settings to it.
     *
     * @return the configured command builder
     */
    protected ScikitLearnWrapperCommandBuilder getCommandBuilder() {
        ScikitLearnWrapperCommandBuilder freshBuilder = new ScikitLearnWrapperCommandBuilder(this.problemType.getScikitLearnCommandLineFlag(), getSKLearnScriptFile());
        return getCommandBuilder(freshBuilder);
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Applies logger, seed, timeout and (if present) the python configuration to the
     * given command builder.
     *
     * @param scikitLearnWrapperCommandBuilder the builder to configure
     * @return the same builder instance that was passed in
     */
    public ScikitLearnWrapperCommandBuilder getCommandBuilder(ScikitLearnWrapperCommandBuilder scikitLearnWrapperCommandBuilder) {
        scikitLearnWrapperCommandBuilder.withLogger(this.logger);
        scikitLearnWrapperCommandBuilder.withSeed(this.seed);
        scikitLearnWrapperCommandBuilder.withTimeout(this.timeout);
        if (this.pythonConfig == null) {
            // No python configuration available: nothing more to apply.
            return scikitLearnWrapperCommandBuilder;
        }
        scikitLearnWrapperCommandBuilder.withPythonConfig(this.pythonConfig);
        return scikitLearnWrapperCommandBuilder;
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Prepares the command builder for a training run.
     *
     * @param file the model file handed to the builder
     * @param file2 the training data file
     * @return the builder configured for fit mode
     */
    public ScikitLearnWrapperCommandBuilder constructCommandLineParametersForFitMode(File file, File file2) {
        final ScikitLearnWrapperCommandBuilder fitBuilder = getCommandBuilder();
        fitBuilder.withFitMode();
        fitBuilder.withModelFile(file);
        fitBuilder.withFitDataFile(file2);
        fitBuilder.withTargetIndices(this.targetIndices);
        return fitBuilder;
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Prepares the command builder for a prediction run.
     *
     * @param file the model file handed to the builder
     * @param file2 the data file to predict
     * @param file3 the file handed to the builder as prediction output
     * @return the builder configured for predict mode
     */
    public ScikitLearnWrapperCommandBuilder constructCommandLineParametersForPredictMode(File file, File file2, File file3) {
        final ScikitLearnWrapperCommandBuilder predictBuilder = getCommandBuilder();
        predictBuilder.withPredictMode();
        predictBuilder.withModelFile(file);
        predictBuilder.withPredictDataFile(file2);
        predictBuilder.withTargetIndices(this.targetIndices);
        predictBuilder.withPredictOutputFile(file3);
        return predictBuilder;
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Prepares the command builder for a combined training and prediction run.
     *
     * @param file the training data file
     * @param file2 the data file to predict
     * @param file3 the file handed to the builder as prediction output
     * @return the builder configured for fit-and-predict mode
     */
    public ScikitLearnWrapperCommandBuilder constructCommandLineParametersForFitAndPredictMode(File file, File file2, File file3) {
        final ScikitLearnWrapperCommandBuilder combinedBuilder = getCommandBuilder();
        combinedBuilder.withFitAndPredictMode();
        combinedBuilder.withFitDataFile(file);
        combinedBuilder.withPredictDataFile(file2);
        combinedBuilder.withPredictOutputFile(file3);
        combinedBuilder.withTargetIndices(this.targetIndices);
        return combinedBuilder;
    }

    /**
     * Starts the given command as a process in the temp folder, attaches a listener
     * to it and evaluates the error output once the listener returns.
     *
     * <p>Non-empty error output is treated as a failure unless it mentions
     * "convergence" (case-insensitive), in which case only a warning is logged.
     *
     * @param strArr the command and its arguments
     * @throws InterruptedException if the thread is interrupted while the process is observed
     * @throws ScikitLearnWrapperExecutionFailedException if the process cannot be run or reports errors
     */
    private void runProcess(String[] strArr) throws InterruptedException, ScikitLearnWrapperExecutionFailedException {
        DefaultProcessListener listener = new DefaultProcessListener(this.listenToPidFromProcess);
        try {
            listener.setLoggerName(this.logger.getName() + ".python");
            this.logger.debug("Set logger name of listener to {}. Now starting python process.", listener.getLoggerName());
            if (this.logger.isDebugEnabled()) {
                String call = Arrays.toString(strArr).replace(",", "");
                this.logger.info("Starting process {}", call.substring(1, call.length() - 1));
            }
            Process process = new ProcessBuilder(strArr).directory(this.scikitLearnWrapperConfig.getTempFolder()).start();

            // The PID is only needed for the log line. A failed lookup must not keep us
            // from listening to the process that has already been started.
            try {
                this.logger.debug("Started process with PID: {}. Listener is {}", Integer.valueOf(ProcessUtil.getPID(process)), listener);
            } catch (ProcessIDNotRetrievableException e) {
                this.logger.warn("Could not retrieve process ID.");
            }

            this.logger.info("Attaching listener {} to process {}", listener, process);
            listener.listenTo(process);
            this.logger.info("Listener attached.");

            String errorOutput = listener.getErrorOutput();
            if (!errorOutput.isEmpty()) {
                if (!errorOutput.toLowerCase().contains("convergence")) {
                    throw new ScikitLearnWrapperExecutionFailedException(COULD_NOT_RUN_SCIKIT_LEARN_MODEL);
                }
                this.logger.warn("Learner {} could not converge. Consider increase number of iterations.", this.pipeline);
            }
        } catch (InterruptedException e) {
            // Interruption is propagated unchanged so callers can react to cancellation.
            throw e;
        } catch (Exception e) {
            throw new ScikitLearnWrapperExecutionFailedException(COULD_NOT_RUN_SCIKIT_LEARN_MODEL, e);
        }
    }

    /**
     * Resolves the result file for a dataset name as
     * {@code <configurationUID>_<name><result extension>} inside the model dump directory.
     *
     * @param str the dataset name
     * @return the (possibly not yet existing) result file
     */
    @Override // ai.libs.jaicore.ml.scikitwrapper.IScikitLearnWrapper
    public File getOutputFile(String str) {
        File dumpDirectory = this.scikitLearnWrapperConfig.getModelDumpsDirectory();
        String resultName = this.configurationUID + "_" + str + this.scikitLearnWrapperConfig.getResultFileExtension();
        return new File(dumpDirectory, resultName);
    }

    /**
     * Converts a result file into a prediction batch. The file-based predict and
     * fitAndPredict variants call this with the output file after the python process
     * has run (or when the output file already existed).
     *
     * @param file the result file to process
     * @return the prediction batch derived from the file
     * @throws PredictionException if the predictions cannot be obtained from the file
     * @throws TrainingException declared for implementations that detect a training failure
     */
    protected abstract B handleOutput(File file) throws PredictionException, TrainingException;

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Reads the given result file and parses its JSON content into a list of lists.
     * The file is deleted right after reading when the delete-on-exit option is set.
     *
     * @param file the result file written by the python process
     * @return the parsed prediction values
     * @throws PredictionException if the file cannot be read, deleted or parsed
     */
    public List<List<Double>> getRawPredictionResults(File file) throws PredictionException {
        final List<List<Double>> parsedResults;
        try {
            String json = FileUtil.readFileAsString(file);
            if (this.scikitLearnWrapperConfig.getDeleteFileOnExit()) {
                Files.delete(file.toPath());
            }
            parsedResults = (List) new ObjectMapper().readValue(json, List.class);
        } catch (IOException e) {
            throw new PredictionException("Could not read result file or parse the json content to a list.", e);
        }

        if (this.logger.isInfoEnabled()) {
            // Log all values as one flat list.
            this.logger.info("{}", parsedResults.stream().flatMap(List::stream).collect(Collectors.toList()));
        }
        return parsedResults;
    }

    /**
     * Replaces the python configuration that is handed to the command builder.
     *
     * @param iPythonConfig the python configuration to use (null disables passing it on)
     */
    @Override // ai.libs.jaicore.ml.scikitwrapper.IScikitLearnWrapper
    public void setPythonConfig(IPythonConfig iPythonConfig) {
        final IPythonConfig newConfig = iPythonConfig;
        this.pythonConfig = newConfig;
    }

    /**
     * Replaces the wrapper configuration that supplies folders, file extensions and
     * the delete-on-exit flag.
     *
     * @param iScikitLearnWrapperConfig the configuration to use from now on
     */
    @Override // ai.libs.jaicore.ml.scikitwrapper.IScikitLearnWrapper
    public void setScikitLearnWrapperConfig(IScikitLearnWrapperConfig iScikitLearnWrapperConfig) {
        final IScikitLearnWrapperConfig newConfig = iScikitLearnWrapperConfig;
        this.scikitLearnWrapperConfig = newConfig;
    }

    /**
     * Resolves the script file of this configuration as
     * {@code <configurationUID><python extension>} inside the temp folder.
     *
     * @return the (possibly not yet existing) script file
     */
    @Override // ai.libs.jaicore.ml.scikitwrapper.IScikitLearnWrapper
    public File getSKLearnScriptFile() {
        File tempFolder = this.scikitLearnWrapperConfig.getTempFolder();
        String scriptName = this.configurationUID + this.scikitLearnWrapperConfig.getPythonFileExtension();
        return new File(tempFolder, scriptName);
    }

    /**
     * @return the currently configured model file, or null if none has been set yet
     */
    @Override // ai.libs.jaicore.ml.scikitwrapper.IScikitLearnWrapper
    public File getModelFile() {
        File currentModel = this.modelFile;
        return currentModel;
    }

    /**
     * @return the name of the logger this wrapper currently writes to
     */
    public String getLoggerName() {
        String currentName = this.logger.getName();
        return currentName;
    }

    /**
     * Switches this wrapper to the logger with the given name.
     *
     * @param str the name of the logger to use from now on
     */
    public void setLoggerName(String str) {
        Logger namedLogger = LoggerFactory.getLogger(str);
        this.logger = namedLogger;
    }

    /**
     * @return the pipeline command this wrapper was created with
     */
    public String toString() {
        String description = this.pipeline;
        return description;
    }
}
