package ai.libs.jaicore.ml.weka.rangequery.learner.intervaltree;

import ai.libs.jaicore.ml.weka.rangequery.learner.intervaltree.aggregation.AggressiveAggregator;
import ai.libs.jaicore.ml.weka.rangequery.learner.intervaltree.aggregation.IntervalAggregator;
import ai.libs.jaicore.ml.weka.rangequery.learner.intervaltree.featurespace.CategoricalFeatureDomain;
import ai.libs.jaicore.ml.weka.rangequery.learner.intervaltree.featurespace.FeatureDomain;
import ai.libs.jaicore.ml.weka.rangequery.learner.intervaltree.featurespace.FeatureSpace;
import ai.libs.jaicore.ml.weka.rangequery.learner.intervaltree.featurespace.NumericFeatureDomain;
import ai.libs.jaicore.ml.weka.rangequery.learner.intervaltree.util.RQPHelper;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Deque;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.commons.math3.geometry.euclidean.oned.Interval;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import weka.classifiers.trees.RandomTree;

/**
 * A Weka {@link RandomTree} that additionally supports range queries ({@link RangeQueryPredictor}) and a decomposition
 * of the variance of its predictions over a {@link FeatureSpace} into contributions of subsets of features, in the
 * manner of functional ANOVA.
 *
 * <p>
 * {@link #preprocess()} partitions the feature space along the leaves of the tree, collects the split points of every
 * feature and derives weighted observation cells from them. The {@code computeMarginal*} methods then attribute parts
 * of the total variance to subsets of features. Computed variances are cached per subset in plain {@link HashMap}s
 * without any synchronization.
 * </p>
 *
 * <p>
 * The tree is treated as a regression tree: the prediction of a leaf is the first entry of its class distribution.
 * </p>
 */
public class ExtendedRandomTree extends RandomTree implements RangeQueryPredictor {
    private static final Logger LOGGER = LoggerFactory.getLogger(ExtendedRandomTree.class);
    private static final String LOG_WARN_VARIANCE_ZERO = "The trees total variance is zero, predictions make no sense at this point!";
    private static final String LOG_WARN_NOT_PREPARED = "Tree is not prepared, preprocessing may take a while";
    private static final String LOG_INDIVIDUAL_VAR = "Individual var for {} = {}";
    private static final String LOG_TOTAL_VAR = "current total variance for {} = {}";
    private static final long serialVersionUID = -467555221387281335L;

    /** Combines the leaf predictions reached by a range query into a single interval. */
    private final IntervalAggregator intervalAggregator;
    /** The feature space the tree is defined on; required by {@link #preprocess()}. */
    private FeatureSpace featureSpace;
    /** Maps every leaf to the region of the feature space it covers. */
    private HashMap<RandomTree.Tree, FeatureSpace> partitioning = new HashMap<>();
    /** All leaves, in the order they were reached while partitioning. */
    private ArrayList<RandomTree.Tree> leaves = new ArrayList<>();
    /** Per feature index, the distinct split points used anywhere in the tree. */
    private ArrayList<Set<Double>> splitPoints;
    /** Variance of the tree's predictions over all features, computed by {@link #preprocess()}. */
    private double totalVariance;
    /** Per feature index, the observation cells derived from the split points; not serialized. */
    private transient Observation[][] allObservations;
    /** Cache: variance attributed to exactly a subset of features (lower-order contributions removed, clamped at 0). */
    private HashMap<Set<Integer>, Double> varianceOfSubsetIndividual = new HashMap<>();
    /** Cache: weighted variance of the marginal prediction over a subset of features. */
    private HashMap<Set<Integer>, Double> varianceOfSubsetTotal = new HashMap<>();
    /** Predictions for empty categorical leaves, taken from the first distribution entry of their parent node. */
    private HashMap<RandomTree.Tree, Double> mapForEmptyLeaves = new HashMap<>();
    /** Whether {@link #preprocess()} has completed. */
    private boolean isPrepared;

    /**
     * One cell of a feature's domain: its midpoint and its width. Categorical values use width 1.
     */
    private static final class Observation {
        private final double midPoint;
        private final double intervalSize;

        Observation(final double midPoint, final double intervalSize) {
            this.midPoint = midPoint;
            this.intervalSize = intervalSize;
        }
    }

    /**
     * Incrementally accumulates the weighted population variance of a stream of values, using a weighted
     * Welford/West-style update of the running mean and the weighted sum of squared deviations.
     */
    private static final class WeightedVarianceHelper {
        private double average = 0.0d;
        private double squaredDistanceToMean = 0.0d;
        private double sumOfWeights = 0.0d;

        /**
         * Adds a value with the given weight.
         *
         * @param value the observed value
         * @param weight the weight of the value
         * @throws IllegalArgumentException if {@code weight <= 0}
         */
        void push(final double value, final double weight) {
            if (weight <= 0.0d) {
                throw new IllegalArgumentException("Weights have to be strictly positive!");
            }
            double delta = value - this.average;
            this.sumOfWeights += weight;
            this.average += (delta * weight) / this.sumOfWeights;
            this.squaredDistanceToMean += weight * delta * (value - this.average);
        }

        /**
         * Returns the weighted population variance, clamped at 0 to absorb rounding error.
         *
         * @return the variance, or {@code NaN} if no value has been pushed
         */
        double getPopulationVariance() {
            if (this.sumOfWeights > 0.0d) {
                return Math.max(0.0d, this.squaredDistanceToMean / this.sumOfWeights);
            }
            return Double.NaN;
        }
    }

    /** Creates a tree that aggregates range-query results with an {@link AggressiveAggregator}. */
    public ExtendedRandomTree() {
        this(new AggressiveAggregator());
    }

    /**
     * Creates a tree over the given feature space that aggregates range-query results with an
     * {@link AggressiveAggregator}.
     *
     * @param featureSpace the feature space used for the variance decomposition
     */
    public ExtendedRandomTree(final FeatureSpace featureSpace) {
        this();
        this.featureSpace = featureSpace;
    }

    /**
     * Creates a tree that aggregates range-query results with the given aggregator.
     *
     * @param intervalAggregator combines the leaf predictions of a range query into an interval
     * @throws IllegalStateException if the Weka options cannot be set
     */
    public ExtendedRandomTree(final IntervalAggregator intervalAggregator) {
        this.intervalAggregator = intervalAggregator;
        try {
            // NOTE(review): in stock Weka, RandomTree's -U means "allow unclassified instances", not "unpruned", and it
            // is overridden by the setter call below - confirm what this option is meant to achieve.
            setOptions(new String[] { "-U" });
        } catch (Exception e) {
            throw new IllegalStateException("Couldn't unprune the tree", e);
        }
        // Must follow setOptions, which resets the tree's options.
        setAllowUnclassifiedInstances(false);
    }

    /**
     * Predicts the interval of outputs for a box-shaped query. Visits every leaf whose region the queried box
     * intersects, and combines the leaf predictions with the configured {@link IntervalAggregator}. Only the first two
     * successors of an inner node are followed.
     *
     * @param intervalAndHeader the queried interval per attribute
     * @return the aggregated interval of the reached leaf predictions
     */
    @Override // ai.libs.jaicore.ml.weka.rangequery.learner.intervaltree.RangeQueryPredictor
    public Interval predictInterval(final RQPHelper.IntervalAndHeader intervalAndHeader) {
        Interval[] queriedIntervals = intervalAndHeader.getIntervals();
        // Nodes still to visit, each paired with the (narrowed) query box that reaches it.
        Deque<Map.Entry<Interval[], RandomTree.Tree>> stack = new ArrayDeque<>();
        stack.push(RQPHelper.getEntry(queriedIntervals, this.m_Tree));
        List<Double> leafPredictions = new ArrayList<>();
        while (!stack.isEmpty()) {
            Map.Entry<Interval[], RandomTree.Tree> current = stack.pop();
            RandomTree.Tree node = current.getValue();
            int attribute = node.getM_Attribute();
            if (attribute == -1) {
                // Leaf: the regression prediction is the first entry of the distribution.
                leafPredictions.add(node.getM_Classdistribution()[0]);
                continue;
            }
            double splitPoint = node.getM_SplitPoint();
            Interval interval = queriedIntervals[attribute];
            RandomTree.Tree[] successors = node.getM_Successors();
            RandomTree.Tree leftChild = successors[0];
            RandomTree.Tree rightChild = successors[1];
            Interval[] box = current.getKey();
            if (interval.getInf() <= splitPoint) {
                if (splitPoint <= interval.getSup()) {
                    // inf <= split <= sup: visit each child exactly once with its half of the query.
                    Interval[] leftBox = RQPHelper.substituteInterval(box, new Interval(interval.getInf(), splitPoint), attribute);
                    Interval[] rightBox = RQPHelper.substituteInterval(box, new Interval(splitPoint, interval.getSup()), attribute);
                    stack.push(RQPHelper.getEntry(leftBox, leftChild));
                    stack.push(RQPHelper.getEntry(rightBox, rightChild));
                } else {
                    // The whole query lies left of the split.
                    stack.push(RQPHelper.getEntry(box, leftChild));
                }
            } else if (interval.getSup() > splitPoint) {
                // The whole query lies right of the split. This is an else-branch so that the right child is not
                // pushed a second time after the split case above.
                stack.push(RQPHelper.getEntry(box, rightChild));
            }
        }
        return this.intervalAggregator.aggregate(leafPredictions);
    }

    public void setFeatureSpace(final FeatureSpace featureSpace) {
        this.featureSpace = featureSpace;
    }

    public FeatureSpace getFeatureSpace() {
        return this.featureSpace;
    }

    /**
     * Returns the standard deviation attributed to exactly the given subset of features, i.e. the square root of its
     * individual (lower-order-free, clamped at 0) variance. Runs {@link #preprocess()} first if necessary.
     *
     * @param set indices of the features
     * @return the standard deviation, or {@code NaN} if the total variance is zero
     */
    public double computeMarginalStandardDeviationForSubsetOfFeatures(final Set<Integer> set) {
        return Math.sqrt(computeIndividualVarianceOrNaN(set));
    }

    /**
     * Returns the share of the total variance attributed to exactly the given subset of features. Runs
     * {@link #preprocess()} first if necessary.
     *
     * @param set indices of the features
     * @return the individual variance divided by the total variance, or {@code NaN} if the total variance is zero
     */
    public double computeMarginalVarianceContributionForSubsetOfFeatures(final Set<Integer> set) {
        double individualVariance = computeIndividualVarianceOrNaN(set);
        if (Double.isNaN(individualVariance)) {
            return Double.NaN;
        }
        return individualVariance / this.totalVariance;
    }

    /**
     * Returns the (not normalized) variance attributed to exactly the given subset of features. Runs
     * {@link #preprocess()} first if necessary.
     *
     * @param set indices of the features
     * @return the individual variance, clamped at 0, or {@code NaN} if the total variance is zero
     */
    public double computeMarginalVarianceContributionForSubsetOfFeaturesNotNormalized(final Set<Integer> set) {
        return computeIndividualVarianceOrNaN(set);
    }

    /** Shared entry point of the marginal methods: prepares the tree and handles the zero-total-variance case. */
    private double computeIndividualVarianceOrNaN(final Set<Integer> set) {
        ensurePrepared();
        if (this.totalVariance == 0.0d) {
            LOGGER.warn(LOG_WARN_VARIANCE_ZERO);
            return Double.NaN;
        }
        return computeIndividualVariance(snapshot(set));
    }

    /**
     * Computes the individual variance of a subset. This is its total variance minus the individual variances of all
     * its non-empty proper subsets, clamped at 0. The result is cached.
     */
    private double computeIndividualVariance(final Set<Integer> subset) {
        double variance = computeTotalVarianceOfSubset(subset);
        LOGGER.trace(LOG_TOTAL_VAR, subset, variance);
        for (int size = 1; size < subset.size(); size++) {
            for (Set<Integer> lowerOrder : Sets.combinations(subset, size)) {
                double lowerOrderVariance = getOrComputeIndividualVariance(lowerOrder);
                LOGGER.trace("Subtracting {} for {}", lowerOrderVariance, lowerOrder);
                variance -= lowerOrderVariance;
            }
        }
        LOGGER.trace(LOG_INDIVIDUAL_VAR, subset, variance);
        double clamped = Math.max(variance, 0.0d);
        this.varianceOfSubsetIndividual.put(subset, clamped);
        return clamped;
    }

    /** Returns the cached individual variance of a subset, computing (and caching) it if it is missing. */
    private double getOrComputeIndividualVariance(final Set<Integer> subset) {
        Double cached = this.varianceOfSubsetIndividual.get(subset);
        return cached != null ? cached : computeIndividualVariance(snapshot(subset));
    }

    /** Copies a set so that later changes to the caller's set cannot corrupt the cache keys. */
    private static Set<Integer> snapshot(final Set<Integer> set) {
        return Collections.unmodifiableSet(new HashSet<>(set));
    }

    /** Runs {@link #preprocess()} unless already done; the transient observations are missing after deserialization. */
    private void ensurePrepared() {
        if (!this.isPrepared || this.allObservations == null) {
            LOGGER.warn(LOG_WARN_NOT_PREPARED);
            preprocess();
        }
    }

    /**
     * Computes the prediction marginalized over all features except the given ones.
     *
     * @param indices sorted indices of the fixed features
     * @param observations one observation per fixed feature, in the same order as {@code indices}
     * @return sum over the leaves consistent with the observation midpoints of (leaf prediction * leaf share of the
     *         range of the remaining features)
     */
    private double getMarginalPrediction(final List<Integer> indices, final List<Observation> observations) {
        Set<Integer> indexSet = new HashSet<>(indices);
        List<Double> midPoints = new ArrayList<>(observations.size());
        for (Observation observation : observations) {
            midPoints.add(observation.midPoint);
        }
        double rangeSizeOfOtherFeatures = this.featureSpace.getRangeSizeOfAllButSubset(indexSet);
        double marginalPrediction = 0.0d;
        boolean consistentWithAnyLeaf = false;
        for (Map.Entry<RandomTree.Tree, FeatureSpace> leafRegion : this.partitioning.entrySet()) {
            FeatureSpace region = leafRegion.getValue();
            if (!region.containsPartialInstance(indices, midPoints)) {
                continue;
            }
            double prediction = getLeafPrediction(leafRegion.getKey());
            assert !Double.isNaN(prediction) : "Prediction must not be NaN";
            marginalPrediction += prediction * (region.getRangeSizeOfAllButSubset(indexSet) / rangeSizeOfOtherFeatures);
            consistentWithAnyLeaf = true;
        }
        if (!consistentWithAnyLeaf) {
            LOGGER.warn("Observation {} is not consistent with any leaf with indices: {}", midPoints, indices);
        }
        return marginalPrediction;
    }

    /** Returns a leaf's own prediction, the one borrowed from its parent for empty leaves, or NaN (with a warning). */
    private double getLeafPrediction(final RandomTree.Tree leaf) {
        double[] distribution = leaf.getM_Classdistribution();
        if (distribution != null) {
            return distribution[0];
        }
        Double borrowed = this.mapForEmptyLeaves.get(leaf);
        if (borrowed != null) {
            return borrowed;
        }
        LOGGER.warn("No prediction found anywhere!");
        return Double.NaN;
    }

    /**
     * Recursively assigns every leaf below {@code node} the region of the feature space it covers. Numeric splits
     * narrow the max (left) or min (right) of the split attribute. Categorical splits restrict the attribute to the
     * successor's index.
     */
    private void computePartitioning(final FeatureSpace subSpace, final RandomTree.Tree node) {
        int attribute = node.getM_Attribute();
        if (attribute == -1) {
            this.leaves.add(node);
            this.partitioning.put(node, subSpace);
            return;
        }
        RandomTree.Tree[] successors = node.getM_Successors();
        FeatureDomain domain = subSpace.getFeatureDomain(attribute);
        if (domain instanceof CategoricalFeatureDomain) {
            for (int value = 0; value < successors.length; value++) {
                RandomTree.Tree child = successors[value];
                if (child.getM_Classdistribution() == null && child.getM_Attribute() == -1) {
                    // Empty leaf: fall back to the parent's prediction.
                    this.mapForEmptyLeaves.put(child, node.getM_Classdistribution()[0]);
                }
                FeatureSpace childSpace = new FeatureSpace(subSpace);
                ((CategoricalFeatureDomain) childSpace.getFeatureDomain(attribute)).setValues(new double[] { value });
                computePartitioning(childSpace, child);
            }
        } else if (domain instanceof NumericFeatureDomain) {
            double splitPoint = node.getM_SplitPoint();
            FeatureSpace leftSpace = new FeatureSpace(subSpace);
            ((NumericFeatureDomain) leftSpace.getFeatureDomain(attribute)).setMax(splitPoint);
            FeatureSpace rightSpace = new FeatureSpace(subSpace);
            ((NumericFeatureDomain) rightSpace.getFeatureDomain(attribute)).setMin(splitPoint);
            computePartitioning(leftSpace, successors[0]);
            computePartitioning(rightSpace, successors[1]);
        }
    }

    /** Collects, per feature index, the distinct split points of all inner nodes (breadth-first). */
    private void collectSplitPoints(final RandomTree.Tree root) {
        int dimensionality = this.featureSpace.getDimensionality();
        this.splitPoints = new ArrayList<>(dimensionality);
        for (int i = 0; i < dimensionality; i++) {
            this.splitPoints.add(new HashSet<>());
        }
        Deque<RandomTree.Tree> queue = new ArrayDeque<>();
        queue.add(root);
        while (!queue.isEmpty()) {
            RandomTree.Tree node = queue.poll();
            int attribute = node.getM_Attribute();
            if (attribute > -1) {
                this.splitPoints.get(attribute).add(node.getM_SplitPoint());
                Collections.addAll(queue, node.getM_Successors());
            }
        }
    }

    /**
     * Builds the observation cells of every feature. Numeric features are split at their split points and domain
     * bounds. Categorical features get one unit-width cell per value. Other domain types leave a {@code null} entry.
     */
    private void computeObservations() {
        int dimensionality = this.featureSpace.getDimensionality();
        this.allObservations = new Observation[dimensionality][];
        for (int i = 0; i < dimensionality; i++) {
            FeatureDomain domain = this.featureSpace.getFeatureDomain(i);
            if (domain instanceof NumericFeatureDomain) {
                this.allObservations[i] = computeNumericObservations(this.splitPoints.get(i), (NumericFeatureDomain) domain);
            } else if (domain instanceof CategoricalFeatureDomain) {
                double[] values = ((CategoricalFeatureDomain) domain).getValues();
                Observation[] observations = new Observation[values.length];
                for (int j = 0; j < values.length; j++) {
                    observations[j] = new Observation(values[j], 1.0d);
                }
                this.allObservations[i] = observations;
            }
        }
    }

    /** Creates one cell between each pair of consecutive boundaries (split points plus the domain's min and max). */
    private static Observation[] computeNumericObservations(final Set<Double> splits, final NumericFeatureDomain domain) {
        List<Double> boundaries = new ArrayList<>(splits);
        boundaries.add(domain.getMin());
        boundaries.add(domain.getMax());
        Collections.sort(boundaries);
        // At least two boundaries exist (min and max), so there is at least one cell.
        Observation[] observations = new Observation[boundaries.size() - 1];
        for (int j = 0; j < observations.length; j++) {
            double lower = boundaries.get(j);
            double upper = boundaries.get(j + 1);
            double width = upper - lower;
            // Zero-width cells (e.g. a split point equal to a bound) get unit width, presumably so that the weight
            // passed to WeightedVarianceHelper.push stays strictly positive.
            observations[j] = new Observation((lower + upper) / 2.0d, width > 0.0d ? width : 1.0d);
        }
        return observations;
    }

    /**
     * Computes the weighted population variance of the marginal prediction over the given subset of features. Every
     * combination of the features' observation cells contributes its marginal prediction, weighted by the product of
     * the cell widths and the range size of the remaining features. Results are cached per subset.
     *
     * @param set indices of the features; {@link #preprocess()} must have computed the observations
     * @return the variance, or {@code NaN} if no combination yielded a prediction
     */
    public double computeTotalVarianceOfSubset(final Set<Integer> set) {
        Set<Integer> subset = snapshot(set);
        Double cached = this.varianceOfSubsetTotal.get(subset);
        if (cached != null) {
            return cached;
        }
        // Build the product in sorted index order so that each observation lines up with the index passed alongside.
        List<Integer> sortedIndices = new ArrayList<>(subset);
        Collections.sort(sortedIndices);
        List<List<Observation>> observationsPerFeature = new ArrayList<>(sortedIndices.size());
        for (int index : sortedIndices) {
            observationsPerFeature.add(Arrays.asList(this.allObservations[index]));
        }
        double rangeSizeOfOtherFeatures = getFeatureSpace().getRangeSizeOfAllButSubset(subset);
        WeightedVarianceHelper varianceHelper = new WeightedVarianceHelper();
        for (List<Observation> combination : Lists.cartesianProduct(observationsPerFeature)) {
            double marginalPrediction = getMarginalPrediction(sortedIndices, combination);
            double cellSize = 1.0d;
            for (Observation observation : combination) {
                if (observation.intervalSize != 0.0d) {
                    cellSize *= observation.intervalSize;
                }
            }
            if (!Double.isNaN(marginalPrediction)) {
                varianceHelper.push(marginalPrediction, rangeSizeOfOtherFeatures * cellSize);
            }
        }
        double variance = varianceHelper.getPopulationVariance();
        this.varianceOfSubsetTotal.put(subset, variance);
        return variance;
    }

    public double getTotalVariance() {
        return this.totalVariance;
    }

    /**
     * Partitions the feature space along the tree's leaves, derives the observation cells and computes the total
     * variance over all features. Clears all previously computed state first, so it can safely be called again (e.g.
     * after retraining).
     */
    public void preprocess() {
        this.partitioning.clear();
        this.leaves.clear();
        this.mapForEmptyLeaves.clear();
        this.varianceOfSubsetTotal.clear();
        this.varianceOfSubsetIndividual.clear();
        computePartitioning(this.featureSpace, this.m_Tree);
        collectSplitPoints(this.m_Tree);
        computeObservations();
        Set<Integer> allFeatures = new HashSet<>();
        for (int i = 0; i < this.featureSpace.getDimensionality(); i++) {
            allFeatures.add(i);
        }
        this.totalVariance = computeTotalVarianceOfSubset(allFeatures);
        this.isPrepared = true;
    }

    /** Logs (at debug level) the observation midpoints of every feature; requires {@link #preprocess()}. */
    public void printObservations() {
        for (int i = 0; i < this.allObservations.length; i++) {
            String midPoints = Arrays.stream(this.allObservations[i]).map(observation -> String.valueOf(observation.midPoint)).collect(Collectors.joining(", "));
            LOGGER.debug("Observations for feature {}: {}", i, midPoints);
        }
    }

    /** Logs (at debug level) the sorted split points of every feature, including numeric domain bounds. */
    public void printSplitPoints() {
        for (int i = 0; i < this.splitPoints.size(); i++) {
            List<Double> points = new ArrayList<>(this.splitPoints.get(i));
            FeatureDomain domain = getFeatureSpace().getFeatureDomain(i);
            if (domain instanceof NumericFeatureDomain) {
                points.add(((NumericFeatureDomain) domain).getMin());
                points.add(((NumericFeatureDomain) domain).getMax());
            }
            Collections.sort(points);
            LOGGER.debug("Split points for feature {}: {}", i, points);
        }
    }

    /** Logs (at debug level) the feature-space size, the summed leaf-region sizes and the product of cell widths. */
    public void printSizeOfFeatureSpaceAndPartitioning() {
        LOGGER.debug("Size of feature space: {}", this.featureSpace.getRangeSize());
        double partitioningSize = 0.0d;
        for (FeatureSpace region : this.partitioning.values()) {
            partitioningSize += region.getRangeSize();
        }
        LOGGER.debug("Complete size of partitioning: {}", partitioningSize);
        double intervalsSize = 1.0d;
        for (Observation[] featureObservations : this.allObservations) {
            double featureSize = 0.0d;
            for (Observation observation : featureObservations) {
                featureSize += observation.intervalSize;
            }
            intervalsSize *= featureSize;
        }
        LOGGER.debug("Complete size of intervals: {}", intervalsSize);
    }
}
