package greycat.ml.profiling;

import greycat.struct.DMatrix;
import greycat.struct.DoubleArray;
import greycat.struct.ENode;
import greycat.struct.matrix.MatrixOps;
import greycat.struct.matrix.VolatileDMatrix;

/* loaded from: input_file:greycat/ml/profiling/GaussianENode.class */
/**
 * Incrementally learned multivariate Gaussian profile whose statistics live in an {@link ENode}.
 *
 * <p>Each call to {@link #learn(double[])} updates, in the backing node:
 * <ul>
 *   <li>{@code Gaussian.TOTAL}: the number of samples seen;</li>
 *   <li>{@code Gaussian.SUM}: the per-dimension sum of samples;</li>
 *   <li>from the second sample on, {@code Gaussian.MIN} / {@code Gaussian.MAX} (per-dimension
 *       extrema) and {@code Gaussian.SUMSQ}, the upper triangle (row-major, diagonal included)
 *       of the sum of outer products {@code x * x^T}.</li>
 * </ul>
 * With a single sample only TOTAL and SUM are stored; min, max and the outer-product sum are then
 * derived from SUM on demand.
 *
 * <p>Average, standard deviation and covariance are cached in plain fields and cleared on every
 * {@code learn} call. No synchronization is performed.
 */
public class GaussianENode {
    public static final String NAME = "GaussianENode";

    // NOTE(review): presumably greycat's Type.INT and Type.DOUBLE_ARRAY codes — confirm and
    // replace with the named constants if available.
    private static final byte TYPE_INT = 4;
    private static final byte TYPE_DOUBLE_ARRAY = 6;

    private ENode backend;
    private double[] avg = null;
    private double[] std = null;
    private DMatrix cov = null;

    /**
     * Wraps an existing node that holds (or will hold) the Gaussian statistics.
     *
     * @param eNode backing node; must not be null
     * @throws IllegalArgumentException if {@code eNode} is null
     */
    public GaussianENode(ENode eNode) {
        if (eNode == null) {
            throw new IllegalArgumentException("backend can't be null for Gaussian node!");
        }
        this.backend = eNode;
    }

    /**
     * Stores per-dimension precision floors.
     *
     * <p>The standard deviation of dimension {@code i} is never reported below
     * {@code dArr[i]}, and the covariance diagonal never below {@code dArr[i]^2}.
     *
     * @param dArr one precision per dimension
     */
    public void setPrecisions(double[] dArr) {
        ((DoubleArray) this.backend.getOrCreate(Gaussian.PRECISIONS, TYPE_DOUBLE_ARRAY)).initWith(dArr);
    }

    /**
     * Adds one sample to the running statistics and clears the cached derived values.
     *
     * @param dArr the sample; after the first sample, every sample must have the same length
     * @throws IllegalArgumentException if the sample length differs from the learned dimension
     */
    public void learn(double[] dArr) {
        int dims = dArr.length;
        int total = readTotal();
        if (total == 0) {
            // Defensive copy: SUM is later updated in place.
            double[] first = new double[dims];
            System.arraycopy(dArr, 0, first, 0, dims);
            this.backend.set(Gaussian.TOTAL, TYPE_INT, Integer.valueOf(1));
            ((DoubleArray) this.backend.getOrCreate(Gaussian.SUM, TYPE_DOUBLE_ARRAY)).initWith(first);
        } else {
            DoubleArray sum = (DoubleArray) this.backend.get(Gaussian.SUM);
            // Validate before creating any new attributes on the node.
            if (dims != sum.size()) {
                throw new IllegalArgumentException("Input dimensions have changed!");
            }
            DoubleArray min = (DoubleArray) this.backend.getOrCreate(Gaussian.MIN, TYPE_DOUBLE_ARRAY);
            DoubleArray max = (DoubleArray) this.backend.getOrCreate(Gaussian.MAX, TYPE_DOUBLE_ARRAY);
            DoubleArray sumSq = (DoubleArray) this.backend.getOrCreate(Gaussian.SUMSQ, TYPE_DOUBLE_ARRAY);
            if (total == 1) {
                // Materialize the statistics implied by the single stored sample (SUM == sample).
                double[] first = sum.extract();
                min.initWith(first);
                max.initWith(first);
                sumSq.initWith(upperOuterProduct(first));
            }
            for (int d = 0; d < dims; d++) {
                double x = dArr[d];
                if (x < min.get(d)) {
                    min.set(d, x);
                }
                if (x > max.get(d)) {
                    max.set(d, x);
                }
                sum.set(d, sum.get(d) + x);
            }
            int k = 0;
            for (int r = 0; r < dims; r++) {
                for (int c = r; c < dims; c++) {
                    sumSq.set(k, sumSq.get(k) + dArr[r] * dArr[c]);
                    k++;
                }
            }
            this.backend.set(Gaussian.TOTAL, TYPE_INT, Integer.valueOf(total + 1));
        }
        invalidate();
    }

    /** Drops every cached derived statistic so it is recomputed from the node on next access. */
    private void invalidate() {
        this.avg = null;
        this.std = null;
        this.cov = null;
    }

    /** Returns the stored sample count, or 0 when nothing has been learned. */
    private int readTotal() {
        return ((Integer) this.backend.getWithDefault(Gaussian.TOTAL, 0)).intValue();
    }

    /** Returns the stored precisions, or an all-zero array of length {@code dims} if none are set. */
    private double[] readPrecisions(int dims) {
        DoubleArray stored = (DoubleArray) this.backend.get(Gaussian.PRECISIONS);
        return stored != null ? stored.extract() : new double[dims];
    }

    /**
     * Packs the upper triangle (row-major, diagonal included) of {@code v * v^T}.
     * The result has {@code n * (n + 1) / 2} entries for an input of length {@code n}.
     */
    private static double[] upperOuterProduct(double[] v) {
        int n = v.length;
        double[] packed = new double[n * (n + 1) / 2];
        int k = 0;
        for (int r = 0; r < n; r++) {
            for (int c = r; c < n; c++) {
                packed[k++] = v[r] * v[c];
            }
        }
        return packed;
    }

    /** Computes and caches the mean; returns false when no sample has been learned. */
    private boolean initAvg() {
        if (this.avg != null) {
            return true;
        }
        int total = readTotal();
        if (total == 0) {
            return false;
        }
        double[] sum = ((DoubleArray) this.backend.get(Gaussian.SUM)).extract();
        double[] mean = new double[sum.length];
        for (int d = 0; d < sum.length; d++) {
            mean[d] = sum[d] / total;
        }
        this.avg = mean;
        return true;
    }

    /**
     * Computes and caches the per-dimension sample standard deviation (Bessel-corrected).
     * Returns false when fewer than two samples have been learned.
     */
    private boolean initStd() {
        if (this.std != null) {
            return true;
        }
        int total = readTotal();
        if (total < 2) {
            return false;
        }
        initAvg();
        int dims = this.avg.length;
        // Precisions are stored as a DoubleArray by setPrecisions, never as a raw double[].
        double[] precisions = readPrecisions(dims);
        double[] sumSq = getSumSq();
        // Must be floating-point: integer division yields 1 for every total >= 3.
        double correction = (double) total / (total - 1);
        double[] result = new double[dims];
        int diag = 0; // index of (d, d) in the packed upper triangle
        for (int d = 0; d < dims; d++) {
            double variance = (sumSq[diag] / total - this.avg[d] * this.avg[d]) * correction;
            // E[x^2] - E[x]^2 can round slightly below zero; a variance cannot be negative.
            result[d] = Math.max(Math.sqrt(Math.max(variance, 0.0)), precisions[d]);
            diag += dims - d;
        }
        this.std = result;
        return true;
    }

    /**
     * Computes and caches the sample covariance matrix (Bessel-corrected), with each diagonal
     * entry floored at the squared precision. Returns false when fewer than two samples exist.
     */
    private boolean initCov() {
        if (this.cov != null) {
            return true;
        }
        int total = readTotal();
        if (total < 2) {
            return false;
        }
        initAvg();
        int dims = this.avg.length;
        double[] precisions = readPrecisions(dims);
        double[] sumSq = getSumSq();
        // Must be floating-point: integer division yields 1 for every total >= 3.
        double correction = (double) total / (total - 1);
        double[] values = new double[dims * dims];
        int k = 0;
        for (int r = 0; r < dims; r++) {
            for (int c = r; c < dims; c++) {
                double v = (sumSq[k] / total - this.avg[r] * this.avg[c]) * correction;
                values[r * dims + c] = v;
                values[c * dims + r] = v;
                k++;
            }
            int rr = r * dims + r;
            double floor = precisions[r] * precisions[r];
            // Clamp rounding-induced negative variance, then apply the precision floor.
            values[rr] = Math.max(Math.max(values[rr], 0.0), floor);
        }
        this.cov = VolatileDMatrix.wrap(values, dims, dims);
        return true;
    }

    /**
     * Returns a copy of the per-dimension mean, or null if nothing has been learned.
     */
    public double[] getAvg() {
        if (!initAvg()) {
            return null;
        }
        return this.avg.clone();
    }

    /**
     * Returns a copy of the per-dimension sample standard deviation (floored at the stored
     * precisions), or null if fewer than two samples have been learned.
     */
    public double[] getSTD() {
        if (!initStd()) {
            return null;
        }
        return this.std.clone();
    }

    /**
     * Returns a copy of the sample covariance matrix, or null if fewer than two samples have
     * been learned.
     */
    public DMatrix getCovariance() {
        if (!initCov()) {
            return null;
        }
        VolatileDMatrix copy = VolatileDMatrix.empty(this.cov.rows(), this.cov.columns());
        MatrixOps.copy(this.cov, copy);
        return copy;
    }

    /**
     * Returns the Pearson correlation matrix derived from the covariance, or null if fewer than
     * two samples have been learned. Entries where either variance is exactly zero are not
     * written and keep the value {@code VolatileDMatrix.empty} initialized them with.
     */
    public DMatrix getPearson() {
        if (!initCov()) {
            return null;
        }
        VolatileDMatrix pearson = VolatileDMatrix.empty(this.cov.rows(), this.cov.columns());
        for (int r = 0; r < pearson.rows(); r++) {
            double varR = this.cov.get(r, r);
            for (int c = 0; c < pearson.columns(); c++) {
                double varC = this.cov.get(c, c);
                if (varR != 0.0d && varC != 0.0d) {
                    pearson.set(r, c, this.cov.get(r, c) / Math.sqrt(varR * varC));
                }
            }
        }
        return pearson;
    }

    /** Returns the per-dimension sum of all samples, or null if nothing has been learned. */
    public double[] getSum() {
        if (readTotal() == 0) {
            return null;
        }
        return ((DoubleArray) this.backend.get(Gaussian.SUM)).extract();
    }

    /**
     * Returns the packed upper triangle (row-major, diagonal included) of the sum of outer
     * products {@code x * x^T} over all samples, or null if nothing has been learned.
     */
    public double[] getSumSq() {
        int total = readTotal();
        if (total == 0) {
            return null;
        }
        if (total == 1) {
            // SUMSQ is not stored yet; derive it from the single sample held in SUM.
            return upperOuterProduct(((DoubleArray) this.backend.get(Gaussian.SUM)).extract());
        }
        return ((DoubleArray) this.backend.get(Gaussian.SUMSQ)).extract();
    }

    /** Returns the per-dimension minimum, or null if nothing has been learned. */
    public double[] getMin() {
        int total = readTotal();
        if (total == 0) {
            return null;
        }
        // With one sample, MIN is not stored yet and the sample itself is held in SUM.
        return ((DoubleArray) this.backend.get(total == 1 ? Gaussian.SUM : Gaussian.MIN)).extract();
    }

    /** Returns the per-dimension maximum, or null if nothing has been learned. */
    public double[] getMax() {
        int total = readTotal();
        if (total == 0) {
            return null;
        }
        // With one sample, MAX is not stored yet and the sample itself is held in SUM.
        return ((DoubleArray) this.backend.get(total == 1 ? Gaussian.SUM : Gaussian.MAX)).extract();
    }

    /** Returns the number of learned samples (0 if none). */
    public int getTotal() {
        return readTotal();
    }

    /** Returns the learned input dimension, or 0 if nothing has been learned. */
    public int getDimensions() {
        if (readTotal() == 0) {
            return 0;
        }
        return ((DoubleArray) this.backend.get(Gaussian.SUM)).size();
    }
}
