package com.gengoai.apollo.ml.model.clustering;

import com.gengoai.Validation;
import com.gengoai.apollo.math.linalg.NDArray;
import com.gengoai.apollo.math.linalg.NDArrayFactory;
import com.gengoai.apollo.math.statistics.measure.Measure;
import java.util.ArrayDeque;
import java.util.Collections;
import java.util.Deque;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Objects;
import org.apache.mahout.math.list.DoubleArrayList;

/* loaded from: input_file:com/gengoai/apollo/ml/model/clustering/HierarchicalClustering.class */
/**
 * A {@link Clustering} represented as a binary tree of {@link Cluster} nodes.
 *
 * <p>The tree is reachable through a single root; {@link #size()} therefore always reports
 * {@code 1}, and {@link #get(int)} / {@link #iterator()} expose only the root. Each node's
 * children are obtained via {@code getLeft()} / {@code getRight()}, either of which may be
 * {@code null}.
 */
public class HierarchicalClustering implements Clustering {
    private static final long serialVersionUID = 1;
    protected Cluster root;
    private Measure measure;

    /**
     * Cuts the hierarchy into a flat clustering.
     *
     * <p>Starting from the root, a node is emitted as a flat cluster when the measure's optimum
     * test accepts its score against {@code threshold}; otherwise its children are examined
     * (left before right). A centroid, the mean of the cluster's points, is then computed for
     * every emitted cluster that has at least one point.
     *
     * @param threshold the score threshold handed to the measure's optimum test
     * @return a {@code FlatClustering} sharing this clustering's measure
     */
    public Clustering asFlat(double threshold) {
        FlatClustering flatClustering = new FlatClustering();
        flatClustering.setMeasure(this.measure);
        process(this.root, flatClustering, threshold);
        for (int i = 0; i < flatClustering.size(); i++) {
            Cluster cluster = flatClustering.get(i);
            List<NDArray> points = cluster.getPoints();
            if (points.isEmpty()) {
                // No points means no shape to size the centroid from and a zero divisor.
                continue;
            }
            NDArray centroid = NDArrayFactory.ND.array(points.get(0).shape());
            for (NDArray point : points) {
                centroid.addi(point);
            }
            centroid.divi(cluster.size());
            cluster.setCentroid(centroid);
        }
        return flatClustering;
    }

    /**
     * Returns the score found at the given percentile of all node scores in the hierarchy.
     *
     * <p>Every node's score is collected (breadth-first), sorted ascending, and the value at
     * index {@code floor(count * percentile)} is returned, clamped to the last element so
     * that {@code percentile == 1} yields the maximum score.
     *
     * @param percentile a value in the range {@code (0, 1]}
     * @return the score at the requested percentile
     * @throws IllegalArgumentException (via {@code Validation.checkArgument}) if the
     *     percentile is out of range
     * @throws NullPointerException if this clustering has no root
     */
    public double calculatePercentile(double percentile) {
        Validation.checkArgument(percentile > 0.0d && percentile <= 1.0d, "Percentile must be > 0 and <= 1");
        Objects.requireNonNull(this.root, "Cannot calculate a percentile: the hierarchy has no root");
        DoubleArrayList scores = new DoubleArrayList();
        Deque<Cluster> queue = new ArrayDeque<>();
        queue.add(this.root);
        while (!queue.isEmpty()) {
            Cluster cluster = queue.remove();
            scores.add(cluster.getScore());
            // ArrayDeque rejects nulls, so missing children are filtered here.
            addIfNotNull(queue, cluster.getLeft());
            addIfNotNull(queue, cluster.getRight());
        }
        scores.sort();
        // floor(size * 1.0) == size would be one past the end; clamp to the last index.
        int index = Math.min(scores.size() - 1, (int) Math.floor(scores.size() * percentile));
        return scores.get(index);
    }

    /**
     * Returns the root cluster for index {@code 0}.
     *
     * @param i the index; only {@code 0} is valid
     * @return the root cluster
     * @throws IndexOutOfBoundsException for any index other than {@code 0}
     */
    @Override // com.gengoai.apollo.ml.model.clustering.Clustering
    public Cluster get(int i) {
        if (i == 0) {
            return this.root;
        }
        throw new IndexOutOfBoundsException();
    }

    /** @return the root of the cluster tree */
    @Override // com.gengoai.apollo.ml.model.clustering.Clustering
    public Cluster getRoot() {
        return this.root;
    }

    /** @return always {@code false} */
    @Override // com.gengoai.apollo.ml.model.clustering.Clustering
    public boolean isFlat() {
        return false;
    }

    /** @return always {@code true} */
    @Override // com.gengoai.apollo.ml.model.clustering.Clustering
    public boolean isHierarchical() {
        return true;
    }

    /** @return an iterator over the single root cluster */
    @Override // java.lang.Iterable
    public Iterator<Cluster> iterator() {
        return Collections.singleton(this.root).iterator();
    }

    /**
     * Walks the subtree rooted at {@code start} and adds to {@code flatClustering} every
     * topmost node whose score passes the measure's optimum test against {@code threshold}.
     *
     * <p>Uses an explicit stack rather than recursion so that very deep (chain-like) trees
     * cannot overflow the call stack. Nodes are emitted in the same order as a recursive
     * pre-order, left-before-right traversal.
     */
    private void process(Cluster start, FlatClustering flatClustering, double threshold) {
        if (start == null) {
            return;
        }
        Deque<Cluster> stack = new ArrayDeque<>();
        stack.push(start);
        while (!stack.isEmpty()) {
            Cluster current = stack.pop();
            if (this.measure.getOptimum().test(current.getScore(), threshold)) {
                flatClustering.add(current);
            } else {
                // Push right first so the left subtree is fully processed before the right.
                pushIfNotNull(stack, current.getRight());
                pushIfNotNull(stack, current.getLeft());
            }
        }
    }

    /** Appends {@code cluster} to the tail of {@code queue} unless it is {@code null}. */
    private static void addIfNotNull(Deque<Cluster> queue, Cluster cluster) {
        if (cluster != null) {
            queue.add(cluster);
        }
    }

    /** Pushes {@code cluster} onto {@code stack} unless it is {@code null}. */
    private static void pushIfNotNull(Deque<Cluster> stack, Cluster cluster) {
        if (cluster != null) {
            stack.push(cluster);
        }
    }

    /** @return always {@code 1}; the hierarchy is exposed as a single root cluster */
    @Override // com.gengoai.apollo.ml.model.clustering.Clustering
    public int size() {
        return 1;
    }

    /** @return the measure used to score clusters */
    @Override // com.gengoai.apollo.ml.model.clustering.Clustering
    public Measure getMeasure() {
        return this.measure;
    }

    /** @param measure the measure used to score clusters and to cut the tree in {@link #asFlat} */
    @Override // com.gengoai.apollo.ml.model.clustering.Clustering
    public void setMeasure(Measure measure) {
        this.measure = measure;
    }
}
