package ai.libs.jaicore.ml.weka.classification.learner.reduction.reducer;

import ai.libs.jaicore.logging.LoggerUtil;
import ai.libs.jaicore.ml.weka.WekaUtil;
import ai.libs.jaicore.ml.weka.classification.learner.reduction.EMCNodeType;
import ai.libs.jaicore.ml.weka.classification.learner.reduction.MCTreeNode;
import ai.libs.jaicore.ml.weka.classification.learner.reduction.MCTreeNodeLeaf;
import ai.libs.jaicore.ml.weka.dataset.IWekaInstances;
import ai.libs.jaicore.ml.weka.dataset.WekaInstances;
import ai.libs.jaicore.search.algorithms.standard.bestfirst.BestFirstEpsilon;
import ai.libs.jaicore.search.model.other.EvaluatedSearchGraphPath;
import ai.libs.jaicore.search.probleminputs.GraphSearchWithSubpathEvaluationsInput;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Deque;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Random;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import org.apache.commons.math3.stat.descriptive.DescriptiveStatistics;
import org.api4.java.ai.graphsearch.problem.implicit.graphgenerator.INodeGoalTester;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.classifiers.rules.OneR;
import weka.core.Attribute;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;

/* loaded from: input_file:ai/libs/jaicore/ml/weka/classification/learner/reduction/reducer/ReductionOptimizer.class */
public class ReductionOptimizer implements Classifier {
    private static final long serialVersionUID = -6241267445544412443L;
    private final long seed;
    private MCTreeNode root;
    private transient Logger logger = LoggerFactory.getLogger(ReductionOptimizer.class);

    public ReductionOptimizer(long j) {
        this.seed = j;
    }

    public void buildClassifier(Instances instances) throws Exception {
        int i;
        BestFirstEpsilon bestFirstEpsilon = new BestFirstEpsilon(new GraphSearchWithSubpathEvaluationsInput(new ReductionGraphGenerator(new Random(this.seed), WekaUtil.getStratifiedSplit(new WekaInstances(instances), this.seed, 0.6000000238418579d).get(0).m66getList()), new INodeGoalTester<RestProblem, Decision>() { // from class: ai.libs.jaicore.ml.weka.classification.learner.reduction.reducer.ReductionOptimizer.1
            public boolean isGoal(RestProblem restProblem) {
                Iterator<Set<String>> it = restProblem.iterator();
                while (it.hasNext()) {
                    if (it.next().size() > 1) {
                        return false;
                    }
                }
                return true;
            }
        }, iLabeledPath -> {
            return Double.valueOf(getLossForClassifier(getTreeFromSolution(iLabeledPath.getNodes(), instances, false), instances) * 1.0d);
        }), iLabeledPath2 -> {
            return Double.valueOf(iLabeledPath2.getNodes().size() * (-1.0d));
        }, 0.1d, false);
        int i2 = 0;
        ArrayList arrayList = new ArrayList();
        do {
            EvaluatedSearchGraphPath nextSolutionCandidate = bestFirstEpsilon.nextSolutionCandidate();
            if (nextSolutionCandidate == null) {
                break;
            }
            arrayList.add(nextSolutionCandidate);
            i = i2;
            i2++;
        } while (i <= 100);
        Optional min = arrayList.stream().min((evaluatedSearchGraphPath, evaluatedSearchGraphPath2) -> {
            return ((Double) evaluatedSearchGraphPath.getScore()).compareTo((Double) evaluatedSearchGraphPath2.getScore());
        });
        if (!min.isPresent()) {
            this.logger.error("No solution found");
        } else {
            this.root = getTreeFromSolution(((EvaluatedSearchGraphPath) min.get()).getNodes(), instances, true);
            this.root.buildClassifier(instances);
        }
    }

    public double classifyInstance(Instance instance) throws Exception {
        return this.root.classifyInstance(instance);
    }

    public double[] distributionForInstance(Instance instance) throws Exception {
        return this.root.distributionForInstance(instance);
    }

    public Capabilities getCapabilities() {
        return null;
    }

    private void completeTree(MCTreeNode mCTreeNode) {
        if (mCTreeNode.isCompletelyConfigured()) {
            return;
        }
        Iterator<MCTreeNode> it = mCTreeNode.iterator();
        while (it.hasNext()) {
            MCTreeNode next = it.next();
            if (next.getChildren().isEmpty() && next.getContainedClasses2().size() != 1) {
                next.setNodeType(EMCNodeType.DIRECT);
                next.setBaseClassifier(new OneR());
                Iterator<Integer> it2 = next.getContainedClasses2().iterator();
                while (it2.hasNext()) {
                    try {
                        next.addChild(new MCTreeNodeLeaf(it2.next().intValue()));
                    } catch (Exception e) {
                        this.logger.error(LoggerUtil.getExceptionInfo(e));
                    }
                }
            }
        }
    }

    private int getLossForClassifier(MCTreeNode mCTreeNode, Instances instances) {
        int round;
        completeTree(mCTreeNode);
        synchronized (this) {
            try {
                DescriptiveStatistics descriptiveStatistics = new DescriptiveStatistics();
                for (int i = 0; i < 2; i++) {
                    List<IWekaInstances> stratifiedSplit = WekaUtil.getStratifiedSplit(new WekaInstances(instances), this.seed + i, 0.6000000238418579d);
                    mCTreeNode.buildClassifier(stratifiedSplit.get(0).m66getList());
                    Evaluation evaluation = new Evaluation(instances);
                    evaluation.evaluateModel(mCTreeNode, stratifiedSplit.get(1).m66getList(), new Object[0]);
                    descriptiveStatistics.addValue(evaluation.pctIncorrect());
                }
                round = (int) Math.round(descriptiveStatistics.getMean() * 100.0d);
            } catch (Exception e) {
                this.logger.error(LoggerUtil.getExceptionInfo(e));
                return Integer.MAX_VALUE;
            }
        }
        return round;
    }

    private MCTreeNode getTreeFromSolution(List<RestProblem> list, Instances instances, boolean z) {
        List<Decision> list2 = (List) list.stream().filter(restProblem -> {
            return restProblem.getEdgeToParent() != null;
        }).map((v0) -> {
            return v0.getEdgeToParent();
        }).collect(Collectors.toList());
        LinkedList linkedList = new LinkedList();
        Attribute classAttribute = instances.classAttribute();
        MCTreeNode mCTreeNode = new MCTreeNode((List) IntStream.range(0, classAttribute.numValues()).mapToObj(i -> {
            return Integer.valueOf(i);
        }).collect(Collectors.toList()));
        linkedList.addFirst(mCTreeNode);
        for (Decision decision : list2) {
            MCTreeNode mCTreeNode2 = (MCTreeNode) linkedList.removeFirst();
            if (mCTreeNode2 == null) {
                throw new IllegalStateException("No node to apply the decision to! Apparently, there are more decisions for nodes than there are inner nodes.");
            }
            mCTreeNode2.setNodeType(decision.getClassificationType());
            mCTreeNode2.setBaseClassifier(decision.getBaseClassifier());
            if (decision.getLft() == null || decision.getRgt() == null) {
                Iterator<Integer> it = mCTreeNode2.getContainedClasses2().iterator();
                while (it.hasNext()) {
                    try {
                        mCTreeNode2.addChild(new MCTreeNodeLeaf(it.next().intValue()));
                    } catch (Exception e) {
                        this.logger.error(LoggerUtil.getExceptionInfo(e));
                    }
                }
            } else {
                boolean z2 = false;
                ArrayList arrayList = new ArrayList(decision.getLft());
                if (arrayList.size() == 1) {
                    try {
                        mCTreeNode2.addChild(new MCTreeNodeLeaf(classAttribute.indexOfValue((String) arrayList.get(0))));
                    } catch (Exception e2) {
                        this.logger.error(LoggerUtil.getExceptionInfo(e2));
                    }
                } else {
                    Stream stream = arrayList.stream();
                    Objects.requireNonNull(classAttribute);
                    MCTreeNode mCTreeNode3 = new MCTreeNode((List) stream.map(classAttribute::indexOfValue).collect(Collectors.toList()));
                    mCTreeNode2.addChild(mCTreeNode3);
                    z2 = true;
                    linkedList.push(mCTreeNode3);
                }
                ArrayList arrayList2 = new ArrayList(decision.getRgt());
                if (arrayList2.size() == 1) {
                    try {
                        mCTreeNode2.addChild(new MCTreeNodeLeaf(instances.classAttribute().indexOfValue((String) arrayList2.get(0))));
                    } catch (Exception e3) {
                        this.logger.error(LoggerUtil.getExceptionInfo(e3));
                    }
                } else {
                    Stream stream2 = arrayList2.stream();
                    Objects.requireNonNull(classAttribute);
                    MCTreeNode mCTreeNode4 = new MCTreeNode((List) stream2.map(classAttribute::indexOfValue).collect(Collectors.toList()));
                    mCTreeNode2.addChild(mCTreeNode4);
                    if (z2) {
                        MCTreeNode mCTreeNode5 = (MCTreeNode) linkedList.pop();
                        linkedList.push(mCTreeNode4);
                        linkedList.push(mCTreeNode5);
                    } else {
                        linkedList.push(mCTreeNode4);
                    }
                }
            }
        }
        if (!z || linkedList.isEmpty()) {
            return mCTreeNode;
        }
        throw new IllegalStateException("Not all nodes have been equipped with decisions!");
    }
}
