package com.github.chen0040.clustering.em;

import com.github.chen0040.data.frame.DataFrame;
import com.github.chen0040.data.frame.DataRow;
import java.util.HashSet;
import java.util.Random;

/* loaded from: input_file:com/github/chen0040/clustering/em/EMClustering.class */
/**
 * Soft clustering of the numeric rows of a {@link DataFrame} using an expectation-maximization
 * style loop over a fixed number of cluster centers.
 *
 * <p>E-step: every row receives a normalized responsibility for each cluster, proportional to
 * {@code exp(-sqrt(getDistance(row, center)) / (2 * sigma0 * sigma0))}. M-step: every center is
 * moved to the responsibility-weighted mean of all rows. After fitting, each row of the returned
 * copy is labelled with the index of its most responsible cluster in the categorical target cell
 * {@code "cluster"}.
 *
 * <p>Instances keep mutable model state in their fields and are not safe for concurrent use.
 */
public class EMClustering {
    private static final Random random = new Random();

    /** Responsibilities from the most recent E-step, indexed {@code [row][cluster]}. */
    protected double[][] expectactionMatrix;

    /** Cluster centers, indexed {@code [cluster][dimension]}. */
    protected double[][] clusters;

    /** Kernel width used by the E-step. */
    protected double sigma0 = 0.1d;

    /** Number of clusters; lowered by {@code initializeCluster} when the data set is small. */
    protected int clusterCount = 10;

    /** Upper bound on the number of E/M iterations performed by {@code fitAndTransform}. */
    private int maxIters = 2000;

    /**
     * Computes the Euclidean distance between two points.
     *
     * @param a first point
     * @param b second point; must have at least {@code a.length} coordinates
     * @return {@code sqrt(sum((a[i] - b[i])^2))} over the coordinates of {@code a}
     */
    public double getDistance(double[] a, double[] b) {
        double sum = 0.0d;
        for (int i = 0; i < a.length; i++) {
            double diff = a[i] - b[i];
            sum += diff * diff;
        }
        return Math.sqrt(sum);
    }

    /**
     * Returns the normalized responsibility of cluster {@code i} for {@code dataRow} under the
     * current cluster centers.
     *
     * @param dataRow the row to evaluate
     * @param i cluster index in {@code [0, clusterCount)}
     * @return a value in {@code [0, 1]}; the values over all clusters sum to 1
     */
    public double calcExpectation(DataRow dataRow, int i) {
        return calcResponsibilities(dataRow.toArray())[i];
    }

    /**
     * Computes the normalized responsibilities of every cluster for one point.
     *
     * <p>The kernel exponents are shifted by their maximum before exponentiation, so that at least
     * one weight is exactly 1. Without the shift, {@code exp} underflows to 0 for every cluster
     * when distances are large relative to {@code sigma0}, and normalizing then yields 0/0 = NaN.
     * The shift cancels out in the normalization, so the result is mathematically unchanged.
     *
     * <p>NOTE(review): the exponent uses {@code sqrt} of {@link #getDistance}, which is already a
     * square root. A Gaussian kernel would use the squared distance; confirm which is intended.
     */
    private double[] calcResponsibilities(double[] point) {
        double scale = 2.0d * this.sigma0 * this.sigma0;
        double[] weights = new double[this.clusterCount];
        double maxExponent = Double.NEGATIVE_INFINITY;
        for (int k = 0; k < this.clusterCount; k++) {
            weights[k] = -Math.sqrt(getDistance(point, this.clusters[k])) / scale;
            maxExponent = Math.max(maxExponent, weights[k]);
        }
        double total = 0.0d;
        for (int k = 0; k < this.clusterCount; k++) {
            weights[k] = Math.exp(weights[k] - maxExponent);
            total += weights[k];
        }
        for (int k = 0; k < this.clusterCount; k++) {
            weights[k] /= total;
        }
        return weights;
    }

    /**
     * Seeds the cluster centers from rows of {@code dataFrame}.
     *
     * <p>If there are fewer than {@code 3 * clusterCount} rows, {@code clusterCount} is lowered to
     * at most the row count and the first rows are used as seeds. Otherwise {@code clusterCount}
     * distinct rows are drawn at random. Note that lowering {@code clusterCount} persists in the
     * field after fitting.
     *
     * @param dataFrame source of seed rows
     * @param i number of leading coordinates copied from each seed row
     */
    private void initializeCluster(DataFrame dataFrame, int i) {
        HashSet<Integer> seedRows = new HashSet<>();
        int rowCount = dataFrame.rowCount();
        if (rowCount < this.clusterCount * 3) {
            this.clusterCount = Math.min(rowCount, this.clusterCount);
            for (int r = 0; r < this.clusterCount; r++) {
                seedRows.add(r);
            }
        } else {
            // Set.add ignores duplicates, so this loops until enough distinct rows are drawn.
            while (seedRows.size() < this.clusterCount) {
                seedRows.add(random.nextInt(rowCount));
            }
        }
        this.clusters = new double[this.clusterCount][];
        int k = 0;
        for (int row : seedRows) {
            double[] center = new double[i];
            System.arraycopy(dataFrame.row(row).toArray(), 0, center, 0, i);
            this.clusters[k++] = center;
        }
    }

    /** Allocates a zero-filled {@code rowCount x clusterCount} responsibility matrix. */
    private void initializeEM(DataFrame dataFrame) {
        this.expectactionMatrix = new double[dataFrame.rowCount()][this.clusterCount];
    }

    /**
     * M-step: moves every center to the responsibility-weighted mean of {@code points}.
     *
     * <p>If a cluster's total responsibility is zero or NaN, its previous center is kept rather
     * than dividing by that total.
     *
     * @return {@code true} if any center coordinate changed
     */
    private boolean updateCenters(double[][] points, int dimension) {
        boolean changed = false;
        for (int k = 0; k < this.clusterCount; k++) {
            double totalWeight = 0.0d;
            for (int r = 0; r < points.length; r++) {
                totalWeight += this.expectactionMatrix[r][k];
            }
            if (!(totalWeight > 0.0d)) {
                continue;
            }
            for (int j = 0; j < dimension; j++) {
                double weightedSum = 0.0d;
                for (int r = 0; r < points.length; r++) {
                    weightedSum += this.expectactionMatrix[r][k] * points[r][j];
                }
                double updated = weightedSum / totalWeight;
                if (updated != this.clusters[k][j]) {
                    this.clusters[k][j] = updated;
                    changed = true;
                }
            }
        }
        return changed;
    }

    /**
     * Returns the index of the first strictly largest entry, or -1 if no entry exceeds negative
     * infinity (for example when every entry is NaN).
     */
    private static int argMax(double[] values) {
        int best = -1;
        double bestValue = Double.NEGATIVE_INFINITY;
        for (int k = 0; k < values.length; k++) {
            if (values[k] > bestValue) {
                bestValue = values[k];
                best = k;
            }
        }
        return best;
    }

    /**
     * Fits the clusters to a copy of {@code dataFrame} and labels every row of that copy.
     *
     * <p>Runs up to {@code maxIters} E/M iterations. It stops early once an M-step leaves every
     * center unchanged; at that exact fixpoint, further iterations would reproduce the same result.
     *
     * @param dataFrame input rows; only {@code makeCopy()} is called on it
     * @return the copy, whose rows carry a {@code "cluster"} categorical target holding the index of
     *     the most responsible cluster ({@code "-1"} if none could be chosen)
     */
    public DataFrame fitAndTransform(DataFrame dataFrame) {
        DataFrame makeCopy = dataFrame.makeCopy();
        int rowCount = makeCopy.rowCount();
        if (rowCount == 0) {
            // Nothing to cluster, and no row from which to read the dimension.
            this.expectactionMatrix = new double[0][];
            return makeCopy;
        }

        // Cache each row's coordinates once instead of re-materializing them in the inner loops.
        double[][] points = new double[rowCount][];
        for (int r = 0; r < rowCount; r++) {
            points[r] = makeCopy.row(r).toArray();
        }
        int dimension = points[0].length;

        initializeCluster(makeCopy, dimension);
        initializeEM(makeCopy);

        for (int iter = 0; iter < this.maxIters; iter++) {
            // E-step
            for (int r = 0; r < rowCount; r++) {
                this.expectactionMatrix[r] = calcResponsibilities(points[r]);
            }
            // M-step
            if (!updateCenters(points, dimension)) {
                break;
            }
        }

        for (int r = 0; r < rowCount; r++) {
            int best = argMax(this.expectactionMatrix[r]);
            makeCopy.row(r).setCategoricalTargetCell("cluster", String.valueOf(best));
        }
        return makeCopy;
    }

    /** @return the kernel width used by the E-step */
    public double getSigma0() {
        return this.sigma0;
    }

    /** @param d the kernel width used by the E-step */
    public void setSigma0(double d) {
        this.sigma0 = d;
    }

    /** @return the number of clusters (possibly lowered by the last fit on a small data set) */
    public int getClusterCount() {
        return this.clusterCount;
    }

    /** @param i the requested number of clusters */
    public void setClusterCount(int i) {
        this.clusterCount = i;
    }

    /** @return the upper bound on E/M iterations per fit */
    public int getMaxIters() {
        return this.maxIters;
    }

    /** @param i the upper bound on E/M iterations per fit; values {@code <= 0} skip fitting */
    public void setMaxIters(int i) {
        this.maxIters = i;
    }
}
