package com.github.chen0040.mlp.ann;

import com.github.chen0040.data.frame.DataFrame;
import com.github.chen0040.data.frame.DataRow;
import com.github.chen0040.data.utils.transforms.Standardization;
import com.github.chen0040.mlp.enums.WeightUpdateMode;
import com.github.chen0040.mlp.functions.RangeScaler;
import java.util.ArrayList;
import java.util.List;

/* loaded from: input_file:com/github/chen0040/mlp/ann/MLP.class */
public abstract class MLP extends MLPNet {
    private Standardization inputNormalization;
    private RangeScaler outputNormalization;
    private boolean normalizeOutputs = false;

    protected abstract boolean isValidTrainingSample(DataRow dataRow);

    public void setNormalizeOutputs(boolean z) {
        this.normalizeOutputs = z;
    }

    public abstract double[] getTarget(DataRow dataRow);

    /* JADX WARN: Multi-variable type inference failed */
    public void train(DataFrame dataFrame, int i) {
        this.inputNormalization = new Standardization(dataFrame);
        if (this.normalizeOutputs) {
            ArrayList arrayList = new ArrayList();
            for (int i2 = 0; i2 < dataFrame.rowCount(); i2++) {
                DataRow row = dataFrame.row(i2);
                if (isValidTrainingSample(row)) {
                    arrayList.add(getTarget(row));
                }
            }
            this.outputNormalization = new RangeScaler(arrayList);
        }
        for (int i3 = 0; i3 < i; i3++) {
            if (this.weightUpdateMode == WeightUpdateMode.StochasticGradientDescend) {
                for (int i4 = 0; i4 < dataFrame.rowCount(); i4++) {
                    DataRow row2 = dataFrame.row(i4);
                    if (isValidTrainingSample(row2)) {
                        double[] standardize = this.inputNormalization.standardize(row2.toArray());
                        double[] target = getTarget(row2);
                        if (this.outputNormalization != null) {
                            target = this.outputNormalization.standardize(target);
                        }
                        stochasticGradientDesend(standardize, target);
                    }
                }
            } else {
                ArrayList arrayList2 = new ArrayList();
                arrayList2.add(this.outputLayer);
                for (int size = this.hiddenLayers.size() - 1; size >= 0; size--) {
                    arrayList2.add(this.hiddenLayers.get(size));
                }
                double[][] dArr = new double[arrayList2.size()];
                for (int i5 = 0; i5 < arrayList2.size(); i5++) {
                    MLPLayer mLPLayer = (MLPLayer) arrayList2.get(i5);
                    dArr[i5] = new double[mLPLayer.neurons.size()];
                    for (int i6 = 0; i6 < mLPLayer.neurons.size(); i6++) {
                        dArr[i5][i6] = new double[mLPLayer.neurons.get(0).dimension()];
                    }
                }
                for (int i7 = 0; i7 < dataFrame.rowCount(); i7++) {
                    DataRow row3 = dataFrame.row(i7);
                    if (isValidTrainingSample(row3)) {
                        double[] standardize2 = this.inputNormalization.standardize(row3.toArray());
                        double[] target2 = getTarget(row3);
                        if (this.outputNormalization != null) {
                            target2 = this.outputNormalization.standardize(target2);
                        }
                        double[] output = this.inputLayer.setOutput(standardize2);
                        for (int i8 = 0; i8 < this.hiddenLayers.size(); i8++) {
                            output = this.hiddenLayers.get(i8).forward_propagate(output);
                        }
                        double[] minus = minus(target2, this.outputLayer.forward_propagate(output));
                        for (int i9 = 0; i9 < arrayList2.size(); i9++) {
                            MLPLayer mLPLayer2 = (MLPLayer) arrayList2.get(i9);
                            double[] dArr2 = new double[minus.length];
                            for (int i10 = 0; i10 < dArr2.length; i10++) {
                                dArr2[i10] = mLPLayer2.getTransfer().gradient(mLPLayer2.neurons.get(i10).output) * minus[i10];
                            }
                            int dimension = mLPLayer2.neurons.get(0).dimension();
                            double[] dArr3 = new double[dimension];
                            for (int i11 = 0; i11 < dimension; i11++) {
                                double d = 0.0d;
                                for (int i12 = 0; i12 < dArr2.length; i12++) {
                                    d += mLPLayer2.neurons.get(i12).getWeight(i11) * dArr2[i12];
                                }
                                dArr3[i11] = d;
                                for (int i13 = 0; i13 < dArr2.length; i13++) {
                                    double d2 = mLPLayer2.neurons.get(i13).values[i11];
                                    double[] dArr4 = dArr[i9][i13];
                                    int i14 = i11;
                                    dArr4[i14] = dArr4[i14] + (d2 * dArr2[i13]);
                                }
                            }
                            minus = dArr3;
                        }
                    }
                }
                for (int i15 = 0; i15 < arrayList2.size(); i15++) {
                    MLPLayer mLPLayer3 = (MLPLayer) arrayList2.get(i15);
                    for (int i16 = 0; i16 < mLPLayer3.neurons.size(); i16++) {
                        int dimension2 = mLPLayer3.neurons.get(0).dimension();
                        for (int i17 = 0; i17 < dimension2; i17++) {
                            mLPLayer3.neurons.get(i16).setWeight(i17, mLPLayer3.neurons.get(i16).getWeight(i17) + ((getLearningRate() * dArr[i15][i16][i17]) / dataFrame.rowCount()));
                        }
                    }
                }
            }
        }
    }

    public double[] transform(DataRow dataRow) {
        double[] transform = transform(this.inputNormalization.standardize(dataRow.toArray()));
        if (this.outputNormalization != null) {
            transform = this.outputNormalization.revert(transform);
        }
        return transform;
    }
}
