package org.aika.training;

import java.util.Collections;
import java.util.TreeSet;
import org.aika.corpus.Document;
import org.aika.neuron.Activation;
import org.aika.neuron.INeuron;
import org.aika.neuron.Synapse;
import org.aika.training.SynapseEvaluation;

/* loaded from: input_file:org/aika/training/LongTermLearning.class */
/**
 * Long-term learning rules applied to the final activations of a {@link Document}.
 *
 * <p>The rules are long-term potentiation (LTP), which strengthens synapses from active inputs,
 * and long-term depression (LTD), which weakens synapses that did not carry an activation.
 */
public class LongTermLearning {

    /**
     * Parameters for {@link LongTermLearning#train(Document, Config)}.
     *
     * <p>Every setter returns {@code this} so that configuration calls can be chained.
     */
    public static class Config {
        /** Evaluates candidate synapses; consulted by both LTP and LTD. */
        public SynapseEvaluation synapseEvaluation;
        /** Scales the weight increase applied during long-term potentiation. */
        public double ltpLearnRate;
        /** Scales the weight decrease applied during long-term depression. */
        public double ltdLearnRate;
        /** Each LTP weight increase, multiplied by this factor, is subtracted from the bias. */
        public double beta;
        /**
         * If {@code true}, LTP considers the final activation of every other node as a potential
         * input, obtaining the synapse via {@code Synapse.createOrLookup}. Otherwise, LTP considers
         * only the activation's existing final input synapse activations.
         */
        public boolean createNewSynapses;

        /** Sets {@link #synapseEvaluation} and returns this config. */
        public Config setSynapseEvaluation(SynapseEvaluation synapseEvaluation) {
            this.synapseEvaluation = synapseEvaluation;
            return this;
        }

        /** Sets {@link #ltpLearnRate} and returns this config. */
        public Config setLTPLearnRate(double learnRate) {
            this.ltpLearnRate = learnRate;
            return this;
        }

        /** Sets {@link #ltdLearnRate} and returns this config. */
        public Config setLTDLearnRate(double learnRate) {
            this.ltdLearnRate = learnRate;
            return this;
        }

        /** Sets {@link #beta} and returns this config. */
        public Config setBeta(double beta) {
            this.beta = beta;
            return this;
        }

        /** Sets {@link #createNewSynapses} and returns this config. */
        public Config setCreateNewSynapses(boolean createNewSynapses) {
            this.createNewSynapses = createNewSynapses;
            return this;
        }
    }

    /**
     * Runs one training pass over every final activation of {@code document}, then commits the
     * document.
     *
     * <p>For each activation, the pass applies three steps in this order:
     * <ol>
     *   <li>potentiation;</li>
     *   <li>depression of its input synapses;</li>
     *   <li>depression of its output synapses.</li>
     * </ol>
     *
     * @param document the processed document whose final activations drive learning
     * @param config   learning parameters
     */
    public static void train(Document document, Config config) {
        document.getFinalActivations().forEach(act -> {
            longTermPotentiation(document, config, act);
            longTermDepression(document, config, act, false);
            longTermDepression(document, config, act, true);
        });
        document.commit();
    }

    /**
     * Applies long-term potentiation to the synapses leading into {@code activation}'s neuron.
     *
     * <p>The neuron is then notified that its input synapse weights have changed.
     *
     * @param document   the document providing the candidate input activations
     * @param config     learning parameters
     * @param activation the output activation whose incoming synapses are strengthened
     */
    public static void longTermPotentiation(Document document, Config config, Activation activation) {
        INeuron neuron = activation.getINeuron();
        double value = activation.getFinalState().value;

        // f(bias + positive direct + positive recurrent sums); presumably the largest activation the
        // neuron can reach. TODO confirm against INeuron.
        double maxActivation =
                neuron.activationFunction.f(neuron.biasSum + neuron.posDirSum + neuron.posRecSum);

        // Scale the learn rate by the remaining headroom (1 - value). If the neuron can be active
        // at all, amplify the rate by how far the current value exceeds the maximum estimate
        // (never below 1).
        double scale = maxActivation > 0.0d ? Math.max(1.0d, value / maxActivation) : 1.0d;
        double learnRate = config.ltpLearnRate * (1.0d - value) * scale;

        if (config.createNewSynapses) {
            // Every final activation of a different node is a candidate input; there is no
            // existing synapse to pass to the evaluation.
            document.getFinalActivations()
                    .filter(in -> in.key.node != activation.key.node)
                    .forEach(in -> synapseLTP(config, in, activation, learnRate,
                            config.synapseEvaluation.evaluate(null, in, activation)));
        } else {
            activation.getFinalInputActivations()
                    .forEach(sa -> synapseLTP(config, sa.input, activation, learnRate,
                            config.synapseEvaluation.evaluate(sa.synapse, sa.input, activation)));
        }
        document.notifyWeightsModified(neuron, neuron.inputSynapses.values());
    }

    /**
     * Applies the LTP update for a single input/output activation pair.
     *
     * <p>Nothing happens if {@code result} is {@code null}, or if the computed increase is not
     * strictly positive. Otherwise, the synapse identified by {@code result.synapseKey} is looked
     * up or created between the two neurons. The increase is added to its {@code weightDelta}, and
     * {@code beta} times the increase is subtracted from its bias.
     *
     * @param config    learning parameters (reads {@code beta})
     * @param input     the presynaptic activation
     * @param output    the postsynaptic activation
     * @param learnRate the scaled LTP learn rate computed for {@code output}
     * @param result    the synapse evaluation, or {@code null} to skip this pair
     */
    public static void synapseLTP(Config config, Activation input, Activation output,
                                  double learnRate, SynapseEvaluation.Result result) {
        if (result == null) {
            return;
        }
        double delta = input.getFinalState().value * learnRate * result.significance;
        // Written as "delta > 0" on purpose so that a NaN delta is skipped as well.
        if (delta > 0.0d) {
            Synapse synapse = Synapse.createOrLookup(
                    result.synapseKey, input.getNeuron(), output.getNeuron());
            synapse.weightDelta += (float) delta;
            synapse.changeBias((-config.beta) * delta);
        }
    }

    /**
     * Applies long-term depression to the synapses of {@code activation}'s neuron that did not
     * carry an activation in this document.
     *
     * <p>The following synapses are never depressed:
     * <ul>
     *   <li>negative synapses;</li>
     *   <li>synapses that appear among the activation's final synapse activations;</li>
     *   <li>synapses whose {@code isConjunction(false)} equals {@code processOutputs}.</li>
     * </ul>
     *
     * <p>Nothing is done when the activation's final value is {@code <= 0}.
     *
     * @param document       the document to notify about weight changes
     * @param config         learning parameters (reads {@code ltdLearnRate}, {@code synapseEvaluation})
     * @param activation     the activation whose neuron's synapses are depressed
     * @param processOutputs {@code true} to depress output synapses, {@code false} for input synapses
     */
    public static void longTermDepression(Document document, Config config, Activation activation,
                                          boolean processOutputs) {
        if (activation.getFinalState().value <= 0.0d) {
            return;
        }
        INeuron neuron = activation.getINeuron();

        // Synapses that actually carried an activation; these are exempt from depression.
        TreeSet<Synapse> activeSynapses = new TreeSet<>(
                processOutputs ? Synapse.OUTPUT_SYNAPSE_COMP : Synapse.INPUT_SYNAPSE_COMP);
        (processOutputs ? activation.getFinalOutputActivations() : activation.getFinalInputActivations())
                .forEach(sa -> activeSynapses.add(sa.synapse));

        // NOTE(review): if deleteMode.checkIfDelete removes the synapse from the map being
        // streamed here, this could throw ConcurrentModificationException. Verify against
        // Synapse / DeleteMode.
        (processOutputs ? neuron.outputSynapses : neuron.inputSynapses).values().stream()
                .filter(s -> !s.isNegative() && !activeSynapses.contains(s))
                .forEach(s -> {
                    if (s.isConjunction(false) == processOutputs) {
                        return;
                    }
                    // The activation is passed as the input side when processing output
                    // synapses, and as the output side otherwise.
                    SynapseEvaluation.Result result = config.synapseEvaluation.evaluate(
                            s,
                            processOutputs ? activation : null,
                            processOutputs ? null : activation);
                    if (result == null) {
                        return;
                    }
                    s.weightDelta -= (float) ((config.ltdLearnRate * activation.getFinalState().value)
                            * result.significance);
                    result.deleteMode.checkIfDelete(s);
                    if (processOutputs) {
                        // Each output synapse belongs to a different neuron, so notify it individually.
                        document.notifyWeightsModified(s.output.get(), Collections.singletonList(s));
                    }
                });

        if (!processOutputs) {
            // All input synapses belong to this neuron, so a single notification suffices.
            document.notifyWeightsModified(neuron, neuron.inputSynapses.values());
        }
    }
}
