package org.aika.training;

import java.util.Collections;
import java.util.TreeSet;
import org.aika.corpus.Document;
import org.aika.lattice.OrNode;
import org.aika.neuron.Activation;
import org.aika.neuron.INeuron;
import org.aika.neuron.Synapse;
import org.aika.training.SynapseEvaluation;

/**
 * Long-term learning rules: potentiation of synapses coming from other final activations, and
 * depression of synapses that did not carry a final activation.
 *
 * <p>The rules never change {@code weight} or {@code bias} directly. They accumulate
 * {@code Synapse.weightDelta} / {@code Synapse.biasDelta}, may flag {@code Synapse.toBeDeleted},
 * and report the touched synapses through {@code Document.notifyWeightsModified}. Potentiation adds
 * to {@code weightDelta} and depression subtracts from it, so a pending delta is treated as
 * additive: the weight after the update is {@code weight + weightDelta}.
 */
public class LongTermLearning {

    /** Parameters of the learning rules; every setter returns {@code this} so calls can be chained. */
    public static class Config {
        /** Queried per candidate synapse for its key, its significance and its deletion policy. */
        public SynapseEvaluation synapseEvaluation;
        /** Learn rate used by {@link LongTermLearning#longTermPotentiation}. */
        public double ltpLearnRate;
        /** Learn rate used by {@link LongTermLearning#longTermDepression}. */
        public double ltdLearnRate;
        /** Factor applied to each potentiation step before subtracting it from {@code biasDelta}. */
        public double beta;

        public Config setSynapseEvaluation(SynapseEvaluation synapseEvaluation) {
            this.synapseEvaluation = synapseEvaluation;
            return this;
        }

        public Config setLTPLearnRate(double d) {
            this.ltpLearnRate = d;
            return this;
        }

        public Config setLTDLearnRate(double d) {
            this.ltdLearnRate = d;
            return this;
        }

        public Config setBeta(double d) {
            this.beta = d;
            return this;
        }
    }

    /**
     * Runs one learning pass over every final activation of {@code document}: potentiation,
     * then depression of input synapses, then depression of output synapses.
     *
     * @param document the processed document whose final activations drive learning
     * @param config   learning parameters
     */
    public static void train(Document document, Config config) {
        document.getFinalActivations().forEach(act -> {
            longTermPotentiation(document, config, act);
            longTermDepression(document, config, act, false);
            longTermDepression(document, config, act, true);
        });
    }

    /**
     * Strengthens synapses from every other final activation (one whose node differs from
     * {@code activation}'s) into the neuron of {@code activation}. Missing synapses are created
     * through {@code Synapse.createOrLookup}.
     *
     * <p>The step size is {@code ltpLearnRate * (1 - activationValue) * coverage * inputValue *
     * significance}. Here {@code coverage} is the weighted sum of the non-negative inputs divided by
     * {@code posDirSum + posRecSum}, or 1.0 when that sum is not positive. Only positive steps are
     * applied. Each one is added to {@code weightDelta}, and {@code beta} times the step is
     * subtracted from {@code biasDelta}.
     *
     * @param document   document providing the final activations
     * @param config     learning parameters
     * @param activation the activation whose neuron receives the strengthened input synapses
     */
    public static void longTermPotentiation(Document document, Config config, Activation activation) {
        INeuron neuron = ((OrNode) activation.key.node).neuron.get();

        // Weighted sum of the final input values over all non-negative input synapses.
        double positiveInputSum = 0.0;
        for (Activation.SynapseActivation sa : activation.getFinalInputActivations()) {
            if (!sa.synapse.isNegative()) {
                positiveInputSum += sa.input.getFinalState().value * sa.synapse.weight;
            }
        }

        double posSum = neuron.posDirSum + neuron.posRecSum;
        double coverage = posSum > 0.0 ? positiveInputSum / posSum : 1.0;
        // The less active the neuron already is, the larger the potentiation step.
        double learnRate = config.ltpLearnRate * (1.0 - activation.getFinalState().value) * coverage;

        document.getFinalActivations()
                .filter(inputAct -> inputAct.key.node != activation.key.node)
                .forEach(inputAct -> {
                    SynapseEvaluation.Result eval = config.synapseEvaluation.evaluate(null, inputAct, activation);
                    double delta = inputAct.getFinalState().value * learnRate * eval.significance;
                    if (delta > 0.0) {
                        Synapse synapse = Synapse.createOrLookup(
                                eval.synapseKey,
                                ((OrNode) inputAct.key.node).neuron,
                                ((OrNode) activation.key.node).neuron);
                        synapse.weightDelta += (float) delta;
                        synapse.biasDelta -= config.beta * delta;
                        // NOTE(review): this guards the neuron's committed bias, which this loop does
                        // not modify; it was possibly meant to guard biasDelta. Confirm.
                        assert !Double.isNaN(neuron.bias);
                    }
                });

        document.notifyWeightsModified(neuron, neuron.inputSynapses.values());
    }

    /**
     * Weakens the non-negative synapses of {@code activation}'s neuron that did not carry one of
     * its final synapse activations.
     *
     * <p>Each such synapse has {@code ltdLearnRate * activationValue * significance} subtracted from
     * its {@code weightDelta}. If the evaluation requests {@code deleteIfNull} and the weight after
     * the pending update would be {@code <= 0}, the synapse is flagged {@code toBeDeleted}.
     *
     * @param document   document to notify about modified weights
     * @param config     learning parameters
     * @param activation the activation whose neuron's synapses are depressed
     * @param dir        {@code false} to process the neuron's input synapses (notified once at
     *                   the end); {@code true} to process its output synapses (each one notified
     *                   on its own output neuron)
     */
    public static void longTermDepression(Document document, Config config, Activation activation, boolean dir) {
        INeuron neuron = ((OrNode) activation.key.node).neuron.get();

        // Synapses that carried a final activation in the chosen direction are exempt.
        TreeSet<Synapse> activeSynapses =
                new TreeSet<>(dir ? Synapse.OUTPUT_SYNAPSE_COMP : Synapse.INPUT_SYNAPSE_COMP);
        (dir ? activation.getFinalOutputActivations() : activation.getFinalInputActivations())
                .forEach(sa -> activeSynapses.add(sa.synapse));

        (dir ? neuron.outputSynapses : neuron.inputSynapses).values().stream()
                .filter(s -> !s.isNegative() && !activeSynapses.contains(s))
                .forEach(s -> {
                    SynapseEvaluation.Result eval =
                            config.synapseEvaluation.evaluate(s, dir ? activation : null, dir ? null : activation);
                    s.weightDelta -= (float) ((config.ltdLearnRate * activation.getFinalState().value) * eval.significance);

                    // The pending delta is additive (potentiation adds to it, depression subtracts),
                    // so the weight after the update is weight + weightDelta. The former
                    // `weight - weightDelta` could never reach zero for a depressed positive weight.
                    if (eval.deleteIfNull && s.weight + s.weightDelta <= 0.0) {
                        s.toBeDeleted = true;
                    }
                    if (dir) {
                        document.notifyWeightsModified(s.output.get(), Collections.singletonList(s));
                    }
                });

        if (!dir) {
            document.notifyWeightsModified(neuron, neuron.inputSynapses.values());
        }
    }
}
