package ai.libs.jaicore.ml.scikitwrapper.simple;

import ai.libs.jaicore.basic.FileUtil;
import ai.libs.jaicore.ml.regression.singlelabel.SingleTargetRegressionPrediction;
import ai.libs.jaicore.ml.regression.singlelabel.SingleTargetRegressionPredictionBatch;
import ai.libs.jaicore.ml.scikitwrapper.ScikitLearnWrapperExecutionFailedException;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.ArrayNode;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.api4.java.ai.ml.core.dataset.supervised.ILabeledDataset;
import org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance;
import org.api4.java.ai.ml.core.exception.PredictionException;
import org.api4.java.ai.ml.regression.evaluation.IRegressionPrediction;
import org.api4.java.ai.ml.regression.evaluation.IRegressionResultBatch;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:ai/libs/jaicore/ml/scikitwrapper/simple/SimpleScikitLearnRegressor.class */
/**
 * Regression variant of {@link ASimpleScikitLearnWrapper}.
 *
 * <p>Runs the wrapped scikit-learn pipeline (via the inherited {@code executePipeline}) on a
 * dataset. It then reads the produced JSON file, which is expected to contain a top-level array
 * with one numeric prediction per instance, and turns it into a
 * {@link SingleTargetRegressionPredictionBatch}.
 */
public class SimpleScikitLearnRegressor extends ASimpleScikitLearnWrapper<IRegressionPrediction, IRegressionResultBatch> {
    private static final Logger LOGGER = LoggerFactory.getLogger(SimpleScikitLearnRegressor.class);

    /** Shared JSON parser; ObjectMapper is thread-safe once configured, so one instance suffices. */
    private static final ObjectMapper JSON_MAPPER = new ObjectMapper();

    /**
     * Creates a regressor wrapper; both arguments are forwarded unchanged to the super constructor
     * together with the problem type {@code "regression"}.
     *
     * @param str  first wrapper argument (presumably the pipeline constructor call — TODO confirm
     *             against {@link ASimpleScikitLearnWrapper})
     * @param str2 second wrapper argument (presumably the import statements — TODO confirm)
     * @throws IOException          if the super constructor fails with an I/O error
     * @throws InterruptedException if the super constructor is interrupted
     */
    public SimpleScikitLearnRegressor(String str, String str2) throws IOException, InterruptedException {
        super(str, str2, "regression");
    }

    /**
     * Predicts a numeric target for every instance of the given dataset.
     *
     * @param iLabeledDataset the instances to predict
     * @return one {@link SingleTargetRegressionPrediction} per element of the JSON output array, in
     *         the same order
     * @throws PredictionException  if the wrapper execution fails, an I/O or JSON parsing error
     *                              occurs, or the output's root element is not a JSON array
     * @throws InterruptedException if the pipeline execution is interrupted (rethrown unchanged)
     */
    @Override // ai.libs.jaicore.ml.core.learner.ASupervisedLearner
    public IRegressionResultBatch predict(ILabeledDataset<? extends ILabeledInstance> iLabeledDataset) throws PredictionException, InterruptedException {
        try {
            // readTree yields a generic JsonNode; the root must be verified to be an array
            // before it can be treated as one.
            JsonNode root = JSON_MAPPER.readTree(FileUtil.readFileAsString(executePipeline(iLabeledDataset)));
            if (!(root instanceof ArrayNode)) {
                throw new PredictionException("Json file for predictions does not contain an array as root element");
            }
            List<IRegressionPrediction> predictions = new ArrayList<>(root.size());
            for (JsonNode prediction : root) {
                // Note: Jackson's asDouble() returns 0.0 for nodes it cannot convert to a number.
                predictions.add(new SingleTargetRegressionPrediction(Double.valueOf(prediction.asDouble())));
            }
            return new SingleTargetRegressionPredictionBatch(predictions);
        } catch (ScikitLearnWrapperExecutionFailedException e) {
            throw new PredictionException("Could not execute scikit learn wrapper", e);
        } catch (IOException e2) {
            // Also covers failures while reading the output file or parsing its JSON content.
            throw new PredictionException("Could not write executable python file.", e2);
        } catch (InterruptedException e3) {
            LOGGER.info("SimpleScikitLearnRegressor for pipeline {} got interrupted.", this.constructorCall);
            throw e3;
        }
    }
}
