package com.github.bentorfs.ai.ml.ann;

import com.github.bentorfs.ai.common.FunctionLearner;
import java.util.ArrayList;
import java.util.List;

/* loaded from: input_file:com/github/bentorfs/ai/ml/ann/BackPropagation.class */
/**
 * Trains a {@link MultiLayerPerceptron} with the stochastic-gradient backpropagation rule.
 *
 * <p>Each call to {@link #showExample} does four things for one training example:
 * <ol>
 *   <li>runs a forward pass on the input;</li>
 *   <li>computes an error term for every unit in every non-input layer, using the current weights;</li>
 *   <li>adds {@code learningRate * input * errorTerm} to every weight;</li>
 *   <li>pairs weight {@code j} of a unit with output {@code j} of the previous layer when doing so.</li>
 * </ol>
 * Predictions are delegated to the wrapped network.
 *
 * <p>The error terms use the factor {@code o * (1 - o)}. This is the derivative of the logistic
 * sigmoid, written in terms of the unit's output. NOTE(review): this assumes every unit uses a
 * sigmoid activation, which should be confirmed against {@code NetworkUnit}.
 *
 * <p>No bias weight is handled separately in this class.
 *
 * <p>No synchronization is performed. Training mutates the weights of the wrapped network in place.
 */
public class BackPropagation implements FunctionLearner<Double, List<Double>> {
    /** Step size applied to every weight update. */
    private final double learningRate;

    /**
     * The network whose weights are trained. This class addresses its layers as
     * {@code 1..getNbOfLayers()}; index 0 of the per-layer output lists holds the raw input.
     */
    private final FeedForwardNetwork feedForwardNetwork;

    /**
     * Creates a trainer for the given network.
     *
     * @param network the network to train; {@link #showExample} modifies its weights in place
     * @param learningRate the step size applied to each weight update
     */
    public BackPropagation(MultiLayerPerceptron network, double learningRate) {
        this.feedForwardNetwork = network;
        this.learningRate = learningRate;
    }

    /**
     * Returns the wrapped network's output for the given input.
     *
     * @param input the input values
     * @return the network's output values
     */
    @Override // com.github.bentorfs.ai.common.FunctionLearner
    public List<Double> predictValue(List<Double> input) {
        return this.feedForwardNetwork.getOutput(input);
    }

    /**
     * Performs one backpropagation step on a single training example.
     *
     * @param target the desired network output for {@code input}
     * @param input the example's input values
     * @throws IllegalArgumentException if {@code target} does not have one value per output unit
     */
    @Override // com.github.bentorfs.ai.common.FunctionLearner
    public void showExample(List<Double> target, List<Double> input) {
        List<List<Double>> layerOutputs = forwardPass(input);
        // All error terms are computed before any weight changes, so they reflect the current weights.
        List<List<Double>> errorTerms = getErrorTerms(target, layerOutputs);
        for (int layerIndex = 1; layerIndex < layerOutputs.size(); layerIndex++) {
            List<Double> layerInputs = layerOutputs.get(layerIndex - 1);
            List<Double> layerErrors = errorTerms.get(layerIndex);
            NetworkLayer layer = this.feedForwardNetwork.getLayer(layerIndex);
            for (int unitIndex = 0; unitIndex < layer.getNumberOfUnits(); unitIndex++) {
                NetworkUnit unit = layer.getUnit(unitIndex);
                for (int w = 0; w < layerInputs.size(); w++) {
                    unit.setWeight(w, unit.getWeight(w)
                            + (this.learningRate * layerInputs.get(w) * layerErrors.get(unitIndex)));
                }
            }
        }
    }

    /**
     * Runs the network on {@code input} and collects the output of every layer.
     *
     * @param input the input values
     * @return a list of size {@code getNbOfLayers() + 1}; element 0 is {@code input} itself and
     *         element {@code l} is the output of layer {@code l}
     */
    private List<List<Double>> forwardPass(List<Double> input) {
        int nbOfLayers = this.feedForwardNetwork.getNbOfLayers();
        List<List<Double>> outputs = new ArrayList<>(nbOfLayers + 1);
        outputs.add(input);
        for (int layer = 1; layer <= nbOfLayers; layer++) {
            outputs.add(this.feedForwardNetwork.getOutputAtLayer(input, layer));
        }
        return outputs;
    }

    /**
     * Computes the error term of every unit in every non-input layer.
     *
     * @param target the desired output values
     * @param layerOutputs per-layer outputs; element 0 is the input and the last element is the
     *        output layer
     * @return a list indexed like {@code layerOutputs}; element 0 (the input layer) is {@code null}
     * @throws IllegalArgumentException if {@code target} does not match the output layer's size
     */
    private List<List<Double>> getErrorTerms(List<Double> target, List<List<Double>> layerOutputs) {
        int outputLayer = layerOutputs.size() - 1;
        List<List<Double>> errorTerms = getEmptyListWithSize(layerOutputs.size());
        errorTerms.set(outputLayer, getOutputLayerErrorTerms(target, layerOutputs));
        // Walk backwards: each hidden layer's errors depend on the errors of the layer after it.
        for (int layerIndex = outputLayer - 1; layerIndex > 0; layerIndex--) {
            List<Double> outputs = layerOutputs.get(layerIndex);
            List<Double> downstreamErrors = errorTerms.get(layerIndex + 1);
            NetworkLayer downstream = this.feedForwardNetwork.getLayer(layerIndex + 1);
            List<Double> layerErrors = new ArrayList<>(outputs.size());
            for (int i = 0; i < outputs.size(); i++) {
                double o = outputs.get(i);
                // Sum over the next layer's units. Weight i of each such unit is paired with unit i here.
                double weightedError = 0.0d;
                for (int k = 0; k < downstream.getNumberOfUnits(); k++) {
                    weightedError += downstream.getUnit(k).getWeight(i) * downstreamErrors.get(k);
                }
                layerErrors.add(o * (1.0d - o) * weightedError);
            }
            errorTerms.set(layerIndex, layerErrors);
        }
        return errorTerms;
    }

    /**
     * Computes {@code o * (1 - o) * (t - o)} for every output unit.
     *
     * @param target the desired output values {@code t}
     * @param layerOutputs per-layer outputs; the last element holds the network outputs {@code o}
     * @return one error term per output unit
     * @throws IllegalArgumentException if {@code target} and the output layer differ in size
     */
    private List<Double> getOutputLayerErrorTerms(List<Double> target, List<List<Double>> layerOutputs) {
        List<Double> outputs = layerOutputs.get(layerOutputs.size() - 1);
        if (outputs.size() != target.size()) {
            throw new IllegalArgumentException(
                    "Mismatch between the number of outputs of the feedforward network and the training example"
                            + " (network: " + outputs.size() + ", example: " + target.size() + ")");
        }
        List<Double> errors = new ArrayList<>(outputs.size());
        for (int i = 0; i < outputs.size(); i++) {
            double o = outputs.get(i);
            errors.add(o * (1.0d - o) * (target.get(i) - o));
        }
        return errors;
    }

    /**
     * Returns a mutable list of {@code size} {@code null} slots, to be filled by index with {@code set}.
     *
     * @param size the number of slots
     * @return a new list containing {@code size} nulls
     */
    private static List<List<Double>> getEmptyListWithSize(int size) {
        List<List<Double>> slots = new ArrayList<>(size);
        for (int i = 0; i < size; i++) {
            slots.add(null);
        }
        return slots;
    }
}
