package hex.deeplearning;

import hex.DataInfo;
import hex.Distribution;
import hex.deeplearning.DeepLearningModel;
import hex.deeplearning.Storage;
import java.nio.ByteBuffer;
import java.util.Arrays;
import water.H2O;
import water.MemoryManager;
import water.util.ArrayUtils;
import water.util.MathUtils;

/* loaded from: input_file:hex/deeplearning/Neurons.class */
public abstract class Neurons {
    // Number of channels per unit. Assigned by the Maxout constructor; when non-zero, bprop and
    // rescale_weights use it (with _maxIncoming) to locate the weight of the winning channel.
    short _k;
    // For each mini-batch row and each unit, the channel index chosen by Maxout.fprop.
    int[][] _maxIncoming;
    // Distribution built from this layer's parameters in init(); used for output-layer gradients.
    Distribution _dist;
    // Number of units in this layer (set by the constructor).
    protected int units;
    // Per-layer clone of the model parameters, created in init().
    protected transient DeepLearningModel.DeepLearningParameters params;
    // Set to (layer position - 1) in init(); used to index weights, biases and dropout ratios.
    protected transient int _index;
    // Copy of the input activations taken before input dropout (auto-encoder only, see Input.setInput).
    public transient Storage.DenseVector[] _origa;
    // Activations, one vector per mini-batch row.
    public transient Storage.DenseVector[] _a;
    // Error terms, one vector per mini-batch row; not allocated for Input layers.
    public transient Storage.DenseVector[] _e;
    // Preceding layer, taken from the layer array in init().
    public Neurons _previous;
    // Read by autoEncoderGradient as the source of the reconstruction targets.
    // NOTE(review): not assigned anywhere in the visible part of this file.
    public Neurons _input;
    // Shared model state from which weights, biases, momenta and AdaDelta buffers are fetched in init().
    DeepLearningModelInfo _minfo;
    // Weight matrix feeding this layer.
    public Storage.DenseRowMatrix _w;
    // Reference weights for the elastic-averaging term in bprop (only used when non-null).
    public Storage.DenseRowMatrix _wEA;
    // Bias vector of this layer.
    public Storage.DenseVector _b;
    // Passed to update_bias alongside _b; presumably the elastic-averaging counterpart of _b — TODO confirm.
    public Storage.DenseVector _bEA;
    // Weight momenta (only fetched when the model has momenta).
    Storage.DenseRowMatrix _wm;
    // Bias momenta (only fetched when the model has momenta).
    Storage.DenseVector _bm;
    // AdaDelta state for the weights (only fetched when AdaDelta is enabled).
    Storage.DenseRowMatrix _ada_dx_g;
    // AdaDelta state for the biases (only fetched when AdaDelta is enabled).
    Storage.DenseVector _bias_ada_dx_g;
    // Dropout helper; created in init() for dropout layers and for Input with a non-zero input dropout ratio.
    protected Dropout _dropout;
    // When true, bprop returns early as soon as it meets a zero partial gradient.
    private boolean _shortcut = false;
    // Average activations, fetched in init() for auto-encoders with a positive sparsity_beta.
    public Storage.DenseVector _avg_a;
    // Decompiled form of the compiler-generated assertion switch: true when assertions are disabled.
    static final /* synthetic */ boolean $assertionsDisabled;

    /* loaded from: input_file:hex/deeplearning/Neurons$ExpRectifier.class */
    /**
     * Hidden layer whose activation is {@code x} for {@code x >= 0} and {@code exp(x) - 1} otherwise.
     */
    public static class ExpRectifier extends Neurons {

        /**
         * @param i number of units in this layer
         */
        public ExpRectifier(int i) {
            super(i);
        }

        /**
         * Forward pass: calls {@code gemv} for each mini-batch row, then applies the activation
         * element-wise and refreshes the sparsity statistics.
         *
         * @param j seed (not used directly by this implementation)
         * @param z training flag (not used directly by this implementation)
         * @param i number of mini-batch rows to process
         */
        @Override
        protected void fprop(long j, boolean z, int i) {
            for (int mb = 0; mb < i; mb++) {
                gemv(this._a[mb], this._w, this._previous._a[mb], this._b, this._dropout != null ? this._dropout.bits() : null);
            }
            final int rows = this._a[0].size();
            for (int row = 0; row < rows; row++) {
                for (int mb = 0; mb < i; mb++) {
                    final double x = this._a[mb].get(row);
                    this._a[mb].set(row, x >= 0.0d ? x : Math.exp(x) - 1.0d);
                }
            }
            compute_sparsity();
        }

        /**
         * Backward pass: multiplies each error term by the activation derivative
         * ({@code 1} for non-negative activations, {@code exp(a)} otherwise) and hands the
         * per-unit gradients to the shared weight update.
         *
         * @param i number of mini-batch rows to process
         */
        @Override
        protected void bprop(int i) {
            assert this._index < this._minfo.get_params()._hidden.length;
            final float momentum = this._minfo.adaDelta() ? 0.0f : momentum();
            final float rate = this._minfo.adaDelta() ? 0.0f : rate(this._minfo.get_processed_total()) * (1.0f - momentum);
            final int rows = this._a[0].size();
            // One buffer reused for every unit (as Rectifier and Tanh do); each entry is overwritten per row.
            final double[] grad = new double[i];
            for (int row = 0; row < rows; row++) {
                for (int mb = 0; mb < i; mb++) {
                    final double a = this._a[mb].get(row);
                    grad[mb] = this._e[mb].get(row) * (a >= 0.0d ? 1.0d : Math.exp(a));
                }
                bprop(row, grad, rate, momentum, i);
            }
        }
    }

    /* loaded from: input_file:hex/deeplearning/Neurons$ExpRectifierDropout.class */
    /**
     * {@link ExpRectifier} with hidden-unit dropout.
     */
    public static class ExpRectifierDropout extends ExpRectifier {

        /**
         * @param i number of units in this layer
         */
        public ExpRectifierDropout(int i) {
            super(i);
        }

        /**
         * When training, re-seeds the dropout mask and runs the parent forward pass with the
         * derived seed. Otherwise runs the parent forward pass and scales every activation by
         * {@code 1 - hidden_dropout_ratio} of this layer.
         */
        @Override
        protected void fprop(long j, boolean z, int i) {
            if (!z) {
                super.fprop(j, false, i);
                int mb = 0;
                while (mb < i) {
                    ArrayUtils.mult(this._a[mb].raw(), 1.0d - this.params._hidden_dropout_ratios[this._index]);
                    mb++;
                }
                return;
            }
            final long dropoutSeed = j + (this.params._seed - 629514240);
            this._dropout.fillBytes(dropoutSeed);
            super.fprop(dropoutSeed, true, i);
        }
    }

    /* loaded from: input_file:hex/deeplearning/Neurons$Input.class */
    /**
     * Input layer: turns a data row into the activation vector of the first layer.
     * Forward and backward propagation are not supported on this layer.
     */
    public static class Input extends Neurons {
        private DataInfo _dinfo;

        /**
         * Allocates one activation vector of {@code i} entries per mini-batch row.
         *
         * @param deepLearningParameters model parameters (only the mini-batch size is read here)
         * @param i                      number of input units
         * @param dataInfo               column layout and normalisation data used by {@code setInput}
         */
        public Input(DeepLearningModel.DeepLearningParameters deepLearningParameters, int i, DataInfo dataInfo) {
            super(i);
            this._dinfo = dataInfo;
            final int batch = deepLearningParameters._mini_batch_size;
            this._a = new Storage.DenseVector[batch];
            for (int mb = 0; mb < batch; mb++) {
                this._a[mb] = new Storage.DenseVector(i);
            }
        }

        @Override
        protected void bprop(int i) {
            throw new UnsupportedOperationException();
        }

        @Override
        protected void fprop(long j, boolean z, int i) {
            throw new UnsupportedOperationException();
        }

        /**
         * Converts a raw row (categorical columns first, then numeric columns) into categorical
         * indices and optionally normalised numeric values, then forwards them to the
         * six-argument overload.
         *
         * @param j    seed forwarded to the dropout step
         * @param dArr raw row values
         * @param i    mini-batch slot to fill
         */
        public void setInput(long j, double[] dArr, int i) {
            assert this._dinfo != null;
            final double[] nums = MemoryManager.malloc8d(this._dinfo._nums);
            final int[] cats = MemoryManager.malloc4(this._dinfo._cats);
            int ncats = 0;
            int col = 0;
            for (; col < this._dinfo._cats; col++, ncats++) {
                assert this._dinfo._catMissing[col];
                final int lastLevel = this._dinfo._catOffsets[col + 1] - 1;
                if (Double.isNaN(dArr[col])) {
                    // Missing values map to the last index of this column's range.
                    cats[ncats] = lastLevel;
                    continue;
                }
                final int level = (int) dArr[col];
                if (this._dinfo._useAllFactorLevels) {
                    cats[ncats] = level + this._dinfo._catOffsets[col];
                } else if (level != 0) {
                    cats[ncats] = (level + this._dinfo._catOffsets[col]) - 1;
                }
                // Indices beyond this column's range are clamped to its last index.
                if (cats[ncats] >= this._dinfo._catOffsets[col + 1]) {
                    cats[ncats] = lastLevel;
                }
            }
            for (; col < dArr.length; col++) {
                final int numIdx = col - this._dinfo._cats;
                double value = dArr[col];
                if (this._dinfo._normMul != null) {
                    value = (value - this._dinfo._normSub[numIdx]) * this._dinfo._normMul[numIdx];
                }
                nums[numIdx] = value;
            }
            setInput(j, null, nums, ncats, cats, i);
        }

        /**
         * Fills activation slot {@code i2} from categorical indices and numeric values, then
         * applies input dropout when a dropout helper is configured.
         *
         * @param j     seed; combined with the model seed for the dropout step
         * @param iArr  target positions of the numeric values, or {@code null} to place them
         *              consecutively from {@code numStart()} (ignored in the hashing branch)
         * @param dArr  numeric values; NaN entries are stored as 0
         * @param i     number of valid entries in {@code iArr2}
         * @param iArr2 categorical indices
         * @param i2    mini-batch slot to fill
         */
        public void setInput(long j, int[] iArr, double[] dArr, int i, int[] iArr2, int i2) {
            final Storage.DenseVector row = this._a[i2];
            Arrays.fill(row.raw(), 0.0d);
            final boolean hashCategoricals = this.params._max_categorical_features < this._dinfo.fullN() - this._dinfo._nums;
            if (hashCategoricals) {
                assert dArr.length == this._dinfo._nums;
                final int expectedSize = dArr.length + this.params._max_categorical_features;
                assert row.size() == expectedSize;
                final int buckets = this.params._max_categorical_features;
                final MurmurHash hasher = MurmurHash.getInstance();
                // Each categorical index is hashed into one of the leading buckets, which act as counters.
                for (int c = 0; c < i; c++) {
                    final byte[] key = ByteBuffer.allocate(4).putInt(iArr2[c]).array();
                    row.add(Math.abs(hasher.hash(key, 4, (int) this.params._seed) % buckets), 1.0d);
                }
                for (int n = 0; n < dArr.length; n++) {
                    row.set(buckets + n, Double.isNaN(dArr[n]) ? 0.0d : dArr[n]);
                }
            } else {
                assert row.size() == this._dinfo.fullN();
                for (int c = 0; c < i; c++) {
                    row.set(iArr2[c], 1.0d);
                }
                if (iArr == null) {
                    for (int n = 0; n < dArr.length; n++) {
                        row.set(this._dinfo.numStart() + n, Double.isNaN(dArr[n]) ? 0.0d : dArr[n]);
                    }
                } else {
                    for (int n = 0; n < iArr.length; n++) {
                        row.set(iArr[n], Double.isNaN(dArr[n]) ? 0.0d : dArr[n]);
                    }
                }
            }
            if (this._dropout != null) {
                if (this.params._autoencoder && this.params._input_dropout_ratio > 0.0d) {
                    // Keep the clean input before it is sparsified.
                    System.arraycopy(this._a[i2].raw(), 0, this._origa[i2].raw(), 0, this._a[i2].raw().length);
                }
                this._dropout.randomlySparsifyActivation((Storage.Vector) this._a[i2], j + this.params._seed + 322417854);
            }
        }
    }

    /* loaded from: input_file:hex/deeplearning/Neurons$Linear.class */
    /**
     * Single-unit output layer with no activation function applied after {@code gemv}.
     */
    public static class Linear extends Output {

        public Linear() {
            super(1);
        }

        /**
         * Forward pass: one {@code gemv} call per mini-batch row.
         */
        @Override
        protected void fprop(long j, boolean z, int i) {
            int mb = 0;
            while (mb < i) {
                gemv(this._a[mb], this._w, this._previous._a[mb], this._b, this._dropout != null ? this._dropout.bits() : null);
                mb++;
            }
        }

        /**
         * Stores {@code -2 * negHalfGradient(d, prediction) / i2} as the error of the single output unit.
         *
         * @param d  value passed to the distribution as the first argument (the target)
         * @param i  mini-batch slot
         * @param i2 divisor applied to the gradient
         */
        @Override
        protected void setOutputLayerGradient(double d, int i, int i2) {
            final double prediction = this._a[i].get(0);
            final double gradient = (-2.0d) * this._dist.negHalfGradient(d, prediction);
            this._e[i].set(0, gradient / i2);
        }
    }

    /* loaded from: input_file:hex/deeplearning/Neurons$Maxout.class */
    /**
     * Hidden layer in which every unit has {@code _k} channels (each with its own weights and
     * bias); the channel with the largest incoming value becomes the unit's activation.
     */
    public static class Maxout extends Neurons {

        /**
         * @param deepLearningParameters model parameters (only the mini-batch size is read here)
         * @param s                      number of channels per unit; only 2 is supported
         * @param i                      number of units in this layer
         */
        public Maxout(DeepLearningModel.DeepLearningParameters deepLearningParameters, short s, int i) {
            super(i);
            this._k = s;
            // One row of winning-channel indices per mini-batch row (2-D allocation; the field is int[][]).
            this._maxIncoming = new int[deepLearningParameters._mini_batch_size][];
            for (int mb = 0; mb < this._maxIncoming.length; mb++) {
                this._maxIncoming[mb] = new int[i];
            }
            if (this._k != 2) {
                throw H2O.unimpl("Maxout is currently hardcoded for 2 channels. Trivial to enable k > 2 though.");
            }
        }

        /**
         * Forward pass: for every unit and mini-batch row, computes the incoming value of each
         * channel, records the winning channel in {@code _maxIncoming} and stores its value as the
         * activation. Units that are inactive under dropout during training keep activation 0.
         *
         * @param j seed (not used directly by this implementation)
         * @param z training flag; dropout is only consulted when true
         * @param i number of mini-batch rows to process
         */
        @Override
        protected void fprop(long j, boolean z, int i) {
            assert this._b.size() == this._a[0].size() * this._k;
            assert this._w.size() == this._a[0].size() * this._previous._a[0].size() * this._k;
            final int rows = this._a[0].size();
            final double[] channel = new double[this._k];
            for (int row = 0; row < rows; row++) {
                for (int mb = 0; mb < i; mb++) {
                    this._a[mb].set(row, 0.0d);
                    if (z && this._dropout != null && !this._dropout.unit_active(row)) {
                        continue;
                    }
                    final int cols = this._previous._a[mb].size();
                    short winner = 0;
                    for (short k = 0; k < this._k; k++) {
                        double sum = 0.0d;
                        for (int col = 0; col < cols; col++) {
                            // Channel weights of one (row, col) pair are stored next to each other.
                            sum += this._w.raw()[(this._k * ((row * cols) + col)) + k] * this._previous._a[mb].get(col);
                        }
                        sum += this._b.raw()[(this._k * row) + k];
                        channel[k] = sum;
                        if (channel[k] > channel[winner]) {
                            winner = k;
                        }
                    }
                    this._maxIncoming[mb][row] = winner;
                    this._a[mb].set(row, channel[winner]);
                }
                // NOTE(review): invoked once per unit rather than once per pass, unlike the other
                // layers; kept as-is because compute_sparsity is not visible here.
                compute_sparsity();
            }
        }

        /**
         * Backward pass: forwards the error terms unchanged to the shared weight update, which
         * selects the winning channel's weight via {@code _maxIncoming}.
         *
         * @param i number of mini-batch rows to process
         */
        @Override
        protected void bprop(int i) {
            assert this._index != this.params._hidden.length;
            final float momentum = this._minfo.adaDelta() ? 0.0f : momentum();
            final float rate = this._minfo.adaDelta() ? 0.0f : rate(this._minfo.get_processed_total()) * (1.0f - momentum);
            final double[] grad = new double[i];
            final int rows = this._a[0].size();
            for (int row = 0; row < rows; row++) {
                for (int mb = 0; mb < i; mb++) {
                    grad[mb] = this._e[mb].get(row);
                }
                bprop(row, grad, rate, momentum, i);
            }
        }
    }

    /* loaded from: input_file:hex/deeplearning/Neurons$MaxoutDropout.class */
    /**
     * {@link Maxout} with hidden-unit dropout.
     */
    public static class MaxoutDropout extends Maxout {

        public MaxoutDropout(DeepLearningModel.DeepLearningParameters deepLearningParameters, short s, int i) {
            super(deepLearningParameters, s, i);
        }

        /**
         * When training, re-seeds the dropout mask and runs the parent forward pass with the
         * derived seed. Otherwise runs the parent forward pass and scales every activation by
         * {@code 1 - hidden_dropout_ratio} of this layer.
         */
        @Override
        protected void fprop(long j, boolean z, int i) {
            if (!z) {
                super.fprop(j, false, i);
                int mb = 0;
                while (mb < i) {
                    ArrayUtils.mult(this._a[mb].raw(), 1.0d - this.params._hidden_dropout_ratios[this._index]);
                    mb++;
                }
                return;
            }
            final long dropoutSeed = j + this.params._seed + 1372114957;
            this._dropout.fillBytes(dropoutSeed);
            super.fprop(dropoutSeed, true, i);
        }
    }

    /* loaded from: input_file:hex/deeplearning/Neurons$Output.class */
    /**
     * Base class of the output layers (Linear and Softmax extend it).
     * The per-layer backward pass is not supported here and always throws.
     */
    public static abstract class Output extends Neurons {
        // i: number of units in the output layer.
        Output(int i) {
            super(i);
        }

        // Always throws UnsupportedOperationException.
        @Override // hex.deeplearning.Neurons
        protected void bprop(int i) {
            throw new UnsupportedOperationException();
        }
    }

    /* loaded from: input_file:hex/deeplearning/Neurons$Rectifier.class */
    public static class Rectifier extends Neurons {
        static final /* synthetic */ boolean $assertionsDisabled;

        public Rectifier(int i) {
            super(i);
        }

        @Override // hex.deeplearning.Neurons
        protected void fprop(long j, boolean z, int i) {
            for (int i2 = 0; i2 < i; i2++) {
                gemv(this._a[i2], this._w, this._previous._a[i2], this._b, this._dropout != null ? this._dropout.bits() : null);
            }
            int size = this._a[0].size();
            for (int i3 = 0; i3 < i; i3++) {
                for (int i4 = 0; i4 < size; i4++) {
                    this._a[i3].set(i4, 0.5d * (this._a[i3].get(i4) + Math.abs(this._a[i3].get(i4))));
                }
            }
            compute_sparsity();
        }

        @Override // hex.deeplearning.Neurons
        protected void bprop(int i) {
            if (!$assertionsDisabled && this._index >= this._minfo.get_params()._hidden.length) {
                throw new AssertionError();
            }
            float momentum = this._minfo.adaDelta() ? 0.0f : momentum();
            float rate = this._minfo.adaDelta() ? 0.0f : rate(this._minfo.get_processed_total()) * (1.0f - momentum);
            int size = this._a[0].size();
            double[] dArr = new double[i];
            for (int i2 = 0; i2 < size; i2++) {
                for (int i3 = 0; i3 < i; i3++) {
                    dArr[i3] = this._a[i3].get(i2) > 0.0d ? this._e[i3].get(i2) : 0.0d;
                }
                bprop(i2, dArr, rate, momentum, i);
            }
        }

        static {
            $assertionsDisabled = !Neurons.class.desiredAssertionStatus();
        }
    }

    /* loaded from: input_file:hex/deeplearning/Neurons$RectifierDropout.class */
    /**
     * {@link Rectifier} with hidden-unit dropout.
     */
    public static class RectifierDropout extends Rectifier {

        /**
         * @param i number of units in this layer
         */
        public RectifierDropout(int i) {
            super(i);
        }

        /**
         * When training, re-seeds the dropout mask and runs the parent forward pass with the
         * derived seed. Otherwise runs the parent forward pass and scales every activation by
         * {@code 1 - hidden_dropout_ratio} of this layer.
         */
        @Override
        protected void fprop(long j, boolean z, int i) {
            if (!z) {
                super.fprop(j, false, i);
                int mb = 0;
                while (mb < i) {
                    ArrayUtils.mult(this._a[mb].raw(), 1.0d - this.params._hidden_dropout_ratios[this._index]);
                    mb++;
                }
                return;
            }
            final long dropoutSeed = j + this.params._seed + 1014100461;
            this._dropout.fillBytes(dropoutSeed);
            super.fprop(dropoutSeed, true, i);
        }
    }

    /* loaded from: input_file:hex/deeplearning/Neurons$Softmax.class */
    public static class Softmax extends Output {
        static final /* synthetic */ boolean $assertionsDisabled;

        public Softmax(int i) {
            super(i);
        }

        @Override // hex.deeplearning.Neurons
        protected void fprop(long j, boolean z, int i) {
            for (int i2 = 0; i2 < i; i2++) {
                gemv(this._a[i2], this._w, this._previous._a[i2], this._b, null);
            }
            for (int i3 = 0; i3 < i; i3++) {
                double maxValue = ArrayUtils.maxValue(this._a[i3].raw());
                double d = 0.0d;
                int size = this._a[i3].size();
                for (int i4 = 0; i4 < size; i4++) {
                    this._a[i3].set(i4, Math.exp(this._a[i3].get(i4) - maxValue));
                    d += this._a[i3].get(i4);
                }
                for (int i5 = 0; i5 < size; i5++) {
                    double[] raw = this._a[i3].raw();
                    int i6 = i5;
                    raw[i6] = raw[i6] / d;
                }
            }
        }

        @Override // hex.deeplearning.Neurons
        protected void setOutputLayerGradient(double d, int i, int i2) {
            double d2;
            if (!$assertionsDisabled && d != ((int) d)) {
                throw new AssertionError();
            }
            int size = this._a[i].size();
            int i3 = 0;
            while (i3 < size) {
                double d3 = i3 == ((int) d) ? 1 : 0;
                double d4 = this._a[i].get(i3);
                switch (this.params._loss) {
                    case CrossEntropy:
                        d2 = d4 - d3;
                        break;
                    case ModifiedHuber:
                        d2 = (-2.0d) * this._dist.negHalfGradient(d3, d4) * (1.0d - d4) * d4;
                        break;
                    case Quadratic:
                        d2 = (d4 - d3) * (1.0d - d4) * d4;
                        break;
                    default:
                        throw H2O.unimpl();
                }
                this._e[i].set(i3, d2 / i2);
                i3++;
            }
        }

        static {
            $assertionsDisabled = !Neurons.class.desiredAssertionStatus();
        }
    }

    /* loaded from: input_file:hex/deeplearning/Neurons$Tanh.class */
    public static class Tanh extends Neurons {
        static final /* synthetic */ boolean $assertionsDisabled;

        public Tanh(int i) {
            super(i);
        }

        @Override // hex.deeplearning.Neurons
        protected void fprop(long j, boolean z, int i) {
            for (int i2 = 0; i2 < i; i2++) {
                gemv(this._a[i2], this._w, this._previous._a[i2], this._b, this._dropout != null ? this._dropout.bits() : null);
            }
            int size = this._a[0].size();
            for (int i3 = 0; i3 < i; i3++) {
                for (int i4 = 0; i4 < size; i4++) {
                    this._a[i3].set(i4, 1.0d - (2.0d / (1.0d + Math.exp(2.0d * this._a[i3].get(i4)))));
                }
            }
            compute_sparsity();
        }

        @Override // hex.deeplearning.Neurons
        protected void bprop(int i) {
            if (!$assertionsDisabled && this._index >= this._minfo.get_params()._hidden.length) {
                throw new AssertionError();
            }
            float momentum = this._minfo.adaDelta() ? 0.0f : momentum();
            float rate = this._minfo.adaDelta() ? 0.0f : rate(this._minfo.get_processed_total()) * (1.0f - momentum);
            int size = this._a[0].size();
            double[] dArr = new double[i];
            for (int i2 = 0; i2 < size; i2++) {
                for (int i3 = 0; i3 < i; i3++) {
                    dArr[i3] = this._e[i3].get(i2) * (1.0d - (this._a[i3].get(i2) * this._a[i3].get(i2)));
                }
                bprop(i2, dArr, rate, momentum, i);
            }
        }

        static {
            $assertionsDisabled = !Neurons.class.desiredAssertionStatus();
        }
    }

    /* loaded from: input_file:hex/deeplearning/Neurons$TanhDropout.class */
    /**
     * {@link Tanh} with hidden-unit dropout.
     */
    public static class TanhDropout extends Tanh {

        /**
         * @param i number of units in this layer
         */
        public TanhDropout(int i) {
            super(i);
        }

        /**
         * When training, re-seeds the dropout mask and runs the parent forward pass with the
         * derived seed. Otherwise runs the parent forward pass and scales every activation by
         * {@code 1 - hidden_dropout_ratio} of this layer.
         */
        @Override
        protected void fprop(long j, boolean z, int i) {
            if (!z) {
                super.fprop(j, false, i);
                int mb = 0;
                while (mb < i) {
                    ArrayUtils.mult(this._a[mb].raw(), 1.0d - this.params._hidden_dropout_ratios[this._index]);
                    mb++;
                }
                return;
            }
            final long dropoutSeed = j + (this.params._seed - 629514240);
            this._dropout.fillBytes(dropoutSeed);
            super.fprop(dropoutSeed, true, i);
        }
    }

    /**
     * Records the number of units; all other state is set up later by {@code init}.
     *
     * @param i number of units in this layer
     */
    Neurons(int i) {
        this.units = i;
    }

    /**
     * Describes the layer: simple class name, unit count, parameters and, when present, the
     * dropout helper.
     *
     * <p>Safe to call before {@code init}: unset parameters are rendered as {@code null}
     * instead of raising a {@link NullPointerException}.
     *
     * @return multi-line description of this layer
     */
    @Override
    public String toString() {
        final StringBuilder text = new StringBuilder(getClass().getSimpleName());
        text.append("\nNumber of Neurons: ").append(this.units);
        // append(Object) goes through String.valueOf, so a null params does not throw.
        text.append("\nParameters:\n").append(this.params);
        if (this._dropout != null) {
            text.append("\nDropout:\n").append(this._dropout);
        }
        return text.toString();
    }

    /**
     * Checks that the state set up by {@code init} is consistent with the optimiser settings.
     * Most checks are assertions and only run when assertions are enabled; the rho/epsilon
     * checks always run.
     *
     * @param z training flag; dropout layers must own a dropout helper when it is true
     * @throws IllegalArgumentException if AdaDelta is enabled and rho or epsilon is 0
     */
    void sanityCheck(boolean z) {
        if (this instanceof Input) {
            // The input layer has no predecessor and nothing else to verify.
            if (!$assertionsDisabled && this._previous != null) {
                throw new AssertionError();
            }
            return;
        }
        if (!$assertionsDisabled && this._previous == null) {
            throw new AssertionError();
        }
        if (this._minfo.has_momenta()) {
            // Momentum buffers must exist; AdaDelta weight state must not.
            if (!$assertionsDisabled && this._wm == null) {
                throw new AssertionError();
            }
            if (!$assertionsDisabled && this._bm == null) {
                throw new AssertionError();
            }
            if (!$assertionsDisabled && this._ada_dx_g != null) {
                throw new AssertionError();
            }
        }
        if (this._minfo.adaDelta()) {
            if (this.params._rho == 0.0d) {
                throw new IllegalArgumentException("rho must be > 0 if epsilon is >0.");
            }
            if (this.params._epsilon == 0.0d) {
                throw new IllegalArgumentException("epsilon must be > 0 if rho is >0.");
            }
            if (!$assertionsDisabled && !this._minfo.adaDelta()) {
                throw new AssertionError();
            }
            if (!$assertionsDisabled && this._bias_ada_dx_g == null) {
                throw new AssertionError();
            }
            if (!$assertionsDisabled && this._wm != null) {
                throw new AssertionError();
            }
            if (!$assertionsDisabled && this._bm != null) {
                throw new AssertionError();
            }
        }
        // Every dropout layer that init() equips with a Dropout helper is covered here,
        // including ExpRectifierDropout.
        final boolean dropoutLayer = (this instanceof MaxoutDropout)
                || (this instanceof TanhDropout)
                || (this instanceof RectifierDropout)
                || (this instanceof ExpRectifierDropout);
        if (dropoutLayer && !$assertionsDisabled && z && this._dropout == null) {
            throw new AssertionError();
        }
    }

    /**
     * Wires this layer into the network: clones and adjusts the parameters, allocates the
     * per-mini-batch activation and error vectors, creates the dropout helper where applicable,
     * and (for non-input layers) fetches weights, biases and optimiser state from the model
     * info. Ends with {@code sanityCheck}.
     *
     * @param neuronsArr             all layers; the entry at {@code i - 1} becomes {@code _previous}
     * @param i                      position of this layer; {@code _index} is set to {@code i - 1}
     * @param deepLearningParameters parameters to clone for this layer
     * @param deepLearningModelInfo  shared model state
     * @param z                      training flag; dropout helpers are only created when true
     */
    public final void init(Neurons[] neuronsArr, int i, DeepLearningModel.DeepLearningParameters deepLearningParameters, DeepLearningModelInfo deepLearningModelInfo, boolean z) {
        final boolean isInput = this instanceof Input;

        this._index = i - 1;
        this.params = (DeepLearningModel.DeepLearningParameters) deepLearningParameters.m1957clone();
        this.params._hidden_dropout_ratios = deepLearningModelInfo.get_params()._hidden_dropout_ratios;
        // The learning rate is scaled by rate_decay once per preceding layer.
        this.params._rate *= Math.pow(this.params._rate_decay, i - 1);
        this.params._distribution = deepLearningModelInfo.get_params()._distribution;
        this._dist = new Distribution(this.params);

        this._a = new Storage.DenseVector[this.params._mini_batch_size];
        for (int mb = 0; mb < this._a.length; mb++) {
            this._a[mb] = new Storage.DenseVector(this.units);
        }

        if (isInput) {
            // Only an auto-encoder with input dropout needs a copy of the clean input.
            if (this.params._autoencoder && this.params._input_dropout_ratio > 0.0d) {
                this._origa = new Storage.DenseVector[this.params._mini_batch_size];
                for (int mb = 0; mb < this._origa.length; mb++) {
                    this._origa[mb] = new Storage.DenseVector(this.units);
                }
            }
        } else {
            this._e = new Storage.DenseVector[this.params._mini_batch_size];
            for (int mb = 0; mb < this._e.length; mb++) {
                this._e[mb] = new Storage.DenseVector(this.units);
            }
        }

        final boolean supportsDropout = isInput
                || (this instanceof MaxoutDropout)
                || (this instanceof TanhDropout)
                || (this instanceof RectifierDropout)
                || (this instanceof ExpRectifierDropout);
        if (z && supportsDropout) {
            if (isInput) {
                if (this.params._input_dropout_ratio == 0.0d) {
                    this._dropout = null;
                } else {
                    this._dropout = new Dropout(this.units, this.params._input_dropout_ratio);
                }
            } else {
                this._dropout = new Dropout(this.units, this.params._hidden_dropout_ratios[this._index]);
            }
        }

        if (!isInput) {
            this._previous = neuronsArr[this._index];
            this._minfo = deepLearningModelInfo;
            this._w = deepLearningModelInfo.get_weights(this._index);
            this._b = deepLearningModelInfo.get_biases(this._index);
            if (this.params._autoencoder && this.params._sparsity_beta > 0.0d && this._index < this.params._hidden.length) {
                this._avg_a = deepLearningModelInfo.get_avg_activations(this._index);
            }
            if (deepLearningModelInfo.has_momenta()) {
                this._wm = deepLearningModelInfo.get_weights_momenta(this._index);
                this._bm = deepLearningModelInfo.get_biases_momenta(this._index);
            }
            if (deepLearningModelInfo.adaDelta()) {
                this._ada_dx_g = deepLearningModelInfo.get_ada_dx_g(this._index);
                this._bias_ada_dx_g = deepLearningModelInfo.get_biases_ada_dx_g(this._index);
            }
            // Zero gradients may be skipped in fast mode, or when nothing but the plain
            // gradient contributes to the update.
            this._shortcut = this.params._fast_mode || (!this.params._adaptive_rate && !this._minfo.has_momenta() && this.params._l1 == 0.0d && this.params._l2 == 0.0d);
        }
        sanityCheck(z);
    }

    /**
     * Forward propagation for this layer.
     *
     * <p>Declared {@code protected}, matching the {@code protected} overrides in the nested
     * layer classes (an override may not reduce visibility).
     *
     * @param j seed; the dropout layers derive their dropout seed from it
     * @param z training flag; the dropout layers apply dropout only when it is true
     * @param i number of mini-batch rows to process
     */
    protected abstract void fprop(long j, boolean z, int i);

    /**
     * Backward propagation for this layer.
     *
     * <p>Declared {@code protected}, matching the {@code protected} overrides in the nested
     * layer classes (an override may not reduce visibility).
     *
     * @param i number of mini-batch rows to process
     */
    protected abstract void bprop(int i);

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Backward pass for the output layer: hands the stored error terms of every unit, unchanged,
     * to the shared weight update.
     *
     * @param i number of mini-batch rows to process
     */
    public final void bpropOutputLayer(int i) {
        if (!$assertionsDisabled && this._index != this.params._hidden.length) {
            throw new AssertionError();
        }
        if (!$assertionsDisabled && this._a.length != this.params._mini_batch_size) {
            throw new AssertionError();
        }
        final int rows = this._a[0].size();
        final float momentum = this._minfo.adaDelta() ? 0.0f : momentum();
        final float rate = this._minfo.adaDelta() ? 0.0f : rate(this._minfo.get_processed_total()) * (1.0f - momentum);
        // One buffer reused for every unit (as the hidden-layer bprop implementations do);
        // each entry is overwritten before use.
        final double[] grad = new double[i];
        for (int row = 0; row < rows; row++) {
            for (int mb = 0; mb < i; mb++) {
                grad[mb] = this._e[mb].raw()[row];
            }
            bprop(row, grad, rate, momentum, i);
        }
    }

    /**
     * Default output-layer gradient, used for auto-encoders: stores
     * {@code autoEncoderGradient(unit, i) / i2} for every unit of mini-batch slot {@code i}.
     *
     * <p>Declared {@code protected}, matching the {@code protected} overrides in
     * {@code Linear} and {@code Softmax} (an override may not reduce visibility).
     *
     * @param d  target value; not used by this implementation
     * @param i  mini-batch slot
     * @param i2 divisor applied to every gradient entry
     */
    protected void setOutputLayerGradient(double d, int i, int i2) {
        if (!$assertionsDisabled && (!this._minfo.get_params()._autoencoder || this._index != this._minfo.get_params()._hidden.length)) {
            throw new AssertionError();
        }
        final int rows = this._a[i].size();
        for (int row = 0; row < rows; row++) {
            this._e[i].set(row, autoEncoderGradient(row, i) / i2);
        }
    }

    /**
     * Shared weight and bias update for one unit of this layer.
     *
     * <p>For every mini-batch row and every incoming connection it adds this unit's
     * contribution to the previous layer's error, builds the weight gradient (partial gradient
     * times input activation, plus L1/L2 terms and the optional elastic-averaging term) and
     * applies it using AdaDelta, Nesterov momentum or plain (momentum) SGD. It then optionally
     * rescales the unit's incoming weights and updates the bias.
     *
     * @param i    index of the unit within this layer
     * @param dArr partial gradient of the unit, one entry per mini-batch row
     * @param f    learning rate
     * @param f2   momentum
     * @param i2   number of mini-batch rows; must equal {@code dArr.length}
     */
    final void bprop(int i, double[] dArr, float f, float f2, int i2) {
        final float rho = (float) this.params._rho;
        final float eps = (float) this.params._epsilon;
        final float l1 = (float) this.params._l1;
        final float l2 = (float) this.params._l2;
        final float maxW2 = this.params._max_w2;
        final boolean haveMomenta = this._minfo.has_momenta();
        final boolean haveAda = this._minfo.adaDelta();
        final boolean nesterov = this.params._nesterov_accelerated_gradient;
        final boolean fastMode = this.params._fast_mode;
        final int cols = this._previous._a[0].size();
        if (!$assertionsDisabled && dArr.length != i2) {
            throw new AssertionError();
        }
        double avgGrad2 = 0.0d;
        final int rowOffset = i * cols;
        for (int mb = 0; mb < i2; mb++) {
            // NOTE(review): this leaves the whole method, so later mini-batch rows and the bias
            // update are skipped as well — confirm that is intended for mini-batch sizes > 1.
            if (this._shortcut && dArr[mb] == 0.0d) {
                return;
            }
            final boolean propagateError = this._previous._e != null && this._previous._e[mb] != null;
            for (int col = 0; col < cols; col++) {
                int w = rowOffset + col;
                if (this._k != 0) {
                    // Maxout: pick the weight of the channel that won in fprop.
                    w = (this._k * w) + this._maxIncoming[mb][i];
                }
                final double weight = this._w.raw()[w];
                if (propagateError) {
                    this._previous._e[mb].add(col, dArr[mb] * weight);
                }
                final double prevActivation = this._previous._a[mb].get(col);
                if (fastMode && prevActivation == 0.0d) {
                    continue;
                }
                double grad = (dArr[mb] * prevActivation) + (Math.signum(weight) * l1) + (weight * l2);
                if (this._wEA != null) {
                    grad += this.params._elastic_averaging_regularization * (this._w.raw()[w] - this._wEA.raw()[w]);
                }
                if (DeepLearningModelInfo.gradientCheck != null) {
                    DeepLearningModelInfo.gradientCheck.apply(this._index, i, col, grad);
                }
                // The compound assignments below narrow the double result back to float implicitly.
                if (haveAda) {
                    avgGrad2 += grad * grad;
                    final float adaRate = computeAdaDeltaRateForWeight(grad, w, this._ada_dx_g, rho, eps);
                    this._w.raw()[w] -= adaRate * grad;
                } else if (nesterov) {
                    double delta = -grad;
                    if (haveMomenta) {
                        this._wm.raw()[w] *= f2;
                        this._wm.raw()[w] += delta;
                        delta = this._wm.raw()[w];
                    }
                    this._w.raw()[w] += f * delta;
                } else {
                    final double delta = (-f) * grad;
                    this._w.raw()[w] += delta;
                    if (haveMomenta) {
                        this._w.raw()[w] += f2 * this._wm.raw()[w];
                        this._wm.raw()[w] = (float) delta;
                    }
                }
            }
        }
        if (maxW2 != Float.POSITIVE_INFINITY) {
            for (int mb = 0; mb < i2; mb++) {
                rescale_weights(this._w, i, maxW2, mb);
            }
        }
        if (haveAda) {
            avgGrad2 /= cols * i2;
        }
        for (int mb = 0; mb < i2; mb++) {
            update_bias(this._b, this._bEA, this._bm, i, dArr, avgGrad2, f, f2, mb);
        }
    }

    /**
     * Scales down a span of weights in place when the sum of their squares exceeds {@code f}.
     *
     * <p>The span inside {@code denseRowMatrix.raw()} is chosen as follows, where {@code cols} is
     * the size of the previous layer's first activation vector:
     * <ul>
     *   <li>{@code _k == 0}: indices {@code [i * cols, i * cols + cols)}; only the call with
     *       {@code i2 == 0} does any work, calls with {@code i2 > 0} return immediately.</li>
     *   <li>{@code _k != 0}: both boundaries are multiplied by {@code _k} and shifted by
     *       {@code _maxIncoming[i2][i]}.</li>
     * </ul>
     *
     * @param denseRowMatrix matrix whose backing float array is modified
     * @param i row (unit) index
     * @param f cap on the sum of squared weights of the span
     * @param i2 first index into {@code _maxIncoming} (presumably the mini-batch sample index;
     *           confirm against the caller)
     */
    private void rescale_weights(Storage.DenseRowMatrix denseRowMatrix, int i, float f, int i2) {
        final int cols = this._previous._a[0].size();
        final int rowBase = i * cols;
        final int from;
        final int to;
        if (this._k == 0) {
            if (i2 > 0) {
                return;
            }
            from = rowBase;
            to = rowBase + cols;
        } else {
            final int shift = this._maxIncoming[i2][i];
            from = (this._k * rowBase) + shift;
            to = (this._k * (rowBase + (cols - 1))) + shift;
        }

        final float[] weights = denseRowMatrix.raw();
        final float squaredNorm = MathUtils.sumSquares(weights, from, to);
        // Written as a negated '>' on purpose: a NaN norm must skip the rescaling.
        if (!(squaredNorm > f)) {
            return;
        }
        final float shrink = MathUtils.approxSqrt(f / squaredNorm);
        for (int pos = from; pos < to; pos++) {
            weights[pos] *= shrink;
        }
    }

    /**
     * Returns {@code -2 * _dist.negHalfGradient(inputValue, outputValue)} for one unit.
     *
     * <p>{@code inputValue} is read from {@code _input._origa[i2]} when that array is present and
     * from {@code _input._a[i2]} otherwise; {@code outputValue} is read from this layer's
     * {@code _a[i2]}.
     *
     * @param i unit index within the vectors
     * @param i2 index selecting which vector of the per-sample arrays to read
     * @return the gradient value described above
     * @throws AssertionError when assertions are enabled and this layer is not the layer at
     *         index {@code _hidden.length} of an autoencoder model
     */
    protected double autoEncoderGradient(int i, int i2) {
        if (!$assertionsDisabled) {
            final boolean autoencoderLastLayer = this._minfo.get_params()._autoencoder
                    && this._index == this._minfo.get_params()._hidden.length;
            if (!autoencoderLastLayer) {
                throw new AssertionError();
            }
        }
        final double inputValue;
        if (this._input._origa != null) {
            inputValue = this._input._origa[i2].get(i);
        } else {
            inputValue = this._input._a[i2].get(i);
        }
        final double outputValue = this._a[i2].get(i);
        return (-2.0d) * this._dist.negHalfGradient(inputValue, outputValue);
    }

    /**
     * Updates the two running averages kept for weight {@code i} and returns the resulting
     * per-weight step size.
     *
     * <p>The state lives interleaved in {@code denseRowMatrix.raw()}: slot {@code 2*i + 1} holds a
     * decayed average of the squared gradient, slot {@code 2*i} a decayed average of
     * {@code rate^2 * gradient^2}. The returned rate is
     * {@code approxSqrt((state[2i] + f2) / (state[2i+1] + f2))}, computed after the gradient
     * average has been refreshed and before the other slot is refreshed.
     *
     * @param d gradient of the weight
     * @param i linear index of the weight
     * @param denseRowMatrix storage for the interleaved state, modified in place
     * @param f decay factor applied to the previous averages (presumably AdaDelta's rho)
     * @param f2 additive term in numerator and denominator (presumably AdaDelta's epsilon)
     * @return the step size for this weight
     */
    private static float computeAdaDeltaRateForWeight(double d, int i, Storage.DenseRowMatrix denseRowMatrix, float f, float f2) {
        final float[] state = denseRowMatrix.raw();
        final int updateSlot = 2 * i;
        final int gradSlot = updateSlot + 1;
        final double gradSquared = d * d;
        final float fresh = 1.0f - f;

        state[gradSlot] = (float) ((f * state[gradSlot]) + (fresh * gradSquared));
        final float stepSize = MathUtils.approxSqrt((state[updateSlot] + f2) / (state[gradSlot] + f2));
        state[updateSlot] = (float) ((f * state[updateSlot]) + (fresh * stepSize * stepSize * gradSquared));
        return stepSize;
    }

    /**
     * Bias counterpart of the per-weight rate computation: refreshes the two running averages
     * stored for bias {@code i} and returns the resulting step size.
     *
     * <p>Unlike the weight variant, {@code d} is not squared here; it is mixed into the averages
     * as given. State is interleaved in {@code denseVector.raw()}: slot {@code 2*i + 1} is updated
     * first, the rate {@code approxSqrt((state[2i] + f2) / (state[2i+1] + f2))} is computed, and
     * then slot {@code 2*i} is updated with {@code rate^2 * d}.
     *
     * @param d value folded into the averages (the caller is expected to pass an already squared
     *          quantity; confirm against the call site)
     * @param i index of the bias
     * @param denseVector storage for the interleaved state, modified in place
     * @param f decay factor applied to the previous averages
     * @param f2 additive term in numerator and denominator
     * @return the step size for this bias
     */
    private static double computeAdaDeltaRateForBias(double d, int i, Storage.DenseVector denseVector, float f, float f2) {
        final double[] state = denseVector.raw();
        final int updateSlot = 2 * i;
        final int gradSlot = updateSlot + 1;
        final float fresh = 1.0f - f;

        state[gradSlot] = (f * state[gradSlot]) + (fresh * d);
        final double stepSize = MathUtils.approxSqrt((state[updateSlot] + f2) / (state[gradSlot] + f2));
        state[updateSlot] = (f * state[updateSlot]) + (fresh * stepSize * stepSize * d);
        return stepSize;
    }

    /**
     * Folds the current activations into the running average {@code _avg_a} using
     * {@code avg = 0.999 * avg + 0.001 * activation} for every unit.
     *
     * <p>Does nothing when {@code _avg_a} is {@code null}. Mini-batch sizes above one are rejected
     * with the exception produced by {@code H2O.unimpl}.
     */
    void compute_sparsity() {
        if (this._avg_a == null) {
            return;
        }
        if (this.params._mini_batch_size > 1) {
            throw H2O.unimpl("Sparsity constraint is not yet implemented for mini-batch size > 1.");
        }
        final int batchSize = this._minfo.get_params()._mini_batch_size;
        for (int sample = 0; sample < batchSize; sample++) {
            final int unitCount = this._avg_a.size();
            for (int unit = 0; unit < unitCount; unit++) {
                final double decayed = 0.999d * this._avg_a.get(unit);
                final double contribution = 0.001d * this._a[sample].get(unit);
                this._avg_a.set(unit, decayed + contribution);
            }
        }
    }

    /**
     * Applies one gradient step to a single bias value.
     *
     * <p>Steps, in order:
     * <ol>
     *   <li>Adds L1 ({@code signum(bias) * _l1}) and L2 ({@code bias * _l2}) terms to
     *       {@code dArr[i2]}, plus an elastic-averaging term when {@code denseVector2} is given.</li>
     *   <li>Reports the gradient to {@code DeepLearningModelInfo.gradientCheck} if one is set.</li>
     *   <li>With AdaDelta enabled, replaces the step size by the value from
     *       {@code computeAdaDeltaRateForBias}.</li>
     *   <li>Updates the bias, with or without momentum, in either the Nesterov or the
     *       non-Nesterov form depending on {@code params._nesterov_accelerated_gradient}.</li>
     *   <li>For autoencoders with a positive {@code _sparsity_beta}, on layers that are neither
     *       {@code Input} nor {@code Output} nor at index {@code _hidden.length}, subtracts a
     *       sparsity penalty based on {@code _avg_a}.</li>
     *   <li>Marks the model unstable when the bias became infinite.</li>
     * </ol>
     *
     * @param denseVector bias vector, modified in place
     * @param denseVector2 elastic-averaging reference biases, or {@code null} to skip that term
     * @param denseVector3 momentum storage for the biases; only touched when momenta are in use
     * @param i unit index
     * @param dArr gradients; entry {@code i2} is read and overwritten with the regularized value
     * @param d value handed to the AdaDelta bias rate computation
     * @param d2 step size used when AdaDelta is off
     * @param d3 momentum factor
     * @param i2 index into {@code dArr} and first index into {@code _maxIncoming}
     */
    private void update_bias(Storage.DenseVector denseVector, Storage.DenseVector denseVector2, Storage.DenseVector denseVector3, int i, double[] dArr, double d, double d2, double d3, int i2) {
        final boolean useMomentum = this._minfo.has_momenta();
        final boolean useAdaDelta = this._minfo.adaDelta();
        final float l1 = (float) this.params._l1;
        final float l2 = (float) this.params._l2;

        int slot = i;
        if (this._k != 0) {
            slot = (this._k * i) + this._maxIncoming[i2][i];
        }

        // Regularization terms are added left to right; keep this association for identical rounding.
        final double bias = denseVector.get(slot);
        final double l1Term = Math.signum(bias) * l1;
        final double l2Term = bias * l2;
        dArr[i2] = dArr[i2] + l1Term + l2Term;
        if (denseVector2 != null) {
            dArr[i2] += (bias - denseVector2.get(slot)) * this.params._elastic_averaging_regularization;
        }

        if (DeepLearningModelInfo.gradientCheck != null) {
            DeepLearningModelInfo.gradientCheck.apply(this._index, i, -1, dArr[i2]);
        }

        final double stepSize;
        if (useAdaDelta) {
            stepSize = computeAdaDeltaRateForBias(d, slot, this._bias_ada_dx_g, (float) this.params._rho, (float) this.params._epsilon);
        } else {
            stepSize = d2;
        }

        if (!this.params._nesterov_accelerated_gradient) {
            final double delta = (-stepSize) * dArr[i2];
            denseVector.add(slot, delta);
            if (useMomentum) {
                // Apply the previous step scaled by the momentum, then remember the current one.
                denseVector.add(slot, d3 * denseVector3.get(slot));
                denseVector3.set(slot, delta);
            }
        } else {
            double direction = -dArr[i2];
            if (useMomentum) {
                // Decay the stored velocity, add the new direction, and step along the result.
                denseVector3.set(slot, denseVector3.get(slot) * d3);
                denseVector3.add(slot, direction);
                direction = denseVector3.get(slot);
            }
            denseVector.add(slot, stepSize * direction);
        }

        if (this.params._autoencoder && this.params._sparsity_beta > 0.0d) {
            final boolean inputOrOutput = (this instanceof Output) || (this instanceof Input);
            if (!inputOrOutput && this._index != this.params._hidden.length) {
                final double penalty = stepSize * this.params._sparsity_beta * (this._avg_a.raw()[slot] - this.params._average_activation);
                denseVector.add(slot, -penalty);
            }
        }

        if (Double.isInfinite(denseVector.get(slot))) {
            this._minfo.setUnstable();
        }
    }

    /**
     * Computes the annealed learning rate {@code _rate / (1 + _rate_annealing * d)}.
     *
     * @param d annealing argument (presumably the number of training samples processed so far)
     * @return the learning rate, narrowed to {@code float}
     */
    public float rate(double d) {
        final double annealingDivisor = 1.0d + (this.params._rate_annealing * d);
        final double annealedRate = this.params._rate / annealingDivisor;
        return (float) annealedRate;
    }

    /**
     * Convenience overload that delegates to {@link #momentum(double)} with an argument of
     * {@code -1}.
     *
     * @return whatever {@code momentum(-1.0)} returns
     */
    protected float momentum() {
        final double noExplicitCount = -1.0d;
        return momentum(noExplicitCount);
    }

    /**
     * Computes the momentum, optionally ramped linearly from {@code _momentum_start} to
     * {@code _momentum_stable} over {@code _momentum_ramp} samples.
     *
     * <p>When {@code _momentum_ramp} is not positive the start value is returned unchanged.
     * Otherwise the sample count drives the ramp: at or beyond {@code _momentum_ramp} samples the
     * stable value is returned, before that the value is interpolated linearly.
     *
     * <p>Fix: the sample-count selection used to be inverted. An explicit count was ignored in
     * favour of the model's processed total, and the {@code -1} sentinel (as passed by the
     * no-argument overload) was itself used as the count, yielding a momentum slightly below the
     * start value that never ramped. Now {@code -1} selects the model's processed total and any
     * other value is used as given.
     *
     * @param d number of samples to evaluate the ramp at, or {@code -1} to use
     *          {@code _minfo.get_processed_total()}
     * @return the momentum, narrowed to {@code float}
     */
    public final float momentum(double d) {
        final double start = this.params._momentum_start;
        final double ramp = this.params._momentum_ramp;
        // Negated '>' keeps the original treatment of a NaN ramp (no ramping).
        if (!(ramp > 0.0d)) {
            return (float) start;
        }
        final double samples = (d == -1.0d) ? this._minfo.get_processed_total() : d;
        final double stable = this.params._momentum_stable;
        if (samples >= ramp) {
            return (float) stable;
        }
        return (float) (start + (((stable - start) * samples) / ramp));
    }

    /**
     * Straightforward matrix-vector product with bias: {@code dArr = fArr * dArr2 + dArr3}, where
     * {@code fArr} is a row-major matrix with {@code dArr3.length} rows and {@code dArr2.length}
     * columns.
     *
     * <p>Every output entry is first reset to zero. When {@code bArr} is non-null it acts as a bit
     * mask (bit {@code row % 8} of byte {@code row / 8}); rows whose bit is clear stay at zero and
     * receive neither the product nor the bias.
     *
     * @param dArr output vector, overwritten
     * @param fArr row-major weights
     * @param dArr2 input vector
     * @param dArr3 bias vector; its length defines the number of rows
     * @param bArr optional row mask, or {@code null} to compute all rows
     * @throws AssertionError when assertions are enabled and {@code dArr.length} differs from
     *         {@code dArr3.length}
     */
    static void gemv_naive(double[] dArr, float[] fArr, double[] dArr2, double[] dArr3, byte[] bArr) {
        final int cols = dArr2.length;
        final int rows = dArr3.length;
        if (!$assertionsDisabled && dArr.length != rows) {
            throw new AssertionError();
        }
        for (int row = 0; row < rows; row++) {
            dArr[row] = 0.0d;
            if (bArr != null && (bArr[row / 8] & (1 << (row % 8))) == 0) {
                continue;
            }
            final int rowStart = row * cols;
            for (int col = 0; col < cols; col++) {
                dArr[row] += fArr[rowStart + col] * dArr2[col];
            }
            dArr[row] += dArr3[row];
        }
    }

    /**
     * Matrix-vector product with bias, {@code dArr = fArr * dArr2 + dArr3}, using an 8-way
     * unrolled inner loop with eight independent partial sums per row.
     *
     * <p>{@code fArr} is read as a row-major matrix with {@code dArr3.length} rows and
     * {@code dArr2.length} columns. Columns beyond the last full group of eight are handled by a
     * scalar tail loop. Every output entry is first reset to zero; when {@code bArr} is non-null,
     * rows whose mask bit (bit {@code row % 8} of byte {@code row / 8}) is clear stay at zero.
     *
     * @param dArr output vector, overwritten
     * @param fArr row-major weights
     * @param dArr2 input vector
     * @param dArr3 bias vector; its length defines the number of rows
     * @param bArr optional row mask, or {@code null} to compute all rows
     * @throws AssertionError when assertions are enabled and {@code dArr.length} differs from
     *         {@code dArr3.length}
     */
    static void gemv_row_optimized(double[] dArr, float[] fArr, double[] dArr2, double[] dArr3, byte[] bArr) {
        final int cols = dArr2.length;
        final int rows = dArr3.length;
        if (!$assertionsDisabled && dArr.length != rows) {
            throw new AssertionError();
        }
        final int tailStart = cols - (cols % 8);
        final int unrolledLimit = ((cols / 8) * 8) - 1;

        for (int row = 0; row < rows; row++) {
            dArr[row] = 0.0d;
            if (bArr != null && (bArr[row / 8] & (1 << (row % 8))) == 0) {
                continue;
            }
            final int rowStart = row * cols;

            double s0 = 0.0d;
            double s1 = 0.0d;
            double s2 = 0.0d;
            double s3 = 0.0d;
            double s4 = 0.0d;
            double s5 = 0.0d;
            double s6 = 0.0d;
            double s7 = 0.0d;
            for (int col = 0; col < unrolledLimit; col += 8) {
                final int w = rowStart + col;
                s0 += fArr[w] * dArr2[col];
                s1 += fArr[w + 1] * dArr2[col + 1];
                s2 += fArr[w + 2] * dArr2[col + 2];
                s3 += fArr[w + 3] * dArr2[col + 3];
                s4 += fArr[w + 4] * dArr2[col + 4];
                s5 += fArr[w + 5] * dArr2[col + 5];
                s6 += fArr[w + 6] * dArr2[col + 6];
                s7 += fArr[w + 7] * dArr2[col + 7];
            }
            // The partial sums are folded in strictly left to right, in two batches, so the
            // floating-point rounding is unchanged.
            dArr[row] = dArr[row] + s0 + s1 + s2 + s3;
            dArr[row] = dArr[row] + s4 + s5 + s6 + s7;

            for (int col = tailStart; col < cols; col++) {
                dArr[row] += fArr[rowStart + col] * dArr2[col];
            }
            dArr[row] += dArr3[row];
        }
    }

    /**
     * Unwraps the storage objects and runs {@code gemv_row_optimized} on their backing arrays.
     *
     * @param denseVector output vector
     * @param denseRowMatrix weight matrix
     * @param denseVector2 input vector
     * @param denseVector3 bias vector
     * @param bArr optional row mask passed through unchanged; may be {@code null}
     */
    static void gemv(Storage.DenseVector denseVector, Storage.DenseRowMatrix denseRowMatrix, Storage.DenseVector denseVector2, Storage.DenseVector denseVector3, byte[] bArr) {
        final double[] result = denseVector.raw();
        final float[] weights = denseRowMatrix.raw();
        final double[] input = denseVector2.raw();
        final double[] bias = denseVector3.raw();
        gemv_row_optimized(result, weights, input, bias, bArr);
    }

    /**
     * Unwraps the storage objects and runs the array-based {@code gemv_naive} on their backing
     * arrays.
     *
     * @param denseVector output vector
     * @param denseRowMatrix weight matrix
     * @param denseVector2 input vector
     * @param denseVector3 bias vector
     * @param bArr optional row mask passed through unchanged; may be {@code null}
     */
    static void gemv_naive(Storage.DenseVector denseVector, Storage.DenseRowMatrix denseRowMatrix, Storage.DenseVector denseVector2, Storage.DenseVector denseVector3, byte[] bArr) {
        final double[] result = denseVector.raw();
        final float[] weights = denseRowMatrix.raw();
        final double[] input = denseVector2.raw();
        final double[] bias = denseVector3.raw();
        gemv_naive(result, weights, input, bias, bArr);
    }

    static {
        $assertionsDisabled = !Neurons.class.desiredAssertionStatus();
    }
}
