package com.gengoai.apollo.ml.model.clustering;

import com.gengoai.ParamMap;
import com.gengoai.apollo.math.statistics.measure.Measure;
import com.gengoai.apollo.ml.DataSet;
import com.gengoai.apollo.ml.Datum;
import com.gengoai.apollo.ml.InMemoryDataSet;
import com.gengoai.apollo.ml.model.Params;
import com.gengoai.apollo.ml.model.clustering.KMeans;
import com.gengoai.conversion.Cast;
import com.gengoai.function.Functional;
import com.gengoai.tuple.Tuples;
import java.util.ArrayDeque;
import java.util.Collection;
import java.util.Deque;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.Objects;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import lombok.NonNull;

/* loaded from: input_file:com/gengoai/apollo/ml/model/clustering/DivisiveKMeans.class */
/**
 * Divisive (top-down) flat clusterer built on repeated {@link KMeans} passes.
 * <p>
 * The full data set is first partitioned with KMeans. Each resulting non-empty cluster is then either accepted as a
 * final cluster — when its {@link Cluster#getScore() score} is at most {@code tolerance}, or it has at most
 * {@code minPoints} points — or re-partitioned with KMeans, and the same test is applied to the new sub-clusters.
 * A cluster whose re-partitioning yields fewer than two non-empty sub-clusters is also accepted as final, so the
 * process always terminates.
 * </p>
 */
public class DivisiveKMeans extends FlatCentroidClusterer {
    private static final long serialVersionUID = 1;

    /**
     * Fit parameters for {@link DivisiveKMeans}.
     */
    public static class Parameters extends ClusterFitParameters {
        /** Clusters with at most this many points are not split further (initialized to 4). */
        public final ParamMap<ClusterFitParameters>.Parameter<Integer> minPoints = parameter(Params.Clustering.minPoints, 4);
        /** Number of clusters requested from each KMeans pass (initialized to 2). */
        public final ParamMap<ClusterFitParameters>.Parameter<Integer> K = parameter(Params.Clustering.K, 2);
        /** Clusters whose score is at most this value are not split further (initialized to 100.0). */
        public final ParamMap<ClusterFitParameters>.Parameter<Double> tolerance = parameter(Params.Optimizable.tolerance, Double.valueOf(100.0d));
    }

    /**
     * Creates a divisive KMeans clusterer with default {@link Parameters}.
     */
    public DivisiveKMeans() {
        super(new Parameters());
    }

    /**
     * Creates a divisive KMeans clusterer whose parameters are configured by the given consumer.
     *
     * @param consumer callback applied to a fresh {@link Parameters} instance
     * @throws NullPointerException if {@code consumer} is null
     */
    public DivisiveKMeans(@NonNull Consumer<Parameters> consumer) {
        // Validate inside the super-call argument: a check placed after super(...) would only run after the
        // consumer had already been handed to Functional.with.
        super((ClusterFitParameters) Functional.with(new Parameters(),
                Objects.requireNonNull(consumer, "consumer is marked non-null but is null")));
    }

    /**
     * Runs a single KMeans pass over the given data set, using this clusterer's K, input, output and measure
     * settings with at most 20 iterations and verbose output disabled.
     *
     * @param dataSet the data to partition
     * @return the clustering produced by KMeans
     */
    private FlatClustering cluster(DataSet dataSet) {
        final Parameters fit = getFitParameters();
        KMeans kMeans = new KMeans((Consumer<KMeans.Parameters>) parameters -> {
            parameters.K.set((Integer) fit.K.value());
            parameters.input.set((String) fit.input.value());
            parameters.maxIterations.set(20);
            parameters.output.set((String) fit.output.value());
            parameters.measure.set((Measure) fit.measure.value());
            parameters.verbose.set(false);
        });
        kMeans.estimate(dataSet);
        return (FlatClustering) Cast.as(kMeans.getClustering());
    }

    /**
     * Re-clusters the points of a single cluster. Each point is wrapped in a {@link Datum} under the configured
     * input name before being passed to {@link #cluster(DataSet)}.
     *
     * @param parent the cluster whose points should be partitioned
     * @param input  the input source name used for the wrapped points
     * @return the KMeans clustering of the parent's points
     */
    private FlatClustering split(Cluster parent, String input) {
        return cluster(new InMemoryDataSet((Collection) parent.getPoints().stream()
                .map(nDArray -> Datum.of(Tuples.$(input, nDArray)))
                .collect(Collectors.toList())));
    }

    /**
     * Counts the clusters in the given clustering whose {@code size()} is non-zero.
     *
     * @param clustering the clustering to inspect
     * @return the number of non-empty clusters
     */
    private static int countNonEmpty(FlatClustering clustering) {
        int count = 0;
        Iterator<Cluster> it = clustering.iterator();
        while (it.hasNext()) {
            if (it.next().size() > 0) {
                count++;
            }
        }
        return count;
    }

    /**
     * Builds the divisive clustering for the given data set and stores it in {@code this.clustering}.
     *
     * @param dataSet the data to cluster
     * @throws NullPointerException if {@code dataSet} is null
     */
    @Override
    public void estimate(@NonNull DataSet dataSet) {
        if (dataSet == null) {
            throw new NullPointerException("dataset is marked non-null but is null");
        }
        // These parameters do not change while estimating, so read them once.
        final Parameters fit = getFitParameters();
        final double tolerance = (Double) fit.tolerance.value();
        final int minPoints = (Integer) fit.minPoints.value();
        final String input = (String) fit.input.value();

        this.clustering = new FlatClustering();
        this.clustering.setMeasure((Measure) fit.measure.value());

        Deque<FlatClustering> pending = new ArrayDeque<>();
        pending.add(cluster(dataSet));
        while (!pending.isEmpty()) {
            Iterator<Cluster> it = pending.remove().iterator();
            while (it.hasNext()) {
                Cluster candidate = it.next();
                if (candidate.size() == 0) {
                    continue;
                }
                if (candidate.getScore() <= tolerance || candidate.getPoints().size() <= minPoints) {
                    this.clustering.add(candidate);
                    continue;
                }
                FlatClustering children = split(candidate, input);
                if (countNonEmpty(children) < 2) {
                    // KMeans did not separate the points (for example when K is 1). Re-splitting would reproduce
                    // the same cluster indefinitely, so accept it as final.
                    this.clustering.add(candidate);
                } else {
                    pending.add(children);
                }
            }
        }
    }

    /**
     * Returns the fit parameters of this clusterer.
     *
     * @return the {@link Parameters} backing this model
     */
    @Override
    public Parameters getFitParameters() {
        return (Parameters) Cast.as(this.parameters);
    }
}
