package ai.libs.reduction;

import ai.libs.jaicore.ml.WekaUtil;
import ai.libs.jaicore.ml.classification.multiclass.reduction.MCTreeNodeReD;
import ai.libs.jaicore.ml.classification.multiclass.reduction.splitters.RPNDSplitter;
import ai.libs.reduction.ensemble.simple.EnsembleOfSimpleOneStepReductionsExperiment;
import ai.libs.reduction.single.ReductionExperiment;
import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.io.Reader;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.classifiers.meta.Vote;
import weka.core.Instances;

/* loaded from: input_file:ai/libs/reduction/Util.class */
/**
 * Static helpers that run reduction experiments on an ARFF-style dataset and report,
 * per repetition, the error rate on a held-out split and the training time.
 */
public class Util {
    private static final Logger logger = LoggerFactory.getLogger(Util.class);
    private static final String LABEL_TRAIN_TIME = "trainTime";
    private static final String LABEL_ERROR_RATE = "errorRate";

    /** Number of independent train/test repetitions each experiment performs. */
    private static final int NUM_REPETITIONS = 10;

    /** Fraction of the data placed in the training part of each stratified split. */
    private static final double TRAIN_PORTION = 0.7;

    private Util() {
        /* static utility class; not meant to be instantiated */
    }

    /**
     * Evaluates a single one-step reduction ({@link MCTreeNodeReD} with an inner classifier and a left and a
     * right child classifier) on {@value #NUM_REPETITIONS} stratified train/test splits of the dataset.
     *
     * @param reductionExperiment
     *            the experiment description (dataset, seed, and classifier class names)
     * @return one map per repetition, holding the error rate on the test split under {@code "errorRate"} (Double)
     *         and the training time in milliseconds under {@code "trainTime"} (Long)
     * @throws Exception
     *             if the dataset cannot be read, or a classifier cannot be created, trained, or evaluated
     * @throws RuntimeException
     *             if the RPND splitter fails to split the classes
     */
    public static List<Map<String, Object>> conductSingleOneStepReductionExperiment(final ReductionExperiment reductionExperiment) throws Exception {
        Instances instances = loadInstances(new FileReader(reductionExperiment.getDataset()));
        int seed = reductionExperiment.getSeed();

        /* the splitter gets its own instance of the inner classifier, separate from the one used in the tree node */
        Classifier splitterClassifier = AbstractClassifier.forName(reductionExperiment.getNameOfInnerClassifier(), (String[]) null);
        Classifier leftClassifier = AbstractClassifier.forName(reductionExperiment.getNameOfLeftClassifier(), (String[]) null);
        Classifier innerClassifier = AbstractClassifier.forName(reductionExperiment.getNameOfInnerClassifier(), (String[]) null);
        Classifier rightClassifier = AbstractClassifier.forName(reductionExperiment.getNameOfRightClassifier(), (String[]) null);
        RPNDSplitter splitter = new RPNDSplitter(new Random(seed), splitterClassifier);

        List<Map<String, Object>> results = new ArrayList<>(NUM_REPETITIONS);
        for (int i = 0; i < NUM_REPETITIONS; i++) {
            /* NOTE(review): the class split is computed on the full dataset (including the later test part); confirm this is intended */
            List<Collection<String>> classSplit;
            try {
                classSplit = new ArrayList<>(splitter.split(instances));
            } catch (Exception e) {
                throw new RuntimeException("Could not create RPND split.", e);
            }
            MCTreeNodeReD classifier = new MCTreeNodeReD(innerClassifier, classSplit.get(0), leftClassifier, classSplit.get(1), rightClassifier);

            long start = System.currentTimeMillis();
            List<Instances> trainTest = WekaUtil.getStratifiedSplit(instances, seed + i, new double[] { TRAIN_PORTION });
            Instances train = trainTest.get(0);
            classifier.buildClassifier(train);
            long trainTime = System.currentTimeMillis() - start;

            Evaluation evaluation = new Evaluation(train);
            evaluation.evaluateModel(classifier, trainTest.get(1));
            double errorRate = errorRateOf(evaluation);
            logger.info("Conducted experiment {} with split {}/{}. Loss: {}. Time: {}ms.", i, classSplit.get(0), classSplit.get(1), errorRate, trainTime);

            Map<String, Object> result = new HashMap<>();
            result.put(LABEL_ERROR_RATE, errorRate);
            result.put(LABEL_TRAIN_TIME, trainTime);
            results.add(result);
        }
        return results;
    }

    /**
     * Evaluates a majority-vote ensemble of one-step reductions on {@value #NUM_REPETITIONS} stratified
     * train/test splits of the dataset. Every stump uses the same classifier class for the inner node and both children.
     *
     * @param ensembleOfSimpleOneStepReductionsExperiment
     *            the experiment description (dataset, seed, classifier class name, number of stumps)
     * @return one map per repetition, holding the error rate on the test split under {@code "errorRate"} (Double)
     *         and the training time in milliseconds, including building the vote, under {@code "trainTime"} (Long)
     * @throws Exception
     *             if the dataset cannot be read, or a classifier cannot be created, split, trained, or evaluated
     */
    public static List<Map<String, Object>> conductEnsembleOfOneStepReductionsExperiment(final EnsembleOfSimpleOneStepReductionsExperiment ensembleOfSimpleOneStepReductionsExperiment) throws Exception {
        Instances instances = loadInstances(new FileReader(ensembleOfSimpleOneStepReductionsExperiment.getDataset()));
        int seed = ensembleOfSimpleOneStepReductionsExperiment.getSeed();
        String nameOfClassifier = ensembleOfSimpleOneStepReductionsExperiment.getNameOfClassifier();
        RPNDSplitter splitter = new RPNDSplitter(new Random(seed), AbstractClassifier.forName(nameOfClassifier, (String[]) null));
        int numberOfStumps = ensembleOfSimpleOneStepReductionsExperiment.getNumberOfStumps();

        List<Map<String, Object>> results = new ArrayList<>(NUM_REPETITIONS);
        for (int i = 0; i < NUM_REPETITIONS; i++) {
            Vote vote = new Vote();
            vote.setOptions(new String[] { "-R", "MAJ" });
            long start = System.currentTimeMillis();
            List<Instances> trainTest = WekaUtil.getStratifiedSplit(instances, seed + i, new double[] { TRAIN_PORTION });
            Instances train = trainTest.get(0);
            Instances test = trainTest.get(1);

            for (int j = 0; j < numberOfStumps; j++) {
                /* NOTE(review): the class split is computed on the full dataset (including the test part); confirm this is intended */
                List<Collection<String>> classSplit = new ArrayList<>(splitter.split(instances));
                MCTreeNodeReD stump = new MCTreeNodeReD(nameOfClassifier, classSplit.get(0), nameOfClassifier, classSplit.get(1), nameOfClassifier);
                stump.buildClassifier(train);
                vote.addPreBuiltClassifier(stump);
            }

            /* build the vote on the training part only: building it on the full dataset would expose the
             * held-out test instances to the ensemble before it is evaluated on them */
            vote.buildClassifier(train);
            long trainTime = System.currentTimeMillis() - start;

            Evaluation evaluation = new Evaluation(train);
            evaluation.evaluateModel(vote, test);
            double errorRate = errorRateOf(evaluation);
            logger.info("Conducted experiment {}. Loss: {}. Time: {}ms.", i, errorRate, trainTime);

            Map<String, Object> result = new HashMap<>();
            result.put(LABEL_TRAIN_TIME, trainTime);
            result.put(LABEL_ERROR_RATE, errorRate);
            results.add(result);
        }
        return results;
    }

    /**
     * Reads a Weka dataset and marks its last attribute as the class attribute. The given reader is always closed.
     *
     * @param source
     *            reader over the dataset; ownership passes to this method
     * @return the loaded instances with the class index set to the last attribute
     * @throws IOException
     *             if the data cannot be read or parsed
     */
    private static Instances loadInstances(final Reader source) throws IOException {
        try (BufferedReader reader = new BufferedReader(source)) {
            Instances instances = new Instances(reader);
            instances.setClassIndex(instances.numAttributes() - 1);
            return instances;
        }
    }

    /** Converts Weka's percentage of correctly classified instances into a fractional error rate. */
    private static double errorRateOf(final Evaluation evaluation) {
        return (100.0 - evaluation.pctCorrect()) / 100.0;
    }
}
