package com.github.chen0040.mlp.ann;

import com.github.chen0040.data.frame.DataFrame;
import com.github.chen0040.data.frame.DataRow;
import com.github.chen0040.data.utils.transforms.Standardization;
import java.util.ArrayList;

/* loaded from: input_file:com/github/chen0040/mlp/ann/MLP.class */
public abstract class MLP extends MLPNet {
    /** Input standardization fitted by the most recent {@link #train(DataFrame, int)}; null before training. */
    private Standardization inputNormalization;
    /**
     * Output (target) standardization. Non-null only when the most recent training run had output
     * normalization enabled (or when copied from a network in that state).
     */
    private Standardization outputNormalization;
    /** Whether the next call to {@link #train(DataFrame, int)} should standardize the target vectors. */
    private boolean normalizeOutputs = false;

    /**
     * Copies the state of {@code other} into this network. This includes the superclass state, clones
     * of the input/output standardizations (when present) and the output-normalization flag.
     *
     * @param other the network to copy from; must be an {@code MLP}, otherwise the cast throws
     *              {@link ClassCastException}
     * @throws CloneNotSupportedException if propagated by {@code super.copy} or by
     *                                    {@code Standardization.clone()}
     */
    @Override
    public void copy(MLPNet other) throws CloneNotSupportedException {
        super.copy(other);
        MLP source = (MLP) other;
        this.inputNormalization = source.inputNormalization == null
                ? null
                : (Standardization) source.inputNormalization.clone();
        this.outputNormalization = source.outputNormalization == null
                ? null
                : (Standardization) source.outputNormalization.clone();
        this.normalizeOutputs = source.normalizeOutputs;
    }

    /**
     * Decides whether a row of the training data frame should be used for training.
     *
     * @param dataRow candidate row
     * @return true if the row is used for training (and for fitting the output standardization)
     */
    protected abstract boolean isValidTrainingSample(DataRow dataRow);

    /**
     * Enables or disables standardization of the target vectors. The setting takes effect on the
     * next call to {@link #train(DataFrame, int)}.
     *
     * @param normalizeOutputs true to standardize targets during training and revert predictions
     */
    public void setNormalizeOutputs(boolean normalizeOutputs) {
        this.normalizeOutputs = normalizeOutputs;
    }

    /**
     * Extracts the target vector for a training row.
     *
     * @param dataRow a row for which {@link #isValidTrainingSample(DataRow)} returned true
     * @return the target vector for that row
     */
    public abstract double[] getTarget(DataRow dataRow);

    /**
     * Trains the network on {@code dataFrame} for the given number of epochs.
     *
     * <p>First, a fresh input standardization is fitted on {@code dataFrame}. Next, the output
     * standardization is refitted on the valid rows' targets if output normalization is enabled,
     * and cleared otherwise. Finally, each valid row is fed (standardized) to the
     * {@code train(double[], double[])} overload once per epoch. That overload is presumably
     * inherited from {@code MLPNet}; it is not visible here.
     *
     * @param dataFrame training data
     * @param epochs    number of passes over {@code dataFrame}; values {@code <= 0} fit the
     *                  standardizations but perform no weight updates
     */
    public void train(DataFrame dataFrame, int epochs) {
        this.inputNormalization = new Standardization(dataFrame);
        // Always refit or clear: keeping a previous run's output statistics would rescale targets
        // with stats from other data, or keep rescaling after normalization was disabled.
        this.outputNormalization = this.normalizeOutputs
                ? new Standardization(collectTargets(dataFrame))
                : null;

        for (int epoch = 0; epoch < epochs; epoch++) {
            for (int rowIndex = 0; rowIndex < dataFrame.rowCount(); rowIndex++) {
                DataRow row = dataFrame.row(rowIndex);
                if (!isValidTrainingSample(row)) {
                    continue;
                }
                double[] input = this.inputNormalization.standardize(row.toArray());
                double[] target = getTarget(row);
                if (this.outputNormalization != null) {
                    target = this.outputNormalization.standardize(target);
                }
                train(input, target);
            }
        }
    }

    /** Collects the target vectors of all valid training rows, in row order. */
    private ArrayList<double[]> collectTargets(DataFrame dataFrame) {
        ArrayList<double[]> targets = new ArrayList<>();
        for (int rowIndex = 0; rowIndex < dataFrame.rowCount(); rowIndex++) {
            DataRow row = dataFrame.row(rowIndex);
            if (isValidTrainingSample(row)) {
                targets.add(getTarget(row));
            }
        }
        return targets;
    }

    /**
     * Predicts the output for a row.
     *
     * <p>The row is standardized with the fitted input statistics and passed through the
     * {@code transform(double[])} overload. When an output standardization is present, the result
     * is mapped back to the original target scale.
     *
     * @param dataRow row to predict
     * @return predicted output vector
     * @throws IllegalStateException if the network has not been trained (no input standardization)
     */
    public double[] transform(DataRow dataRow) {
        if (this.inputNormalization == null) {
            throw new IllegalStateException("MLP has not been trained: call train(DataFrame, int) first");
        }
        double[] prediction = transform(this.inputNormalization.standardize(dataRow.toArray()));
        if (this.outputNormalization != null) {
            prediction = this.outputNormalization.revert(prediction);
        }
        return prediction;
    }
}
