package com.gengoai.apollo.ml.model.clustering;

import com.gengoai.ParamMap;
import com.gengoai.apollo.math.linalg.NDArray;
import com.gengoai.apollo.math.statistics.measure.Measure;
import com.gengoai.apollo.ml.DataSet;
import com.gengoai.apollo.ml.model.Params;
import com.gengoai.collection.Iterables;
import com.gengoai.conversion.Cast;
import com.gengoai.function.Functional;
import com.gengoai.tuple.Tuple3;
import com.gengoai.tuple.Tuples;
import java.lang.invoke.SerializedLambda;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import java.util.PriorityQueue;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import lombok.NonNull;

/**
 * Agglomerative (bottom-up) hierarchical clusterer.
 *
 * <p>Each input vector starts in its own singleton cluster. On every step the candidate pair with
 * the lowest score under the configured {@link Linkage} (computed with the configured
 * {@link Measure}) is merged into a new parent cluster, until a single root cluster remains. The
 * resulting tree is stored in a {@link HierarchicalClustering}.
 */
public class AgglomerativeClusterer extends HierarchicalClusterer {
    private static final long serialVersionUID = 1;

    /** Source of cluster ids; reset to zero at the start of every {@link #estimate(DataSet)}. */
    private final AtomicInteger idGenerator;

    /** The parameters this clusterer was configured with; returned by {@link #getFitParameters()}. */
    private final Parameters fitParameters;

    /**
     * Fit parameters for {@link AgglomerativeClusterer}: the base cluster parameters plus the
     * {@link Linkage} used to score cluster pairs (defaults to {@code Linkage.Complete}).
     */
    public static class Parameters extends ClusterFitParameters {
        private static final long serialVersionUID = 1;
        public final ParamMap<ClusterFitParameters>.Parameter<Linkage> linkage = parameter(Params.Clustering.linkage, Linkage.Complete);
    }

    /** Creates a clusterer with default {@link Parameters}. */
    public AgglomerativeClusterer() {
        this(new Parameters());
    }

    /**
     * Creates a clusterer that uses the given parameters.
     *
     * @param parameters the fit parameters to use
     * @throws NullPointerException if {@code parameters} is null
     */
    public AgglomerativeClusterer(@NonNull Parameters parameters) {
        super(Objects.requireNonNull(parameters, "parameters is marked non-null but is null"));
        this.fitParameters = parameters;
        this.idGenerator = new AtomicInteger();
    }

    /**
     * Creates a clusterer whose default {@link Parameters} are first customized by {@code consumer}.
     *
     * @param consumer updater applied to a fresh {@link Parameters} instance
     * @throws NullPointerException if {@code consumer} is null
     */
    public AgglomerativeClusterer(@NonNull Consumer<Parameters> consumer) {
        this((Parameters) Functional.with(new Parameters(),
                                          Objects.requireNonNull(consumer, "updater is marked non-null but is null")));
    }

    /**
     * Performs one merge step. It pops the lowest-scoring candidate pair, drops every queued pair that
     * involves either of its clusters, replaces the two clusters in {@code clusters} with a new
     * parent, and queues the scores between the parent and every remaining cluster.
     *
     * @param candidates queue of scored cluster pairs, lowest score first
     * @param clusters   the clusters that have not yet been merged (mutated in place)
     * @param parameters the fit parameters supplying linkage and measure
     * @throws IllegalStateException if no candidate pair is available to merge
     */
    private void doTurn(PriorityQueue<Tuple3<Cluster, Cluster, Double>> candidates,
                        List<Cluster> clusters,
                        Parameters parameters) {
        Tuple3<Cluster, Cluster, Double> best = candidates.poll();
        if (best == null) {
            // Every pair of live clusters is always queued, so an empty queue here is a bug; fail
            // loudly instead of letting estimate() loop forever.
            throw new IllegalStateException("No candidate pairs left while " + clusters.size() + " clusters remain");
        }
        Cluster left = best.v1;
        Cluster right = best.v2;

        // Any queued pair that mentions either merged cluster is now stale.
        candidates.removeIf(pair -> involves(pair, left) || involves(pair, right));

        Cluster parent = new Cluster();
        parent.setId(this.idGenerator.getAndIncrement());
        parent.setLeft(left);
        parent.setRight(right);
        parent.setScore(best.v3);
        left.setParent(parent);
        right.setParent(parent);
        for (NDArray point : left.getPoints()) {
            parent.addPoint(point);
        }
        for (NDArray point : right.getPoints()) {
            parent.addPoint(point);
        }

        clusters.remove(left);
        clusters.remove(right);
        candidates.addAll(clusters.parallelStream()
                                  .map(other -> scoredPair(parent, other, parameters))
                                  .collect(Collectors.toList()));
        clusters.add(parent);
    }

    /**
     * Builds the cluster tree for the given data set.
     *
     * <p>The vector for each datum is read from the configured {@code input} source. Cluster ids
     * restart from zero: the leaves receive ids {@code 0..n-1} and merged clusters receive the
     * following ids. An empty data set leaves the clustering without a root.
     *
     * @param dataSet the data to cluster
     * @throws NullPointerException if {@code dataSet} is null
     */
    @Override // com.gengoai.apollo.ml.model.Model
    public void estimate(@NonNull DataSet dataSet) {
        if (dataSet == null) {
            throw new NullPointerException("dataset is marked non-null but is null");
        }
        Parameters parameters = getFitParameters();
        this.clustering = new HierarchicalClustering();
        this.clustering.setMeasure((Measure) parameters.measure.value());
        List<NDArray> vectors = dataSet.parallelStream()
                                       .map(datum -> datum.get(parameters.input.value()).asNDArray())
                                       .collect();
        this.idGenerator.set(0);

        Comparator<Tuple3<Cluster, Cluster, Double>> lowestScoreFirst = Comparator.comparingDouble(pair -> pair.v3);
        PriorityQueue<Tuple3<Cluster, Cluster, Double>> candidates = new PriorityQueue<>(lowestScoreFirst);
        List<Cluster> clusters = initDistanceMatrix(vectors, candidates, parameters);
        while (clusters.size() > 1) {
            doTurn(candidates, clusters, parameters);
        }
        if (!clusters.isEmpty()) {
            this.clustering.root = clusters.get(0);
        }
    }

    /**
     * Returns the parameters this clusterer was constructed with.
     *
     * @return the configured fit parameters
     */
    @Override // com.gengoai.apollo.ml.model.SingleSourceModel, com.gengoai.apollo.ml.model.Model
    public Parameters getFitParameters() {
        return fitParameters;
    }

    /**
     * Wraps every vector in its own singleton cluster and queues the score of every unordered pair
     * of those clusters.
     *
     * @param vectors    the input vectors
     * @param candidates queue that receives one scored entry per unordered pair
     * @param parameters the fit parameters supplying linkage and measure
     * @return the singleton clusters, in the same order as {@code vectors}
     */
    private List<Cluster> initDistanceMatrix(List<NDArray> vectors,
                                             PriorityQueue<Tuple3<Cluster, Cluster, Double>> candidates,
                                             Parameters parameters) {
        List<Cluster> clusters = new ArrayList<>(vectors.size());
        for (NDArray vector : vectors) {
            Cluster singleton = new Cluster();
            singleton.addPoint(vector);
            singleton.setId(this.idGenerator.getAndIncrement());
            clusters.add(singleton);
        }
        int size = clusters.size();
        // i must reach size - 2 (exclusive bound size - 1) so the final pair (size - 2, size - 1)
        // is scored too; with two points that pair is the only one.
        candidates.addAll(IntStream.range(0, size - 1)
                                   .parallel()
                                   .boxed()
                                   .flatMap(i -> IntStream.range(i + 1, size)
                                                          .mapToObj(j -> scoredPair(clusters.get(i), clusters.get(j), parameters)))
                                   .collect(Collectors.toList()));
        return clusters;
    }

    /**
     * Scores {@code a} against {@code b} with the configured linkage and measure.
     *
     * @return the triple {@code (a, b, score)}
     */
    private static Tuple3<Cluster, Cluster, Double> scoredPair(Cluster a, Cluster b, Parameters parameters) {
        double score = ((Linkage) parameters.linkage.value()).calculate(a, b, (Measure) parameters.measure.value());
        return Tuples.$(a, b, Double.valueOf(score));
    }

    /** Returns whether either side of {@code pair} has the same id as {@code cluster}. */
    private static boolean involves(Tuple3<Cluster, Cluster, Double> pair, Cluster cluster) {
        return pair.v1.getId() == cluster.getId() || pair.v2.getId() == cluster.getId();
    }
}
