package gov.sandia.cognition.learning.algorithm.clustering;

import gov.sandia.cognition.annotation.CodeReview;
import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.collection.CollectionUtil;
import gov.sandia.cognition.learning.algorithm.AbstractAnytimeBatchLearner;
import gov.sandia.cognition.learning.algorithm.clustering.cluster.CentroidCluster;
import gov.sandia.cognition.learning.algorithm.clustering.cluster.Cluster;
import gov.sandia.cognition.learning.algorithm.clustering.cluster.IncrementalClusterCreator;
import gov.sandia.cognition.learning.algorithm.clustering.cluster.NormalizedCentroidCluster;
import gov.sandia.cognition.learning.algorithm.clustering.cluster.NormalizedCentroidClusterCreator;
import gov.sandia.cognition.learning.algorithm.clustering.cluster.VectorMeanCentroidClusterCreator;
import gov.sandia.cognition.learning.algorithm.clustering.divergence.CentroidClusterDivergenceFunction;
import gov.sandia.cognition.learning.algorithm.clustering.divergence.ClusterDivergenceFunction;
import gov.sandia.cognition.learning.algorithm.clustering.divergence.WithinClusterDivergence;
import gov.sandia.cognition.learning.algorithm.clustering.divergence.WithinClusterDivergenceWrapper;
import gov.sandia.cognition.learning.algorithm.clustering.divergence.WithinNormalizedCentroidClusterCosineDivergence;
import gov.sandia.cognition.learning.algorithm.clustering.hierarchy.BatchHierarchicalClusterer;
import gov.sandia.cognition.learning.algorithm.clustering.hierarchy.BinaryClusterHierarchyNode;
import gov.sandia.cognition.learning.algorithm.clustering.hierarchy.ClusterHierarchyNode;
import gov.sandia.cognition.learning.function.distance.DivergenceFunctionContainer;
import gov.sandia.cognition.learning.function.distance.EuclideanDistanceMetric;
import gov.sandia.cognition.math.DivergenceFunction;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.Vectorizable;
import gov.sandia.cognition.util.ArgumentChecker;
import gov.sandia.cognition.util.DefaultPair;
import gov.sandia.cognition.util.ObjectUtil;
import gov.sandia.cognition.util.Randomized;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.Iterator;
import java.util.Random;

/**
 * A divisive (top-down) hierarchical clusterer.
 * <p>
 * Learning starts with every data point in a single root cluster. Clusters are
 * then visited in creation order, and each one is tried for a split into two
 * halves:
 * <ol>
 * <li>Its members are partitioned at random.</li>
 * <li>Individual members are greedily moved between the halves while doing so
 * lowers the summed within-cluster divergence.</li>
 * </ol>
 * A split is kept when both halves have at least {@code minClusterSize}
 * members and
 * {@code |criterion(parent) * maxCriterionDecrease - (criterion(first) + criterion(second))|}
 * exceeds {@code tolerance}. Every cluster created is recorded as a node in a
 * binary cluster hierarchy whose root is the initial cluster.
 * <p>
 * The flat result ({@link #getResult2()}) is extracted from that hierarchy by
 * repeatedly replacing the selected node with the largest within-cluster
 * divergence by its children, until the requested number of clusters is
 * reached or no node can be expanded further.
 * <p>
 * NOTE(review): this source appears to be decompiled; {@code mo0clone} and
 * {@code getResult2} are decompiler renames of {@code clone} and
 * {@code getResult}. They are kept as-is so existing callers still link.
 *
 * @param <DataType> type of the data being clustered
 * @param <ClusterType> type of cluster produced
 */
@CodeReview(reviewer = {"Justin Basilico"}, date = "2011-03-09", comments = {"Should make do a greedy splitting prioritization.", "Should make an interface for incremental cluster creation for this to use."}, changesNeeded = true)
@PublicationReference(author = {"Ying Zhao", "George Karypis"}, title = "Hierarchical Clustering Algorithms for Document Datasets", type = PublicationType.Journal, year = 2005, publication = "Data Mining and Knowledge Discovery", pages = {141, 168}, url = "http://www.springerlink.com/index/jx3825j42x4333m5.pdf")
public class PartitionalClusterer<DataType, ClusterType extends Cluster<DataType>> extends AbstractAnytimeBatchLearner<Collection<? extends DataType>, Collection<ClusterType>> implements BatchClusterer<DataType, ClusterType>, BatchHierarchicalClusterer<DataType, ClusterType>, Randomized, DivergenceFunctionContainer<ClusterType, DataType> {

    /** Default minimum number of members in each half of an accepted split. */
    public static final int DEFAULT_MIN_CLUSTER_SIZE = 1;

    /** Default factor applied to the parent's criterion in the split test. */
    public static final double DEFAULT_MAX_CRITERION_DECREASE = 1.0d;

    /** Default iteration cap handed to the anytime learner. */
    public static final int DEFAULT_MAX_ITERATIONS = Integer.MAX_VALUE;

    /** Default number of requested clusters (effectively "split as far as possible"). */
    public static final int DEFAULT_NUM_REQUESTED_CLUSTERS = Integer.MAX_VALUE;

    /** Default minimum absolute criterion change for a split to be kept. */
    public static final double DEFAULT_TOLERANCE = 1.0E-10d;

    /** Criterion used to score a cluster when deciding splits and ordering results. */
    protected WithinClusterDivergence<? super ClusterType, ? super DataType> clusterDivergenceFunction;

    /**
     * Point-to-cluster divergence given via {@link #setDivergenceFunction}. It is
     * {@code null} when a within-cluster divergence was set directly.
     */
    protected DivergenceFunction<? super ClusterType, ? super DataType> divergenceFunction;

    /** Minimum absolute criterion change for a split to be kept. */
    protected double tolerance;

    /** Builds clusters and moves members between them. */
    protected IncrementalClusterCreator<ClusterType, DataType> creator;

    /**
     * Clusters with at most this many members are not split. Both halves of an
     * accepted split must have at least this many members.
     */
    protected int minClusterSize;

    /** Factor applied to the parent's criterion in the split test. */
    protected double maxCriterionDecrease;

    /** Index into {@link #clusters} of the next cluster to try to split. */
    private int clusterIndex;

    /** Source of randomness for the initial random partition of a split. */
    protected Random random;

    /**
     * Every cluster created, in creation order. Entry i is the cluster held by
     * entry i of {@link #clustersHierarchy}.
     */
    protected transient ArrayList<ClusterType> clusters;

    /** Hierarchy nodes in creation order; index 0 is the root. */
    protected transient ArrayList<BinaryClusterHierarchyNode<DataType, ClusterType>> clustersHierarchy;

    /** Number of clusters requested in the flat result. */
    protected int numRequestedClusters;

    /**
     * Set while {@link #learnUsingCachedClusters} runs, so that
     * {@link #initializeAlgorithm()} keeps the existing clusters and hierarchy
     * instead of resetting them.
     */
    protected boolean useCachedClusters;

    /**
     * Pairs a hierarchy node with its within-cluster criterion.
     * <p>
     * Instances also act as a comparator that orders pairs by DECREASING
     * criterion, so the first element of a sorted list has the largest criterion.
     */
    public class NodeCriterionPair<DataType, ClusterType extends Cluster<DataType>> implements Comparator<PartitionalClusterer<DataType, ClusterType>.NodeCriterionPair<DataType, ClusterType>> {

        /** The hierarchy node. */
        ClusterHierarchyNode<DataType, ClusterType> node;

        /** The criterion of the node's cluster, evaluated once at construction. */
        double criterion;

        /**
         * Creates a pair and evaluates the node's cluster with the given criterion.
         *
         * @param clusterHierarchyNode the node to wrap
         * @param withinClusterDivergence the criterion to evaluate the node's cluster with
         */
        public NodeCriterionPair(ClusterHierarchyNode<DataType, ClusterType> clusterHierarchyNode, WithinClusterDivergence<? super ClusterType, ? super DataType> withinClusterDivergence) {
            this.node = clusterHierarchyNode;
            this.criterion = withinClusterDivergence.evaluate(this.node.getCluster());
        }

        /**
         * Orders by decreasing criterion.
         *
         * @return -1 if {@code first} has the larger criterion, 0 if they are equal,
         *         1 otherwise (including when either criterion is NaN)
         */
        @Override
        public int compare(PartitionalClusterer<DataType, ClusterType>.NodeCriterionPair<DataType, ClusterType> first, PartitionalClusterer<DataType, ClusterType>.NodeCriterionPair<DataType, ClusterType> second) {
            if (first.criterion > second.criterion) {
                return -1;
            } else if (first.criterion == second.criterion) {
                return 0;
            } else {
                return 1;
            }
        }
    }

    /**
     * Shared constructor. Sets the requested cluster count and creator, and
     * fills the remaining settings with their defaults.
     */
    private PartitionalClusterer(int i, IncrementalClusterCreator<ClusterType, DataType> incrementalClusterCreator) {
        super(DEFAULT_MAX_ITERATIONS);
        this.tolerance = DEFAULT_TOLERANCE;
        setNumRequestedClusters(i);
        setCreator(incrementalClusterCreator);
        setMinClusterSize(DEFAULT_MIN_CLUSTER_SIZE);
        setMaxCriterionDecrease(DEFAULT_MAX_CRITERION_DECREASE);
        setClusters(null);
        setClustersHierarchy(null);
        setRandom(new Random());
    }

    /**
     * Creates a clusterer whose criterion wraps a cluster divergence function.
     *
     * @param i number of requested clusters
     * @param clusterDivergenceFunction divergence between a cluster and a data point
     * @param incrementalClusterCreator creator used to build and modify clusters
     */
    public PartitionalClusterer(int i, ClusterDivergenceFunction<ClusterType, DataType> clusterDivergenceFunction, IncrementalClusterCreator<ClusterType, DataType> incrementalClusterCreator) {
        this(i, incrementalClusterCreator);
        setDivergenceFunction(clusterDivergenceFunction);
    }

    /**
     * Creates a clusterer with an explicit within-cluster criterion.
     *
     * @param i number of requested clusters
     * @param withinClusterDivergence criterion used to score clusters
     * @param incrementalClusterCreator creator used to build and modify clusters
     */
    public PartitionalClusterer(int i, WithinClusterDivergence<ClusterType, DataType> withinClusterDivergence, IncrementalClusterCreator<ClusterType, DataType> incrementalClusterCreator) {
        this(i, incrementalClusterCreator);
        setWithinClusterDivergenceFunction(withinClusterDivergence);
    }

    /**
     * Creates a clusterer over vectors that uses Euclidean distance to the
     * centroid and mean-centroid clusters.
     *
     * @param i number of requested clusters
     * @return a new clusterer
     */
    public static PartitionalClusterer<Vector, CentroidCluster<Vector>> create(int i) {
        return new PartitionalClusterer<>(i, new CentroidClusterDivergenceFunction(EuclideanDistanceMetric.INSTANCE), VectorMeanCentroidClusterCreator.INSTANCE);
    }

    /**
     * Creates a clusterer over vectorizable data that uses cosine divergence
     * within normalized-centroid clusters.
     *
     * @param i number of requested clusters
     * @return a new clusterer
     */
    public static PartitionalClusterer<Vectorizable, NormalizedCentroidCluster<Vectorizable>> createSpherical(int i) {
        return new PartitionalClusterer<>(i, new WithinNormalizedCentroidClusterCosineDivergence(), new NormalizedCentroidClusterCreator());
    }

    /**
     * Clones this clusterer (decompiler rename of {@code clone}).
     * <p>
     * The divergence functions and creator are cloned with
     * {@code ObjectUtil.cloneSmart}. Learned clusters and hierarchy are not
     * copied; the clone starts with them set to {@code null}.
     *
     * @return the clone
     */
    @Override
    public PartitionalClusterer<DataType, ClusterType> mo0clone() {
        PartitionalClusterer<DataType, ClusterType> copy = (PartitionalClusterer) super.mo0clone();
        copy.clusterDivergenceFunction = (WithinClusterDivergence) ObjectUtil.cloneSmart(this.clusterDivergenceFunction);
        copy.divergenceFunction = (DivergenceFunction) ObjectUtil.cloneSmart(this.divergenceFunction);
        copy.creator = (IncrementalClusterCreator) ObjectUtil.cloneSmart(this.creator);
        copy.clusters = null;
        copy.clustersHierarchy = null;
        return copy;
    }

    /**
     * Learns on the given data and returns the root of the cluster hierarchy.
     *
     * @param collection data to cluster
     * @return the root node, or {@code null} if no hierarchy was produced
     */
    @Override
    public ClusterHierarchyNode<DataType, ClusterType> clusterHierarchically(Collection<? extends DataType> collection) {
        learn(collection);
        if (CollectionUtil.isEmpty((Collection<?>) this.clustersHierarchy)) {
            return null;
        }
        return this.clustersHierarchy.get(0);
    }

    /**
     * Starts a fresh learning run: builds one root cluster holding all the data
     * and resets the split cursor.
     * <p>
     * When {@link #useCachedClusters} is set, the existing state is kept so
     * splitting resumes where it left off.
     *
     * @return always {@code true}
     */
    @Override
    protected boolean initializeAlgorithm() {
        if (this.useCachedClusters) {
            return true;
        }
        setClusters(new ArrayList<>());
        setClustersHierarchy(new ArrayList<>());
        ClusterType root = this.creator.createCluster(new ArrayList<DataType>(this.data));
        this.clusters.add(root);
        this.clustersHierarchy.add(new BinaryClusterHierarchyNode<>(root));
        this.clusterIndex = 0;
        return true;
    }

    /**
     * Tries to split the cluster at the current cursor into two children.
     * <p>
     * The cursor always advances, whether the cluster is skipped for being too
     * small, the split is rejected, or the split is accepted.
     *
     * @return {@code false} once every cluster has been visited or enough
     *         hierarchy nodes exist for the requested cluster count;
     *         {@code true} otherwise
     */
    @Override
    protected boolean step() {
        // A binary hierarchy with k leaves has 2k - 1 nodes. Compute this in long so
        // large requested counts (e.g. the Integer.MAX_VALUE default) do not overflow.
        final long maxNodeCount = 2L * getNumRequestedClusters() - 1L;
        if (this.clusterIndex >= getClusterCount() || getClusterCount() >= maxNodeCount) {
            return false;
        }

        final ClusterType parent = this.clusters.get(this.clusterIndex);
        if (parent.getMembers().size() <= getMinClusterSize()) {
            // Too small to split; move on to the next cluster.
            this.clusterIndex++;
            return true;
        }

        final BinaryClusterHierarchyNode<DataType, ClusterType> parentNode = this.clustersHierarchy.get(this.clusterIndex);
        final DefaultPair<ClusterType, ClusterType> halves = randomPartition(parent);
        final ClusterType first = halves.getFirst();
        final ClusterType second = halves.getSecond();
        greedySwap(parent.getMembers(), first, second);

        if (first.getMembers().size() >= getMinClusterSize() && second.getMembers().size() >= getMinClusterSize()) {
            final double parentCriterion = this.clusterDivergenceFunction.evaluate(parent);
            final double childCriterion = this.clusterDivergenceFunction.evaluate(first) + this.clusterDivergenceFunction.evaluate(second);
            if (Math.abs(parentCriterion * this.maxCriterionDecrease - childCriterion) > this.tolerance) {
                this.clusters.add(first);
                this.clusters.add(second);
                final BinaryClusterHierarchyNode<DataType, ClusterType> firstNode = new BinaryClusterHierarchyNode<>(first);
                final BinaryClusterHierarchyNode<DataType, ClusterType> secondNode = new BinaryClusterHierarchyNode<>(second);
                this.clustersHierarchy.add(firstNode);
                this.clustersHierarchy.add(secondNode);
                parentNode.setFirstChild(firstNode);
                parentNode.setSecondChild(secondNode);
            }
        }
        this.clusterIndex++;
        return true;
    }

    /**
     * Intentionally does nothing. The clusters and hierarchy are kept so that
     * {@link #getResult2()} and {@link #learnUsingCachedClusters} can use them.
     */
    @Override
    protected void cleanupAlgorithm() {
    }

    /**
     * Extracts a flat clustering from the learned hierarchy (decompiler rename
     * of {@code getResult}).
     * <p>
     * Starting from the root, the frontier node with the largest criterion is
     * repeatedly replaced in the selection by its children. This stops when the
     * selection holds {@link #getNumRequestedClusters()} clusters or no node
     * remains to expand.
     *
     * @return the selected clusters; empty if nothing has been learned yet
     */
    @Override
    public ArrayList<ClusterType> getResult2() {
        final ArrayList<ClusterType> result = new ArrayList<>();
        if (CollectionUtil.isEmpty((Collection<?>) this.clustersHierarchy)) {
            return result;
        }

        // Nodes not yet examined, kept sorted by decreasing criterion.
        final ArrayList<NodeCriterionPair<DataType, ClusterType>> frontier = new ArrayList<>();
        // The current flat selection (leaves of the expanded part of the tree).
        final ArrayList<NodeCriterionPair<DataType, ClusterType>> selected = new ArrayList<>();
        final NodeCriterionPair<DataType, ClusterType> root = new NodeCriterionPair<DataType, ClusterType>(this.clustersHierarchy.get(0), this.clusterDivergenceFunction);
        frontier.add(root);
        selected.add(root);

        while (!frontier.isEmpty() && selected.size() < getNumRequestedClusters()) {
            // Pop the parent BEFORE inserting its children: a child whose criterion is
            // greater than or equal to the parent's may sort to index 0 and must not be
            // removed in the parent's place.
            final NodeCriterionPair<DataType, ClusterType> current = frontier.remove(0);
            final ClusterHierarchyNode<DataType, ClusterType> node = current.node;
            if (node.hasChildren()) {
                for (ClusterHierarchyNode<DataType, ClusterType> child : node.getChildren()) {
                    final NodeCriterionPair<DataType, ClusterType> childPair = new NodeCriterionPair<DataType, ClusterType>(child, this.clusterDivergenceFunction);
                    int position = Collections.binarySearch(frontier, childPair, childPair);
                    if (position < 0) {
                        // Not found: binarySearch returns -(insertionPoint) - 1.
                        position = -(position + 1);
                    }
                    frontier.add(position, childPair);
                    selected.add(childPair);
                }
                selected.remove(current);
            }
        }

        for (NodeCriterionPair<DataType, ClusterType> pair : selected) {
            result.add(pair.node.getCluster());
        }
        return result;
    }

    /**
     * Learns, reusing a previously built hierarchy when possible.
     * <ul>
     * <li>With no prior clusters, runs a normal {@code learn}.</li>
     * <li>When the prior hierarchy has fewer leaves than requested, resumes
     * splitting from the cached state.</li>
     * <li>Otherwise reuses the cached hierarchy as-is.</li>
     * </ul>
     *
     * @param collection data to cluster
     * @return the flat clustering from {@link #getResult2()}
     */
    public Collection<ClusterType> learnUsingCachedClusters(Collection<? extends DataType> collection) {
        if (this.clusters == null || this.clusters.isEmpty()) {
            learn(collection);
        } else if ((this.clusters.size() / 2) + 1 < this.numRequestedClusters) {
            this.useCachedClusters = true;
            try {
                learn(collection);
            } finally {
                // Always restore, so a failed run does not make later learn() calls reuse stale state.
                this.useCachedClusters = false;
            }
        }
        return getResult2();
    }

    /**
     * Randomly partitions a cluster's members into two new, non-empty clusters.
     * <p>
     * Assumes the cluster has at least two members; {@link #step()} only calls
     * this when the size exceeds the (positive) minimum cluster size.
     *
     * @param cluster the cluster to partition
     * @return the two new clusters
     */
    private DefaultPair<ClusterType, ClusterType> randomPartition(ClusterType cluster) {
        final ArrayList<DataType> firstMembers = new ArrayList<>();
        final ArrayList<DataType> secondMembers = new ArrayList<>();
        for (DataType member : cluster.getMembers()) {
            if (this.random.nextBoolean()) {
                firstMembers.add(member);
            } else {
                secondMembers.add(member);
            }
        }
        // Guarantee neither side is empty by moving one random member across.
        if (firstMembers.isEmpty()) {
            firstMembers.add(removeRandomElement(secondMembers));
        }
        if (secondMembers.isEmpty()) {
            secondMembers.add(removeRandomElement(firstMembers));
        }
        return DefaultPair.create(this.creator.createCluster(firstMembers), this.creator.createCluster(secondMembers));
    }

    /**
     * Removes and returns a uniformly random element of a non-empty list.
     * <p>
     * The chosen element is overwritten by the last element, and the last slot
     * is then dropped, so the removal is O(1).
     */
    private DataType removeRandomElement(ArrayList<DataType> list) {
        final int last = list.size() - 1;
        final DataType chosen = list.set(this.random.nextInt(list.size()), list.get(last));
        list.remove(last);
        return chosen;
    }

    /**
     * Greedily moves individual members between two clusters while doing so
     * strictly lowers their summed criterion.
     * <p>
     * Each pass tries moving every member once. A move that does not improve the
     * criterion is undone. Passes repeat until one makes no improvement.
     *
     * @param members the members to consider moving
     * @param first the first cluster
     * @param second the second cluster
     */
    private void greedySwap(Collection<DataType> members, ClusterType first, ClusterType second) {
        double best = this.clusterDivergenceFunction.evaluate(first) + this.clusterDivergenceFunction.evaluate(second);
        boolean improved = true;
        while (improved) {
            improved = false;
            for (DataType member : members) {
                swapElement(first, second, member);
                final double candidate = this.clusterDivergenceFunction.evaluate(first) + this.clusterDivergenceFunction.evaluate(second);
                if (candidate < best) {
                    best = candidate;
                    improved = true;
                } else {
                    // Undo the move.
                    swapElement(first, second, member);
                }
            }
        }
    }

    /**
     * Moves {@code member} to the other cluster.
     * <p>
     * Does nothing if the member belongs to neither cluster, or if moving it
     * would leave its current cluster empty.
     */
    private void swapElement(ClusterType first, ClusterType second, DataType member) {
        if (first.getMembers().contains(member) && first.getMembers().size() > 1) {
            this.creator.removeClusterMember(first, member);
            this.creator.addClusterMember(second, member);
        } else if (second.getMembers().contains(member) && second.getMembers().size() > 1) {
            this.creator.removeClusterMember(second, member);
            this.creator.addClusterMember(first, member);
        }
    }

    /**
     * Returns the number of clusters created so far (hierarchy nodes, not
     * leaves).
     *
     * @return the count, or 0 before learning
     */
    public int getClusterCount() {
        return this.clusters == null ? 0 : this.clusters.size();
    }

    /**
     * Returns the criterion used to score clusters.
     *
     * @return the within-cluster divergence
     */
    public WithinClusterDivergence<? super ClusterType, ? super DataType> getWithinClusterDivergenceFunction() {
        return this.clusterDivergenceFunction;
    }

    /**
     * Returns the point divergence given via {@link #setDivergenceFunction}.
     *
     * @return the divergence, or {@code null} if a within-cluster divergence
     *         was set directly
     */
    @Override
    public DivergenceFunction<? super ClusterType, ? super DataType> getDivergenceFunction() {
        return this.divergenceFunction;
    }

    /**
     * Sets the point divergence and derives the within-cluster criterion from it
     * by wrapping it in a {@code WithinClusterDivergenceWrapper}.
     *
     * @param divergenceFunction the divergence; must not be {@code null}
     */
    public void setDivergenceFunction(DivergenceFunction<? super ClusterType, ? super DataType> divergenceFunction) {
        ArgumentChecker.assertIsNotNull("divergenceFunction", divergenceFunction);
        // This call also clears divergenceFunction, so assign it afterwards.
        setWithinClusterDivergenceFunction(new WithinClusterDivergenceWrapper(divergenceFunction));
        this.divergenceFunction = divergenceFunction;
    }

    /**
     * Sets the within-cluster criterion directly and clears any point
     * divergence.
     *
     * @param withinClusterDivergence the criterion; must not be {@code null}
     */
    public void setWithinClusterDivergenceFunction(WithinClusterDivergence<? super ClusterType, ? super DataType> withinClusterDivergence) {
        ArgumentChecker.assertIsNotNull("clusterDivergenceFunction", withinClusterDivergence);
        this.clusterDivergenceFunction = withinClusterDivergence;
        this.divergenceFunction = null;
    }

    /** @return minimum absolute criterion change for a split to be kept */
    public double getTolerance() {
        return this.tolerance;
    }

    /** @param d minimum absolute criterion change for a split to be kept */
    public void setTolerance(double d) {
        this.tolerance = d;
    }

    /** @return the cluster creator */
    public IncrementalClusterCreator<ClusterType, DataType> getCreator() {
        return this.creator;
    }

    /** @param incrementalClusterCreator the cluster creator; must not be {@code null} */
    public void setCreator(IncrementalClusterCreator<ClusterType, DataType> incrementalClusterCreator) {
        ArgumentChecker.assertIsNotNull("creator", incrementalClusterCreator);
        this.creator = incrementalClusterCreator;
    }

    @Override
    public Random getRandom() {
        return this.random;
    }

    /** @param random the random source; must not be {@code null} */
    @Override
    public void setRandom(Random random) {
        ArgumentChecker.assertIsNotNull("random", random);
        this.random = random;
    }

    /** @return the minimum cluster size */
    public int getMinClusterSize() {
        return this.minClusterSize;
    }

    /** @param i the minimum cluster size; must be positive */
    public void setMinClusterSize(int i) {
        ArgumentChecker.assertIsPositive("minClusterSize", i);
        this.minClusterSize = i;
    }

    /** @return the factor applied to the parent's criterion in the split test */
    public double getMaxCriterionDecrease() {
        return this.maxCriterionDecrease;
    }

    /** @param d the factor applied to the parent's criterion in the split test; must be positive */
    public void setMaxCriterionDecrease(double d) {
        ArgumentChecker.assertIsPositive("maxCriterionDecrease", d);
        this.maxCriterionDecrease = d;
    }

    /** @return every cluster created, in creation order (may be {@code null} before learning) */
    public ArrayList<ClusterType> getClusters() {
        return this.clusters;
    }

    /** @param arrayList the clusters list */
    protected void setClusters(ArrayList<ClusterType> arrayList) {
        this.clusters = arrayList;
    }

    /** @return the hierarchy nodes in creation order, root first (may be {@code null} before learning) */
    public ArrayList<BinaryClusterHierarchyNode<DataType, ClusterType>> getClustersHierarchy() {
        return this.clustersHierarchy;
    }

    /** @param arrayList the hierarchy node list */
    protected void setClustersHierarchy(ArrayList<BinaryClusterHierarchyNode<DataType, ClusterType>> arrayList) {
        this.clustersHierarchy = arrayList;
    }

    /** @return the number of clusters requested in the flat result */
    public int getNumRequestedClusters() {
        return this.numRequestedClusters;
    }

    /**
     * @param i the number of clusters requested in the flat result.
     *        NOTE(review): not validated; values of 1 or less yield just the root cluster.
     */
    public void setNumRequestedClusters(int i) {
        this.numRequestedClusters = i;
    }
}
