package ai.libs.jaicore.ml.weka.classification.learner.reduction;

import ai.libs.jaicore.ml.weka.WekaUtil;
import ai.libs.jaicore.ml.weka.classification.singlelabel.timeseries.learner.trees.TimeSeriesTreeLearningAlgorithm;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import org.apache.commons.lang3.builder.HashCodeBuilder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.Classifier;
import weka.classifiers.meta.MultiClassClassifier;
import weka.classifiers.rules.ZeroR;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.WekaException;

/* loaded from: input_file:ai/libs/jaicore/ml/weka/classification/learner/reduction/MCTreeNode.class */
public class MCTreeNode extends AMCTreeNode<Integer> implements ITreeClassifier, Iterable<MCTreeNode> {
    private static final long serialVersionUID = 8873192747068561266L;
    private EMCNodeType nodeType;
    private List<MCTreeNode> children;
    private Classifier classifier;
    private String classifierID;
    private boolean trained;
    private transient Logger logger;
    public static final AtomicInteger cacheRetrievals;
    private static Map<String, Classifier> classifierCacheMap;
    private static Lock classifierCacheMapLock;
    static final /* synthetic */ boolean $assertionsDisabled;

    /* JADX INFO: Access modifiers changed from: package-private */
    /* renamed from: ai.libs.jaicore.ml.weka.classification.learner.reduction.MCTreeNode$2, reason: invalid class name */
    /* loaded from: input_file:ai/libs/jaicore/ml/weka/classification/learner/reduction/MCTreeNode$2.class */
    public static /* synthetic */ class AnonymousClass2 {
        static final /* synthetic */ int[] $SwitchMap$ai$libs$jaicore$ml$weka$classification$learner$reduction$EMCNodeType = new int[EMCNodeType.values().length];

        static {
            try {
                $SwitchMap$ai$libs$jaicore$ml$weka$classification$learner$reduction$EMCNodeType[EMCNodeType.ONEVSREST.ordinal()] = 1;
            } catch (NoSuchFieldError e) {
            }
            try {
                $SwitchMap$ai$libs$jaicore$ml$weka$classification$learner$reduction$EMCNodeType[EMCNodeType.ALLPAIRS.ordinal()] = 2;
            } catch (NoSuchFieldError e2) {
            }
            try {
                $SwitchMap$ai$libs$jaicore$ml$weka$classification$learner$reduction$EMCNodeType[EMCNodeType.DIRECT.ordinal()] = 3;
            } catch (NoSuchFieldError e3) {
            }
        }
    }

    public MCTreeNode(List<Integer> list) {
        super(list);
        this.children = new ArrayList();
        this.trained = false;
        this.logger = LoggerFactory.getLogger(MCTreeNode.class);
    }

    public MCTreeNode(List<Integer> list, EMCNodeType eMCNodeType, String str) throws Exception {
        this(list, eMCNodeType, AbstractClassifier.forName(str, (String[]) null));
    }

    public MCTreeNode(List<Integer> list, EMCNodeType eMCNodeType, Classifier classifier) {
        this(list);
        setNodeType(eMCNodeType);
        setBaseClassifier(classifier);
    }

    public EMCNodeType getNodeType() {
        return this.nodeType;
    }

    public void addChild(MCTreeNode mCTreeNode) {
        if (mCTreeNode.getNodeType() != EMCNodeType.MERGE) {
            this.children.add(mCTreeNode);
            return;
        }
        Iterator<MCTreeNode> it = mCTreeNode.getChildren().iterator();
        while (it.hasNext()) {
            this.children.add(it.next());
        }
    }

    public List<MCTreeNode> getChildren() {
        return this.children;
    }

    public boolean isCompletelyConfigured() {
        if (this.classifier == null || this.children.isEmpty()) {
            return false;
        }
        Iterator<MCTreeNode> it = this.children.iterator();
        while (it.hasNext()) {
            if (!it.next().isCompletelyConfigured()) {
                return false;
            }
        }
        return true;
    }

    public void buildClassifier(Instances instances) throws Exception {
        if (!$assertionsDisabled && getNodeType() == EMCNodeType.MERGE) {
            throw new AssertionError("MERGE node detected while building classifier. This must not happen!");
        }
        if (instances.isEmpty()) {
            throw new IllegalArgumentException("Cannot train MCTree with empty set of instances.");
        }
        if (this.children.isEmpty()) {
            throw new IllegalStateException("Cannot train MCTree without children");
        }
        ArrayList arrayList = new ArrayList();
        IntStream.range(0, this.children.size()).forEach(i -> {
            arrayList.add(new HashSet());
        });
        int i2 = 0;
        Iterator<MCTreeNode> it = this.children.iterator();
        while (it.hasNext()) {
            Iterator<Integer> it2 = it.next().getContainedClasses2().iterator();
            while (it2.hasNext()) {
                ((Set) arrayList.get(i2)).add(instances.classAttribute().value(it2.next().intValue()));
            }
            i2++;
        }
        String str = this.classifier.getClass().getName() + "#" + arrayList + "#" + instances.size() + "#" + new HashCodeBuilder().append(instances.toString()).toHashCode();
        Instances mergeClassesOfInstances = WekaUtil.mergeClassesOfInstances(instances, arrayList);
        try {
            this.classifier.buildClassifier(mergeClassesOfInstances);
        } catch (WekaException e) {
            this.classifier = new ZeroR();
            this.classifier.buildClassifier(mergeClassesOfInstances);
        }
        classifierCacheMapLock.lock();
        try {
            classifierCacheMap.put(str, this.classifier);
            classifierCacheMapLock.unlock();
            ((Stream) this.children.stream().parallel()).forEach(mCTreeNode -> {
                try {
                    mCTreeNode.buildClassifier(instances);
                } catch (Exception e2) {
                    this.logger.error("Encountered problem when training MCTreeNode.", e2);
                }
            });
            this.trained = true;
        } catch (Throwable th) {
            classifierCacheMapLock.unlock();
            throw th;
        }
    }

    public void distributionForInstance(Instance instance, double[] dArr) throws Exception {
        double[] distributionForInstance = this.classifier.distributionForInstance(WekaUtil.getRefactoredInstance(instance, (List) IntStream.range(0, this.children.size()).mapToObj(i -> {
            return i + ".0";
        }).collect(Collectors.toList())));
        for (MCTreeNode mCTreeNode : this.children) {
            mCTreeNode.distributionForInstance(instance, dArr);
            int indexOf = this.children.indexOf(mCTreeNode);
            Iterator<Integer> it = mCTreeNode.getContainedClasses2().iterator();
            while (it.hasNext()) {
                int intValue = it.next().intValue();
                dArr[intValue] = dArr[intValue] * distributionForInstance[indexOf];
            }
        }
    }

    public double[] distributionForInstance(Instance instance) throws Exception {
        if (!this.trained) {
            throw new IllegalStateException("Cannot get distribution from untrained classifier " + toStringWithOffset());
        }
        double[] dArr = new double[getContainedClasses2().size()];
        distributionForInstance(instance, dArr);
        return dArr;
    }

    public Capabilities getCapabilities() {
        return this.classifier.getCapabilities();
    }

    @Override // ai.libs.jaicore.ml.weka.classification.learner.reduction.ITreeClassifier
    public int getHeight() {
        Stream<R> map = this.children.stream().map((v0) -> {
            return v0.getHeight();
        });
        Class cls = Integer.TYPE;
        Objects.requireNonNull(cls);
        return 1 + map.mapToInt((v1) -> {
            return r2.cast(v1);
        }).max().getAsInt();
    }

    @Override // ai.libs.jaicore.ml.weka.classification.learner.reduction.ITreeClassifier
    public int getDepthOfFirstCommonParent(List<Integer> list) {
        for (MCTreeNode mCTreeNode : this.children) {
            if (mCTreeNode.getContainedClasses2().containsAll(list)) {
                return 1 + mCTreeNode.getDepthOfFirstCommonParent(list);
            }
        }
        return 1;
    }

    public static void clearCache() {
        classifierCacheMap.clear();
    }

    public static Map<String, Classifier> getClassifierCache() {
        return classifierCacheMap;
    }

    public Classifier getClassifier() {
        return this.classifier;
    }

    public void setBaseClassifier(Classifier classifier) {
        if (classifier == null) {
            throw new IllegalArgumentException("Cannot set null classifier!");
        }
        this.classifierID = classifier.getClass().getName();
        switch (AnonymousClass2.$SwitchMap$ai$libs$jaicore$ml$weka$classification$learner$reduction$EMCNodeType[this.nodeType.ordinal()]) {
            case TimeSeriesTreeLearningAlgorithm.USE_BIAS_CORRECTION /* 1 */:
                MultiClassClassifier multiClassClassifier = new MultiClassClassifier();
                multiClassClassifier.setClassifier(classifier);
                this.classifier = multiClassClassifier;
                return;
            case 2:
                MultiClassClassifier multiClassClassifier2 = new MultiClassClassifier();
                try {
                    multiClassClassifier2.setOptions(new String[]{"-M", "3"});
                } catch (Exception e) {
                    this.logger.error("Observed problem when setting options for classifier.", e);
                }
                multiClassClassifier2.setClassifier(classifier);
                this.classifier = multiClassClassifier2;
                return;
            case 3:
                this.classifier = classifier;
                return;
            default:
                return;
        }
    }

    public void setNodeType(EMCNodeType eMCNodeType) {
        this.nodeType = eMCNodeType;
    }

    public String toString() {
        return toStringWithOffset("", null);
    }

    public String toStringWithOffset() {
        return toStringWithOffset("", "  ");
    }

    public String toStringWithOffset(String str, String str2) {
        StringBuilder sb = new StringBuilder();
        sb.append(str).append("(").append(getContainedClasses2()).append(":").append(this.classifierID).append(":").append(this.nodeType).append(") {");
        boolean z = true;
        for (MCTreeNode mCTreeNode : this.children) {
            if (z) {
                z = false;
            } else {
                sb.append(",");
            }
            if (str2 != null) {
                sb.append("\n");
            }
            sb.append(mCTreeNode.toStringWithOffset(str + (str2 != null ? str2 : ""), str2));
        }
        if (str2 != null) {
            sb.append("\n").append(str);
        }
        sb.append("}");
        return sb.toString();
    }

    @Override // java.lang.Iterable
    public Iterator<MCTreeNode> iterator() {
        return new Iterator<MCTreeNode>() { // from class: ai.libs.jaicore.ml.weka.classification.learner.reduction.MCTreeNode.1
            private int currentlyTraversedChild = -1;
            private Iterator<MCTreeNode> childIterator = null;

            @Override // java.util.Iterator
            public boolean hasNext() {
                if (this.currentlyTraversedChild < 0) {
                    return true;
                }
                if (MCTreeNode.this.children.isEmpty()) {
                    return false;
                }
                if (this.childIterator == null) {
                    this.childIterator = ((MCTreeNode) MCTreeNode.this.children.get(this.currentlyTraversedChild)).iterator();
                }
                if (this.childIterator.hasNext()) {
                    return true;
                }
                if (this.currentlyTraversedChild == MCTreeNode.this.children.size() - 1) {
                    return false;
                }
                this.currentlyTraversedChild++;
                this.childIterator = ((MCTreeNode) MCTreeNode.this.children.get(this.currentlyTraversedChild)).iterator();
                return this.childIterator.hasNext();
            }

            /* JADX WARN: Can't rename method to resolve collision */
            @Override // java.util.Iterator
            public MCTreeNode next() {
                if (this.currentlyTraversedChild != -1) {
                    return this.childIterator.next();
                }
                this.currentlyTraversedChild++;
                return MCTreeNode.this;
            }
        };
    }

    static {
        $assertionsDisabled = !MCTreeNode.class.desiredAssertionStatus();
        cacheRetrievals = new AtomicInteger();
        classifierCacheMap = new HashMap();
        classifierCacheMapLock = new ReentrantLock();
    }
}
