package com.gengoai.apollo.ml.model.clustering;

import com.gengoai.apollo.math.linalg.NDArray;
import com.gengoai.apollo.math.statistics.measure.Measure;
import com.gengoai.apollo.ml.DataSet;
import com.gengoai.collection.Iterables;
import com.gengoai.function.Functional;
import com.gengoai.math.Optimum;
import com.gengoai.tuple.Tuples;
import java.lang.invoke.SerializedLambda;
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import lombok.NonNull;

/* loaded from: input_file:com/gengoai/apollo/ml/model/clustering/GreedyAgglomerativeClusterer.class */
/**
 * Greedy agglomerative (bottom-up) hierarchical clusterer.
 *
 * <p>Every datum in the data set starts as a singleton cluster. Its centroid is the datum's
 * vector and its score is {@code 0}. Clusters are then merged pairwise until a single root
 * remains. The result is a binary tree stored in a {@link HierarchicalClustering}.
 *
 * <p>Merging follows a chain of best neighbors, where "best" is decided by the configured
 * {@link Measure} and its {@link Optimum}. The current cluster is merged with its best neighbor
 * in two cases: the two are each other's best neighbor, or {@code Optimum.test} accepts the
 * pair's score against the neighbor's own best score. Otherwise the chain advances to the
 * neighbor. A merged cluster's centroid is the <em>unweighted</em> mean of its two children's
 * centroids, and its score is the measure value between those children.
 *
 * <p>Once the tree is complete, cluster ids are reassigned in breadth-first order, starting at
 * {@code 0} for the root.
 */
public class GreedyAgglomerativeClusterer extends HierarchicalClusterer {

    /** Creates a clusterer with default {@link ClusterFitParameters}. */
    public GreedyAgglomerativeClusterer() {
        super(new ClusterFitParameters());
    }

    /**
     * Creates a clusterer with the given parameters.
     *
     * @param parameters the fit parameters to use
     * @throws NullPointerException if {@code parameters} is null
     */
    public GreedyAgglomerativeClusterer(@NonNull ClusterFitParameters parameters) {
        super(Objects.requireNonNull(parameters, "parameters is marked non-null but is null"));
    }

    /**
     * Creates a clusterer whose parameters start from defaults and are then passed to
     * {@code updater} via {@code Functional.with}.
     *
     * @param updater callback that customizes a fresh {@link ClusterFitParameters}
     * @throws NullPointerException if {@code updater} is null
     */
    public GreedyAgglomerativeClusterer(@NonNull Consumer<ClusterFitParameters> updater) {
        // Validate inside the super(...) argument so a null updater is rejected before
        // Functional.with gets a chance to use it.
        super((ClusterFitParameters) Functional.with(
                new ClusterFitParameters(),
                Objects.requireNonNull(updater, "updater is marked non-null but is null")));
    }

    /**
     * Builds the cluster hierarchy for {@code dataSet} and stores it in {@code this.clustering}.
     *
     * <p>An empty data set yields a clustering with a {@code null} root. A single datum yields
     * a root that is that datum's leaf cluster.
     *
     * @param dataSet the data to cluster
     * @throws NullPointerException if {@code dataSet} is null
     */
    @Override
    public void estimate(@NonNull DataSet dataSet) {
        Objects.requireNonNull(dataSet, "dataset is marked non-null but is null");
        Measure measure = (Measure) ((ClusterFitParameters) this.parameters).measure.value();
        this.clustering = new HierarchicalClustering();
        this.clustering.setMeasure(measure);

        Set<Cluster> clusters = createLeaves(dataSet);
        if (clusters.size() <= 1) {
            // Nothing to merge: the root is the lone leaf, or null for an empty data set.
            this.clustering.root = clusters.isEmpty() ? null : clusters.iterator().next();
            return;
        }
        this.clustering.root = agglomerate(clusters, measure);
        assignBreadthFirstIds(this.clustering.root);
    }

    /** Creates one singleton cluster per datum in a mutable set, as the merge loop requires. */
    private Set<Cluster> createLeaves(DataSet dataSet) {
        return dataSet.parallelStream()
                      .map(this::getNDArray)
                      .zipWithIndex()
                      .map(GreedyAgglomerativeClusterer::createLeaf)
                      // toSet() does not promise mutability; the merge loop removes/adds elements.
                      .collect(Collectors.toCollection(HashSet::new));
    }

    /** Builds a leaf cluster holding a single point, which also serves as its centroid. */
    private static Cluster createLeaf(NDArray point, Long index) {
        Cluster leaf = new Cluster();
        leaf.setId(index.intValue());
        leaf.setCentroid(point);
        leaf.setScore(0.0d);
        leaf.addPoint(point);
        return leaf;
    }

    /**
     * Repeatedly merges clusters in {@code clusters} (which must hold at least two elements)
     * until one remains, and returns it. The set is mutated in place.
     */
    private Cluster agglomerate(Set<Cluster> clusters, Measure measure) {
        Cluster current = clusters.iterator().next();
        Cluster neighbor = findOptimal(current, clusters, measure);
        while (clusters.size() > 1) {
            double pairScore = measure.calculate(current.getCentroid(), neighbor.getCentroid());
            Cluster neighborsBest = findOptimal(neighbor, clusters, measure);
            double neighborScore = measure.calculate(neighbor.getCentroid(), neighborsBest.getCentroid());
            if (current.equals(neighborsBest) || measure.getOptimum().test(pairScore, neighborScore)) {
                clusters.remove(current);
                clusters.remove(neighbor);
                Cluster merged = merge(current, neighbor, pairScore);
                clusters.add(merged);
                current = merged;
                neighbor = findOptimal(current, clusters, measure);
            } else {
                // Advance along the chain of best neighbors.
                current = neighbor;
                neighbor = neighborsBest;
            }
        }
        return clusters.iterator().next();
    }

    /**
     * Creates the parent of {@code left} and {@code right}. The parent reuses {@code left}'s id
     * (ids are rewritten once the tree is complete). Its centroid is the unweighted mean of the
     * two child centroids, and it contains every point of both children.
     */
    private static Cluster merge(Cluster left, Cluster right, double score) {
        NDArray centroid = left.getCentroid().add(right.getCentroid()).divi(2.0d);
        Cluster merged = new Cluster();
        merged.setId(left.getId());
        merged.setCentroid(centroid);
        merged.setScore(score);
        merged.setLeft(left);
        merged.setRight(right);
        for (NDArray point : left.getPoints()) {
            merged.addPoint(point);
        }
        for (NDArray point : right.getPoints()) {
            merged.addPoint(point);
        }
        return merged;
    }

    /**
     * Returns the member of {@code candidates} (other than {@code cluster} itself, by identity)
     * whose centroid scores best against {@code cluster}'s centroid. The largest score wins
     * when the measure's optimum is {@link Optimum#MAXIMUM}; otherwise the smallest wins.
     * Scores are ordered with {@link Double#compare}, and ties keep the first candidate
     * encountered.
     *
     * @return the best candidate, or {@code null} if there is no other candidate
     */
    private static Cluster findOptimal(Cluster cluster, Set<Cluster> candidates, Measure measure) {
        boolean maximize = measure.getOptimum() == Optimum.MAXIMUM;
        Cluster best = null;
        double bestScore = 0.0d;
        for (Cluster candidate : candidates) {
            if (candidate == cluster) {
                continue;
            }
            double score = measure.calculate(cluster.getCentroid(), candidate.getCentroid());
            int cmp = Double.compare(score, bestScore);
            if (best == null || (maximize ? cmp > 0 : cmp < 0)) {
                best = candidate;
                bestScore = score;
            }
        }
        return best;
    }

    /** Renumbers every cluster in the tree rooted at {@code root} in breadth-first order from 0. */
    private static void assignBreadthFirstIds(Cluster root) {
        Deque<Cluster> queue = new ArrayDeque<>();
        queue.add(root);
        int nextId = 0;
        while (!queue.isEmpty()) {
            Cluster node = queue.remove();
            node.setId(nextId++);
            // ArrayDeque rejects nulls, so absent children are skipped here.
            Cluster left = node.getLeft();
            if (left != null) {
                queue.add(left);
            }
            Cluster right = node.getRight();
            if (right != null) {
                queue.add(right);
            }
        }
    }
}
