package org.tribuo.clustering.kmeans;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import com.oracle.labs.mlrg.olcut.util.MutableLong;
import com.oracle.labs.mlrg.olcut.util.StreamUtil;
import java.security.AccessController;
import java.security.PrivilegedAction;
import java.time.OffsetDateTime;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.SplittableRandom;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinWorkerThread;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Consumer;
import java.util.logging.Level;
import java.util.logging.Logger;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import org.tribuo.Dataset;
import org.tribuo.Example;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.Model;
import org.tribuo.Trainer;
import org.tribuo.clustering.ClusterID;
import org.tribuo.clustering.ImmutableClusteringInfo;
import org.tribuo.math.la.DenseVector;
import org.tribuo.math.la.SGDVector;
import org.tribuo.math.la.SparseVector;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.provenance.TrainerProvenance;
import org.tribuo.provenance.impl.TrainerProvenanceImpl;
import org.tribuo.util.Util;

/* loaded from: input_file:org/tribuo/clustering/kmeans/KMeansTrainer.class */
/**
 * A K-Means trainer using Lloyd's algorithm.
 * <p>
 * Each iteration has two steps:
 * <ul>
 *   <li>E step: every example is assigned to its nearest centroid under the configured
 *       {@link Distance}.</li>
 *   <li>M step: each centroid is recomputed as the weighted mean of the examples assigned to it.</li>
 * </ul>
 * Training stops after {@code iterations} iterations, or earlier if an iteration changes no
 * assignments.
 * <p>
 * When {@code numThreads > 1}, both steps run on a {@link ForkJoinPool} created for the
 * duration of a single {@code train} call. The pool is shut down when that call finishes,
 * whether it succeeds or throws.
 */
public class KMeansTrainer implements Trainer<ClusterID> {
    private static final Logger logger = Logger.getLogger(KMeansTrainer.class.getName());

    /** Worker-thread factory used instead of the default one when a SecurityManager is installed. */
    private static final CustomForkJoinWorkerThreadFactory THREAD_FACTORY = new CustomForkJoinWorkerThreadFactory();

    @Config(mandatory = true, description = "Number of centroids (i.e., the \"k\" in k-means).")
    private int centroids;

    @Config(mandatory = true, description = "The number of iterations to run.")
    private int iterations;

    @Config(mandatory = true, description = "The distance function to use.")
    private Distance distanceType;

    @Config(description = "The centroid initialisation method to use.")
    private Initialisation initialisationType = Initialisation.RANDOM;

    @Config(description = "The number of threads to use for training.")
    private int numThreads = 1;

    @Config(mandatory = true, description = "The seed to use for the RNG.")
    private long seed;

    // Root RNG, re-seeded from `seed` in postConfig/setInvocationCount; each train call uses a split of it.
    private SplittableRandom rng;

    // Number of train invocations so far; guarded by `this` when written.
    private int trainInvocationCounter;

    /**
     * Creates {@link ForkJoinWorkerThread}s inside a privileged block. This is used when a
     * SecurityManager is installed, so that pool threads can be created.
     */
    public static final class CustomForkJoinWorkerThreadFactory implements ForkJoinPool.ForkJoinWorkerThreadFactory {
        private CustomForkJoinWorkerThreadFactory() {
        }

        @Override
        public final ForkJoinWorkerThread newThread(ForkJoinPool pool) {
            // The explicit PrivilegedAction type disambiguates between the doPrivileged overloads.
            return AccessController.doPrivileged(
                    (PrivilegedAction<ForkJoinWorkerThread>) () -> new ForkJoinWorkerThread(pool) { });
        }
    }

    /** The distance function used to compare examples to centroids. */
    public enum Distance {
        /** Uses {@code DenseVector.euclideanDistance}. */
        EUCLIDEAN,
        /** Uses {@code DenseVector.cosineDistance}. */
        COSINE,
        /** Uses {@code DenseVector.l1Distance}. */
        L1
    }

    /** The strategy used to choose the initial centroids. */
    public enum Initialisation {
        /** Each centroid coordinate is sampled via the feature's {@code uniformSample}. */
        RANDOM,
        /**
         * k-means++: the first centroid is a uniformly chosen example. Each subsequent centroid
         * is an example sampled with probability proportional to its squared distance to the
         * nearest already-chosen centroid.
         */
        PLUSPLUS
    }

    /** Pairs an example's index in the training arrays with its feature vector, for streaming. */
    public static class IntAndVector {
        final int idx;
        final SGDVector vector;

        /**
         * @param idx    the example index.
         * @param vector the example's feature vector.
         */
        public IntAndVector(int idx, SGDVector vector) {
            this.idx = idx;
            this.vector = vector;
        }
    }

    /** For configuration-based construction; the fields are populated by the config system. */
    private KMeansTrainer() {
    }

    /**
     * Constructs a K-Means trainer that uses {@link Initialisation#RANDOM} initialisation.
     *
     * @param centroids    the number of centroids.
     * @param iterations   the maximum number of iterations.
     * @param distanceType the distance function.
     * @param numThreads   the number of threads; values above 1 enable parallel training.
     * @param seed         the RNG seed.
     */
    public KMeansTrainer(int centroids, int iterations, Distance distanceType, int numThreads, long seed) {
        this(centroids, iterations, distanceType, Initialisation.RANDOM, numThreads, seed);
    }

    /**
     * Constructs a K-Means trainer.
     *
     * @param centroids          the number of centroids.
     * @param iterations         the maximum number of iterations.
     * @param distanceType       the distance function.
     * @param initialisationType the centroid initialisation method.
     * @param numThreads         the number of threads; values above 1 enable parallel training.
     * @param seed               the RNG seed.
     */
    public KMeansTrainer(int centroids, int iterations, Distance distanceType, Initialisation initialisationType, int numThreads, long seed) {
        this.centroids = centroids;
        this.iterations = iterations;
        this.distanceType = distanceType;
        this.initialisationType = initialisationType;
        this.numThreads = numThreads;
        this.seed = seed;
        postConfig();
    }

    /** (Re)creates the root RNG from the configured seed. */
    public synchronized void postConfig() {
        this.rng = new SplittableRandom(seed);
    }

    /**
     * Trains a K-Means model using the next RNG split.
     *
     * @param examples      the training data.
     * @param runProvenance extra provenance to record in the model.
     * @return the trained model.
     */
    public KMeansModel train(Dataset<ClusterID> examples, Map<String, Provenance> runProvenance) {
        return train(examples, runProvenance, -1);
    }

    /**
     * Trains a K-Means model.
     *
     * @param examples        the training data.
     * @param runProvenance   extra provenance to record in the model.
     * @param invocationCount if not {@code -1}, the RNG state is first reset via
     *                        {@link #setInvocationCount(int)}.
     * @return the trained model. Its clustering info records how many examples were assigned to
     *         each centroid in the final E step.
     * @throws IllegalArgumentException if {@code invocationCount} is negative and not {@code -1},
     *                                  or if k-means++ initialisation is requested with more
     *                                  centroids than examples.
     * @throws RuntimeException         if a parallel step fails or the thread is interrupted.
     */
    public KMeansModel train(Dataset<ClusterID> examples, Map<String, Provenance> runProvenance, int invocationCount) {
        // Take the RNG split and the provenance under the lock, before bumping the counter.
        SplittableRandom localRNG;
        TrainerProvenance trainerProvenance;
        synchronized (this) {
            if (invocationCount != -1) {
                setInvocationCount(invocationCount);
            }
            localRNG = rng.split();
            trainerProvenance = getProvenance();
            trainInvocationCounter++;
        }

        ImmutableFeatureMap featureMap = examples.getFeatureIDMap();
        int numExamples = examples.size();
        int[] oldCentre = new int[numExamples];
        SGDVector[] data = new SGDVector[numExamples];
        double[] weights = new double[numExamples];
        int n = 0;
        for (Example<ClusterID> example : examples) {
            weights[n] = example.getWeight();
            // Examples containing every feature are stored densely, all others sparsely.
            if (example.size() == featureMap.size()) {
                data[n] = DenseVector.createDenseVector(example, featureMap, false);
            } else {
                data[n] = SparseVector.createSparseVector(example, featureMap, false);
            }
            oldCentre[n] = -1; // -1 means "not yet assigned", so the first E step counts every example.
            n++;
        }

        DenseVector[] centroidVectors;
        switch (initialisationType) {
            case RANDOM:
                centroidVectors = initialiseRandomCentroids(centroids, featureMap, localRNG);
                break;
            case PLUSPLUS:
                centroidVectors = initialisePlusPlusCentroids(centroids, data, localRNG, distanceType);
                break;
            default:
                throw new IllegalStateException("Unknown initialisation " + initialisationType);
        }

        boolean parallel = numThreads > 1;
        Map<Integer, List<Integer>> clusterAssignments = new HashMap<>();
        for (int i = 0; i < centroids; i++) {
            // The E step may append from several threads, so the lists are synchronized when parallel.
            List<Integer> members = parallel ? Collections.synchronizedList(new ArrayList<>()) : new ArrayList<>();
            clusterAssignments.put(i, members);
        }

        AtomicInteger changeCounter = new AtomicInteger(0);
        Consumer<IntAndVector> eStepFunc = (IntAndVector e) -> {
            double minDist = Double.POSITIVE_INFINITY;
            int clusterID = -1;
            for (int j = 0; j < centroids; j++) {
                double distance = getDistance(centroidVectors[j], e.vector, distanceType);
                if (distance < minDist) {
                    minDist = distance;
                    clusterID = j;
                }
            }
            // NOTE(review): if every distance is NaN, clusterID stays -1 and get() returns null (NPE).
            clusterAssignments.get(clusterID).add(e.idx);
            if (oldCentre[e.idx] != clusterID) {
                oldCentre[e.idx] = clusterID;
                changeCounter.incrementAndGet();
            }
        };

        boolean converged = false;
        ForkJoinPool fjp = null;
        try {
            if (parallel) {
                fjp = System.getSecurityManager() == null
                        ? new ForkJoinPool(numThreads)
                        : new ForkJoinPool(numThreads, THREAD_FACTORY, null, false);
            }
            for (int i = 0; i < iterations && !converged; i++) {
                logger.log(Level.FINE, "Beginning iteration " + i);
                changeCounter.set(0);
                for (List<Integer> members : clusterAssignments.values()) {
                    members.clear();
                }

                // E step: assign every example to its nearest centroid.
                Stream<IntAndVector> vecStream = StreamUtil.zip(IntStream.range(0, data.length).boxed(), Arrays.stream(data), IntAndVector::new);
                if (parallel) {
                    Stream<IntAndVector> parallelStream = StreamUtil.boundParallelism(vecStream.parallel());
                    runInPool(fjp, () -> parallelStream.forEach(eStepFunc));
                } else {
                    vecStream.forEach(eStepFunc);
                }
                logger.log(Level.FINE, "E step completed. " + changeCounter.get() + " examples updated.");

                // M step: recompute the centroids from the new assignments.
                mStep(fjp, centroidVectors, clusterAssignments, data, weights);
                logger.log(Level.INFO, "Iteration " + i + " completed. " + changeCounter.get() + " examples updated.");

                if (changeCounter.get() == 0) {
                    converged = true;
                    logger.log(Level.INFO, "K-Means converged at iteration " + i);
                }
            }
        } finally {
            // Shut the pool down only after all iterations have used it.
            if (fjp != null) {
                fjp.shutdown();
            }
        }

        Map<Integer, MutableLong> counts = new HashMap<>();
        for (Map.Entry<Integer, List<Integer>> e : clusterAssignments.entrySet()) {
            counts.put(e.getKey(), new MutableLong(e.getValue().size()));
        }

        ModelProvenance provenance = new ModelProvenance(KMeansModel.class.getName(), OffsetDateTime.now(),
                examples.getProvenance(), trainerProvenance, runProvenance);
        return new KMeansModel("k-means-model", provenance, featureMap,
                new ImmutableClusteringInfo(counts), centroidVectors, distanceType);
    }

    /**
     * Trains a K-Means model with no extra run provenance.
     *
     * @param examples the training data.
     * @return the trained model.
     */
    public KMeansModel train(Dataset<ClusterID> examples) {
        return train(examples, Collections.emptyMap());
    }

    /**
     * @return the number of times {@code train} has been invoked, as tracked by this trainer.
     */
    public int getInvocationCount() {
        return trainInvocationCounter;
    }

    /**
     * Resets the RNG to the state it would have after {@code invocationCount} train calls.
     * This re-seeds the RNG and replays that many splits.
     *
     * @param invocationCount the target invocation count.
     * @throws IllegalArgumentException if {@code invocationCount} is negative.
     */
    public synchronized void setInvocationCount(int invocationCount) {
        if (invocationCount < 0) {
            throw new IllegalArgumentException("The supplied invocationCount is less than zero.");
        }
        rng = new SplittableRandom(seed);
        for (trainInvocationCounter = 0; trainInvocationCounter < invocationCount; trainInvocationCounter++) {
            rng.split();
        }
    }

    /**
     * Builds centroids whose coordinates are drawn via each feature's {@code uniformSample}.
     *
     * @param numCentroids the number of centroids.
     * @param featureMap   the feature map; its size gives the centroid dimension.
     * @param rng          the RNG to sample from.
     * @return the initial centroids.
     */
    private static DenseVector[] initialiseRandomCentroids(int numCentroids, ImmutableFeatureMap featureMap, SplittableRandom rng) {
        DenseVector[] centroidVectors = new DenseVector[numCentroids];
        int numFeatures = featureMap.size();
        for (int i = 0; i < numCentroids; i++) {
            double[] coords = new double[numFeatures];
            for (int j = 0; j < numFeatures; j++) {
                coords[j] = featureMap.get(j).uniformSample(rng);
            }
            centroidVectors[i] = DenseVector.createDenseVector(coords);
        }
        return centroidVectors;
    }

    /**
     * Builds centroids using k-means++ seeding.
     * <p>
     * The first centroid is a copy of a uniformly chosen example. Each later centroid is a copy
     * of an example sampled with probability proportional to the squared distance from that
     * example to its nearest already-chosen centroid.
     *
     * @param numCentroids the number of centroids.
     * @param data         the training vectors.
     * @param rng          the RNG to sample from.
     * @param distanceType the distance function.
     * @return the initial centroids.
     * @throws IllegalArgumentException if {@code numCentroids > data.length}.
     */
    private static DenseVector[] initialisePlusPlusCentroids(int numCentroids, SGDVector[] data, SplittableRandom rng, Distance distanceType) {
        if (numCentroids > data.length) {
            throw new IllegalArgumentException("The number of centroids may not exceed the number of samples.");
        }
        double[] minDistances = new double[data.length];
        Arrays.fill(minDistances, Double.POSITIVE_INFINITY);
        double[] squaredMinDistances = new double[data.length];
        double[] probabilities = new double[data.length];
        DenseVector[] centroidVectors = new DenseVector[numCentroids];
        centroidVectors[0] = getRandomCentroidFromData(data, rng);
        for (int i = 1; i < numCentroids; i++) {
            // Only the newest centroid can lower each example's nearest-centroid distance.
            DenseVector prevCentroid = centroidVectors[i - 1];
            for (int j = 0; j < data.length; j++) {
                minDistances[j] = Math.min(minDistances[j], getDistance(prevCentroid, data[j], distanceType));
            }
            double total = 0.0;
            for (int j = 0; j < data.length; j++) {
                squaredMinDistances[j] = minDistances[j] * minDistances[j];
                total += squaredMinDistances[j];
            }
            // NOTE(review): if total is 0 (every example coincides with a centroid) these become NaN.
            for (int j = 0; j < probabilities.length; j++) {
                probabilities[j] = squaredMinDistances[j] / total;
            }
            int sampleIdx = Util.sampleFromCDF(Util.generateCDF(probabilities), rng);
            centroidVectors[i] = DenseVector.createDenseVector(data[sampleIdx].toArray());
        }
        return centroidVectors;
    }

    /**
     * @return a dense copy of a uniformly chosen element of {@code data}.
     */
    private static DenseVector getRandomCentroidFromData(SGDVector[] data, SplittableRandom rng) {
        return DenseVector.createDenseVector(data[rng.nextInt(data.length)].toArray());
    }

    /**
     * Computes the distance between a centroid and an example vector.
     *
     * @param cluster      the centroid.
     * @param vector       the example vector.
     * @param distanceType the distance function.
     * @return the distance.
     */
    private static double getDistance(DenseVector cluster, SGDVector vector, Distance distanceType) {
        switch (distanceType) {
            case EUCLIDEAN:
                return cluster.euclideanDistance(vector);
            case COSINE:
                return cluster.cosineDistance(vector);
            case L1:
                return cluster.l1Distance(vector);
            default:
                throw new IllegalStateException("Unknown distance " + distanceType);
        }
    }

    /**
     * Runs the M step: each centroid is overwritten in place with the weighted mean of its
     * assigned examples. A centroid with zero total weight (for example, no members) is left
     * as the zero vector.
     *
     * @param fjp                the pool to run on, or {@code null} to run sequentially.
     * @param centroidVectors    the centroids; these are mutated.
     * @param clusterAssignments map from centroid index to the indices of its member examples.
     * @param data               the training vectors.
     * @param weights            the example weights.
     */
    protected void mStep(ForkJoinPool fjp, DenseVector[] centroidVectors, Map<Integer, List<Integer>> clusterAssignments, SGDVector[] data, double[] weights) {
        Consumer<Map.Entry<Integer, List<Integer>>> mStepFunc = (e) -> {
            DenseVector newCentroid = centroidVectors[e.getKey()];
            newCentroid.fill(0.0);
            double weightSum = 0.0;
            for (Integer idx : e.getValue()) {
                double w = weights[idx];
                newCentroid.intersectAndAddInPlace(data[idx], v -> v * w);
                weightSum += w;
            }
            if (weightSum != 0.0) {
                newCentroid.scaleInPlace(1.0 / weightSum);
            }
        };

        Stream<Map.Entry<Integer, List<Integer>>> mStream = clusterAssignments.entrySet().stream();
        if (fjp == null) {
            mStream.forEach(mStepFunc);
        } else {
            Stream<Map.Entry<Integer, List<Integer>>> parallelStream = StreamUtil.boundParallelism(mStream.parallel());
            runInPool(fjp, () -> parallelStream.forEach(mStepFunc));
        }
    }

    /**
     * Submits {@code task} to {@code pool} and blocks until it completes.
     *
     * @throws RuntimeException wrapping the cause if the task fails or the wait is interrupted.
     *                          On interruption the thread's interrupt flag is restored first.
     */
    private static void runInPool(ForkJoinPool pool, Runnable task) {
        try {
            pool.submit(task).get();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new RuntimeException("Parallel execution failed", e);
        } catch (ExecutionException e) {
            throw new RuntimeException("Parallel execution failed", e);
        }
    }

    @Override
    public String toString() {
        return "KMeansTrainer(centroids=" + centroids + ",distanceType=" + distanceType + ",seed=" + seed + ",numThreads=" + numThreads + ", initialisationType=" + initialisationType + ")";
    }

    /**
     * @return the provenance of this trainer.
     */
    public TrainerProvenance getProvenance() {
        return new TrainerProvenanceImpl(this);
    }

    /**
     * Decompiler-renamed alias of {@link #getProvenance()}, kept for compatibility.
     *
     * @return the provenance of this trainer.
     */
    public TrainerProvenance m7getProvenance() {
        return getProvenance();
    }

    /** Raw-typed alias of {@link #train(Dataset, Map, int)}, kept for compatibility. */
    @SuppressWarnings("unchecked")
    public Model m4train(Dataset dataset, Map map, int invocationCount) {
        return train((Dataset<ClusterID>) dataset, (Map<String, Provenance>) map, invocationCount);
    }

    /** Raw-typed alias of {@link #train(Dataset, Map)}, kept for compatibility. */
    @SuppressWarnings("unchecked")
    public Model m5train(Dataset dataset, Map map) {
        return train((Dataset<ClusterID>) dataset, (Map<String, Provenance>) map);
    }

    /** Raw-typed alias of {@link #train(Dataset)}, kept for compatibility. */
    @SuppressWarnings("unchecked")
    public Model m6train(Dataset dataset) {
        return train((Dataset<ClusterID>) dataset);
    }
}
