package com.github.bentorfs.ai.ml.ann;

import com.github.bentorfs.ai.common.FunctionLearner;
import java.util.List;

/* loaded from: input_file:com/github/bentorfs/ai/ml/ann/PerceptronLearner.class */
public class PerceptronLearner implements FunctionLearner<Double, Double> {

    /** Learning rate used when none is supplied to the constructor. */
    private static final double DEFAULT_LEARNING_RATE = 0.1d;

    /** Training strategy used when none is supplied to the constructor. */
    private static final PerceptronTrainingStrategy DEFAULT_STRATEGY = PerceptronTrainingStrategy.delta_rule;

    /** Step size (eta) applied to every weight update. */
    private final double learningRate;

    /** The single unit whose weights are trained; its output is also used for prediction. */
    private final NetworkUnit perceptron;

    /** Selects how the unit's output is turned into the error signal during training. */
    private final PerceptronTrainingStrategy strategy;

    /**
     * Rule used by {@link #showExample(Double, List)} to compute the unit's output before
     * the weight update.
     */
    public enum PerceptronTrainingStrategy {
        /** Thresholds the unit's output to +1 / -1 (see {@link #getBinaryValue(double)}) before computing the error. */
        perceptron_rule,
        /** Uses the unit's raw output value to compute the error. */
        delta_rule
    }

    /**
     * Creates a learner with the default strategy ({@code delta_rule}) and learning rate ({@code 0.1}).
     *
     * @param networkUnit the unit to train and to predict with
     */
    public PerceptronLearner(NetworkUnit networkUnit) {
        this(networkUnit, DEFAULT_STRATEGY, DEFAULT_LEARNING_RATE);
    }

    /**
     * Creates a learner with an explicit strategy and learning rate.
     *
     * @param networkUnit the unit to train and to predict with
     * @param perceptronTrainingStrategy how the unit's output is used when computing the error;
     *        a {@code null} value behaves like {@code delta_rule} (no thresholding)
     * @param learningRate step size applied to every weight update
     */
    public PerceptronLearner(NetworkUnit networkUnit, PerceptronTrainingStrategy perceptronTrainingStrategy, double learningRate) {
        this.perceptron = networkUnit;
        this.strategy = perceptronTrainingStrategy;
        this.learningRate = learningRate;
    }

    /**
     * Returns the unit's raw output for the given inputs.
     *
     * <p>No thresholding is applied here, even when the learner trains with {@code perceptron_rule}.
     *
     * @param inputs the input vector passed to the unit
     * @return the unit's output value
     */
    @Override
    public Double predictValue(List<Double> inputs) {
        return Double.valueOf(this.perceptron.getValue(inputs));
    }

    /**
     * Performs one training step on a single labelled example.
     *
     * <p>Each weight is updated as {@code w_i += learningRate * (target - output) * x_i}. The
     * constant-input (bias) weight is updated as if its input were always {@code 1}. Under
     * {@code perceptron_rule} the output is first thresholded with {@link #getBinaryValue(double)}.
     *
     * @param target the desired output for this example
     * @param inputs the input vector; must contain at least as many values as the unit has weights
     * @throws IllegalArgumentException if {@code inputs} has fewer values than the unit has weights;
     *         this is detected before any weight is modified
     */
    @Override
    public void showExample(Double target, List<Double> inputs) {
        double output = this.perceptron.getValue(inputs);
        if (this.strategy == PerceptronTrainingStrategy.perceptron_rule) {
            output = getBinaryValue(output);
        }

        int weightCount = this.perceptron.getWeights().size();
        // Validate before mutating anything. Otherwise a short input list would throw halfway
        // through the loop and leave the unit with only some of its weights updated.
        if (inputs.size() < weightCount) {
            throw new IllegalArgumentException(
                    "Expected at least " + weightCount + " inputs but got " + inputs.size());
        }

        // The scaled error is the same for every weight, so compute it once.
        // (eta * err) * x_i is evaluated exactly as the original left-associative eta * err * x_i.
        double step = this.learningRate * (target.doubleValue() - output);
        for (int i = 0; i < weightCount; i++) {
            this.perceptron.setWeight(i, this.perceptron.getWeight(i) + step * inputs.get(i).doubleValue());
        }
        // The bias input is the constant 1, so its update is just the scaled error.
        this.perceptron.setConstantInputWeight(this.perceptron.getConstantInputWeight() + step);
    }

    /**
     * Thresholds a raw output to a bipolar class label.
     *
     * @param d the raw output value
     * @return {@code 1} if {@code d} is strictly positive, otherwise {@code -1} (including for {@code 0})
     */
    public int getBinaryValue(double d) {
        return d > 0.0d ? 1 : -1;
    }
}
