package greycat.ml.common.matrix.operation;

import greycat.ml.common.matrix.MatrixOps;
import greycat.ml.common.matrix.VolatileDMatrix;
import greycat.struct.DMatrix;

/* loaded from: input_file:greycat/ml/common/matrix/operation/MultivariateNormalDistribution.class */
/**
 * Multivariate normal (Gaussian) distribution defined by a mean vector and a covariance matrix.
 *
 * <p>The covariance is factorised with {@link PInvSVD}. Its pseudo-inverse, determinant and rank
 * are cached and used by {@link #density(double[], boolean)} and
 * {@link #densityExponent(double[])}.
 *
 * <p>Instances are mutable: {@link #setMin} and {@link #setMax} replace the optional bounds.
 * The getters return internal arrays, not copies.
 */
public class MultivariateNormalDistribution {

    /** 2&pi;, the base of the Gaussian normalisation factor (equal to 6.283185307179586). */
    private static final double TWO_PI = 2 * Math.PI;

    /**
     * Fraction of {@code sqrt(var_i) * sqrt(var_j)} subtracted from each off-diagonal covariance
     * entry when a rank-deficient covariance is regularised.
     */
    private static final double OFF_DIAGONAL_SHRINK = 1.0E-4d;

    double[] min;
    double[] max;
    double[] means;
    double[] covDiag;
    DMatrix inv;
    DMatrix covariance;
    PInvSVD pinvsvd;
    int rank;
    double det;

    /**
     * Creates a distribution from a mean vector and a covariance matrix.
     *
     * @param means mean vector; stored by reference, not copied
     * @param covarianceMatrix covariance matrix (assumed square and symmetric, TODO confirm
     *     callers). It may be {@code null}; then only {@code means} is set and the
     *     factorisation fields keep their defaults.
     * @param allowSingular when {@code false} and the factorised covariance has rank below its
     *     row count, the off-diagonal entries of a <em>copy</em> of the covariance are reduced
     *     by {@code 1e-4 * sqrt(var_i) * sqrt(var_j)} and the copy is factorised again. When
     *     {@code true}, the covariance is used as given.
     */
    public MultivariateNormalDistribution(double[] means, DMatrix covarianceMatrix, boolean allowSingular) {
        this.means = means;
        if (covarianceMatrix == null) {
            return;
        }
        this.covariance = covarianceMatrix;
        this.covDiag = new double[covarianceMatrix.rows()];
        for (int i = 0; i < this.covDiag.length; i++) {
            this.covDiag[i] = covarianceMatrix.get(i, i);
        }
        factorCovariance();
        if (allowSingular || this.rank >= covarianceMatrix.rows()) {
            return;
        }

        // Regularise a copy so the caller's matrix is never modified.
        this.covariance = VolatileDMatrix.cloneFrom(covarianceMatrix);
        double[] stdDev = new double[this.covDiag.length];
        for (int i = 0; i < stdDev.length; i++) {
            stdDev[i] = Math.sqrt(this.covDiag[i]);
        }
        for (int row = 0; row < stdDev.length; row++) {
            for (int col = row + 1; col < stdDev.length; col++) {
                double shrunk = this.covariance.get(row, col) - OFF_DIAGONAL_SHRINK * stdDev[row] * stdDev[col];
                // Keep the matrix symmetric.
                this.covariance.set(row, col, shrunk);
                this.covariance.set(col, row, shrunk);
            }
        }
        factorCovariance();
    }

    /** Factorises {@link #covariance} with a fresh {@link PInvSVD} and caches inv, det and rank. */
    private void factorCovariance() {
        this.pinvsvd = new PInvSVD();
        this.pinvsvd.factor(this.covariance, false);
        this.inv = this.pinvsvd.getPInv();
        this.det = this.pinvsvd.getDeterminant();
        this.rank = this.pinvsvd.getRank();
    }

    /** @return the lower bounds set through {@link #setMin}, or {@code null} if never set */
    public double[] getMin() {
        return this.min;
    }

    /** @return the upper bounds set through {@link #setMax}, or {@code null} if never set */
    public double[] getMax() {
        return this.max;
    }

    /** @return the mean vector (internal array, not a copy) */
    public double[] getAvg() {
        return this.means;
    }

    /**
     * @return the covariance diagonal captured at construction (internal array), or
     *     {@code null} if the distribution was built without a covariance
     */
    public double[] getCovDiag() {
        return this.covDiag;
    }

    /** @param min lower bounds to store; kept by reference */
    public void setMin(double[] min) {
        this.min = min;
    }

    /** @param max upper bounds to store; kept by reference */
    public void setMax(double[] max) {
        this.max = max;
    }

    /**
     * Builds the unbiased sample covariance matrix from running sums.
     *
     * @param sum per-dimension running sums; its length is the dimension
     * @param sumSquares packed upper-triangular running sums of products, row-major, with
     *     {@code dim * (dim + 1) / 2} entries (presumably {@code sum of x_r * x_c} for r &lt;= c;
     *     verify against callers)
     * @param total number of samples accumulated
     * @return a {@code dim x dim} symmetric matrix, or {@code null} when {@code total < 2}
     */
    public static DMatrix getCovariance(double[] sum, double[] sumSquares, int total) {
        if (total < 2) {
            return null;
        }
        return computeCovariance(computeMeans(sum, total), sumSquares, total);
    }

    /**
     * Builds a distribution from running sums. Its mean is {@code sum / total} and its
     * covariance is computed as in {@link #getCovariance(double[], double[], int)}.
     *
     * @param sum per-dimension running sums
     * @param sumSquares packed upper-triangular running sums of products (see
     *     {@link #getCovariance(double[], double[], int)})
     * @param total number of samples accumulated
     * @param allowSingular forwarded to the constructor
     * @return the distribution, or {@code null} when {@code total < 2}
     */
    public static MultivariateNormalDistribution getDistribution(double[] sum, double[] sumSquares, int total, boolean allowSingular) {
        if (total < 2) {
            return null;
        }
        double[] avg = computeMeans(sum, total);
        return new MultivariateNormalDistribution(avg, computeCovariance(avg, sumSquares, total), allowSingular);
    }

    /** Divides each running sum by {@code total} to obtain the per-dimension means. */
    private static double[] computeMeans(double[] sum, int total) {
        double[] avg = new double[sum.length];
        for (int i = 0; i < sum.length; i++) {
            avg[i] = sum[i] / total;
        }
        return avg;
    }

    /** Expands packed sums of products into a dense, symmetric, row-major covariance matrix. */
    private static DMatrix computeCovariance(double[] avg, double[] sumSquares, int total) {
        int dim = avg.length;
        double[] cov = new double[dim * dim];
        // Bessel correction n / (n - 1). The cast is essential: with int / int the factor
        // truncates to 1 for every total >= 3 and the covariance comes out biased.
        double correction = (double) total / (total - 1);
        int packed = 0;
        for (int row = 0; row < dim; row++) {
            for (int col = row; col < dim; col++) {
                double value = (sumSquares[packed] / total - avg[row] * avg[col]) * correction;
                cov[row * dim + col] = value;
                cov[col * dim + row] = value;
                packed++;
            }
        }
        return VolatileDMatrix.wrap(cov, dim, dim);
    }

    /**
     * Evaluates the density at {@code features}.
     *
     * @param features point to evaluate; must not be longer than the mean vector
     * @param exponentOnly when {@code true}, return only {@code exp(densityExponent(features))},
     *     which is 1 at the mean. When {@code false}, multiply that by
     *     {@code (2*pi)^(-rank/2) * det^(-1/2)}.
     * @return the (possibly unnormalised) density value
     */
    public double density(double[] features, boolean exponentOnly) {
        double exponentTerm = getExponentTerm(features);
        if (exponentOnly) {
            return exponentTerm;
        }
        return Math.pow(TWO_PI, -0.5d * this.rank) * Math.pow(this.det, -0.5d) * exponentTerm;
    }

    /** Returns {@code exp(densityExponent(features))}. */
    private double getExponentTerm(double[] features) {
        return Math.exp(densityExponent(features));
    }

    /**
     * Creates a distribution with new means that shares this instance's covariance and cached
     * factorisation (pseudo-inverse, determinant, rank and diagonal). The shared objects are
     * not copied. min/max are not carried over.
     *
     * @param avg mean vector of the new distribution
     * @return the new distribution
     */
    public MultivariateNormalDistribution clone(double[] avg) {
        MultivariateNormalDistribution copy = new MultivariateNormalDistribution(avg, null, false);
        copy.covariance = this.covariance;
        copy.pinvsvd = this.pinvsvd;
        copy.inv = this.inv;
        copy.det = this.det;
        copy.rank = this.rank;
        copy.covDiag = this.covDiag;
        return copy;
    }

    /**
     * Computes the Gaussian exponent {@code -0.5 * (x - mu)^T * inv * (x - mu)}. Here
     * {@code inv} is the matrix returned by {@code PInvSVD.getPInv()}.
     *
     * @param features point {@code x}; must not be longer than the mean vector
     * @return the exponent (non-positive when {@code inv} is positive semi-definite)
     * @throws NullPointerException if the distribution was built without a covariance and
     *     never had {@code inv} assigned (assumes {@code MatrixOps.multiply} rejects null; TODO
     *     confirm)
     */
    public double densityExponent(double[] features) {
        double[] centered = new double[features.length];
        for (int i = 0; i < features.length; i++) {
            centered[i] = features[i] - this.means[i];
        }
        DMatrix rowVector = VolatileDMatrix.wrap(centered, 1, centered.length);
        DMatrix columnVector = VolatileDMatrix.wrap(centered, centered.length, 1);
        double quadraticForm = MatrixOps.multiply(MatrixOps.multiply(rowVector, this.inv), columnVector).get(0, 0);
        return -0.5d * quadraticForm;
    }
}
