package ai.libs.jaicore.ml.core.evaluation.evaluator;

import java.util.List;
import org.api4.java.ai.ml.classification.multilabel.evaluation.IMultiLabelClassification;
import org.api4.java.ai.ml.core.dataset.supervised.ILabeledDataset;
import org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance;
import org.api4.java.ai.ml.core.evaluation.IPrediction;
import org.api4.java.ai.ml.core.evaluation.execution.ILearnerRunReport;
import org.api4.java.ai.ml.core.evaluation.execution.ISupervisedLearnerExecutor;
import org.api4.java.ai.ml.core.evaluation.execution.LearnerExecutionFailedException;
import org.api4.java.ai.ml.core.evaluation.execution.LearnerExecutionInterruptedException;
import org.api4.java.ai.ml.core.exception.PredictionException;
import org.api4.java.ai.ml.core.exception.TrainingException;
import org.api4.java.ai.ml.core.learner.ISupervisedLearner;
import org.api4.java.common.control.ILoggingCustomizable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:ai/libs/jaicore/ml/core/evaluation/evaluator/SupervisedLearnerExecutor.class */
/**
 * Runs a supervised learner on given data and packages the outcome into an {@link ILearnerRunReport}.
 *
 * <p>The three-argument {@code execute} fits the learner on a training set and then predicts a test set; the
 * two-argument variant only predicts with an (assumed already trained) learner. In both cases the label of every
 * test instance is paired with the corresponding prediction in a {@link TypelessPredictionDiff}, and the
 * {@link System#currentTimeMillis()} timestamps delimiting the training and prediction phases are recorded.
 */
public class SupervisedLearnerExecutor implements ISupervisedLearnerExecutor, ILoggingCustomizable {

    /** Timestamp value reported for the training phase when no training took place. */
    private static final long NO_TRAINING_TIMESTAMP = -1L;

    /** Logger used for progress/failure messages; replaceable via {@link #setLoggerName(String)}. */
    private Logger logger = LoggerFactory.getLogger(SupervisedLearnerExecutor.class);

    /**
     * Fits {@code learner} on {@code train} and then predicts {@code test}.
     *
     * @param learner the learner to fit and evaluate
     * @param train the data the learner is fitted on
     * @param test the data to predict; its labels are paired with the predictions in the report
     * @return a report containing both datasets, the train/test timestamps, and the label/prediction pairs
     * @throws LearnerExecutionFailedException if fitting throws a {@link TrainingException} or predicting throws a
     *         {@link PredictionException}; the original exception is attached as cause
     * @throws LearnerExecutionInterruptedException if the thread is interrupted while fitting or predicting
     */
    public <I extends ILabeledInstance, D extends ILabeledDataset<? extends I>> ILearnerRunReport execute(ISupervisedLearner<I, D> learner, D train, D test) throws LearnerExecutionFailedException, LearnerExecutionInterruptedException {
        long trainStart = System.currentTimeMillis();
        long trainEnd;

        /* Training phase. This try is deliberately NOT wrapped around the prediction phase: the exception thrown when
         * prediction is interrupted is itself an InterruptedException subtype (NOTE(review): per api4 — confirm), so a
         * surrounding catch (InterruptedException) would re-catch it and misreport it as a training interrupt. */
        try {
            this.logger.info("Fitting the learner (class: {}) {} with {} instances, each of which with {} attributes", learner.getClass().getName(), learner, train.size(), train.getNumAttributes());
            learner.fit(train);
            trainEnd = System.currentTimeMillis();
        } catch (InterruptedException e) {
            long interruptTime = System.currentTimeMillis();
            this.logger.info("Training was interrupted after {}ms, sending respective LearnerExecutionInterruptedException.", interruptTime - trainStart);
            throw new LearnerExecutionInterruptedException(trainStart, interruptTime);
        } catch (TrainingException e) {
            long failTime = System.currentTimeMillis();
            this.logger.info("Training failed due to an {} after {}ms.", e.getClass().getName(), failTime - trainStart);
            throw new LearnerExecutionFailedException(trainStart, failTime, e);
        }
        this.logger.debug("Training finished successfully after {}ms. Now acquiring predictions.", trainEnd - trainStart);

        // Prediction phase. The end of training is used as the (approximate) start of testing in the failure reports.
        try {
            return this.getReportForTrainedLearner(learner, train, test, trainStart, trainEnd);
        } catch (PredictionException e) {
            this.logger.info("Prediction failed with exception {}.", e.getClass().getName());
            throw new LearnerExecutionFailedException(trainStart, trainEnd, trainEnd, System.currentTimeMillis(), e);
        } catch (InterruptedException e) {
            long interruptTime = System.currentTimeMillis();
            this.logger.info("Learner was interrupted during prediction after a runtime of {}ms for training and {}ms for testing ({}ms total walltime).", trainEnd - trainStart, interruptTime - trainEnd, interruptTime - trainStart);
            // Reuse interruptTime so the reported timestamps agree with the durations just logged.
            throw new LearnerExecutionInterruptedException(trainStart, trainEnd, trainEnd, interruptTime);
        }
    }

    /**
     * Predicts {@code test} with {@code learner} without fitting it first.
     *
     * <p>The report carries no training set, and both training timestamps are {@code -1}.
     *
     * @param learner the learner used for prediction
     * @param test the data to predict; its labels are paired with the predictions in the report
     * @return a report containing the test set, the test timestamps, and the label/prediction pairs
     * @throws LearnerExecutionFailedException if prediction throws a {@link PredictionException}, or if the thread is
     *         interrupted (in which case the interrupt flag is restored before throwing)
     */
    public <I extends ILabeledInstance, D extends ILabeledDataset<? extends I>> ILearnerRunReport execute(ISupervisedLearner<I, D> learner, D test) throws LearnerExecutionFailedException {
        long testStart = System.currentTimeMillis();
        try {
            return this.getReportForTrainedLearner(learner, null, test, NO_TRAINING_TIMESTAMP, NO_TRAINING_TIMESTAMP);
        } catch (PredictionException e) {
            throw new LearnerExecutionFailedException(NO_TRAINING_TIMESTAMP, NO_TRAINING_TIMESTAMP, testStart, System.currentTimeMillis(), e);
        } catch (InterruptedException e) {
            // This signature cannot propagate an interrupt exception, so keep the interrupt visible via the thread flag.
            Thread.currentThread().interrupt();
            throw new LearnerExecutionFailedException(NO_TRAINING_TIMESTAMP, NO_TRAINING_TIMESTAMP, testStart, System.currentTimeMillis(), e);
        }
    }

    /**
     * Predicts {@code test} with {@code learner} and builds the run report.
     *
     * <p>For each test instance, its label is paired with the i-th prediction. A prediction that is an
     * {@link IMultiLabelClassification} is stored as-is; any other prediction is unwrapped via
     * {@link IPrediction#getPrediction()}.
     *
     * <p>NOTE(review): assumes the learner returns at least one prediction per test instance; fewer would raise an
     * IndexOutOfBoundsException, and surplus predictions are silently ignored.
     *
     * @param learner the (trained) learner
     * @param train the training set to store in the report, or {@code null}
     * @param test the data to predict
     * @param trainStart training start timestamp to store in the report
     * @param trainEnd training end timestamp to store in the report
     * @return the assembled report
     * @throws PredictionException if the learner fails to predict
     * @throws InterruptedException if the thread is interrupted during prediction
     */
    private <I extends ILabeledInstance, D extends ILabeledDataset<? extends I>> ILearnerRunReport getReportForTrainedLearner(ISupervisedLearner<I, D> learner, D train, D test, long trainStart, long trainEnd) throws PredictionException, InterruptedException {
        long testStart = System.currentTimeMillis();
        List<?> predictions = learner.predict(test).getPredictions();
        long testEnd = System.currentTimeMillis();

        TypelessPredictionDiff diff = new TypelessPredictionDiff();
        int numInstances = test.size();
        for (int i = 0; i < numInstances; i++) {
            Object groundTruth = ((ILabeledInstance) test.get(i)).getLabel();
            Object prediction = predictions.get(i);
            Object predictedValue = prediction instanceof IMultiLabelClassification ? prediction : ((IPrediction) prediction).getPrediction();
            diff.addPair(groundTruth, predictedValue);
        }
        return new LearnerRunReport(train, test, trainStart, trainEnd, testStart, testEnd, diff);
    }

    /** @return the name of the logger currently used by this executor */
    public String getLoggerName() {
        return this.logger.getName();
    }

    /**
     * Replaces the logger with the one registered under {@code str}.
     *
     * @param str the logger name
     */
    public void setLoggerName(String str) {
        this.logger = LoggerFactory.getLogger(str);
    }
}
