package ai.libs.jaicore.ml.scikitwrapper;

import ai.libs.jaicore.ml.classification.singlelabel.SingleLabelClassification;
import ai.libs.jaicore.ml.classification.singlelabel.SingleLabelClassificationPredictionBatch;
import ai.libs.jaicore.ml.core.EScikitLearnProblemType;
import java.io.File;
import java.io.IOException;
import java.util.Collection;
import java.util.List;
import java.util.stream.Collectors;
import org.api4.java.ai.ml.classification.singlelabel.evaluation.ISingleLabelClassification;
import org.api4.java.ai.ml.classification.singlelabel.evaluation.ISingleLabelClassificationPredictionBatch;
import org.api4.java.ai.ml.core.dataset.schema.attribute.ICategoricalAttribute;
import org.api4.java.ai.ml.core.dataset.supervised.ILabeledDataset;
import org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance;
import org.api4.java.ai.ml.core.exception.PredictionException;
import org.api4.java.ai.ml.core.exception.TrainingException;

/**
 * Classification specialisation of {@link AScikitLearnWrapper}: it registers itself for
 * {@link EScikitLearnProblemType#CLASSIFICATION}, accepts only datasets with a categorical label, and
 * converts the raw numeric output read back from scikit-learn into single-label classifications.
 */
public class ScikitLearnClassificationWrapper extends AScikitLearnWrapper<ISingleLabelClassification, ISingleLabelClassificationPredictionBatch> {

    /**
     * Creates a classification wrapper.
     *
     * @param str  first string argument, forwarded unchanged to the {@link AScikitLearnWrapper}
     *             constructor (presumably the pipeline description — TODO confirm against the superclass)
     * @param str2 second string argument, forwarded unchanged to the {@link AScikitLearnWrapper}
     *             constructor (presumably the required imports — TODO confirm against the superclass)
     * @throws IOException          if raised by the superclass constructor
     * @throws InterruptedException if raised by the superclass constructor
     */
    public ScikitLearnClassificationWrapper(String str, String str2) throws IOException, InterruptedException {
        super(EScikitLearnProblemType.CLASSIFICATION, str, str2);
    }

    /**
     * Reports whether the given dataset can be handled by a classifier.
     *
     * @param iLabeledDataset the dataset to inspect
     * @return {@code true} iff the dataset's label attribute is an {@link ICategoricalAttribute}
     */
    @Override // ai.libs.jaicore.ml.scikitwrapper.AScikitLearnWrapper
    protected boolean doLabelsFitToProblemType(ILabeledDataset<? extends ILabeledInstance> iLabeledDataset) {
        return iLabeledDataset.getLabelAttribute() instanceof ICategoricalAttribute;
    }

    /**
     * Converts the raw prediction rows read from {@code file} into a prediction batch.
     *
     * <p>All rows must have the same, non-zero number of columns. The columns are interpreted as follows:
     * <ul>
     * <li>A single column holds the index of the predicted class. It is truncated with
     * {@link Double#intValue()} and combined with the number of labels of the current label attribute.</li>
     * <li>Several columns are taken over as-is as one value per class.</li>
     * </ul>
     *
     * @param file the output file produced by the scikit-learn process
     * @return one prediction per raw row, in row order
     * @throws PredictionException if no rows were read, if the first row is empty, or if any row's width
     *                             differs from the first row's
     * @throws TrainingException   if raised while reading the raw prediction results
     */
    // NOTE(review): the decompiler reported this method was originally protected; kept public so
    // existing callers keep compiling.
    @Override // ai.libs.jaicore.ml.scikitwrapper.AScikitLearnWrapper
    public ISingleLabelClassificationPredictionBatch handleOutput(File file) throws PredictionException, TrainingException {
        List<List<Double>> rawPredictionResults = getRawPredictionResults(file);
        if (rawPredictionResults.isEmpty()) {
            throw new PredictionException("Reading the output file lead to empty predictions.");
        }
        int width = rawPredictionResults.get(0).size();
        // Validate every row up front: only the first row used to be inspected, so a malformed later row
        // silently shifted or corrupted all subsequent predictions.
        checkRowWidths(rawPredictionResults, width);

        if (width != 1) {
            // One value per class in each row.
            return new SingleLabelClassificationPredictionBatch(rawPredictionResults.stream()
                    .map(row -> row.stream().mapToDouble(Double::doubleValue).toArray())
                    .map(SingleLabelClassification::new)
                    .collect(Collectors.<ISingleLabelClassification>toList()));
        }

        // One predicted class index per row.
        int numClasses = ((ICategoricalAttribute) this.data.getLabelAttribute()).getLabels().size();
        return new SingleLabelClassificationPredictionBatch(rawPredictionResults.stream()
                .map(row -> new SingleLabelClassification(numClasses, row.get(0).intValue()))
                .collect(Collectors.<ISingleLabelClassification>toList()));
    }

    /**
     * Ensures {@code width} is positive and that every row in {@code rows} has exactly {@code width} columns.
     *
     * @param rows  the raw prediction rows; must be non-empty
     * @param width the width of the first row
     * @throws PredictionException if {@code width} is zero or some row has a different width
     */
    private static void checkRowWidths(List<List<Double>> rows, int width) throws PredictionException {
        if (width == 0) {
            throw new PredictionException("The first prediction row read from the output file is empty.");
        }
        for (int i = 1; i < rows.size(); i++) {
            int rowWidth = rows.get(i).size();
            if (rowWidth != width) {
                throw new PredictionException("Prediction row " + i + " has " + rowWidth + " columns, but the first row has " + width + ".");
            }
        }
    }
}
