package math.linalg;

import java.util.ArrayList;
import math.distribution.StudentT;
import math.list.DoubleArrayList;
import math.list.DoubleList;

/* loaded from: input_file:math/linalg/OLS.class */
/**
 * Ordinary least squares (OLS) linear regression.
 */
public class OLS {
    /**
     * Fits {@code y = X * beta + e} by ordinary least squares and collects the estimates and
     * inference statistics in an {@link LSSummary}.
     *
     * <p>The coefficients are computed from the normal equations, {@code beta = (X'X)^-1 X'y}.
     * The scalar statistics (mean of y, R-squared, sigma-hat-squared, t-values, p-values and
     * confidence intervals) read only column 0 of {@code y} and {@code beta}. {@code y} is
     * presumably expected to be a single n x 1 column; confirm with callers.
     *
     * <p>R-squared is computed as {@code sum((yHat - yBar)^2) / sum((y - yBar)^2)} and capped at 1.
     * This matches the usual {@code 1 - RSS/TSS} only when X contains an intercept column.
     *
     * <p>NOTE(review): forming and inverting X'X explicitly squares the condition number of X. A
     * QR-based solve would be more stable for near-collinear designs, if DMatrix offers one.
     *
     * @param d significance level alpha used for the two-sided confidence intervals; must lie
     *     strictly inside (0, 1)
     * @param dMatrix design matrix X with n rows and p columns
     * @param dMatrix2 response y; must have the same number of rows as X
     * @return a summary holding beta, fitted values, mean of y, R-squared, residuals, degrees of
     *     freedom (n - p), sigma-hat-squared, the variance-covariance matrix, coefficient standard
     *     errors, t-values, two-sided p-values and {@code 1 - alpha} confidence intervals
     * @throws IllegalArgumentException if the row counts of X and y differ, if n - p < 1, or if
     *     alpha is NaN or not strictly between 0 and 1
     */
    public static LSSummary estimate(double d, DMatrix dMatrix, DMatrix dMatrix2) {
        // Descriptive aliases; the parameter names are kept for compatibility.
        final double alpha = d;
        final DMatrix x = dMatrix;
        final DMatrix y = dMatrix2;

        if (x.numRows() != y.numRows()) {
            throw new IllegalArgumentException("X.numRows != y.numRows : " + x.numRows() + " != " + y.numRows());
        }
        final int dof = x.numRows() - x.numColumns();
        if (dof < 1) {
            throw new IllegalArgumentException("degrees of freedom < 1 : " + dof);
        }
        // NaN fails every comparison, so it would otherwise slip past both range checks below
        // and silently produce NaN confidence intervals.
        if (Double.isNaN(alpha)) {
            throw new IllegalArgumentException("alpha is NaN : " + alpha);
        }
        if (alpha <= 0.0d) {
            throw new IllegalArgumentException("alpha <= 0 : " + alpha);
        }
        if (alpha >= 1.0d) {
            throw new IllegalArgumentException("alpha >= 1 : " + alpha);
        }

        LSSummary summary = new LSSummary(alpha, x, y);

        // Coefficients: beta = (X'X)^-1 X'y, and fitted values yHat = X beta.
        DMatrix xt = x.transpose();
        DMatrix xtxInv = xt.mul(x).inverse();
        DMatrix beta = xtxInv.mul(xt).mul(y);
        summary.setBeta(beta);
        DMatrix yHat = x.mul(beta);
        summary.setYHat(yHat);

        // Mean of y: a 1 x n row of ones times y, scaled by 1/n.
        int n = y.numRows();
        DMatrix onesRow = new DMatrix(1, n);
        for (int i = 0; i < n; i++) {
            onesRow.setUnsafe(0, i, 1.0d);
        }
        double yBar = onesRow.mul(y).scaleInplace(1.0d / n).get(0, 0);
        summary.setYBar(yBar);

        // n x 1 column whose entries all equal yBar (a ones column scaled in place).
        DMatrix onesColumn = new DMatrix(n, 1);
        for (int i = 0; i < n; i++) {
            onesColumn.setUnsafe(i, 0, 1.0d);
        }
        DMatrix yBarColumn = onesColumn.scaleInplace(yBar);

        // R-squared = explained sum of squares / total sum of squares, capped at 1.
        DMatrix explained = yHat.minus(yBarColumn);
        DMatrix centered = y.minus(yBarColumn);
        double rSquared = explained.transpose().mul(explained).get(0, 0)
                / centered.transpose().mul(centered).get(0, 0);
        summary.setRSquared(rSquared > 1.0d ? 1.0d : rSquared);

        DMatrix residuals = y.minus(yHat);
        summary.setResiduals(residuals);
        summary.setDegreesOfFreedom(dof);

        // Unbiased residual variance: RSS / (n - p).
        double sigmaHatSquared = residuals.transpose().mul(residuals).scaleInplace(1.0d / dof).get(0, 0);
        summary.setSigmaHatSquared(sigmaHatSquared);

        // Var(beta) = sigma^2 (X'X)^-1. This scales xtxInv in place; it is not needed afterwards.
        DMatrix varCov = xtxInv.scaleInplace(sigmaHatSquared);
        summary.setVarianceCovarianceMatrix(varCov);

        int k = varCov.numRows();
        DoubleArrayList standardErrors = new DoubleArrayList(k);
        for (int i = 0; i < k; i++) {
            double variance = varCov.get(i, i);
            if (variance < 0.0d) {
                // A negative variance (presumably round-off) is clamped to the smallest
                // positive normal double, in both the matrix and the standard error.
                variance = Double.MIN_NORMAL;
                varCov.set(i, i, Double.MIN_NORMAL);
            }
            standardErrors.add(Math.sqrt(variance));
        }
        summary.setCoefficientStandardErrors(standardErrors);

        // t-statistics, two-sided p-values and (1 - alpha) confidence intervals, using a
        // Student-t distribution constructed with n - p degrees of freedom.
        DoubleArrayList tValues = new DoubleArrayList(k);
        DoubleArrayList pValues = new DoubleArrayList(k);
        // Left raw: the element type expected by setConfidenceIntervals is not visible here.
        ArrayList confidenceIntervals = new ArrayList(k);
        StudentT tDist = new StudentT(dof);
        double tCritical = tDist.inverseCdf(1.0d - (alpha / 2.0d));
        for (int i = 0; i < k; i++) {
            double coefficient = beta.get(i, 0);
            double se = standardErrors.get(i);
            double tValue = coefficient / se;
            double pValue = 2.0d * (1.0d - tDist.cdf(Math.abs(tValue)));
            tValues.add(tValue);
            pValues.add(pValue);
            confidenceIntervals.add(DoubleList.of(coefficient - (tCritical * se), coefficient + (tCritical * se)));
        }
        summary.setTValues(tValues);
        summary.setPValues(pValues);
        summary.setConfidenceIntervals(confidenceIntervals);
        return summary;
    }
}
