package org.aika;

import java.util.Collection;
import java.util.Comparator;
import java.util.Iterator;
import java.util.TreeSet;
import org.aika.lattice.AndNode;
import org.aika.lattice.InputNode;
import org.aika.lattice.Node;
import org.aika.lattice.OrNode;
import org.aika.neuron.INeuron;
import org.aika.neuron.Synapse;

/* loaded from: input_file:org/aika/Converter.class */
/**
 * Applies the pending weight/bias deltas of a neuron's modified input synapses, updates the
 * neuron's cached weight sums, and rebuilds the inputs of the neuron's {@link OrNode}.
 *
 * <p>Instances are single-use: the constructor is private and the only entry point is
 * {@link #convert(Model, int, INeuron, Collection)}, which creates one instance and runs it once.
 */
public class Converter {

    /**
     * Upper bound on how many synapses {@link #convert()} chains together (via
     * {@link AndNode#createNextLevelNode}) into the required node of the output {@link OrNode}.
     */
    public static int MAX_AND_NODE_SIZE = 4;

    /** Orders synapses by descending weight; ties are broken by {@link Synapse#INPUT_SYNAPSE_COMP}. */
    public static Comparator<Synapse> SYNAPSE_COMP = (s1, s2) -> {
        int byWeight = Double.compare(s2.weight, s1.weight);
        return byWeight != 0 ? byWeight : Synapse.INPUT_SYNAPSE_COMP.compare(s1, s2);
    };

    /** First index of the weight-sum delta table: non-recurrent synapses. */
    public static final int DIRECT = 0;
    /** First index of the weight-sum delta table: recurrent synapses. */
    public static final int RECURRENT = 1;
    /** Second index of the weight-sum delta table: positive weights. */
    public static final int POSITIVE = 0;
    /** Second index of the weight-sum delta table: negative (or, for new weights, zero) weights. */
    public static final int NEGATIVE = 1;

    private final Model model;
    private final int threadId;
    private final INeuron neuron;
    private final Collection<Synapse> modifiedSynapses;
    /** The neuron's OrNode; resolved at the start of {@link #convert()}. */
    private OrNode outputNode;

    /**
     * Converts {@code iNeuron} after the synapses in {@code collection} have been modified.
     *
     * @param model      model handed to {@link InputNode#add} and {@link AndNode#createNextLevelNode}
     * @param i          thread id handed to the node operations
     * @param iNeuron    the neuron whose inputs are rebuilt
     * @param collection the modified input synapses of {@code iNeuron}
     * @return always {@code true}
     */
    public static boolean convert(Model model, int i, INeuron iNeuron, Collection<Synapse> collection) {
        return new Converter(model, i, iNeuron, collection).convert();
    }

    private Converter(Model model, int i, INeuron iNeuron, Collection<Synapse> collection) {
        this.model = model;
        this.neuron = iNeuron;
        this.threadId = i;
        this.modifiedSynapses = collection;
    }

    /**
     * Rebuilds the inputs of {@link #outputNode}.
     *
     * <ol>
     *   <li>Applies synapse deltas and updates the neuron's sums
     *       ({@link #initInputNodesAndComputeWeightSums()}).</li>
     *   <li>Collects the non-recurrent, non-negative input synapses ordered by descending weight
     *       ({@link #SYNAPSE_COMP}).</li>
     *   <li>If their weight sum plus {@code posRecSum} and {@code biasSum} is positive, greedily
     *       chains the heaviest candidates into a required node. The greedy pass stops when the
     *       current candidate is not needed (the selected plus the remaining candidates without it
     *       would still be positive), when {@link #MAX_AND_NODE_SIZE} synapses are selected, or when
     *       the selected synapses alone are sufficient.</li>
     *   <li>If the selection was sufficient or hit the size limit, the chain itself becomes an input.
     *       Otherwise each further unselected candidate extends the chain by one level and is added
     *       as a separate input, for as long as the running sum stays positive.</li>
     *   <li>Finally every modified synapse with {@code weight + posRecSum + biasSum > 0} is added as
     *       a direct input.</li>
     * </ol>
     *
     * <p>The association order of every floating-point sum below is intentional and must be kept
     * as-is.
     *
     * @return always {@code true}
     */
    private boolean convert() {
        outputNode = neuron.node.get();
        initInputNodesAndComputeWeightSums();

        double remainingSum = 0.0;
        TreeSet<Synapse> candidates = new TreeSet<>(SYNAPSE_COMP);
        for (Synapse s : neuron.inputSynapses.values()) {
            if (!s.isNegative() && !s.key.isRecurrent) {
                remainingSum += s.weight;
                candidates.add(s);
            }
        }

        // Smallest relativeRid among the selected synapses (null while nothing is selected).
        Integer minRelativeRid = null;
        // Chain built from the selected synapses (null while nothing is selected).
        Node requiredNode = null;
        boolean sufficient = false;
        TreeSet<Synapse> selected = new TreeSet<>(Synapse.INPUT_SYNAPSE_COMP);
        double selectedSum = 0.0;

        if (remainingSum + neuron.posRecSum + neuron.biasSum > 0.0) {
            int size = 0;
            for (Synapse s : candidates) {
                boolean notNeeded =
                        selectedSum + remainingSum - s.weight + neuron.posRecSum + neuron.biasSum > 0.0;
                boolean sizeLimitReached = size >= MAX_AND_NODE_SIZE;
                if (notNeeded || sizeLimitReached) {
                    break;
                }
                remainingSum -= s.weight;
                selected.add(s);
                // Must use the previous minRelativeRid, so extend the chain before updating it.
                requiredNode = getNextLevelNode(minRelativeRid, requiredNode, s);
                minRelativeRid = Utils.nullSafeMin(s.key.relativeRid, minRelativeRid);
                size++;
                selectedSum += s.weight;
                if (selectedSum + neuron.posRecSum + neuron.biasSum > 0.0) {
                    sufficient = true;
                    break;
                }
            }

            // NOTE(review): presumably detaches the previously attached inputs — confirm in OrNode.
            outputNode.removeParents(threadId, false);
            outputNode.requiredNode = requiredNode;

            if (sufficient || size == MAX_AND_NODE_SIZE) {
                outputNode.addInput(minRelativeRid, threadId, requiredNode, false);
            } else {
                for (Synapse s : candidates) {
                    if (selectedSum + s.weight + remainingSum + neuron.posRecSum + neuron.biasSum <= 0.0) {
                        break;
                    }
                    if (!selected.contains(s)) {
                        outputNode.addInput(
                                Utils.nullSafeMin(s.key.relativeRid, minRelativeRid),
                                threadId,
                                getNextLevelNode(minRelativeRid, requiredNode, s),
                                false);
                        remainingSum -= s.weight;
                    }
                }
            }
        }

        for (Synapse s : modifiedSynapses) {
            if (s.weight + neuron.posRecSum + neuron.biasSum > 0.0) {
                outputNode.addInput(s.key.relativeRid, threadId, s.inputNode.get(), false);
            }
        }
        return true;
    }

    /**
     * Applies the pending deltas of every modified synapse and refreshes the neuron's sums.
     *
     * <p>For each synapse in {@link #modifiedSynapses}:
     * <ul>
     *   <li>if it is marked {@code toBeDeleted}, its deltas are set to cancel its current weight and
     *       bias;</li>
     *   <li>it gets an {@link InputNode} if it has none;</li>
     *   <li>{@code weightDelta}/{@code biasDelta} are folded into {@code weight}/{@code bias}, and
     *       the change is recorded in a [DIRECT|RECURRENT][POSITIVE|NEGATIVE] delta table;</li>
     *   <li>if it was marked {@code toBeDeleted}, it is unlinked.</li>
     * </ul>
     * The second and third steps run while holding the input neuron's write lock. The lock is
     * released in a {@code finally} block, so it is freed even if node creation throws.
     *
     * <p>Afterwards the neuron's own bias delta is applied. {@code biasSum} is recomputed as the sum
     * of the modified synapses' biases plus the neuron bias, clamped to at most 0. The
     * {@code maxRecurrentSum}, {@code posDirSum}, {@code negDirSum}, {@code posRecSum} and
     * {@code negRecSum} fields are then updated from the accumulated deltas.
     *
     * <p>NOTE(review): {@code biasSum} is rebuilt from {@code modifiedSynapses} only — confirm that
     * callers pass every input synapse, or that this is intended.
     */
    private void initInputNodesAndComputeWeightSums() {
        double[][] sumDelta = new double[2][2];
        double maxRecurrentDelta = 0.0;
        neuron.biasSum = 0.0;

        for (Synapse synapse : modifiedSynapses) {
            if (synapse.toBeDeleted) {
                synapse.weightDelta = -synapse.weight;
                synapse.biasDelta = -synapse.bias;
            }

            INeuron inputNeuron = synapse.input.get();
            inputNeuron.lock.acquireWriteLock();
            try {
                if (synapse.inputNode == null) {
                    InputNode inputNode =
                            InputNode.add(model, synapse.key.createInputNodeKey(), synapse.input.get());
                    inputNode.setModified();
                    inputNode.setSynapse(synapse);
                    synapse.inputNode = inputNode.provider;
                }

                if (synapse.key.isRecurrent) {
                    maxRecurrentDelta +=
                            Math.abs(synapse.weight + synapse.weightDelta) - Math.abs(synapse.weight);
                }

                int recurrence = synapse.key.isRecurrent ? RECURRENT : DIRECT;
                // Take the old weight out of its bucket, then add the new weight to its bucket.
                // NOTE(review): the old weight is classified with isNegative(), the new one with
                // "<= 0.0" — confirm these agree for zero weights.
                sumDelta[recurrence][synapse.isNegative() ? NEGATIVE : POSITIVE] -= synapse.weight;
                int newSign = synapse.weight + synapse.weightDelta <= 0.0 ? NEGATIVE : POSITIVE;
                sumDelta[recurrence][newSign] =
                        sumDelta[recurrence][newSign] + synapse.weight + synapse.weightDelta;

                synapse.weight += synapse.weightDelta;
                synapse.weightDelta = 0.0;
                synapse.bias += synapse.biasDelta;
                synapse.biasDelta = 0.0;

                neuron.biasSum += synapse.bias;
            } finally {
                inputNeuron.lock.releaseWriteLock();
            }

            if (synapse.toBeDeleted) {
                synapse.unlink();
            }
        }

        neuron.bias += neuron.biasDelta;
        neuron.biasDelta = 0.0;
        neuron.biasSum += neuron.bias;
        neuron.biasSum = Math.min(neuron.biasSum, 0.0);
        assert Double.isFinite(neuron.biasSum) : neuron.biasSum;

        neuron.maxRecurrentSum += maxRecurrentDelta;
        neuron.posDirSum += sumDelta[DIRECT][POSITIVE];
        neuron.negDirSum += sumDelta[DIRECT][NEGATIVE];
        neuron.negRecSum += sumDelta[RECURRENT][NEGATIVE];
        neuron.posRecSum += sumDelta[RECURRENT][POSITIVE];
        neuron.setModified();
    }

    /**
     * Extends {@code node} by {@code synapse}.
     *
     * @param minRelativeRid smallest relativeRid of the synapses already in {@code node}, or
     *                       {@code null} if {@code node} is {@code null}
     * @param node           the chain built so far, or {@code null}
     * @param synapse        the synapse to add
     * @return the synapse's own input node if {@code node} is {@code null}; otherwise the next-level
     *         {@link AndNode} created from {@code node} and a refinement built from the synapse
     */
    private Node getNextLevelNode(Integer minRelativeRid, Node node, Synapse synapse) {
        if (node == null) {
            return synapse.inputNode.get();
        }
        AndNode.Refinement refinement =
                new AndNode.Refinement(synapse.key.relativeRid, minRelativeRid, synapse.inputNode);
        return AndNode.createNextLevelNode(model, threadId, node, refinement, null);
    }
}
