package ai.libs.jaicore.ml.core.evaluation.evaluator;

import ai.libs.jaicore.ml.core.evaluation.evaluator.events.TrainTestSplitEvaluationCompletedEvent;
import ai.libs.jaicore.ml.core.evaluation.evaluator.events.TrainTestSplitEvaluationFailedEvent;
import com.google.common.eventbus.EventBus;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
import org.api4.java.ai.ml.classification.IClassifierEvaluator;
import org.api4.java.ai.ml.core.dataset.splitter.SplitFailedException;
import org.api4.java.ai.ml.core.dataset.supervised.ILabeledDataset;
import org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance;
import org.api4.java.ai.ml.core.evaluation.execution.IAggregatedPredictionPerformanceMeasure;
import org.api4.java.ai.ml.core.evaluation.execution.IDatasetSplitSet;
import org.api4.java.ai.ml.core.evaluation.execution.IFixedDatasetSplitSetGenerator;
import org.api4.java.ai.ml.core.evaluation.execution.ILearnerRunReport;
import org.api4.java.ai.ml.core.evaluation.execution.LearnerExecutionFailedException;
import org.api4.java.ai.ml.core.evaluation.execution.LearnerExecutionInterruptedException;
import org.api4.java.ai.ml.core.learner.ISupervisedLearner;
import org.api4.java.common.attributedobjects.ObjectEvaluationFailedException;
import org.api4.java.common.control.ILoggingCustomizable;
import org.api4.java.common.event.IEventEmitter;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:ai/libs/jaicore/ml/core/evaluation/evaluator/TrainPredictionBasedClassifierEvaluator.class */
/**
 * Evaluates a supervised learner on the train/test splits produced by a fixed split-set generator and
 * aggregates the per-split prediction/ground-truth tables into a single loss via a performance measure.
 *
 * <p>Each split must consist of exactly two folds: the first is used for training, the second for testing.
 * Completed split evaluations are posted on an internal {@link EventBus} only once at least one listener has
 * been registered; failed split evaluations are always posted.
 */
public class TrainPredictionBasedClassifierEvaluator implements IClassifierEvaluator, ILoggingCustomizable, IEventEmitter<Object> {

    /** Produces the split set (train fold, test fold per split) used by {@link #evaluate}. */
    private final IFixedDatasetSplitSetGenerator<ILabeledDataset<? extends ILabeledInstance>> splitGenerator;

    /** Aggregates the prediction/ground-truth tables of all splits into one loss value (kept raw, as in the constructor). */
    private final IAggregatedPredictionPerformanceMeasure metric;

    /** Becomes true once a listener is registered; lets us skip posting completed-split events when nobody listens. */
    private boolean hasListeners;

    private Logger logger = LoggerFactory.getLogger(TrainPredictionBasedClassifierEvaluator.class);
    private final SupervisedLearnerExecutor executor = new SupervisedLearnerExecutor();
    private final EventBus eventBus = new EventBus();

    /**
     * Creates an evaluator.
     *
     * @param splitGenerator generator whose next split set defines the train/test splits; every split must have two folds
     * @param metric measure used to aggregate the prediction/ground-truth tables of all splits into a loss
     */
    @SuppressWarnings({ "unchecked", "rawtypes" })
    public TrainPredictionBasedClassifierEvaluator(IFixedDatasetSplitSetGenerator<ILabeledDataset<?>> splitGenerator, IAggregatedPredictionPerformanceMeasure<?, ?> metric) {
        // The parameter's type argument (ILabeledDataset<?>) is not spelled like the field's
        // (ILabeledDataset<? extends ILabeledInstance>); go through the raw type so the assignment
        // compiles regardless of how the compiler relates the two wildcards.
        this.splitGenerator = (IFixedDatasetSplitSetGenerator) splitGenerator;
        this.metric = metric;
    }

    /**
     * Evaluates the learner on every split of the next split set of the split generator.
     *
     * <p>For each split the learner is trained on fold 0 and tested on fold 1. The prediction/ground-truth
     * tables of all runs are passed to the metric, whose loss is returned.
     *
     * @param learner the learner to evaluate
     * @return the loss computed by the metric over all split executions
     * @throws InterruptedException declared for interruption during split generation or learner execution; the
     *         {@link LearnerExecutionInterruptedException} path posts a failure event before being rethrown
     * @throws ObjectEvaluationFailedException if splitting the data or executing the learner fails
     * @throws IllegalStateException if the split set does not have exactly two folds per split
     */
    public Double evaluate(ISupervisedLearner<ILabeledInstance, ILabeledDataset<? extends ILabeledInstance>> learner) throws InterruptedException, ObjectEvaluationFailedException {
        try {
            long evaluationStart = System.currentTimeMillis();
            this.logger.info("Using {} to split the given data into two folds.", this.splitGenerator.getClass().getName());
            IDatasetSplitSet splitSet = this.splitGenerator.nextSplitSet();
            if (splitSet.getNumberOfFoldsPerSplit() != 2) {
                throw new IllegalStateException("Number of folds for each split should be 2 but is " + splitSet.getNumberOfFoldsPerSplit() + "! Split generator: " + this.splitGenerator);
            }
            int numberOfSplits = splitSet.getNumberOfSplits();
            List<ILearnerRunReport> reports = new ArrayList<>(numberOfSplits);
            for (int i = 0; i < numberOfSplits; i++) {
                List<?> folds = splitSet.getFolds(i);
                reports.add(this.evaluateOnSplit(learner, folds, i, evaluationStart));
            }

            this.logger.debug("Compute metric ({}) for the diff of predictions and ground truth.", this.metric.getClass().getName());
            List<?> predictionTables = reports.stream().map(ILearnerRunReport::getPredictionDiffList).collect(Collectors.toList());
            double loss = this.metric.loss(predictionTables);
            this.logger.info("Computed value for metric {} of {} executions. Metric value is: {}", this.metric, numberOfSplits, loss);
            return Double.valueOf(loss);
        } catch (LearnerExecutionFailedException | SplitFailedException e) {
            // Learner goes into the placeholder; the exception is the trailing argument so SLF4J logs its stack trace.
            this.logger.debug("Failed to evaluate the classifier {}.", learner, e);
            throw new ObjectEvaluationFailedException(e);
        }
    }

    /**
     * Trains the learner on fold 0 and tests it on fold 1 of a single split.
     *
     * <p>A single try with ordered catch clauses guarantees that each failure is logged and posted exactly once;
     * an exception rethrown from one handler is never caught by a sibling handler.
     *
     * @param learner the learner to execute
     * @param folds the two folds of the split (train, test)
     * @param splitIndex zero-based index of the split, used for log messages
     * @param evaluationStart timestamp (ms) at which the overall evaluation started, used for log messages
     * @return the run report of the executor
     */
    private ILearnerRunReport evaluateOnSplit(ISupervisedLearner<ILabeledInstance, ILabeledDataset<? extends ILabeledInstance>> learner, List<?> folds, int splitIndex, long evaluationStart)
            throws LearnerExecutionFailedException, LearnerExecutionInterruptedException {
        @SuppressWarnings("unchecked")
        ILabeledDataset<? extends ILabeledInstance> train = (ILabeledDataset<? extends ILabeledInstance>) folds.get(0);
        @SuppressWarnings("unchecked")
        ILabeledDataset<? extends ILabeledInstance> test = (ILabeledDataset<? extends ILabeledInstance>) folds.get(1);
        this.logger.debug("Executing learner {} on folds of sizes {} (train) and {} (test) using {}.", learner, train.size(), test.size(), this.executor.getClass().getName());

        ILearnerRunReport report;
        try {
            report = this.executor.execute(learner, train, test);
        } catch (LearnerExecutionInterruptedException e) {
            this.logger.info("Received interrupt of training in iteration #{} after a total evaluation time of {}ms. Sending an event over the bus and forwarding the exception.", splitIndex + 1,
                    System.currentTimeMillis() - evaluationStart);
            this.eventBus.post(new TrainTestSplitEvaluationFailedEvent(learner, new LearnerRunReport(train, test, e.getTrainTimeStart(), e.getTrainTimeEnd(), e.getTestTimeStart(), e.getTestTimeEnd(), (Throwable) e)));
            throw e;
        } catch (LearnerExecutionFailedException e) {
            this.logger.info("Catching {} in iteration #{} after a total evaluation time of {}ms. Sending an event over the bus and forwarding the exception.", e.getClass().getName(), splitIndex + 1,
                    System.currentTimeMillis() - evaluationStart);
            this.eventBus.post(new TrainTestSplitEvaluationFailedEvent(learner, new LearnerRunReport(train, test, e.getTrainTimeStart(), e.getTrainTimeEnd(), e.getTestTimeStart(), e.getTestTimeEnd(), (Throwable) e)));
            throw e;
        }

        // Only materialize the (possibly large) vectors when trace logging will actually print them.
        if (this.logger.isTraceEnabled()) {
            this.logger.trace("Obtained report. Training times was {}ms, testing time {}ms. Ground truth vector: {}, prediction vector: {}", report.getTrainEndTime() - report.getTrainStartTime(),
                    report.getTestEndTime() - report.getTestStartTime(), report.getPredictionDiffList().getGroundTruthAsList(), report.getPredictionDiffList().getPredictionsAsList());
        }
        if (this.hasListeners) {
            this.eventBus.post(new TrainTestSplitEvaluationCompletedEvent(learner, report));
        }
        if (this.logger.isDebugEnabled()) {
            this.logPredictionStatistics(report);
        }
        return report;
    }

    /** Logs how many test instances were predicted correctly and warns if none of a non-empty test fold was. */
    private void logPredictionStatistics(ILearnerRunReport report) {
        List<?> groundTruth = report.getPredictionDiffList().getGroundTruthAsList();
        List<?> predictions = report.getPredictionDiffList().getPredictionsAsList();
        int total = groundTruth.size();
        int correct = countCorrectPredictions(groundTruth, predictions);
        // An empty test fold trivially has 0 correct predictions; that is not suspicious.
        if (total > 0 && correct == 0) {
            this.logger.warn("0 correct predictions seems suspicious. Here are the vectors: \n\tGround truth: {}\n\tPredictions: {}", groundTruth, predictions);
        }
        this.logger.debug("Execution completed. Classifier predicted {}/{} test instances correctly.", correct, total);
    }

    /**
     * Counts positions at which the prediction equals the ground truth. Null-safe; positions beyond the
     * shorter of the two lists count as incorrect, so diagnostics logging never aborts the evaluation.
     */
    private static int countCorrectPredictions(List<?> groundTruth, List<?> predictions) {
        int comparable = Math.min(groundTruth.size(), predictions.size());
        int correct = 0;
        for (int i = 0; i < comparable; i++) {
            if (Objects.equals(predictions.get(i), groundTruth.get(i))) {
                correct++;
            }
        }
        return correct;
    }

    /** @return the split generator used to produce the train/test splits */
    protected IFixedDatasetSplitSetGenerator<ILabeledDataset<? extends ILabeledInstance>> getSplitGenerator() {
        return this.splitGenerator;
    }

    /** @return the name of the logger currently used by this evaluator */
    public String getLoggerName() {
        return this.logger.getName();
    }

    /**
     * Switches this evaluator to the logger with the given name, and derives logger names for the split
     * generator ({@code <name>.splitgen}, if it supports logging customization) and the executor ({@code <name>.executor}).
     *
     * @param str the new logger name
     */
    public void setLoggerName(String str) {
        this.logger = LoggerFactory.getLogger(str);
        if (this.splitGenerator instanceof ILoggingCustomizable) {
            // Explicit cast: the generator's declared type is not guaranteed to expose setLoggerName.
            ((ILoggingCustomizable) this.splitGenerator).setLoggerName(str + ".splitgen");
            this.logger.info("Setting logger of split generator {} to {}.splitgen", this.splitGenerator.getClass().getName(), str);
        } else {
            this.logger.info("Split generator {} is not configurable for logging, so not configuring it.", this.splitGenerator.getClass().getName());
        }
        this.executor.setLoggerName(str + ".executor");
        this.logger.info("Setting logger of learner executor {} to {}.executor", this.executor.getClass().getName(), str);
    }

    /**
     * Registers a listener on the internal event bus; from then on, completed-split events are posted as well.
     *
     * @param obj the listener (Guava EventBus subscriber)
     */
    public void registerListener(Object obj) {
        this.eventBus.register(obj);
        this.hasListeners = true;
    }
}
