package org.aika.lattice;

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Map;
import java.util.TreeMap;
import org.aika.Model;
import org.aika.Neuron;
import org.aika.Provider;
import org.aika.ReadWriteLock;
import org.aika.Utils;
import org.aika.Writable;
import org.aika.corpus.Document;
import org.aika.corpus.InterpretationNode;
import org.aika.corpus.Range;
import org.aika.lattice.AndNode;
import org.aika.lattice.NodeActivation;
import org.aika.neuron.Activation;
import org.aika.neuron.INeuron;
import org.aika.neuron.Synapse;
import org.aika.training.PatternDiscovery;

/* loaded from: input_file:org/aika/lattice/InputNode.class */
/**
 * Lattice node fed by the activations of a single input neuron through one {@link Synapse.Key}.
 *
 * <p>Activations added here are propagated to the {@link AndNode} children registered in
 * {@code andChildren} (combined with activations of the partner input node) and handed to
 * {@link OrNode#processCandidate}. The node also keeps the synapses that connect it to output
 * neurons, keyed by (relative rid, output neuron).
 *
 * <p>NOTE(review): this class was recovered by a decompiler (JADX); local variable names are
 * reconstructed.
 */
public class InputNode extends Node<InputNode, NodeActivation<InputNode>> {
    /** Synapse key this node represents; interned through {@link Synapse#lookupKey}. */
    public Synapse.Key key;
    /** Neuron feeding this node; stays {@code null} when created by {@link #add} without a neuron. */
    public Neuron inputNeuron;
    /** Synapses keyed by (relative rid, output neuron); created lazily, guarded by {@link #synapseLock}. */
    public Map<SynapseKey, Synapse> synapses;
    /** Lock guarding {@link #synapses}. */
    public ReadWriteLock synapseLock = new ReadWriteLock();
    /** Marker written by {@link #discover} so each partner node is examined at most once per pass. */
    private long visitedDiscover;

    /**
     * Ordering key for {@link #synapses}: relative rid first (via {@link Utils#compareInteger}),
     * then output neuron.
     *
     * <p>NOTE(review): the decompiler reports this class was originally private. It defines no
     * equals/hashCode, so it is only suitable as a key of sorted maps.
     */
    public static class SynapseKey implements Writable, Comparable<SynapseKey> {
        Integer rid;
        Neuron neuron;

        private SynapseKey() {
        }

        public SynapseKey(Integer num, Neuron neuron) {
            this.rid = num;
            this.neuron = neuron;
        }

        @Override
        public int compareTo(SynapseKey other) {
            int byRid = Utils.compareInteger(this.rid, other.rid);
            return byRid != 0 ? byRid : this.neuron.compareTo((Provider<?>) other.neuron);
        }

        /** Deserializes a key previously written by {@link #write}. */
        public static SynapseKey read(DataInput dataInput, Model model) throws IOException {
            SynapseKey synapseKey = new SynapseKey();
            synapseKey.readFields(dataInput, model);
            return synapseKey;
        }

        /** Writes a presence flag, the rid (if present) and the neuron id. */
        @Override
        public void write(DataOutput dataOutput) throws IOException {
            dataOutput.writeBoolean(this.rid != null);
            if (this.rid != null) {
                dataOutput.writeInt(this.rid.intValue());
            }
            dataOutput.writeInt(this.neuron.id.intValue());
        }

        /** Reads the fields written by {@link #write}; {@code rid} is reset to null if absent. */
        @Override
        public void readFields(DataInput dataInput, Model model) throws IOException {
            this.rid = dataInput.readBoolean() ? Integer.valueOf(dataInput.readInt()) : null;
            this.neuron = model.lookupNeuron(dataInput.readInt());
        }
    }

    public InputNode() {
    }

    public InputNode(Model model, Synapse.Key key) {
        super(model, 1);
        this.key = Synapse.lookupKey(key);
    }

    /**
     * Returns the input node registered for {@code key} on {@code iNeuron}, creating and
     * registering a new one if none exists yet.
     *
     * @param model   model the node belongs to
     * @param key     synapse key of the node
     * @param iNeuron input neuron; may be {@code null}, in which case the new node is not
     *                registered anywhere and has no {@link #inputNeuron}
     * @return the existing or newly created node
     */
    public static InputNode add(Model model, Synapse.Key key, INeuron iNeuron) {
        Provider<InputNode> existing = iNeuron != null ? iNeuron.outputNodes.get(key) : null;
        if (existing != null) {
            return existing.get();
        }
        InputNode node = new InputNode(model, key);
        if (iNeuron != null && node.inputNeuron == null) {
            node.inputNeuron = (Neuron) iNeuron.provider;
            iNeuron.outputNodes.put(key, node.provider);
            iNeuron.setModified();
        }
        return node;
    }

    /** Creates an activation of this node, consuming the next id from the document's counter. */
    @Override
    protected NodeActivation<InputNode> createActivation(Document document, NodeActivation.Key key) {
        return new NodeActivation<>(document.activationIdCounter++, document, key);
    }

    /**
     * Derives this node's activation key from an incoming activation.
     *
     * @return {@code null} if the rid does not match the required absolute rid, or the
     *         interpretation reports a conflict; otherwise the mapped key
     */
    private NodeActivation.Key computeActivationKey(NodeActivation inputAct) {
        NodeActivation.Key inputKey = inputAct.key;
        // Compare boxed Integers by value: '!=' compares references and breaks outside the Integer cache.
        if (this.key.absoluteRid != null && !this.key.absoluteRid.equals(inputKey.rid)) {
            return null;
        }
        InterpretationNode interpretation = inputKey.interpretation;
        // isConflicting receives the counter value before it is incremented.
        if (interpretation.isConflicting(interpretation.doc.visitedCounter++)) {
            return null;
        }
        return new NodeActivation.Key(
                this,
                this.key.rangeOutput.map(inputKey.range),
                this.key.relativeRid != null ? inputKey.rid : null,
                interpretation);
    }

    /** Adds and propagates an activation of this node caused by {@code nodeActivation}, if admissible. */
    public void addActivation(Document document, NodeActivation nodeActivation) {
        NodeActivation.Key activationKey = computeActivationKey(nodeActivation);
        if (activationKey == null) {
            return;
        }
        addActivationAndPropagate(document, activationKey, Collections.singleton(nodeActivation));
    }

    /** Propagates a new activation unless this node's key is recurrent. */
    @Override
    public void propagateAddedActivation(Document document, NodeActivation<InputNode> nodeActivation) {
        if (!this.key.isRecurrent) {
            apply(document, nodeActivation);
        }
    }

    @Override
    public boolean isAllowedOption(int i, InterpretationNode interpretationNode, NodeActivation nodeActivation, long j) {
        return false;
    }

    /** Returns the refinement describing this node followed by the given {@code refinement}. */
    @Override
    public Collection<AndNode.Refinement> collectNodeAndRefinements(AndNode.Refinement refinement) {
        ArrayList<AndNode.Refinement> refinements = new ArrayList<>(2);
        refinements.add(new AndNode.Refinement(this.key.relativeRid, refinement.rid, this.provider));
        refinements.add(refinement);
        return refinements;
    }

    /**
     * Combines {@code nodeActivation} with partner input activations for every AND child, then
     * offers it to the OR layer. The read lock is released even if propagation throws.
     */
    @Override
    void apply(Document document, NodeActivation<InputNode> nodeActivation) {
        this.lock.acquireReadLock();
        try {
            if (this.andChildren != null) {
                this.andChildren.forEach((refinement, child) -> {
                    InputNode partner = refinement.input.getIfNotSuspended();
                    if (partner != null) {
                        addNextLevelActivations(document, partner, refinement, child, nodeActivation);
                    }
                });
            }
        } finally {
            this.lock.releaseReadLock();
        }
        OrNode.processCandidate(document, this, nodeActivation, false);
    }

    /**
     * Creates AND-node activations by pairing {@code nodeActivation} with each matching
     * activation of {@code inputNode}'s input neuron.
     *
     * <p>Returns early when that neuron has no activations in this thread, or when the AND node
     * is flagged {@code combinatorialExpensive}.
     */
    private static void addNextLevelActivations(Document document, InputNode inputNode, AndNode.Refinement refinement, Provider<AndNode> provider, NodeActivation nodeActivation) {
        INeuron secondNeuron = inputNode.inputNeuron.get();
        INeuron.ThreadState threadState = secondNeuron.getThreadState(document.threadId, false);
        if (threadState == null || threadState.activations.isEmpty()) {
            return;
        }
        Activation firstInput = (Activation) nodeActivation.inputs.firstEntry().getValue();
        AndNode andNode = provider.get(document);
        if (andNode.combinatorialExpensive) {
            return;
        }
        NodeActivation.Key firstKey = nodeActivation.key;
        NodeActivation.Key firstInputKey = firstInput.key;
        InputNode firstNode = (InputNode) firstKey.node;
        Range.Relation query = Range.Relation.createQuery(firstNode.key.rangeMatch, inputNode.key.rangeOutput, firstNode.key.rangeOutput, inputNode.key.rangeMatch);
        Activation.select(threadState, secondNeuron, Utils.nullSafeAdd(firstKey.rid, false, refinement.rid, false), firstInputKey.range, query, (InterpretationNode) null, (InterpretationNode.Relation) null).forEach(secondInput -> {
            NodeActivation secondAct = inputNode.getInputNodeActivation(secondInput);
            if (secondAct == null) {
                return;
            }
            InterpretationNode interpretation = InterpretationNode.add(document, true, firstKey.interpretation, secondInput.key.interpretation);
            if (interpretation == null) {
                return;
            }
            Node.addActivationAndPropagate(
                    document,
                    new NodeActivation.Key(
                            andNode,
                            Range.mergeRange(firstNode.key.rangeOutput.map(firstInputKey.range), inputNode.key.rangeOutput.map(secondInput.key.range)),
                            Utils.nullSafeMin(firstKey.rid, secondAct.key.rid),
                            interpretation),
                    AndNode.prepareInputActs(nodeActivation, secondAct));
        });
    }

    /** Returns the output activation of {@code activation} that belongs to this node, or {@code null}. */
    private NodeActivation getInputNodeActivation(Activation activation) {
        for (NodeActivation<?> output : activation.outputs.values()) {
            if (output.key.node == this) {
                return output;
            }
        }
        return null;
    }

    /**
     * Pattern discovery: pairs {@code nodeActivation} with outputs of the document's final
     * activations and creates next-level AND nodes for matching, non-recurrent partner nodes.
     * Each partner node is considered at most once per call.
     */
    @Override
    public void discover(Document document, NodeActivation<InputNode> nodeActivation, PatternDiscovery.Config config) {
        long visitMarker = Model.visitedCounter.addAndGet(1L);
        document.getFinalActivations().forEach(activation -> {
            for (NodeActivation<?> secondAct : activation.outputs.values()) {
                AndNode.Refinement refinement = new AndNode.Refinement(secondAct.key.rid, nodeActivation.key.rid, secondAct.key.node.provider);
                InputNode secondNode = refinement.input.get(document);
                Range.Relation query = Range.Relation.createQuery(this.key.rangeMatch, secondNode.key.rangeOutput, this.key.rangeOutput, secondNode.key.rangeMatch);
                if (nodeActivation != secondAct && this != secondNode && secondNode.visitedDiscover != visitMarker && !secondNode.key.isRecurrent && query.compare(secondAct.key.range, nodeActivation.key.range)) {
                    secondNode.visitedDiscover = visitMarker;
                    AndNode created = AndNode.createNextLevelNode(document.model, document.threadId, this, refinement, config);
                    if (created != null) {
                        created.isDiscovered = true;
                        document.addedNodes.add(created);
                    }
                }
            }
        });
    }

    /** True if {@code refinement} refers to this node with the same relative rid. */
    @Override
    public boolean contains(AndNode.Refinement refinement) {
        return this == refinement.input.get() && Utils.compareInteger(this.key.relativeRid, refinement.rid) == 0;
    }

    /**
     * Returns the neuron's bias sum plus the absolute weight of the synapse from this node.
     *
     * <p>NOTE(review): throws NullPointerException if no such synapse is registered; confirm
     * callers guarantee it exists.
     */
    @Override
    public double computeSynapseWeightSum(Integer num, INeuron iNeuron) {
        Integer rid = this.key.relativeRid == null ? null : num;
        return iNeuron.biasSum + Math.abs(getSynapse(rid, (Neuron) iNeuron.provider).weight);
    }

    /** Looks up the synapse for (rid, output neuron); {@code null} if absent. */
    public Synapse getSynapse(Integer num, Neuron neuron) {
        this.synapseLock.acquireReadLock();
        try {
            return this.synapses != null ? this.synapses.get(new SynapseKey(num, neuron)) : null;
        } finally {
            this.synapseLock.releaseReadLock();
        }
    }

    /** Registers {@code synapse} under its relative rid and output neuron, creating the map lazily. */
    public void setSynapse(Synapse synapse) {
        this.synapseLock.acquireWriteLock();
        try {
            if (this.synapses == null) {
                this.synapses = new TreeMap<>();
            }
            this.synapses.put(new SynapseKey(synapse.key.relativeRid, synapse.output), synapse);
        } finally {
            this.synapseLock.releaseWriteLock();
        }
    }

    /** Removes {@code synapse}; the null check runs under the lock to avoid racing {@link #setSynapse}. */
    public void removeSynapse(Synapse synapse) {
        this.synapseLock.acquireWriteLock();
        try {
            if (this.synapses != null) {
                this.synapses.remove(new SynapseKey(synapse.key.relativeRid, synapse.output));
            }
        } finally {
            this.synapseLock.releaseWriteLock();
        }
    }

    @Override
    public void cleanup() {
    }

    /** Unregisters this node from its input neuron (if any) and removes it from the lattice. */
    @Override
    public void remove() {
        if (this.inputNeuron != null) {
            this.inputNeuron.get().outputNodes.remove(this.key);
        }
        super.remove();
    }

    /** Renders e.g. {@code I[42,label]}: 'R' marks recurrence; brackets encode the range mapping. */
    @Override
    public String logicToString() {
        StringBuilder sb = new StringBuilder("I");
        sb.append(this.key.isRecurrent ? "R" : "");
        sb.append(getRangeBrackets(this.key.rangeOutput.begin));
        if (this.inputNeuron != null) {
            sb.append(this.inputNeuron.id);
            INeuron neuron = this.inputNeuron.get();
            if (neuron.label != null) {
                sb.append(",");
                sb.append(neuron.label);
            }
        }
        sb.append(getRangeBrackets(this.key.rangeOutput.end));
        return sb.toString();
    }

    /** Maps NONE to "|", BEGIN to "[", anything else to "]". */
    private String getRangeBrackets(Range.Mapping mapping) {
        if (mapping == Range.Mapping.NONE) {
            return "|";
        }
        return mapping == Range.Mapping.BEGIN ? "[" : "]";
    }

    /**
     * Serializes this node as follows:
     * <ul>
     *   <li>a {@code false} flag;</li>
     *   <li>the type tag {@code 'I'};</li>
     *   <li>the base node state;</li>
     *   <li>the synapse key;</li>
     *   <li>the optional input-neuron id.</li>
     * </ul>
     *
     * <p>NOTE(review): the leading flag and tag are not consumed by {@link #readFields};
     * presumably a type dispatcher elsewhere reads them. Confirm this.
     */
    @Override
    public void write(DataOutput dataOutput) throws IOException {
        dataOutput.writeBoolean(false);
        dataOutput.writeChar('I');
        super.write(dataOutput);
        this.key.write(dataOutput);
        dataOutput.writeBoolean(this.inputNeuron != null);
        if (this.inputNeuron != null) {
            dataOutput.writeInt(this.inputNeuron.id.intValue());
        }
    }

    /** Reads the state written by {@link #write}, after the flag and type tag. */
    @Override
    public void readFields(DataInput dataInput, Model model) throws IOException {
        super.readFields(dataInput, model);
        this.key = Synapse.lookupKey(Synapse.Key.read(dataInput, model));
        if (dataInput.readBoolean()) {
            this.inputNeuron = model.lookupNeuron(dataInput.readInt());
        }
    }
}
