package com.github.chen0040.clustering.kmeans;

import com.github.chen0040.clustering.DistanceMeasureService;
import com.github.chen0040.data.frame.DataFrame;
import com.github.chen0040.data.frame.DataRow;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.function.BiFunction;

/* loaded from: input_file:com/github/chen0040/clustering/kmeans/KMeans.class */
/**
 * K-means clustering over the rows of a {@link DataFrame}.
 *
 * <p>Centroids are kept in a map keyed by cluster index ({@code 0 .. clusterCount - 1}).
 * {@link #fitAndTransform(DataFrame)} learns the centroids and labels every row of a copy of
 * the input; {@link #transform(DataRow)} assigns a single row to the nearest learned centroid.
 *
 * <p>Distances are obtained from {@code DistanceMeasureService.getDistance}, which is handed
 * the row, the centroid and the configured {@link #getDistanceMeasure() distance measure}.
 * The distance measure is never assigned by this class, so it stays {@code null} unless a
 * caller sets it.
 */
public class KMeans {
    private static final Random random = new Random();

    /** Learned centroids, keyed by cluster index. */
    private final Map<Integer, double[]> clusters = new HashMap<>();
    private int maxIters = 2000;
    private int clusterCount = 5;
    private BiFunction<DataRow, double[], Double> distanceMeasure;

    /** Accumulates the rows assigned to one cluster during a single iteration. */
    private static final class Cluster {
        private final List<DataRow> elements = new ArrayList<>();

        /** Adds a row to this cluster. */
        void append(DataRow dataRow) {
            this.elements.add(dataRow);
        }

        /**
         * Computes the component-wise mean of the rows assigned to this cluster.
         *
         * @return the centroid, or {@code null} when no row was assigned to this cluster
         */
        double[] calcCenter() {
            if (this.elements.isEmpty()) {
                return null;
            }
            int dimension = this.elements.get(0).toArray().length;
            double[] center = new double[dimension];
            // Average the rows that belong to THIS cluster; indexing the data frame by
            // position here would average an unrelated prefix of the whole data set.
            for (DataRow element : this.elements) {
                double[] values = element.toArray();
                for (int j = 0; j < dimension; j++) {
                    center[j] += values[j];
                }
            }
            int size = this.elements.size();
            for (int j = 0; j < dimension; j++) {
                center[j] /= size;
            }
            return center;
        }
    }

    /**
     * Finds the cluster whose centroid is nearest to the given row.
     *
     * @param dataRow the row to classify
     * @return the index of the nearest cluster, or {@code -1} when no cluster yields a
     *         distance strictly below {@link Double#MAX_VALUE} (for example when
     *         {@code clusterCount} is 0 or every distance is NaN)
     */
    public int transform(DataRow dataRow) {
        return nearestCluster(dataRow);
    }

    /** Returns the index of the centroid closest to {@code dataRow}, or -1 if none qualifies. */
    private int nearestCluster(DataRow dataRow) {
        double bestDistance = Double.MAX_VALUE;
        int bestIndex = -1;
        for (int c = 0; c < this.clusterCount; c++) {
            double distance = DistanceMeasureService.getDistance(dataRow, this.clusters.get(c), this.distanceMeasure);
            if (bestDistance > distance) {
                bestDistance = distance;
                bestIndex = c;
            }
        }
        return bestIndex;
    }

    /**
     * Seeds the centroids from rows of the data frame, unless the centroid map already holds
     * exactly {@code clusterCount} entries (in which case the existing centroids are reused).
     *
     * <p>When the frame has fewer than {@code 3 * clusterCount} rows, {@code clusterCount} is
     * capped at the row count and the first rows are used as seeds; otherwise distinct rows
     * are drawn at random.
     */
    private void initializeCluster(DataFrame dataFrame) {
        if (this.clusters.size() == this.clusterCount) {
            return;
        }
        Set<Integer> seedRows = new HashSet<>();
        int rowCount = dataFrame.rowCount();
        if (rowCount < this.clusterCount * 3) {
            this.clusterCount = Math.min(rowCount, this.clusterCount);
            for (int i = 0; i < this.clusterCount; i++) {
                seedRows.add(i);
            }
        } else {
            // Set.add ignores duplicates, so this loops until enough distinct rows are drawn.
            while (seedRows.size() < this.clusterCount) {
                seedRows.add(random.nextInt(rowCount));
            }
        }
        this.clusters.clear();
        int clusterIndex = 0;
        for (int rowIndex : seedRows) {
            this.clusters.put(clusterIndex++, dataFrame.row(rowIndex).toArray());
        }
    }

    /**
     * Learns the centroids from the given data and labels each row with its cluster.
     *
     * <p>The input frame is copied first; the copy's rows receive a categorical target cell
     * named {@code "cluster"} holding the cluster index as a decimal string. Iteration stops
     * after {@code maxIters} rounds, or earlier once a round leaves every centroid unchanged
     * (further rounds would then repeat the same assignment, assuming the distance function
     * is deterministic).
     *
     * @param dataFrame the data to cluster
     * @return the labelled copy of {@code dataFrame}
     * @throws IllegalStateException if a row cannot be assigned to any cluster because no
     *         distance is below {@link Double#MAX_VALUE}
     */
    public DataFrame fitAndTransform(DataFrame dataFrame) {
        DataFrame labelled = dataFrame.makeCopy();
        initializeCluster(labelled);
        int rowCount = labelled.rowCount();
        for (int iter = 0; iter < this.maxIters; iter++) {
            Cluster[] groups = new Cluster[this.clusterCount];
            for (int c = 0; c < this.clusterCount; c++) {
                groups[c] = new Cluster();
            }
            for (int r = 0; r < rowCount; r++) {
                DataRow row = labelled.row(r);
                int nearest = nearestCluster(row);
                if (nearest < 0) {
                    throw new IllegalStateException("Row " + r
                            + " could not be assigned to a cluster: no distance was below Double.MAX_VALUE");
                }
                groups[nearest].append(row);
                row.setCategoricalTargetCell("cluster", String.format("%d", nearest));
            }
            boolean changed = false;
            for (int c = 0; c < this.clusterCount; c++) {
                double[] center = groups[c].calcCenter();
                // An empty cluster keeps its previous centroid.
                if (center != null) {
                    double[] previous = this.clusters.put(c, center);
                    if (!java.util.Arrays.equals(previous, center)) {
                        changed = true;
                    }
                }
            }
            if (!changed) {
                break;
            }
        }
        return labelled;
    }

    /**
     * Returns the live centroid map (cluster index to centroid); it is not a copy.
     *
     * @return the centroids currently held by this model
     */
    public Map<Integer, double[]> getClusters() {
        return this.clusters;
    }

    /**
     * @return the maximum number of iterations performed by {@link #fitAndTransform(DataFrame)}
     */
    public int getMaxIters() {
        return this.maxIters;
    }

    /**
     * @return the number of clusters; may have been lowered during fitting when the data had
     *         fewer rows than requested clusters
     */
    public int getClusterCount() {
        return this.clusterCount;
    }

    /**
     * @return the distance function passed to {@code DistanceMeasureService}, or {@code null}
     *         if none was set
     */
    public BiFunction<DataRow, double[], Double> getDistanceMeasure() {
        return this.distanceMeasure;
    }

    /**
     * @param i the maximum number of iterations for {@link #fitAndTransform(DataFrame)}
     */
    public void setMaxIters(int i) {
        this.maxIters = i;
    }

    /**
     * @param i the requested number of clusters
     */
    public void setClusterCount(int i) {
        this.clusterCount = i;
    }

    /**
     * @param biFunction the distance function between a row and a centroid
     */
    public void setDistanceMeasure(BiFunction<DataRow, double[], Double> biFunction) {
        this.distanceMeasure = biFunction;
    }
}
