package ai.libs.mlplan.multiclasswithreduction;

import ai.libs.jaicore.basic.sets.SetUtil;
import ai.libs.jaicore.ml.WekaUtil;
import ai.libs.jaicore.ml.classification.multiclass.reduction.splitters.RPNDSplitter;
import ai.libs.mlplan.multiclass.wekamlplan.weka.model.MLPipeline;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import weka.attributeSelection.InfoGainAttributeEval;
import weka.attributeSelection.Ranker;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.Classifier;
import weka.core.Instance;
import weka.core.Instances;

/**
 * Static helpers that compute binary splits ({@link ClassSplit}) of a collection of class labels
 * for nested dichotomies.
 */
public final class NestedDichotomyUtil {
    private static final Logger logger = LoggerFactory.getLogger(NestedDichotomyUtil.class);

    private NestedDichotomyUtil() {
    }

    /**
     * Splits {@code collection} in two using an {@link RPNDSplitter}. The splitter's learner is an
     * {@link MLPipeline} that combines {@link InfoGainAttributeEval}/{@link Ranker} attribute
     * selection with the Weka classifier named {@code str}.
     *
     * @param collection the class labels to split; must contain at least two elements
     * @param random source of randomness handed to the splitter
     * @param str fully qualified Weka classifier name, resolved via {@link AbstractClassifier#forName}
     * @param instances the data the splitter works on
     * @return the split, or {@code null} if any unexpected error occurs (the error is logged)
     * @throws IllegalArgumentException if {@code collection} has fewer than two elements
     * @throws InterruptedException if the thread is interrupted while the split is computed
     */
    public static ClassSplit<String> createGeneralRPNDBasedSplit(Collection<String> collection, Random random, String str, Instances instances) throws InterruptedException {
        if (collection.size() < 2) {
            throw new IllegalArgumentException("Cannot compute split for less than two classes!");
        }
        try {
            Classifier baseClassifier = AbstractClassifier.forName(str, (String[]) null);
            MLPipeline pipeline = new MLPipeline(new Ranker(), new InfoGainAttributeEval(), baseClassifier);
            return toClassSplit(collection, new RPNDSplitter(random, pipeline).split(instances));
        } catch (InterruptedException e) {
            throw e;
        } catch (Exception e) {
            logger.error("Unexpected exception occurred while creating an RPND split", e);
            return null;
        }
    }

    /**
     * Splits {@code collection} in two using an {@link RPNDSplitter} whose learner is the Weka
     * classifier named {@code str}. The splitter also receives {@code collection2} and
     * {@code collection3}.
     *
     * <p>NOTE(review): presumably {@code collection2}/{@code collection3} are pre-assigned members
     * of the two sides; confirm against {@code RPNDSplitter#split}.
     *
     * <p>This method does not declare {@link InterruptedException}. If an interrupt occurs, the
     * thread's interrupt status is restored and {@code null} is returned.
     *
     * @param collection the class labels to split
     * @param collection2 forwarded to the splitter as its second argument
     * @param collection3 forwarded to the splitter as its third argument
     * @param random source of randomness handed to the splitter
     * @param str fully qualified Weka classifier name, resolved via {@link AbstractClassifier#forName}
     * @param instances the data the splitter works on
     * @return the split, or {@code null} if any error occurs (the error is logged)
     */
    public static ClassSplit<String> createGeneralRPNDBasedSplit(Collection<String> collection, Collection<String> collection2, Collection<String> collection3, Random random, String str, Instances instances) {
        try {
            Classifier baseClassifier = AbstractClassifier.forName(str, new String[0]);
            return toClassSplit(collection, new RPNDSplitter(random, baseClassifier).split(collection, collection2, collection3, instances));
        } catch (InterruptedException e) {
            // Cannot propagate without changing the signature, so keep the interrupt visible to callers.
            Thread.currentThread().interrupt();
            logger.warn("Interrupted while creating an RPND split", e);
            return null;
        } catch (Exception e) {
            logger.error("Unexpected exception occurred while creating an RPND split", e);
            return null;
        }
    }

    /**
     * Computes a split in which every class other than two random seed classes goes wholesale
     * to one side.
     *
     * <p>The collection is shuffled with {@code random}. The first two classes become seeds of
     * the two sides. The instances are merged into the two meta-classes
     * ({@code {seed1}} vs. {@code {seed2}}) and the classifier named {@code str} is trained on
     * them. Then every instance of every remaining class is classified. If strictly more
     * predictions equal {@code 0.0} than not, all remaining classes join the first seed's side.
     * Otherwise, including ties, they join the second seed's side. If training fails, no
     * predictions are made and the remaining classes join the second seed's side.
     *
     * <p>NOTE(review): assumes prediction {@code 0.0} denotes the meta-class built from the first
     * set passed to {@link WekaUtil#mergeClassesOfInstances}; confirm there.
     *
     * @param collection the class labels to split; must not be empty
     * @param random source of randomness for choosing the seed classes
     * @param str fully qualified Weka classifier name, resolved via {@link AbstractClassifier#forName}
     * @param instances the data whose classes are being split
     * @return a split with {@code null} sides if {@code collection} has one element; otherwise the
     *         computed split, or {@code null} if the classifier cannot be instantiated
     * @throws IllegalArgumentException if {@code collection} is empty
     */
    public static ClassSplit<String> createUnaryRPNDBasedSplit(Collection<String> collection, Random random, String str, Instances instances) {
        if (collection.isEmpty()) {
            throw new IllegalArgumentException("Cannot compute split for an empty set of classes!");
        }
        if (collection.size() == 1) {
            return new ClassSplit<>(collection, null, null);
        }

        List<String> shuffled = new ArrayList<>(collection);
        Collections.shuffle(shuffled, random);
        HashSet<String> first = new HashSet<>();
        first.add(shuffled.get(0));
        HashSet<String> second = new HashSet<>();
        second.add(shuffled.get(1));
        Instances merged = WekaUtil.mergeClassesOfInstances(instances, first, second);

        Classifier classifier;
        try {
            classifier = AbstractClassifier.forName(str, new String[0]);
        } catch (Exception e) {
            logger.error("Could not get object of classifier with name {}", str, e);
            return null;
        }

        boolean trained = trainClassifier(classifier, merged);
        Collection<String> remaining = new ArrayList<>(SetUtil.difference(SetUtil.difference(collection, first), second));
        // Never query a model whose training failed; its predictions would be meaningless.
        if (trained && majorityPredictsFirstMetaClass(classifier, instances, remaining)) {
            first.addAll(remaining);
        } else {
            second.addAll(remaining);
        }
        return new ClassSplit<>(collection, first, second);
    }

    /**
     * Builds a {@link ClassSplit} from the first two elements yielded by {@code parts}.
     *
     * @throws java.util.NoSuchElementException if {@code parts} yields fewer than two elements
     * @throws ClassCastException if one of those elements is not a {@link Collection}
     */
    @SuppressWarnings("unchecked") // the splitter is expected to yield collections of class-name strings
    private static ClassSplit<String> toClassSplit(Collection<String> classes, Iterable<?> parts) {
        Iterator<?> it = parts.iterator();
        Collection<String> left = (Collection<String>) it.next();
        Collection<String> right = (Collection<String>) it.next();
        return new ClassSplit<>(classes, left, right);
    }

    /**
     * Trains {@code classifier} on {@code data}.
     *
     * @return {@code true} on success; {@code false} if training threw (the error is logged, and
     *         the interrupt status is restored if the cause was an interrupt)
     */
    private static boolean trainClassifier(Classifier classifier, Instances data) {
        try {
            classifier.buildClassifier(data);
            return true;
        } catch (Exception e) {
            if (e instanceof InterruptedException) {
                Thread.currentThread().interrupt();
            }
            logger.error("Could not train classifier", e);
            return false;
        }
    }

    /**
     * Classifies every instance of each class in {@code classes}.
     *
     * <p>Instances whose prediction fails are logged and skipped. On interrupt, the interrupt
     * status is restored and the decision is made on the votes collected so far.
     *
     * @return {@code true} iff strictly more predictions equal {@code 0.0} than not
     */
    private static boolean majorityPredictsFirstMetaClass(Classifier classifier, Instances instances, Collection<String> classes) {
        int votesFirst = 0;
        int votesSecond = 0;
        for (String className : classes) {
            for (Instance instance : WekaUtil.getInstancesOfClass(instances, className)) {
                try {
                    if (classifier.classifyInstance(WekaUtil.getRefactoredInstance(instance)) == 0.0d) {
                        votesFirst++;
                    } else {
                        votesSecond++;
                    }
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                    logger.error("Could not get prediction for some instance to assign it to a meta-class", e);
                    return votesFirst > votesSecond;
                } catch (Exception e) {
                    logger.error("Could not get prediction for some instance to assign it to a meta-class", e);
                }
            }
        }
        return votesFirst > votesSecond;
    }
}
