package ai.libs.reduction.single;

import ai.libs.jaicore.experiments.exceptions.ExperimentEvaluationFailedException;
import ai.libs.jaicore.ml.classification.loss.dataset.EClassificationPerformanceMeasure;
import ai.libs.jaicore.ml.core.evaluation.evaluator.FixedSplitClassifierEvaluator;
import ai.libs.jaicore.ml.core.evaluation.evaluator.MonteCarloCrossValidationEvaluator;
import ai.libs.jaicore.ml.weka.WekaUtil;
import ai.libs.jaicore.ml.weka.classification.learner.WekaClassifier;
import ai.libs.jaicore.ml.weka.classification.learner.reduction.MCTreeNodeReD;
import ai.libs.jaicore.ml.weka.classification.learner.reduction.splitter.ISplitter;
import ai.libs.jaicore.ml.weka.classification.learner.reduction.splitter.ISplitterFactory;
import ai.libs.jaicore.ml.weka.dataset.IWekaInstances;
import ai.libs.jaicore.ml.weka.dataset.WekaInstances;
import java.io.BufferedReader;
import java.io.FileReader;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;
import org.api4.java.ai.ml.core.dataset.supervised.ILabeledDataset;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.Classifier;
import weka.core.Instances;

/* loaded from: input_file:ai/libs/reduction/single/ExperimentRunner.class */
/**
 * Runs single one-step reduction experiments.
 *
 * <p>The dataset of an experiment is split (stratified) into a training and a test partition.
 * {@code k} candidate class splits are drawn from the training partition by a seeded
 * {@link ISplitter}. Each resulting {@link MCTreeNodeReD} is scored by Monte Carlo
 * cross-validation on the training partition only. The best-scoring candidate is finally trained
 * on the training partition and its error rate is measured on the held-out test partition.
 *
 * @param <T> the type of splitter produced by the splitter factory
 */
public class ExperimentRunner<T extends ISplitter> {

    /** Portion of the data used for training, both in the outer split and inside MCCV. */
    private static final double TRAIN_PORTION = 0.7;

    private static final Logger LOGGER = LoggerFactory.getLogger(ExperimentRunner.class);

    /** Number of candidate splits that are drawn and evaluated. */
    private final int k;
    /** Number of Monte Carlo cross-validation repetitions used to score each candidate. */
    private final int mccvRepeats;
    private final ISplitterFactory<T> splitterFactory;

    /**
     * Creates a runner.
     *
     * @param k number of candidate splits to draw and evaluate; must be positive for
     *     {@link #conductSingleOneStepReductionExperiment(ReductionExperiment)} to succeed
     * @param mccvRepeats number of MCCV repetitions used to score each candidate
     * @param splitterFactory factory that creates the (seeded) splitter proposing candidate splits
     */
    public ExperimentRunner(int k, int mccvRepeats, ISplitterFactory<T> splitterFactory) {
        this.k = k;
        this.mccvRepeats = mccvRepeats;
        this.splitterFactory = splitterFactory;
    }

    /**
     * Conducts one reduction experiment and reports the test error rate of the selected reduction.
     *
     * <p>The dataset file is read with Weka's {@link Instances} reader, and its last attribute is
     * used as the class attribute.
     *
     * @param reductionExperiment the experiment description (dataset file, seed, classifier names)
     * @return a map with the single key {@code "errorRate"}, holding the error rate of the selected
     *     reduction on the held-out test partition
     * @throws IllegalStateException if this runner was configured with {@code k <= 0}
     * @throws InterruptedException if evaluation of a candidate is interrupted
     * @throws ExperimentEvaluationFailedException if creating or evaluating a candidate split fails
     * @throws Exception if reading the dataset, instantiating the classifiers or the final
     *     evaluation fails
     */
    public Map<String, Object> conductSingleOneStepReductionExperiment(ReductionExperiment reductionExperiment) throws Exception {
        if (this.k <= 0) {
            // Without at least one candidate there is nothing to select and evaluate.
            throw new IllegalStateException("At least one candidate split must be evaluated, but k = " + this.k);
        }

        /* load data; the reader is closed even if parsing fails */
        Instances instances;
        try (BufferedReader reader = new BufferedReader(new FileReader(reductionExperiment.getDataset()))) {
            instances = new Instances(reader);
        }
        instances.setClassIndex(instances.numAttributes() - 1);
        WekaInstances dataset = new WekaInstances(instances);
        int seed = reductionExperiment.getSeed();

        Classifier leftClassifier = AbstractClassifier.forName(reductionExperiment.getNameOfLeftClassifier(), (String[]) null);
        Classifier innerClassifier = AbstractClassifier.forName(reductionExperiment.getNameOfInnerClassifier(), (String[]) null);
        Classifier rightClassifier = AbstractClassifier.forName(reductionExperiment.getNameOfRightClassifier(), (String[]) null);

        /* outer split: index 0 is the training partition, index 1 the held-out test partition */
        List<?> outerSplit = WekaUtil.getStratifiedSplit(dataset, seed, TRAIN_PORTION);
        IWekaInstances trainSplit = (IWekaInstances) outerSplit.get(0);

        // Candidate selection must only see training data; otherwise the test partition used for
        // the final error rate would influence which candidate is chosen. A copy is passed so the
        // evaluator cannot alter the training partition used by the splitter and final evaluation.
        MonteCarloCrossValidationEvaluator mccv = new MonteCarloCrossValidationEvaluator(new WekaInstances(new Instances(trainSplit.getInstances())), this.mccvRepeats, TRAIN_PORTION, new Random(seed));

        ISplitter splitter = this.splitterFactory.getSplitter(seed);
        MCTreeNodeReD bestReduction = null;
        double bestLoss = Double.MAX_VALUE;
        for (int i = 0; i < this.k; i++) {
            try {
                List<?> classSplit = new ArrayList<>(splitter.split(trainSplit.getInstances()));
                @SuppressWarnings({ "rawtypes", "unchecked" })
                MCTreeNodeReD candidate = new MCTreeNodeReD(innerClassifier, (Collection) classSplit.get(0), leftClassifier, (Collection) classSplit.get(1), rightClassifier);
                double loss = ((Double) mccv.evaluate(new WekaClassifier(candidate))).doubleValue();
                LOGGER.info("\t\t\tComputed loss {}", loss);
                // Double.compare ranks NaN above every other value, so a NaN loss never displaces a
                // real one; the first candidate is always taken so bestReduction is never null.
                if (bestReduction == null || Double.compare(loss, bestLoss) < 0) {
                    bestLoss = loss;
                    bestReduction = candidate;
                }
            } catch (InterruptedException e) {
                throw e;
            } catch (Exception e) {
                throw new ExperimentEvaluationFailedException("Could not create a split.", e);
            }
        }

        /* retrain the selected reduction on the training partition and measure it on the test partition */
        double testErrorRate = new FixedSplitClassifierEvaluator((ILabeledDataset) outerSplit.get(0), (ILabeledDataset) outerSplit.get(1), EClassificationPerformanceMeasure.ERRORRATE).evaluate(new WekaClassifier(bestReduction)).doubleValue();
        LOGGER.info("\t\t\tBest previously observed loss was {}. The retrained classifier achieves {} on the full data.", bestLoss, testErrorRate);

        Map<String, Object> result = new HashMap<>();
        result.put("errorRate", testErrorRate);
        return result;
    }
}
