package org.aika;

import java.util.Collection;
import java.util.Comparator;
import java.util.Iterator;
import java.util.TreeSet;
import org.aika.lattice.AndNode;
import org.aika.lattice.InputNode;
import org.aika.lattice.Node;
import org.aika.lattice.OrNode;
import org.aika.neuron.INeuron;
import org.aika.neuron.Synapse;

/* loaded from: input_file:org/aika/Converter.class */
/**
 * Commits the pending weights ({@code nw}) of a neuron's modified input synapses and updates the
 * lattice inputs of the neuron's OR node ({@code neuron.node}) accordingly.
 *
 * <p>Instances are single-use and private; callers go through
 * {@link #convert(Model, int, INeuron, Collection)}.
 */
public class Converter {

    /**
     * Orders synapses by descending weight {@code w}; synapses of equal weight are ordered by
     * {@code Synapse.INPUT_SYNAPSE_COMP}.
     */
    public static final Comparator<Synapse> SYNAPSE_COMP = new Comparator<Synapse>() {
        @Override
        public int compare(Synapse s1, Synapse s2) {
            // Arguments intentionally swapped so that heavier synapses sort first.
            int byWeight = Double.compare(s2.w, s1.w);
            return byWeight != 0 ? byWeight : Synapse.INPUT_SYNAPSE_COMP.compare(s1, s2);
        }
    };

    private final Model m;
    private final int threadId;
    private final INeuron neuron;
    private final Collection<Synapse> modifiedSynapses;

    /** The neuron's OR node; resolved at the start of {@link #convert()}. */
    private OrNode outputNode;

    /**
     * Commits the modified synapses' pending weights and adds the corresponding lattice inputs to
     * {@code neuron}'s OR node.
     *
     * @param m the model; passed to {@code InputNode.add} and {@code AndNode.createNextLevelNode}
     * @param threadId id of the calling thread, passed through to the lattice node operations
     * @param neuron the neuron whose input synapses were modified
     * @param modifiedSynapses the synapses whose pending weight {@code nw} is committed to {@code w}
     * @return always {@code true} in the current implementation
     */
    public static boolean convert(Model m, int threadId, INeuron neuron, Collection<Synapse> modifiedSynapses) {
        return new Converter(m, threadId, neuron, modifiedSynapses).convert();
    }

    private Converter(Model m, int threadId, INeuron neuron, Collection<Synapse> modifiedSynapses) {
        this.m = m;
        this.neuron = neuron;
        this.threadId = threadId;
        this.modifiedSynapses = modifiedSynapses;
    }

    /**
     * Performs the conversion.
     *
     * <p>The steps are:
     * <ol>
     *   <li>Commit the weights.</li>
     *   <li>Collect the non-negative, non-recurrent synapses, heaviest first.</li>
     *   <li>Either take the incremental shortcut, or greedily build a chain of "required" inputs
     *       (as AND nodes) and attach it, optionally extended by one further candidate, to the
     *       OR node.</li>
     * </ol>
     *
     * @return always {@code true}
     */
    private boolean convert() {
        outputNode = neuron.node.get();
        initInputNodesAndComputeWeightSums();

        // Only some of the neuron's input synapses were modified.
        boolean incremental = neuron.inputSynapses.size() != modifiedSynapses.size();

        // Collect the non-negative, non-recurrent synapses (heaviest first) and sum their weights.
        TreeSet<Synapse> candidates = new TreeSet<>(SYNAPSE_COMP);
        double remainingSum = 0.0;
        boolean hasSingleSufficientInput = false;
        for (Synapse s : neuron.inputSynapses.values()) {
            if (s.isNegative() || s.key.isRecurrent) {
                continue;
            }
            if (s.w + neuron.bias > 0.0) {
                hasSingleSufficientInput = true;
                if (incremental) {
                    // The shortcut below will be taken; candidates/remainingSum are not needed.
                    break;
                }
            }
            remainingSum += s.w;
            candidates.add(s);
        }

        // Incremental shortcut: some synapse's weight plus the bias is already positive, so connect
        // the modified synapses' input nodes directly to the OR node.
        if (hasSingleSufficientInput && incremental) {
            for (Synapse s : modifiedSynapses) {
                outputNode.addInput(s.key.relativeRid, threadId, s.inputNode.get(), false);
            }
            return true;
        }

        // Greedily mark the heaviest candidates as required: a candidate is required while the
        // total input without it is not positive. Required inputs are chained into AND nodes,
        // up to AndNode.MAX_AND_NODE_SIZE of them.
        TreeSet<Synapse> required = new TreeSet<>(Synapse.INPUT_SYNAPSE_COMP);
        double requiredSum = 0.0;
        Node requiredNode = null;
        Integer minRid = null;
        int requiredCount = 0;
        boolean requiredSufficient = false;
        for (Synapse s : candidates) {
            boolean positiveWithout = totalInput(requiredSum + remainingSum - s.w) > 0.0;
            boolean andNodeFull = requiredCount >= AndNode.MAX_AND_NODE_SIZE;
            if (positiveWithout || andNodeFull) {
                break;
            }
            remainingSum -= s.w;
            required.add(s);
            requiredNode = getNextLevelNode(minRid, requiredNode, s);
            minRid = Utils.nullSafeMin(s.key.relativeRid, minRid);
            requiredCount++;
            requiredSum += s.w;
            if (totalInput(requiredSum) > 0.0) {
                // The required inputs alone already yield a positive total input.
                requiredSufficient = true;
                break;
            }
        }

        outputNode.removeParents(threadId, false);
        outputNode.requiredNode = requiredNode;

        if (requiredSufficient || requiredCount == AndNode.MAX_AND_NODE_SIZE) {
            outputNode.addInput(minRid, threadId, requiredNode, false);
            return true;
        }

        // Otherwise add one input per non-required candidate: the AND of the required node and
        // that candidate (or the candidate's own input node if nothing is required). Stop as soon
        // as the total input can no longer become positive.
        for (Synapse s : candidates) {
            if (totalInput(requiredSum + s.w + remainingSum) <= 0.0) {
                return true;
            }
            if (!required.contains(s)) {
                outputNode.addInput(
                        Utils.nullSafeMin(s.key.relativeRid, minRid),
                        threadId,
                        getNextLevelNode(minRid, requiredNode, s),
                        false);
                remainingSum -= s.w;
            }
        }
        return true;
    }

    /**
     * Returns {@code partialSum - negRecSum - negDirSum + posRecSum + bias} for this neuron.
     *
     * <p>The terms are applied left to right in exactly this order so that floating-point rounding
     * matches evaluating the whole expression inline. Callers compare the result themselves,
     * which keeps their {@code > 0} / {@code <= 0} NaN semantics intact.
     */
    private double totalInput(double partialSum) {
        return partialSum - neuron.negRecSum - neuron.negDirSum + neuron.posRecSum + neuron.bias;
    }

    /**
     * Processes each modified synapse and then applies the accumulated deltas to the neuron.
     *
     * <p>Per synapse, under the input neuron's write lock:
     * <ul>
     *   <li>creates its {@code InputNode} if missing;</li>
     *   <li>replaces the contribution of its old weight {@code w} with that of its new weight
     *       {@code nw} in the neuron's weight sums;</li>
     *   <li>commits {@code nw} into {@code w}.</li>
     * </ul>
     *
     * <p>The accumulated deltas are added to {@code maxRecurrentSum}, {@code negDirSum},
     * {@code negRecSum} and {@code posRecSum} after the loop.
     *
     * <p>NOTE(review): if an exception escapes mid-loop, the weights of already-processed synapses
     * stay committed but the sum deltas are not applied. The lock itself is always released.
     */
    private void initInputNodesAndComputeWeightSums() {
        double negDirSumDelta = 0.0;
        double negRecSumDelta = 0.0;
        double posRecSumDelta = 0.0;
        double maxRecurrentSumDelta = 0.0;

        for (Synapse s : modifiedSynapses) {
            INeuron in = s.input.get();
            in.lock.acquireWriteLock();
            try {
                if (s.inputNode == null) {
                    InputNode inputNode = InputNode.add(m, s.key.createInputNodeKey(), s.input.get());
                    inputNode.provider.setModified();
                    inputNode.setSynapse(s);
                    s.inputNode = inputNode.provider;
                }

                boolean recurrent = s.key.isRecurrent;
                if (recurrent) {
                    maxRecurrentSumDelta += Math.abs(s.nw) - Math.abs(s.w);
                    ((Neuron) neuron.provider).setModified();
                }

                // Remove the old weight's contribution.
                if (s.isNegative()) {
                    if (recurrent) {
                        negRecSumDelta -= s.w;
                    } else {
                        negDirSumDelta -= s.w;
                    }
                } else if (recurrent) {
                    posRecSumDelta -= s.w;
                }

                // Add the new weight's contribution.
                if (s.nw <= 0.0) {
                    if (recurrent) {
                        negRecSumDelta += s.nw;
                    } else {
                        negDirSumDelta += s.nw;
                    }
                } else if (recurrent) {
                    posRecSumDelta += s.nw;
                }

                s.w = s.nw;
            } finally {
                // Always release, otherwise an exception above would leave the lock held forever.
                in.lock.releaseWriteLock();
            }
        }

        neuron.maxRecurrentSum += maxRecurrentSumDelta;
        neuron.negDirSum += negDirSumDelta;
        neuron.negRecSum += negRecSumDelta;
        neuron.posRecSum += posRecSumDelta;
    }

    /**
     * Returns the node for {@code synapse} on top of {@code parent}.
     *
     * <p>If {@code parent} is null, this is the synapse's own input node. Otherwise it is the
     * next-level {@code AndNode} that refines {@code parent} with the synapse (relative rid
     * {@code synapse.key.relativeRid}, reference rid {@code minRid}).
     */
    private Node getNextLevelNode(Integer minRid, Node parent, Synapse synapse) {
        if (parent == null) {
            return synapse.inputNode.get();
        }
        return AndNode.createNextLevelNode(
                m,
                threadId,
                parent,
                new AndNode.Refinement(synapse.key.relativeRid, minRid, synapse.inputNode),
                false);
    }
}
