package gov.sandia.cognition.learning.algorithm.clustering;

import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.learning.algorithm.clustering.cluster.ClusterCreator;
import gov.sandia.cognition.learning.algorithm.clustering.cluster.MiniBatchCentroidCluster;
import gov.sandia.cognition.learning.algorithm.clustering.cluster.VectorMeanMiniBatchCentroidClusterCreator;
import gov.sandia.cognition.learning.algorithm.clustering.divergence.CentroidClusterDivergenceFunction;
import gov.sandia.cognition.learning.algorithm.clustering.initializer.FixedClusterInitializer;
import gov.sandia.cognition.learning.algorithm.clustering.initializer.GreedyClusterInitializer;
import gov.sandia.cognition.learning.function.distance.EuclideanDistanceMetric;
import gov.sandia.cognition.math.Semimetric;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.statistics.DiscreteSamplingUtil;
import gov.sandia.cognition.util.ObjectUtil;
import gov.sandia.cognition.util.Randomized;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.RandomAccess;
import java.util.concurrent.ConcurrentMap;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;

/**
 * Mini-batch variant of k-means clustering over {@link Vector} data.
 * <p>
 * Each {@link #step()} samples {@link #getMinibatchSize()} data indices with
 * replacement (via {@link DiscreteSamplingUtil#sampleWithReplacement}), assigns
 * the sampled points to their closest clusters, and lets each cluster fold its
 * newly assigned points in through {@link MiniBatchCentroidCluster#updateCluster}.
 * A step asks for another iteration only while the fraction of sampled points
 * whose assignment changed exceeds {@link #getStoppingCriterion()}; the
 * iteration cap is handed to the superclass (presumably enforced there —
 * verify in {@code KMeansClusterer}). When learning finishes, every data point
 * is assigned to its closest cluster and the cluster member lists are rebuilt
 * from those assignments.
 *
 * @param <DataType> element type reported by {@link #getData()}; the clustering
 *     itself operates on {@link Vector}
 */
@PublicationReference(author = {"Jeff Piersol"}, title = "Parallel Mini-Batch k-means Clustering", type = PublicationType.Conference, year = 2016, publication = "to appear", url = "to appear")
public class MiniBatchKMeansClusterer<DataType extends Vector> extends KMeansClusterer<Vector, MiniBatchCentroidCluster> implements Randomized {
    private static final long serialVersionUID = 2587013040037999607L;

    /** Default cap on the number of iterations. */
    public static final int DEFAULT_MAX_ITERATIONS = 100000;

    /** Upper bound on the mini-batch size chosen automatically when none is set. */
    public static final int DEFAULT_MINIBATCH_SIZE = 10000;

    /** Default fraction of changed sampled assignments at or below which learning stops. */
    public static final double DEFAULT_STOPPING_CRITERION = 0.01;

    /** Inputs with more elements than this are assigned to clusters using a parallel stream. */
    private static final int PARALLEL_ASSIGNMENT_THRESHOLD = 25;

    /** Source of randomness for mini-batch sampling. */
    protected Random random;

    /** Points sampled per step; a value {@code <= 0} means "choose automatically" at initialization. */
    private int minibatchSize;

    /** The indices {@code 0 .. n-1} of the current data, used as the population to sample from. */
    protected List<Integer> dataIndices;

    /** Learning stops once (changed assignments / mini-batch size) is no longer above this value. */
    private double stoppingCriterion;

    /**
     * Fluent builder for {@link MiniBatchKMeansClusterer}.
     * <p>
     * Unless {@link #withInitializer} is called, {@link #build()} creates a
     * {@link GreedyClusterInitializer} from the metric, creator and random
     * configured at build time, so later {@code with*} calls are honored.
     *
     * @param <DataType> element type of the built clusterer
     */
    public static class Builder<DataType extends Vector> {
        private int numClusters;
        private int maxIterations;
        private int minibatchSize;
        private double stoppingCriterion;
        // null means "build a greedy initializer from the final configuration".
        private FixedClusterInitializer<MiniBatchCentroidCluster, Vector> initializer;
        private Semimetric<? super Vector> metric;
        private ClusterCreator<MiniBatchCentroidCluster, Vector> creator;
        private Random random;

        /**
         * Creates a builder that uses Euclidean distance.
         *
         * @param numClusters number of clusters to find
         */
        public Builder(int numClusters) {
            this(numClusters, EuclideanDistanceMetric.INSTANCE);
        }

        /**
         * Creates a builder with the given metric and default settings for
         * everything else.
         *
         * @param numClusters number of clusters to find
         * @param metric distance used between points and cluster centroids
         */
        public Builder(int numClusters, Semimetric<? super Vector> metric) {
            this.numClusters = numClusters;
            this.maxIterations = DEFAULT_MAX_ITERATIONS;
            this.minibatchSize = 0;
            this.stoppingCriterion = DEFAULT_STOPPING_CRITERION;
            this.random = new Random();
            this.creator = VectorMeanMiniBatchCentroidClusterCreator.INSTANCE;
            this.metric = metric;
            this.initializer = null;
        }

        /**
         * Builds a clusterer from the current settings.
         *
         * @return a new clusterer
         */
        public MiniBatchKMeansClusterer<DataType> build() {
            FixedClusterInitializer<MiniBatchCentroidCluster, Vector> clusterInitializer = this.initializer;
            if (clusterInitializer == null) {
                // Built here (not in the constructor) so withRandom/withCreator/withMetric take effect.
                clusterInitializer = new GreedyClusterInitializer<>(this.metric, this.creator, this.random);
            }
            final MiniBatchKMeansClusterer<DataType> result = new MiniBatchKMeansClusterer<>(
                this.numClusters, this.maxIterations, clusterInitializer, this.metric, this.creator, this.random);
            result.setMinibatchSize(this.minibatchSize);
            result.setStoppingCriterion(this.stoppingCriterion);
            return result;
        }

        /**
         * @param numClusters number of clusters to find
         * @return this builder
         */
        public Builder<DataType> withNumClusters(int numClusters) {
            this.numClusters = numClusters;
            return this;
        }

        /**
         * @param maxIterations maximum number of iterations
         * @return this builder
         */
        public Builder<DataType> withMaxIterations(int maxIterations) {
            this.maxIterations = maxIterations;
            return this;
        }

        /**
         * @param minibatchSize points sampled per step; {@code <= 0} chooses automatically
         * @return this builder
         */
        public Builder<DataType> withMinibatchSize(int minibatchSize) {
            this.minibatchSize = minibatchSize;
            return this;
        }

        /**
         * @param stoppingCriterion fraction of changed sampled assignments at or below which learning stops
         * @return this builder
         */
        public Builder<DataType> withStoppingCriterion(double stoppingCriterion) {
            this.stoppingCriterion = stoppingCriterion;
            return this;
        }

        /**
         * @param fixedClusterInitializer initializer to use instead of the default greedy one
         * @return this builder
         */
        public Builder<DataType> withInitializer(FixedClusterInitializer<MiniBatchCentroidCluster, Vector> fixedClusterInitializer) {
            this.initializer = fixedClusterInitializer;
            return this;
        }

        /**
         * @param metric distance used between points and cluster centroids
         * @return this builder
         */
        public Builder<DataType> withMetric(Semimetric<? super Vector> metric) {
            this.metric = metric;
            return this;
        }

        /**
         * @param clusterCreator creator used to build clusters
         * @return this builder
         */
        public Builder<DataType> withCreator(ClusterCreator<MiniBatchCentroidCluster, Vector> clusterCreator) {
            this.creator = clusterCreator;
            return this;
        }

        /**
         * @param random source of randomness for sampling and default initialization
         * @return this builder
         */
        public Builder<DataType> withRandom(Random random) {
            this.random = random;
            return this;
        }
    }

    /**
     * Creates a clusterer with Euclidean distance, a greedy initializer, the
     * vector-mean mini-batch cluster creator, and a fresh {@link Random}.
     *
     * @param numClusters number of clusters to find
     */
    public MiniBatchKMeansClusterer(int numClusters) {
        this(numClusters, new Random(), VectorMeanMiniBatchCentroidClusterCreator.INSTANCE);
    }

    /**
     * Shares one {@link Random} and one creator between the default greedy
     * initializer and the clusterer.
     */
    private MiniBatchKMeansClusterer(int numClusters, Random random, ClusterCreator<MiniBatchCentroidCluster, Vector> creator) {
        this(numClusters, DEFAULT_MAX_ITERATIONS,
            new GreedyClusterInitializer<>(EuclideanDistanceMetric.INSTANCE, creator, random),
            EuclideanDistanceMetric.INSTANCE, creator, random);
    }

    /**
     * Creates a fully configured clusterer.
     *
     * @param numClusters number of clusters to find
     * @param maxIterations maximum number of iterations
     * @param initializer produces the initial clusters
     * @param metric distance between points and centroids (wrapped in a
     *     {@link CentroidClusterDivergenceFunction})
     * @param creator builds clusters
     * @param random source of randomness for mini-batch sampling
     */
    public MiniBatchKMeansClusterer(int numClusters, int maxIterations, FixedClusterInitializer<MiniBatchCentroidCluster, Vector> initializer, Semimetric<? super Vector> metric, ClusterCreator<MiniBatchCentroidCluster, Vector> creator, Random random) {
        super(numClusters, maxIterations, initializer, new CentroidClusterDivergenceFunction<>(metric), creator);
        this.stoppingCriterion = DEFAULT_STOPPING_CRITERION;
        setRandom(random);
    }

    /**
     * Clones this clusterer (the decompiler renamed this method from {@code clone}).
     * The copy receives its own clone of the random generator; this instance is
     * left untouched.
     *
     * @return the copy
     */
    @Override
    @SuppressWarnings("unchecked")
    public MiniBatchKMeansClusterer<DataType> mo0clone() {
        final MiniBatchKMeansClusterer<DataType> result = (MiniBatchKMeansClusterer<DataType>) super.mo0clone();
        // Assign to the copy, not to this: cloning must not swap out the source's generator.
        result.random = (Random) ObjectUtil.cloneSmart(this.random);
        return result;
    }

    /**
     * Runs the superclass initialization, then resolves the mini-batch size:
     * zero when there are no clusters, otherwise
     * {@code min(number of elements, DEFAULT_MINIBATCH_SIZE)} when no positive
     * size was configured.
     *
     * @return the superclass result
     */
    @Override
    public boolean initializeAlgorithm() {
        final boolean initialized = super.initializeAlgorithm();
        if (initialized) {
            if (getNumClusters() < 1) {
                // No clusters to assign to, so nothing is sampled.
                this.minibatchSize = 0;
            } else if (this.minibatchSize <= 0) {
                // NOTE(review): this overwrites the "auto" setting, so a later learn() on
                // different data reuses this size.
                this.minibatchSize = Math.min(getNumElements(), DEFAULT_MINIBATCH_SIZE);
            }
        }
        return initialized;
    }

    /**
     * Performs one mini-batch iteration.
     *
     * @return {@code true} if the fraction of sampled points whose assignment
     *     changed exceeds the stopping criterion
     */
    @Override
    public boolean step() {
        if (this.minibatchSize <= 0 || this.dataIndices == null || this.dataIndices.isEmpty()) {
            // Nothing can be sampled: report no changes and stop.
            setNumChanged(0);
            return false;
        }

        final List<? extends DataType> data = getData();
        final List<Integer> sampleIndices =
            DiscreteSamplingUtil.sampleWithReplacement(this.random, this.dataIndices, this.minibatchSize);

        final List<Vector> minibatch = new ArrayList<>(sampleIndices.size());
        for (final int index : sampleIndices) {
            minibatch.add(data.get(index));
        }
        final int[] minibatchAssignments = assignDataToClusters(minibatch);

        // Sequential grouping keeps each cluster's batch in sampling order, so a
        // seeded run is reproducible; keying by index avoids relying on cluster equals/hashCode.
        final Map<Integer, List<Vector>> batchByCluster = IntStream.range(0, minibatch.size())
            .boxed()
            .collect(Collectors.groupingBy(
                position -> minibatchAssignments[position],
                Collectors.mapping(minibatch::get, Collectors.toList())));

        // Each cluster is updated by exactly one task, so the updates can run in parallel.
        batchByCluster.entrySet().parallelStream()
            .forEach(entry -> this.clusters.get(entry.getKey()).updateCluster(entry.getValue()));

        int numChanged = 0;
        for (int i = 0; i < minibatchAssignments.length; i++) {
            if (setAssignment(sampleIndices.get(i), minibatchAssignments[i])) {
                numChanged++;
            }
        }
        setNumChanged(numChanged);
        return ((double) getNumChanged()) / this.minibatchSize > this.stoppingCriterion;
    }

    /**
     * Assigns every data point to its closest cluster and rebuilds each
     * cluster's member list from those assignments (members appear in data
     * order). Does nothing when there are no clusters.
     */
    protected void saveFinalClustering() {
        if (this.clusters.isEmpty()) {
            return;
        }
        final List<? extends DataType> data = getData();
        this.assignments = assignDataToClusters(data);
        for (final MiniBatchCentroidCluster cluster : this.clusters) {
            cluster.getMembers().clear();
        }
        for (int i = 0; i < this.assignments.length; i++) {
            this.clusters.get(this.assignments[i]).getMembers().add(data.get(i));
        }
    }

    /** Saves the final clustering when learning finishes. */
    @Override
    protected void cleanupAlgorithm() {
        saveFinalClustering();
    }

    @Override
    public Random getRandom() {
        return this.random;
    }

    @Override
    public final void setRandom(Random random) {
        this.random = random;
    }

    /**
     * @return the data as the random-access list stored by {@link #setData}
     */
    @Override
    @SuppressWarnings("unchecked")
    public List<? extends DataType> getData() {
        return (List<? extends DataType>) super.getData();
    }

    /**
     * Stores the data as a random-access list (copying if necessary) and
     * rebuilds the index population used for sampling. {@code null} is treated
     * as an empty collection.
     *
     * @param collection the data to cluster
     */
    @Override
    public void setData(Collection<? extends Vector> collection) {
        final Collection<? extends Vector> source = (collection == null) ? new ArrayList<Vector>() : collection;
        // step() fetches sampled points by index, so keep only random-access lists as-is.
        final Collection<? extends Vector> indexable =
            (source instanceof List && source instanceof RandomAccess) ? source : new ArrayList<Vector>(source);
        super.setData(indexable);
        this.dataIndices = IntStream.range(0, indexable.size()).boxed().collect(Collectors.toList());
    }

    /**
     * @return fraction of changed sampled assignments at or below which learning stops
     */
    public double getStoppingCriterion() {
        return this.stoppingCriterion;
    }

    /**
     * @param stoppingCriterion fraction of changed sampled assignments at or below which learning stops
     */
    public void setStoppingCriterion(double stoppingCriterion) {
        this.stoppingCriterion = stoppingCriterion;
    }

    /**
     * @return points sampled per step ({@code <= 0} before initialization means "automatic")
     */
    public int getMinibatchSize() {
        return this.minibatchSize;
    }

    /**
     * @param minibatchSize points sampled per step; {@code <= 0} chooses automatically
     */
    public void setMinibatchSize(int minibatchSize) {
        this.minibatchSize = minibatchSize;
    }

    /**
     * Finds the closest cluster index for every element, in iteration order.
     * Inputs larger than {@value #PARALLEL_ASSIGNMENT_THRESHOLD} elements use a
     * parallel stream; {@code toArray} preserves encounter order either way.
     *
     * @param collection points to assign
     * @return closest-cluster index for each point
     */
    @Override
    public int[] assignDataToClusters(Collection<? extends Vector> collection) {
        final Stream<? extends Vector> points = (collection.size() > PARALLEL_ASSIGNMENT_THRESHOLD)
            ? collection.parallelStream()
            : collection.stream();
        return points.mapToInt(vector -> getClosestClusterIndex(vector)).toArray();
    }
}
