package com.gengoai.apollo.ml.model.clustering;

import com.gengoai.ParamMap;
import com.gengoai.apollo.math.linalg.NDArray;
import com.gengoai.apollo.math.linalg.NDArrayFactory;
import com.gengoai.apollo.math.linalg.Shape;
import com.gengoai.apollo.math.statistics.measure.Measure;
import com.gengoai.apollo.ml.DataSet;
import com.gengoai.apollo.ml.model.Params;
import com.gengoai.apollo.ml.model.StoppingCriteria;
import com.gengoai.conversion.Cast;
import com.gengoai.function.Functional;
import com.gengoai.math.Optimum;
import java.lang.invoke.SerializedLambda;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import java.util.Random;
import java.util.function.Consumer;
import java.util.logging.Logger;
import java.util.stream.IntStream;
import lombok.NonNull;

/* loaded from: input_file:com/gengoai/apollo/ml/model/clustering/KMeans.class */
/**
 * Flat centroid clusterer that partitions the vectors of a {@link DataSet} into {@code K} clusters by
 * iteratively (1) assigning every vector to its best-scoring centroid and (2) recomputing each centroid as
 * the element-wise mean of the vectors assigned to it.
 *
 * <p>Whether "best" means the largest or the smallest score is decided by the configured {@link Measure}'s
 * {@link Optimum}. Iteration is driven by a {@link StoppingCriteria} named {@code "numPointsChanged"} that is
 * fed the number of vectors whose predicted cluster changed in the last pass.</p>
 *
 * <p>The vectorizing method reference {@code this::getNDArray} is bound to a serializable functional
 * interface; the {@code $deserializeLambda$} hook required for that is synthesized by the compiler and must
 * not be declared by hand in this class.</p>
 */
public class KMeans extends FlatCentroidClusterer {
    private static final Logger log = Logger.getLogger(KMeans.class.getName());
    private static final long serialVersionUID = 1L;

    /**
     * Fit parameters specific to {@link KMeans}; parameters declared by {@link ClusterFitParameters}
     * (e.g. {@code measure} and {@code verbose}, both read during fitting) are inherited.
     */
    public static class Parameters extends ClusterFitParameters {
        /** Number of clusters to produce (default 2). */
        public final ParamMap<ClusterFitParameters>.Parameter<Integer> K = parameter(Params.Clustering.K, 2);
        /** Maximum number of iterations handed to the stopping criteria (default 100). */
        public final ParamMap<ClusterFitParameters>.Parameter<Integer> maxIterations = parameter(Params.Optimizable.maxIterations, 100);
        /** Tolerance handed to the stopping criteria (default 0.001). */
        public final ParamMap<ClusterFitParameters>.Parameter<Double> tolerance = parameter(Params.Optimizable.tolerance, 0.001);
    }

    /**
     * Creates a KMeans model with default {@link Parameters}.
     */
    public KMeans() {
        super(new Parameters());
    }

    /**
     * Creates a KMeans model with the given fit parameters.
     *
     * @param parameters the fit parameters
     * @throws NullPointerException if {@code parameters} is null
     */
    public KMeans(@NonNull Parameters parameters) {
        // Validate inside the super(...) argument so the check runs before the superclass sees the value.
        super(Objects.requireNonNull(parameters, "parameters is marked non-null but is null"));
    }

    /**
     * Creates a KMeans model whose default {@link Parameters} are first passed to the given consumer for
     * customization.
     *
     * @param consumer callback that updates a fresh {@link Parameters} instance
     * @throws NullPointerException if {@code consumer} is null
     */
    public KMeans(@NonNull Consumer<Parameters> consumer) {
        // Validate inside the super(...) argument so a null consumer is rejected before it is used.
        super((ClusterFitParameters) Functional.with(
                new Parameters(),
                Objects.requireNonNull(consumer, "updater is marked non-null but is null")));
    }

    /**
     * Selects the cluster that best matches the given vector.
     *
     * <p>The vector is passed through {@code transform(...)} and the index of the maximum (when the measure's
     * optimum is {@link Optimum#MAXIMUM}) or minimum (otherwise) entry of the result is used to look up the
     * cluster in the current clustering.</p>
     *
     * @param point the vector to assign
     * @return the cluster at the winning index
     */
    private Cluster estimate(NDArray point) {
        NDArray scores = transform(point).asNDArray();
        Measure measure = (Measure) ((ClusterFitParameters) this.parameters).measure.value();
        int best = measure.getOptimum() == Optimum.MAXIMUM
                ? (int) scores.argmax()
                : (int) scores.argmin();
        return this.clustering.get(best);
    }

    /**
     * Fits the clustering to the given data set.
     *
     * <p>Steps: vectorize the data, seed {@code K} centroids via {@link #initCentroids(int, List)}, run
     * {@link #iteration(List)} under the stopping criteria, then assign each cluster its index as id and a
     * score. The score of a non-empty cluster is the average of the measure over all ordered pairs of
     * distinct points in the cluster (0.0 when there is no such pair, as reported by
     * {@code DoubleSummaryStatistics.getAverage()}); an empty cluster is scored {@link Double#MAX_VALUE}.</p>
     *
     * @param dataSet the data to cluster; {@link #initCentroids(int, List)} reads its first vector, so it
     *                must yield at least one vector
     * @throws NullPointerException if {@code dataSet} is null
     */
    @Override // com.gengoai.apollo.ml.model.Model
    public void estimate(@NonNull DataSet dataSet) {
        Objects.requireNonNull(dataSet, "dataset is marked non-null but is null");
        Parameters fitParameters = (Parameters) Cast.as(this.parameters);
        final Measure measure = (Measure) fitParameters.measure.value();

        this.clustering = new FlatClustering();
        this.clustering.setMeasure(measure);

        final List<NDArray> vectors = dataSet.parallelStream().map(this::getNDArray).collect();

        int k = (Integer) fitParameters.K.value();
        for (NDArray centroid : initCentroids(k, vectors)) {
            Cluster cluster = new Cluster();
            cluster.setCentroid(centroid);
            this.clustering.add(cluster);
        }

        int maxIterations = (Integer) fitParameters.maxIterations.value();
        double tolerance = (Double) fitParameters.tolerance.value();
        boolean verbose = (Boolean) fitParameters.verbose.value();
        StoppingCriteria.create("numPointsChanged")
                .historySize(3)
                .maxIterations(maxIterations)
                .tolerance(tolerance)
                .reportInterval(verbose ? 1 : -1)
                .logger(log)
                .untilTermination(i -> iteration(vectors));

        for (int index = 0; index < this.clustering.size(); index++) {
            Cluster cluster = this.clustering.get(index);
            cluster.setId(index);
            if (cluster.size() > 0) {
                cluster.getPoints().removeIf(Objects::isNull);
                // Compare every point against every OTHER point of the same cluster (identity comparison),
                // so a point is never scored against itself.
                double averagePairwise = cluster.getPoints()
                        .parallelStream()
                        .flatMapToDouble(point -> cluster.getPoints()
                                .stream()
                                .filter(other -> other != point)
                                .mapToDouble(other -> measure.calculate(point, other)))
                        .summaryStatistics()
                        .getAverage();
                cluster.setScore(averagePairwise);
            } else {
                cluster.setScore(Double.MAX_VALUE);
            }
        }
    }

    /**
     * @return this model's fit parameters viewed as {@link Parameters}
     */
    @Override // com.gengoai.apollo.ml.model.SingleSourceModel, com.gengoai.apollo.ml.model.Model
    public Parameters getFitParameters() {
        return (Parameters) Cast.as(this.parameters);
    }

    /**
     * Builds the initial centroids by assigning every vector to one of {@code k} groups uniformly at random
     * and averaging each group. A group that received no vector keeps the array produced by
     * {@code NDArrayFactory.ND.array(dimension)} unchanged.
     *
     * @param k       number of centroids to create
     * @param vectors the vectors to cluster; must be non-empty because the first element determines the
     *                centroid length
     * @return an array of {@code k} centroids
     */
    private NDArray[] initCentroids(int k, List<NDArray> vectors) {
        int dimension = (int) vectors.get(0).length();
        NDArray[] centroids = IntStream.range(0, k)
                .mapToObj(i -> NDArrayFactory.ND.array(dimension))
                .toArray(NDArray[]::new);
        double[] counts = new double[k];
        Random random = new Random();
        for (NDArray vector : vectors) {
            int group = random.nextInt(k);
            centroids[group].addi(vector);
            counts[group] += 1.0d;
        }
        for (int group = 0; group < k; group++) {
            if (counts[group] > 0.0d) {
                centroids[group].divi((float) counts[group]);
            }
        }
        return centroids;
    }

    /**
     * Runs one assignment/update pass.
     *
     * <p>Clusters are reduced to their centroids, every vector is added (in parallel) to the cluster chosen
     * by {@link #estimate(NDArray)}, and each centroid is replaced by the mean of its assigned vectors. A
     * cluster that received no vectors gets a centroid from {@code NDArrayFactory.ND.uniform(shape, -1, 1)}
     * (presumably uniformly random values in that range - confirm against NDArrayFactory).</p>
     *
     * @param vectors the vectors being clustered; the first element determines the centroid length
     * @return the number of vectors whose predicted label was previously unset or differs from the id of the
     *         cluster they were assigned to in this pass
     */
    private double iteration(List<NDArray> vectors) {
        int dimension = (int) vectors.get(0).length();
        this.clustering.keepOnlyCentroids();

        // One lock object per cluster so parallel workers only contend when adding to the same cluster.
        // NOTE(review): indexing by getId() assumes cluster ids are already valid indices into the
        // clustering at this point (ids are only set explicitly after fitting) - confirm that
        // FlatClustering.add or Cluster assigns them.
        Object[] locks = new Object[this.clustering.size()];
        for (int i = 0; i < locks.length; i++) {
            locks[i] = new Object();
        }
        vectors.parallelStream().forEach(vector -> {
            Cluster target = estimate(vector);
            synchronized (locks[target.getId()]) {
                target.addPoint(vector);
            }
        });

        double numChanged = 0.0d;
        Iterator<Cluster> clusters = this.clustering.iterator();
        while (clusters.hasNext()) {
            Cluster cluster = clusters.next();
            NDArray centroid;
            if (cluster.size() == 0) {
                centroid = NDArrayFactory.ND.uniform(Shape.shape(dimension), -1, 1);
            } else {
                centroid = NDArrayFactory.ND.array(dimension);
                for (NDArray member : cluster.getPoints()) {
                    centroid.addi(member);
                }
                centroid.divi(cluster.size());
            }
            cluster.setCentroid(centroid);

            // Record the new label on every member and count those whose label was unset or changed.
            numChanged += cluster.getPoints().parallelStream().mapToDouble(member -> {
                boolean changed = member.getPredicted() == null
                        || member.getPredictedAsDouble() != (double) cluster.getId();
                member.setPredicted(Double.valueOf(cluster.getId()));
                return changed ? 1.0d : 0.0d;
            }).sum();
        }
        return numChanged;
    }
}
