package com.github.chen0040.mlp.ann;

import com.github.chen0040.mlp.functions.TransferFunction;
import java.util.ArrayList;
import java.util.List;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:com/github/chen0040/mlp/ann/MLPNet.class */
/**
 * A multi-layer perceptron made of one input layer, zero or more hidden layers and one
 * output layer, trained one sample at a time via {@link #train(double[], double[])}.
 *
 * <p>The input and output layers must be created with {@link #createInputLayer(int)} and
 * {@link #createOutputLayer(int)} before {@link #train(double[], double[])} or
 * {@link #transform(double[])} is called; otherwise those methods fail with a
 * {@link NullPointerException}.
 *
 * <p>NOTE(review): the class is marked {@link Cloneable} but declares no {@code clone()}
 * of its own in this file; use {@link #copy(MLPNet)} to duplicate a network.
 */
public class MLPNet implements Cloneable {
    private static final Logger logger = LoggerFactory.getLogger(MLPNet.class);

    /** Layer that receives the raw input vector; {@code null} until created. */
    protected MLPLayer inputLayer = null;

    /** Layer that produces the network output; {@code null} until created. */
    public MLPLayer outputLayer = null;

    /** Learning rate handed to every layer when its weights are adjusted. */
    protected double learningRate = 0.25d;

    /** Hidden layers, ordered from the input side to the output side. */
    protected List<MLPLayer> hiddenLayers = new ArrayList<>();

    /**
     * Replaces the state of this network with clones of the layers of {@code mLPNet} and
     * with its learning rate.
     *
     * <p>All clones are created before any field of this network is modified, so a failing
     * clone leaves this network untouched, and copying a network onto itself keeps its
     * hidden layers instead of discarding them.
     *
     * @param mLPNet the network to copy from; must not be {@code null}
     * @throws CloneNotSupportedException if cloning one of the layers fails
     */
    public void copy(MLPNet mLPNet) throws CloneNotSupportedException {
        MLPLayer inputCopy = mLPNet.inputLayer == null ? null : (MLPLayer) mLPNet.inputLayer.clone();
        MLPLayer outputCopy = mLPNet.outputLayer == null ? null : (MLPLayer) mLPNet.outputLayer.clone();

        // Snapshot the clones first: mLPNet.hiddenLayers may be the very list cleared below.
        List<MLPLayer> hiddenCopies = new ArrayList<>(mLPNet.hiddenLayers.size());
        for (MLPLayer layer : mLPNet.hiddenLayers) {
            hiddenCopies.add((MLPLayer) layer.clone());
        }

        this.inputLayer = inputCopy;
        this.outputLayer = outputCopy;
        // Reuse the existing list instance so references held elsewhere stay valid.
        this.hiddenLayers.clear();
        this.hiddenLayers.addAll(hiddenCopies);
        this.learningRate = mLPNet.learningRate;
    }

    /**
     * Creates a new input layer, replacing any existing one.
     *
     * @param i value passed to the {@code MLPLayer} constructor (presumably the neuron count)
     * @return the newly created input layer
     */
    public MLPLayer createInputLayer(int i) {
        this.inputLayer = new MLPLayer(i);
        return this.inputLayer;
    }

    /**
     * Creates a new output layer, replacing any existing one.
     *
     * @param i value passed to the {@code MLPLayer} constructor (presumably the neuron count)
     * @return the newly created output layer
     */
    public MLPLayer createOutputLayer(int i) {
        this.outputLayer = new MLPLayer(i);
        return this.outputLayer;
    }

    /**
     * Appends a hidden layer after the hidden layers added so far.
     *
     * @param i value passed to the {@code MLPLayer} constructor (presumably the neuron count)
     */
    public void addHiddenLayer(int i) {
        this.hiddenLayers.add(new MLPLayer(i));
    }

    /**
     * Appends a hidden layer that uses the given transfer function.
     *
     * @param i value passed to the {@code MLPLayer} constructor (presumably the neuron count)
     * @param transferFunction transfer function set on the new layer
     */
    public void addHiddenLayer(int i, TransferFunction transferFunction) {
        MLPLayer layer = new MLPLayer(i);
        layer.setTransfer(transferFunction);
        this.hiddenLayers.add(layer);
    }

    /**
     * Runs one training step on a single sample: forward pass, error back-propagation from
     * the output layer down through the hidden layers, then a weight adjustment of every
     * hidden layer and the output layer.
     *
     * @param dArr the input vector
     * @param dArr2 the target output vector
     * @return the error of the forward pass, measured before any weight is adjusted
     *         (see {@link #get_target_error(double[])})
     */
    public double train(double[] dArr, double[] dArr2) {
        double[] predicted = forwardPass(dArr);
        double error = get_target_error(dArr2);

        // Propagate the error signal backwards, from the output layer to the first hidden layer.
        double[] backSignal = this.outputLayer.back_propagate(minus(dArr2, predicted));
        for (int idx = this.hiddenLayers.size() - 1; idx >= 0; idx--) {
            backSignal = this.hiddenLayers.get(idx).back_propagate(backSignal);
        }

        // Adjust weights front to back; each layer is given the output of the layer before it.
        double[] layerInput = this.inputLayer.output();
        for (MLPLayer hidden : this.hiddenLayers) {
            hidden.adjust_weights(layerInput, getLearningRate());
            layerInput = hidden.output();
        }
        this.outputLayer.adjust_weights(layerInput, getLearningRate());
        return error;
    }

    /**
     * Computes the element-wise difference {@code dArr[i] - dArr2[i]}.
     *
     * @param dArr the minuend; its length determines the length of the result
     * @param dArr2 the subtrahend; must have at least {@code dArr.length} elements
     * @return a new array holding the differences
     * @throws ArrayIndexOutOfBoundsException if {@code dArr2} is shorter than {@code dArr}
     */
    public double[] minus(double[] dArr, double[] dArr2) {
        double[] diff = new double[dArr.length];
        for (int i = 0; i < dArr.length; i++) {
            diff[i] = dArr[i] - dArr2[i];
        }
        return diff;
    }

    /**
     * Computes half the sum of squared differences between the target and the current
     * output of the output layer.
     *
     * @param dArr the target vector; must have at least as many elements as the output
     * @return {@code sum(0.5 * (dArr[i] - output[i])^2)} over all output elements
     */
    protected double get_target_error(double[] dArr) {
        double[] actual = this.outputLayer.output();
        double total = 0.0d;
        for (int i = 0; i < actual.length; i++) {
            double delta = dArr[i] - actual[i];
            total += 0.5d * delta * delta;
        }
        return total;
    }

    /**
     * Feeds the input vector forward through the network without adjusting any weights.
     *
     * @param dArr the input vector
     * @return the array returned by the output layer's forward propagation
     */
    public double[] transform(double[] dArr) {
        return forwardPass(dArr);
    }

    /**
     * Returns the learning rate used when adjusting weights.
     *
     * @return the current learning rate
     */
    public double getLearningRate() {
        return this.learningRate;
    }

    /**
     * Sets the learning rate used when adjusting weights.
     *
     * @param d the new learning rate
     */
    public void setLearningRate(double d) {
        this.learningRate = d;
    }

    /** Sets the input on the input layer and propagates it through every hidden layer and the output layer. */
    private double[] forwardPass(double[] input) {
        double[] signal = this.inputLayer.setOutput(input);
        for (MLPLayer hidden : this.hiddenLayers) {
            signal = hidden.forward_propagate(signal);
        }
        return this.outputLayer.forward_propagate(signal);
    }
}
