package hex.genmodel.algos.glrm;

import hex.ModelCategory;
import hex.genmodel.MojoModel;
import java.util.EnumSet;
import java.util.Random;

/**
 * MOJO scoring model for Generalized Low Rank Models (GLRM).
 *
 * <p>Scoring a row computes its representation {@code x} (length {@code _ncolX}) with respect to the
 * fixed archetype matrix {@code _archetypes}. It minimizes the row objective (per-column losses plus
 * {@code _gammax} times the {@code _regx} regularizer) by repeatedly taking a gradient step on the loss
 * followed by {@code _regx.rproxgrad}, with an adaptive step size. The final {@code x} is written to the
 * prediction array.
 */
public class GlrmMojoModel extends MojoModel {
    /** Number of columns of the input row (also the length of {@link #_permutation}). */
    public int _ncolA;
    /** Length of the row representation x, i.e. of the prediction array. */
    public int _ncolX;
    /** Number of columns of the archetype matrix. */
    public int _ncolY;
    /** Number of rows of the archetype matrix; scoring expects this to equal {@link #_ncolX}. */
    public int _nrowY;
    /** Archetype matrix, indexed {@code [component of x][archetype column]}. */
    public double[][] _archetypes;
    /** Number of levels of each categorical column (the first {@link #_ncats} permuted columns). */
    public int[] _numLevels;
    /** Maps permuted column index {@code i} to the index of that column in the input row. */
    public int[] _permutation;
    /** Loss function for each permuted column. */
    public GlrmLoss[] _losses;
    /** Regularizer applied to x. */
    public GlrmRegularizer _regx;
    /** Weight of the x regularizer in the objective. */
    public double _gammax;
    /** Initialization setting; not read by this class (scoring always starts from a random point). */
    public GlrmInitialization _init;
    /** Number of categorical columns; they come first in permuted column order. */
    public int _ncats;
    /** Number of numeric columns; not read by this class. */
    public int _nnums;
    /** Per-numeric-column value subtracted before scaling. */
    public double[] _normSub;
    /** Per-numeric-column multiplier applied after subtracting {@link #_normSub}. */
    public double[] _normMul;

    /** Step size used at the start of every scored row (reset per call, never shared between calls). */
    private static final double INITIAL_ALPHA = 1.0d;
    /** Factor applied to the step size after a rejected (objective-increasing) step. */
    private static final double DOWN_FACTOR = 0.5d;
    /** Factor applied to the step size after an accepted step; four increases undo one decrease. */
    private static final double UP_FACTOR = Math.pow(1.0d / DOWN_FACTOR, 0.25d);
    /** Maximum number of outer (gradient) iterations per row. */
    private static final int MAX_ITERATIONS = 100;
    /** Relative objective improvement below which the iteration is considered converged. */
    private static final double CONVERGENCE_TOLERANCE = 1.0E-6d;
    /** Model categories of this model; copied on every request so callers cannot mutate it. */
    private static final EnumSet<ModelCategory> CATEGORIES =
            EnumSet.of(ModelCategory.AutoEncoder, ModelCategory.DimReduction);

    static {
        assert UP_FACTOR > 1.0d;
    }

    /**
     * Returns the model categories of a GLRM model.
     *
     * @return a fresh set containing {@code AutoEncoder} and {@code DimReduction}
     */
    @Override
    public EnumSet<ModelCategory> getModelCategories() {
        return EnumSet.copyOf(CATEGORIES);
    }

    public GlrmMojoModel(String[] strArr, String[][] strArr2, String str) {
        super(strArr, strArr2, str);
    }

    /**
     * Returns the size of the prediction array.
     *
     * @param modelCategory ignored
     * @return {@link #_ncolX}, the length of the row representation x
     */
    @Override
    public int getPredsSize(ModelCategory modelCategory) {
        return _ncolX;
    }

    /**
     * Scores one row, starting the optimization from an unseeded random point.
     *
     * @param row   input row of length {@link #_ncolA}; NaN entries are treated as missing
     * @param preds output array of length {@link #_ncolX}; receives the row representation x
     * @return {@code preds}
     */
    @Override
    public double[] score0(double[] row, double[] preds) {
        return score0(row, preds, new Random());
    }

    /**
     * Scores one row using the supplied random source for the starting point and for the regularizer's
     * projection / proximal operators, so that scoring can be made reproducible.
     *
     * @param row    input row of length {@link #_ncolA}; NaN entries are treated as missing
     * @param preds  output array of length {@link #_ncolX}; receives the row representation x
     * @param random random source; it is consumed by this call
     * @return {@code preds}
     */
    public double[] score0(double[] row, double[] preds, Random random) {
        assert row.length == _ncolA;
        assert preds.length == _ncolX;
        assert _nrowY == _ncolX;
        assert _archetypes.length == _nrowY;
        assert _archetypes[0].length == _ncolY;

        // Step 0: reorder the input row into the model's internal (permuted) column order.
        double[] a = new double[_ncolA];
        for (int i = 0; i < _ncolA; i++) {
            a[i] = row[_permutation[i]];
        }

        // Step 1: random Gaussian starting point, projected by the regularizer.
        double[] x = new double[_ncolX];
        for (int k = 0; k < _ncolX; k++) {
            x[k] = random.nextGaussian();
        }
        x = _regx.project(x, random);

        // Step 2: gradient step + proximal operator, adapting the step size until convergence.
        // The step size is local so calls neither influence each other nor race on shared state.
        double alpha = INITIAL_ALPHA;
        double obj = objective(x, a);
        // A zero objective cannot be improved by the relative-improvement test below (it divides by obj);
        // the objective is presumed non-negative.
        boolean done = obj == 0;
        for (int iter = 0; !done && iter < MAX_ITERATIONS; iter++) {
            double[] grad = gradientL(x, a);
            // Allocated per outer iteration on purpose: rproxgrad may return its input array, in which
            // case an accepted x aliases this buffer and must not be overwritten by the next trial step.
            double[] u = new double[_ncolX];
            while (true) {
                for (int k = 0; k < _ncolX; k++) {
                    u[k] = x[k] - alpha * grad[k];
                }
                double[] xnew = _regx.rproxgrad(u, alpha * _gammax, random);
                double newobj = objective(xnew, a);
                if (newobj == 0) {
                    // Perfect fit: keep it and stop instead of discarding the step.
                    x = xnew;
                    done = true;
                    break;
                }
                double improvement = 1 - newobj / obj;
                if (improvement >= 0) {
                    // Accept the step, grow the step size and recompute the gradient at the new point.
                    if (improvement < CONVERGENCE_TOLERANCE) {
                        done = true;
                    }
                    obj = newobj;
                    x = xnew;
                    alpha *= UP_FACTOR;
                    break;
                }
                // Reject the step and retry with a smaller step size.
                alpha *= DOWN_FACTOR;
                if (alpha == 0) {
                    // Step size underflowed: no further progress is possible, avoid looping forever.
                    done = true;
                    break;
                }
            }
        }
        System.arraycopy(x, 0, preds, 0, _ncolX);
        return preds;
    }

    /**
     * Computes the gradient of the (unregularized) row loss with respect to x.
     *
     * @param x current row representation, length {@link #_ncolX}
     * @param a permuted input row, length {@link #_ncolA}; NaN entries are skipped
     * @return a new gradient array of length {@link #_ncolX}
     */
    private double[] gradientL(double[] x, double[] a) {
        double[] grad = new double[_ncolX];

        // Categorical columns: each owns a contiguous block of _numLevels[j] archetype columns.
        int catOffset = 0;
        for (int j = 0; j < _ncats; j++) {
            int nLevels = _numLevels[j];
            if (!Double.isNaN(a[j])) {
                double[] gradL = _losses[j].mlgrad(categoricalScores(x, nLevels, catOffset), (int) a[j]);
                for (int k = 0; k < _ncolX; k++) {
                    for (int level = 0; level < nLevels; level++) {
                        grad[k] += gradL[level] * _archetypes[k][catOffset + level];
                    }
                }
            }
            // Advance even for missing values so later columns keep their own archetype columns.
            catOffset += nLevels;
        }

        // Numeric columns: one archetype column each, located after all categorical blocks.
        for (int j = _ncats; j < _ncolA; j++) {
            if (Double.isNaN(a[j])) {
                continue;
            }
            int n = j - _ncats;
            int col = catOffset + n;
            double lgrad = _losses[j].lgrad(numericScore(x, col), normalizeNumeric(a[j], n));
            for (int k = 0; k < _ncolX; k++) {
                grad[k] += lgrad * _archetypes[k][col];
            }
        }
        return grad;
    }

    /**
     * Computes the row objective: the sum of per-column losses over non-missing entries plus
     * {@link #_gammax} times the regularization of x.
     *
     * @param x current row representation, length {@link #_ncolX}
     * @param a permuted input row, length {@link #_ncolA}; NaN entries are skipped
     * @return the objective value
     */
    private double objective(double[] x, double[] a) {
        double res = 0.0d;

        int catOffset = 0;
        for (int j = 0; j < _ncats; j++) {
            int nLevels = _numLevels[j];
            if (!Double.isNaN(a[j])) {
                res += _losses[j].mloss(categoricalScores(x, nLevels, catOffset), (int) a[j]);
            }
            // Advance even for missing values so later columns keep their own archetype columns.
            catOffset += nLevels;
        }

        for (int j = _ncats; j < _ncolA; j++) {
            if (Double.isNaN(a[j])) {
                continue;
            }
            int n = j - _ncats;
            res += _losses[j].loss(numericScore(x, catOffset + n), normalizeNumeric(a[j], n));
        }
        return res + _gammax * _regx.regularize(x);
    }

    /** Computes x times the block of {@code nLevels} archetype columns starting at {@code offset}. */
    private double[] categoricalScores(double[] x, int nLevels, int offset) {
        double[] xy = new double[nLevels];
        for (int level = 0; level < nLevels; level++) {
            double sum = 0.0d;
            for (int k = 0; k < _ncolX; k++) {
                sum += x[k] * _archetypes[k][offset + level];
            }
            xy[level] = sum;
        }
        return xy;
    }

    /** Computes the dot product of x with archetype column {@code col}. */
    private double numericScore(double[] x, int col) {
        double xy = 0.0d;
        for (int k = 0; k < _ncolX; k++) {
            xy += x[k] * _archetypes[k][col];
        }
        return xy;
    }

    /** Normalizes a raw value of numeric column {@code n} as {@code (value - _normSub[n]) * _normMul[n]}. */
    private double normalizeNumeric(double value, int n) {
        return (value - _normSub[n]) * _normMul[n];
    }
}
