package com.gengoai.apollo.ml.model.clustering;

import com.gengoai.ParamMap;
import com.gengoai.apollo.math.linalg.NDArray;
import com.gengoai.apollo.math.linalg.NDArrayFactory;
import com.gengoai.apollo.math.statistics.measure.Distance;
import com.gengoai.apollo.ml.DataSet;
import com.gengoai.apollo.ml.Datum;
import com.gengoai.apollo.ml.model.Params;
import com.gengoai.conversion.Cast;
import com.gengoai.function.Functional;
import com.gengoai.tuple.Tuples;
import java.lang.invoke.SerializedLambda;
import java.util.Iterator;
import java.util.function.Consumer;
import lombok.NonNull;
import org.apache.spark.mllib.clustering.KMeansModel;
import org.apache.spark.mllib.linalg.DenseVector;
import org.apache.spark.mllib.linalg.Vector;

/* loaded from: input_file:com/gengoai/apollo/ml/model/clustering/DistributedKMeans.class */
/**
 * K-means clusterer that delegates the fitting step to Spark MLlib's
 * {@link org.apache.spark.mllib.clustering.KMeans}.
 *
 * <p>{@link #estimate(DataSet)} converts every datum's input observation into a Spark
 * {@link DenseVector}, runs Spark's k-means over the resulting RDD, copies the learned
 * centers into a {@link FlatClustering} (Euclidean distance), and finally assigns every
 * datum to the cluster predicted by the fitted Spark model.
 *
 * <p>The lambdas and method references used in {@link #estimate(DataSet)} target
 * serializable functional interfaces; the matching {@code $deserializeLambda$} method is
 * synthesized by the Java compiler and must not be declared by hand in this class.
 */
public class DistributedKMeans extends FlatCentroidClusterer {
    private static final long serialVersionUID = 1L;

    /**
     * Fit parameters for {@link DistributedKMeans}.
     */
    public static class Parameters extends ClusterFitParameters {
        private static final long serialVersionUID = 1L;

        /** Number of clusters passed to Spark's {@code KMeans.setK}; defaults to 2. */
        public final ParamMap<ClusterFitParameters>.Parameter<Integer> K = parameter(Params.Clustering.K, 2);

        /** Iteration cap passed to Spark's {@code KMeans.setMaxIterations}; defaults to 100. */
        public final ParamMap<ClusterFitParameters>.Parameter<Integer> maxIterations = parameter(Params.Optimizable.maxIterations, 100);

        /** Value passed to Spark's {@code KMeans.setEpsilon}; defaults to 0.001. */
        public final ParamMap<ClusterFitParameters>.Parameter<Double> tolerance = parameter(Params.Optimizable.tolerance, 0.001);
    }

    /**
     * Creates a clusterer configured with default {@link Parameters}.
     */
    public DistributedKMeans() {
        super(new Parameters());
    }

    /**
     * Creates a clusterer configured with the given parameters.
     *
     * @param parameters the fit parameters to use
     * @throws NullPointerException if {@code parameters} is null
     */
    public DistributedKMeans(@NonNull Parameters parameters) {
        super(parameters);
        // Java requires super(...) to be the first statement, so the check necessarily follows it.
        if (parameters == null) {
            throw new NullPointerException("parameters is marked non-null but is null");
        }
    }

    /**
     * Creates a clusterer whose parameters are a fresh {@link Parameters} instance that has
     * been handed to {@code consumer} (via {@code Functional.with}) for customization.
     *
     * @param consumer callback applied to the new parameters before they are passed to the superclass
     * @throws NullPointerException if {@code consumer} is null
     */
    public DistributedKMeans(@NonNull Consumer<Parameters> consumer) {
        super(Functional.with(new Parameters(), consumer));
        // NOTE(review): consumer has already been handed to Functional.with by the time this
        // check runs; whether a null consumer fails there first depends on that helper.
        if (consumer == null) {
            throw new NullPointerException("updater is marked non-null but is null");
        }
    }

    /**
     * Fits the model: runs Spark k-means on the data set, rebuilds {@code this.clustering}
     * from the learned centers, and assigns every datum to its predicted cluster.
     *
     * @param dataSet the data to cluster
     * @throws NullPointerException if {@code dataSet} is null
     */
    @Override // com.gengoai.apollo.ml.model.Model
    public void estimate(@NonNull DataSet dataSet) {
        if (dataSet == null) {
            throw new NullPointerException("dataset is marked non-null but is null");
        }
        Parameters fitParameters = Cast.as(this.parameters);

        org.apache.spark.mllib.clustering.KMeans kMeans = new org.apache.spark.mllib.clustering.KMeans();
        kMeans.setK(fitParameters.K.value());
        kMeans.setMaxIterations(fitParameters.maxIterations.value());
        kMeans.setEpsilon(fitParameters.tolerance.value());

        // The vectorized RDD is cached before being handed to Spark.
        // NOTE(review): the cached RDD is never explicitly unpersisted — confirm that is acceptable.
        KMeansModel run = kMeans.run(dataSet.stream()
                                            .toDistributedStream()
                                            .getRDD()
                                            .map(this::toDenseVector)
                                            .cache()
                                            .rdd());

        // One Cluster per learned center; the cluster id is the center's index in Spark's model.
        this.clustering = new FlatClustering();
        this.clustering.setMeasure(Distance.Euclidean);
        Vector[] centers = run.clusterCenters();
        for (int id = 0; id < centers.length; id++) {
            Cluster cluster = new Cluster();
            cluster.setId(id);
            cluster.setCentroid(NDArrayFactory.ND.rowVector(centers[id].toArray()));
            this.clustering.add(cluster);
        }

        // Second pass over the data: pair each datum's NDArray with its predicted cluster index,
        // then add the point to that cluster.
        dataSet.stream()
               .map(datum -> Tuples.$(run.predict(toDenseVector(datum)),
                                      datum.get(((ClusterFitParameters) this.parameters).input.value()).asNDArray()))
               .forEachLocal(assignment -> {
                   int clusterId = assignment.v1;
                   NDArray point = assignment.v2;
                   this.clustering.get(clusterId).addPoint(point);
               });

        // Each cluster's score is divided by its number of points.
        // NOTE(review): a cluster that received no points makes this a division by zero
        // (non-finite result if the score is floating point) — confirm whether that can occur.
        for (Cluster cluster : this.clustering) {
            cluster.setScore(cluster.getScore() / cluster.size());
        }
    }

    /**
     * Returns a newly constructed {@link Parameters} holding the default values.
     *
     * <p>NOTE(review): this is not the instance this model was configured with — confirm
     * that returning defaults here is intended.
     *
     * @return a fresh default {@link Parameters}
     */
    @Override // com.gengoai.apollo.ml.model.SingleSourceModel, com.gengoai.apollo.ml.model.Model
    public Parameters getFitParameters() {
        return new Parameters();
    }

    /**
     * Converts the datum's configured input observation into a Spark dense vector.
     *
     * @param datum the datum whose {@code input} observation is read
     * @return a {@link DenseVector} wrapping the observation's values as a double array
     */
    private Vector toDenseVector(Datum datum) {
        return new DenseVector(datum.get(((ClusterFitParameters) this.parameters).input.value())
                                    .asNDArray()
                                    .toDoubleArray());
    }
}
