package org.aika.lattice;

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.Map;
import java.util.SortedMap;
import java.util.TreeMap;
import java.util.TreeSet;
import org.aika.Model;
import org.aika.Neuron;
import org.aika.Provider;
import org.aika.Utils;
import org.aika.corpus.Document;
import org.aika.corpus.InterpretationNode;
import org.aika.corpus.Range;
import org.aika.lattice.Node;
import org.aika.lattice.NodeActivation;
import org.aika.neuron.INeuron;
import org.aika.neuron.Synapse;
import org.aika.training.PatternDiscovery;

/* loaded from: input_file:org/aika/lattice/AndNode.class */
/**
 * Conjunction (AND) node of the pattern lattice.
 *
 * <p>{@link #parents} maps each {@link Refinement} to the provider of a parent node. When the node is initialised
 * ({@link #init()}) every parent registers this node as an AND-child under that refinement; {@link #remove()} undoes
 * the registration. {@link #processAddedActivation} only accepts an activation when the number of input activations
 * plus {@link #numCombExpParents} equals this node's {@code level}.
 *
 * <p>Locking: a parent's {@code lock} is held for reading while its reverse AND-child table and output activations
 * are scanned, and the write locks of all future parents are held while a new child node is created and registered.
 * Every lock is released in a {@code finally} block so an exception cannot leave a node locked.
 */
public class AndNode extends Node<AndNode, NodeActivation<AndNode>> {
    /** Parent nodes of this conjunction, keyed by the refinement under which each parent lists this node. */
    public SortedMap<Refinement, Provider<? extends Node>> parents = new TreeMap<>();

    /**
     * Result of {@code model.getAndNodeCheck().checkIfCombinatorialExpensive(this)}, computed in {@link #init()};
     * stays {@code false} when the model has no AND-node check.
     */
    public boolean combinatorialExpensive = false;

    /**
     * Number of parents flagged {@link #combinatorialExpensive} (only counted when this node's level is above 2).
     * Added to the input-activation count in {@link #processAddedActivation}.
     */
    public int numCombExpParents = 0;

    /**
     * Edge label between a parent node and one of its AND-children: an input node plus an optional relative id
     * ({@code rid}). Refinements are ordered by input provider first and {@code rid} second; {@link #MIN} and
     * {@link #MAX} are identity-compared sentinels that sort below/above every other refinement.
     */
    public static class Refinement implements Comparable<Refinement> {
        /** Sentinel that compares less than every other refinement. */
        public static final Refinement MIN = new Refinement(null, null);
        /** Sentinel that compares greater than every other refinement. */
        public static final Refinement MAX = new Refinement(null, null);

        /** Relative id of the input; may be {@code null}. */
        public Integer rid;
        /** Provider of the input node this refinement refers to. */
        public Provider<InputNode> input;

        /** Creates an empty refinement that is filled in by {@link #readFields}. */
        private Refinement() {
        }

        /**
         * @param num      relative id, may be {@code null}
         * @param provider input node provider
         */
        public Refinement(Integer num, Provider<InputNode> provider) {
            this.rid = num;
            this.input = provider;
        }

        /**
         * Creates a refinement whose rid is expressed relative to {@code num2}.
         *
         * @param num      absolute position; when {@code null} the resulting rid is {@code null}
         * @param num2     reference position; when {@code null} (and {@code num} is not) the rid is 0
         * @param provider input node provider
         */
        public Refinement(Integer num, Integer num2, Provider<InputNode> provider) {
            if (num == null) {
                this.rid = null;
            } else if (num2 == null) {
                this.rid = 0;
            } else {
                this.rid = num - num2;
            }
            this.input = provider;
        }

        /** Returns the non-positive part of {@code rid} ({@code min(0, rid)}), or {@code null} if rid is null. */
        public Integer getOffset() {
            return this.rid == null ? null : Integer.valueOf(Math.min(0, this.rid.intValue()));
        }

        /** Returns the non-negative part of {@code rid} ({@code max(0, rid)}), or {@code null} if rid is null. */
        public Integer getRelativePosition() {
            return this.rid == null ? null : Integer.valueOf(Math.max(0, this.rid.intValue()));
        }

        /**
         * Looks up the synapse of {@code neuron} on the input node at position
         * {@code Utils.nullSafeAdd(getRelativePosition(), false, num, false)}.
         */
        public Synapse getSynapse(Integer num, Neuron neuron) {
            return this.input.get().getSynapse(Utils.nullSafeAdd(getRelativePosition(), false, num, false), neuron);
        }

        @Override
        public String toString() {
            return "(" + (this.rid != null ? this.rid + ":" : "") + this.input.get().logicToString() + ")";
        }

        /** Writes a presence flag for {@code rid}, the rid itself (if present) and the input provider id. */
        public void write(DataOutput dataOutput) throws IOException {
            boolean hasRid = this.rid != null;
            dataOutput.writeBoolean(hasRid);
            if (hasRid) {
                dataOutput.writeInt(this.rid.intValue());
            }
            dataOutput.writeInt(this.input.id.intValue());
        }

        /** Reads the format produced by {@link #write}; always returns {@code true}. */
        public boolean readFields(DataInput dataInput, Model model) throws IOException {
            if (dataInput.readBoolean()) {
                this.rid = Integer.valueOf(dataInput.readInt());
            }
            this.input = model.lookupNodeProvider(dataInput.readInt());
            return true;
        }

        /** Deserializes a refinement written by {@link #write}. */
        public static Refinement read(DataInput dataInput, Model model) throws IOException {
            Refinement refinement = new Refinement();
            refinement.readFields(dataInput, model);
            return refinement;
        }

        @Override
        public int compareTo(Refinement refinement) {
            // Identity first: otherwise MIN.compareTo(MIN) / MAX.compareTo(MAX) would return -1, violating
            // the Comparable contract (x.compareTo(x) must be 0).
            if (this == refinement) {
                return 0;
            }
            if (this == MIN || refinement == MAX) {
                return -1;
            }
            if (this == MAX || refinement == MIN) {
                return 1;
            }
            int byInput = this.input.compareTo((Provider<?>) refinement.input);
            return byInput != 0 ? byInput : Utils.compareInteger(this.rid, refinement.rid);
        }
    }

    public AndNode() {
    }

    /**
     * @param model   owning model
     * @param i       lattice level of the node
     * @param sortedMap parent refinements; stored as-is (not copied)
     */
    public AndNode(Model model, int i, SortedMap<Refinement, Provider<? extends Node>> sortedMap) {
        super(model, i);
        this.parents = sortedMap;
    }

    /**
     * Registers this node as AND-child with every parent, counts combinatorially expensive parents and evaluates
     * the model's AND-node check (if any) for this node.
     */
    private void init() {
        for (Map.Entry<Refinement, Provider<? extends Node>> entry : this.parents.entrySet()) {
            Node parentNode = entry.getValue().get();
            parentNode.addAndChild(entry.getKey(), this.provider);
            parentNode.setModified();
            // The cast assumes that parents of nodes above level 2 are AndNodes.
            if (this.level > 2 && ((AndNode) parentNode).combinatorialExpensive) {
                this.numCombExpParents++;
            }
        }
        if (this.provider.model.getAndNodeCheck() != null) {
            this.combinatorialExpensive = this.provider.model.getAndNodeCheck().checkIfCombinatorialExpensive(this);
        }
    }

    /**
     * Returns {@code true} if the node of any input activation allows the option. Returns {@code false} without
     * recursing if this node was already visited with visit id {@code j} on thread {@code i}.
     */
    @Override
    public boolean isAllowedOption(int i, InterpretationNode interpretationNode, NodeActivation<?> nodeActivation, long j) {
        Node.ThreadState<AndNode, NodeActivation<AndNode>> threadState = getThreadState(i, true);
        if (threadState.visited == j) {
            return false;
        }
        threadState.visited = j;
        for (NodeActivation<?> inputAct : nodeActivation.inputs.values()) {
            if (inputAct.key.node.isAllowedOption(i, interpretationNode, inputAct, j)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Rejects the activation (returns {@code null}) unless the number of input activations plus
     * {@link #numCombExpParents} equals this node's level; otherwise delegates to the superclass.
     */
    @Override
    public NodeActivation<AndNode> processAddedActivation(Document document, NodeActivation.Key<AndNode> key, Collection<NodeActivation> collection) {
        if (collection.size() + this.numCombExpParents != this.level) {
            return null;
        }
        return super.processAddedActivation(document, key, collection);
    }

    /** Propagates a newly added activation by calling {@link #apply}. */
    @Override
    public void propagateAddedActivation(Document document, NodeActivation<AndNode> nodeActivation) {
        apply(document, nodeActivation);
    }

    /**
     * Removes this node unless it is already removed or still required, then asks every parent to clean up too.
     */
    @Override
    public void cleanup() {
        if (this.isRemoved || isRequired()) {
            return;
        }
        remove();
        for (Provider<? extends Node> parent : this.parents.values()) {
            parent.get().cleanup();
        }
    }

    /**
     * Combines {@code nodeActivation} with sibling activations into activations of existing next-level AND-children,
     * then hands the activation to {@code OrNode.processCandidate}.
     *
     * <p>Siblings are the other output activations of each input (parent) activation. For each sibling, the
     * refinements of this node and of the sibling's node in the shared parent are looked up and combined into the
     * refinement of the next-level child; an activation is only added if this node has such a child.
     */
    @Override
    void apply(Document document, NodeActivation<AndNode> nodeActivation) {
        if (this.andChildren != null) {
            for (NodeActivation<?> parentAct : nodeActivation.inputs.values()) {
                Node<?, ?> parentNode = parentAct.key.node;
                parentNode.lock.acquireReadLock();
                try {
                    Refinement ownRef = parentNode.reverseAndChildren.get(new Node.ReverseAndRefinement(
                            nodeActivation.key.node.provider, nodeActivation.key.rid, parentAct.key.rid));
                    if (ownRef == null) {
                        continue;
                    }
                    for (NodeActivation<?> candidate : parentAct.outputs.values()) {
                        @SuppressWarnings("unchecked")
                        NodeActivation<AndNode> sibling = (NodeActivation<AndNode>) candidate;
                        if (nodeActivation == sibling) {
                            continue;
                        }
                        Refinement siblingRef = parentNode.reverseAndChildren.get(new Node.ReverseAndRefinement(
                                sibling.key.node.provider, sibling.key.rid, parentAct.key.rid));
                        if (siblingRef == null) {
                            continue;
                        }
                        Provider<AndNode> child = getAndChild(
                                new Refinement(siblingRef.rid, ownRef.getOffset(), siblingRef.input));
                        if (child != null) {
                            addNextLevelActivation(document, nodeActivation, sibling, child);
                        }
                    }
                } finally {
                    parentNode.lock.releaseReadLock();
                }
            }
        }
        OrNode.processCandidate(document, this, nodeActivation, false);
    }

    /**
     * Pattern discovery: pairs {@code nodeActivation} with every other AND-node output activation of the same parent
     * activation that passes {@code config.checkExpandable}, and creates the corresponding next-level node via
     * {@link #createNextLevelNode}. Nodes created this way are marked {@code isDiscovered} and added to
     * {@code document.addedNodes}.
     */
    @Override
    public void discover(Document document, NodeActivation<AndNode> nodeActivation, PatternDiscovery.Config config) {
        for (NodeActivation<?> parentAct : nodeActivation.inputs.values()) {
            Node<?, ?> parentNode = parentAct.key.node;
            parentNode.lock.acquireReadLock();
            try {
                Refinement ownRef = parentNode.reverseAndChildren.get(new Node.ReverseAndRefinement(
                        nodeActivation.key.node.provider, nodeActivation.key.rid, parentAct.key.rid));
                for (NodeActivation<?> candidate : parentAct.outputs.values()) {
                    @SuppressWarnings("unchecked")
                    NodeActivation<AndNode> sibling = (NodeActivation<AndNode>) candidate;
                    if (!(candidate.key.node instanceof AndNode) || nodeActivation == sibling
                            || !config.checkExpandable.evaluate(sibling)) {
                        continue;
                    }
                    Refinement siblingRef = parentNode.reverseAndChildren.get(new Node.ReverseAndRefinement(
                            sibling.key.node.provider, sibling.key.rid, parentAct.key.rid));
                    // Skip when either reverse lookup misses (as apply() does) instead of dereferencing null.
                    if (ownRef == null || siblingRef == null) {
                        continue;
                    }
                    AndNode created = createNextLevelNode(document.model, document.threadId, this,
                            new Refinement(siblingRef.rid, ownRef.getOffset(), siblingRef.input), config);
                    if (created != null) {
                        created.isDiscovered = true;
                        document.addedNodes.add(created);
                    }
                }
            } finally {
                parentNode.lock.releaseReadLock();
            }
        }
    }

    /**
     * Returns whether this node already contains {@code refinement} among its parents' refinements.
     *
     * <p>A negative rid is never contained. A {@code null} or positive rid requires an exact key in
     * {@link #parents}. A rid of 0 matches any parent refinement with a {@code null} or non-positive rid whose input
     * provider is the same instance.
     */
    @Override
    public boolean contains(Refinement refinement) {
        if (refinement.rid != null && refinement.rid.intValue() < 0) {
            return false;
        }
        this.lock.acquireReadLock();
        try {
            if (refinement.rid == null || refinement.rid.intValue() > 0) {
                return this.parents.containsKey(refinement);
            }
            // Remaining case: rid == 0.
            for (Refinement parentRef : this.parents.keySet()) {
                boolean nonPositive = parentRef.rid == null || parentRef.rid.intValue() <= 0;
                if (nonPositive && parentRef.input == refinement.input) {
                    return true;
                }
            }
            return false;
        } finally {
            this.lock.releaseReadLock();
        }
    }

    /**
     * Returns the AND-child of {@code node} for {@code refinement}, creating it if necessary.
     *
     * <p>If the child already exists it is returned, except in discovery mode ({@code config != null}) where
     * {@code null} is returned instead. {@code null} is also returned when {@code node} already contains the
     * refinement, when no parent set can be computed, when another thread registered the child in the meantime, or
     * when {@code config.checkValidPattern} rejects the new node (its provider is then removed from the model).
     * The write locks of all new parents are held, in provider sort order, while the node is created and registered.
     */
    public static AndNode createNextLevelNode(Model model, int i, Node node, Refinement refinement, PatternDiscovery.Config config) {
        Provider<AndNode> existing = node.getAndChild(refinement);
        if (existing != null) {
            return config != null ? null : existing.get();
        }
        if (node.contains(refinement)) {
            return null;
        }
        SortedMap<Refinement, Provider<? extends Node>> nextLevelParents = computeNextLevelParents(model, i, node, refinement, config);
        if (nextLevelParents == null) {
            return null;
        }

        TreeSet<Provider<? extends Node>> lockOrder = new TreeSet<>(nextLevelParents.values());
        for (Provider<? extends Node> parent : lockOrder) {
            parent.get().lock.acquireWriteLock();
        }
        AndNode andNode = null;
        try {
            // Re-check under the parents' write locks before creating the child.
            if (node.andChildren == null || !node.andChildren.containsKey(refinement)) {
                andNode = new AndNode(model, node.level + 1, nextLevelParents);
                if (config == null || config.checkValidPattern.evaluate(andNode)) {
                    andNode.init();
                } else {
                    model.removeProvider(andNode.provider);
                    andNode = null;
                }
            }
        } finally {
            for (Provider<? extends Node> parent : lockOrder) {
                parent.get().lock.releaseWriteLock();
            }
        }
        return andNode;
    }

    /**
     * Adds and propagates an activation of the node behind {@code provider} built from two activations.
     *
     * <p>The interpretation is {@code InterpretationNode.add(document, true, ...)} of both keys' interpretations
     * (nothing is added if that yields {@code null}). The range is the merged range, and the rid is the
     * {@code Utils.nullSafeMin} of both rids.
     */
    public static void addNextLevelActivation(Document document, NodeActivation<AndNode> nodeActivation, NodeActivation<AndNode> nodeActivation2, Provider<AndNode> provider) {
        NodeActivation.Key<AndNode> first = nodeActivation.key;
        NodeActivation.Key<AndNode> second = nodeActivation2.key;
        InterpretationNode interpretation = InterpretationNode.add(document, true, first.interpretation, second.interpretation);
        if (interpretation == null) {
            return;
        }
        Node.addActivationAndPropagate(document,
                new NodeActivation.Key(provider.get(document), Range.mergeRange(first.range, second.range),
                        Utils.nullSafeMin(first.rid, second.rid), interpretation),
                prepareInputActs(nodeActivation, nodeActivation2));
    }

    /** Applies the neuron-reference change to this node and, recursively, to every parent. */
    @Override
    public void changeNumberOfNeuronRefs(int i, long j, int i2) {
        super.changeNumberOfNeuronRefs(i, j, i2);
        for (Provider<? extends Node> parent : this.parents.values()) {
            parent.get().changeNumberOfNeuronRefs(i, j, i2);
        }
    }

    /** Returns a new two-element list containing both activations, in argument order. */
    public static Collection<NodeActivation<?>> prepareInputActs(NodeActivation<?> nodeActivation, NodeActivation<?> nodeActivation2) {
        ArrayList<NodeActivation<?>> inputActs = new ArrayList<>(2);
        inputActs.add(nodeActivation);
        inputActs.add(nodeActivation2);
        return inputActs;
    }

    /**
     * Computes the parent set of the AND-node obtained by adding {@code refinement} to {@code node}.
     *
     * <p>For every refinement returned by {@code node.collectNodeAndRefinements(refinement)}, the refinement's input
     * node contributes parents via {@code computeAndParents}, given all other refinements. Returns {@code null} as
     * soon as one input reports failure or a {@code RidOutOfRange} is thrown.
     */
    public static SortedMap<Refinement, Provider<? extends Node>> computeNextLevelParents(Model model, int i, Node node, Refinement refinement, PatternDiscovery.Config config) {
        Collection<Refinement> refinements = node.collectNodeAndRefinements(refinement);
        long visitedId = Model.visitedCounter.addAndGet(1L);
        TreeMap<Refinement, Provider<? extends Node>> nextLevelParents = new TreeMap<>();
        for (Refinement current : refinements) {
            TreeSet<Refinement> others = new TreeSet<>(refinements);
            others.remove(current);
            try {
                if (!current.input.get().computeAndParents(model, i, current.getRelativePosition(), others, nextLevelParents, config, visitedId)) {
                    return null;
                }
            } catch (Node.ThreadState.RidOutOfRange ignored) {
                // An out-of-range rid means the next-level node cannot be built; treated like a failed computation.
                return null;
            }
        }
        return nextLevelParents;
    }

    /**
     * Returns {@code refinement} followed by one entry per parent refinement of this node, adjusted relative to the
     * new refinement's rid.
     *
     * <ul>
     *   <li>If the new rid is non-null and either negative or exactly one parent refinement has a rid, each parent
     *       refinement is rebuilt as {@code Refinement(parent.getRelativePosition(), refinement.rid, parent.input)}.</li>
     *   <li>Otherwise a parent refinement is kept as-is if its rid or the new rid is {@code null}, or its offset is
     *       non-negative.</li>
     *   <li>Otherwise it is rebuilt with rid {@code -min(-parent.getOffset(), refinement.getRelativePosition())}.</li>
     * </ul>
     */
    @Override
    public Collection<Refinement> collectNodeAndRefinements(Refinement refinement) {
        ArrayList<Refinement> result = new ArrayList<>(this.parents.size() + 1);
        result.add(refinement);

        int numRidRefs = 0;
        for (Refinement parentRef : this.parents.keySet()) {
            if (parentRef.rid != null) {
                numRidRefs++;
            }
        }

        for (Refinement parentRef : this.parents.keySet()) {
            if (refinement.rid != null && (refinement.rid.intValue() < 0 || numRidRefs == 1)) {
                result.add(new Refinement(parentRef.getRelativePosition(), refinement.rid, parentRef.input));
            } else if (parentRef.rid == null || refinement.rid == null || parentRef.getOffset().intValue() >= 0) {
                result.add(parentRef);
            } else {
                int shift = Math.min(-parentRef.getOffset().intValue(), refinement.getRelativePosition().intValue());
                result.add(new Refinement(0, Integer.valueOf(shift), parentRef.input));
            }
        }
        return result;
    }

    /**
     * Returns {@code iNeuron.biasSum} plus the absolute weights of the synapses selected by each parent refinement
     * (see {@link Refinement#getSynapse}).
     */
    @Override
    public double computeSynapseWeightSum(Integer num, INeuron iNeuron) {
        double sum = iNeuron.biasSum;
        for (Refinement parentRef : this.parents.keySet()) {
            sum += Math.abs(parentRef.getSynapse(num, (Neuron) iNeuron.provider).weight);
        }
        return sum;
    }

    /** Creates an activation using the next id from {@code document.activationIdCounter}. */
    @Override
    protected NodeActivation<AndNode> createActivation(Document document, NodeActivation.Key key) {
        int id = document.activationIdCounter++;
        return new NodeActivation<>(id, document, key);
    }

    /** Removes this node and unregisters it from every parent while holding that parent's write lock. */
    @Override
    public void remove() {
        super.remove();
        for (Map.Entry<Refinement, Provider<? extends Node>> entry : this.parents.entrySet()) {
            Node parentNode = entry.getValue().get();
            parentNode.lock.acquireWriteLock();
            try {
                parentNode.removeAndChild(entry.getKey());
                parentNode.setModified();
            } finally {
                parentNode.lock.releaseWriteLock();
            }
        }
    }

    /** Renders the node as {@code AND[ref1,ref2,...]} using each parent refinement's {@code toString}. */
    @Override
    public String logicToString() {
        StringBuilder sb = new StringBuilder("AND[");
        boolean first = true;
        for (Refinement refinement : this.parents.keySet()) {
            if (!first) {
                sb.append(",");
            }
            first = false;
            sb.append(refinement);
        }
        return sb.append("]").toString();
    }

    /**
     * Serializes the node: a {@code false} flag, the type tag {@code 'A'}, the superclass fields, the parent count
     * and then each refinement followed by the parent provider id. The flag and tag are not read by
     * {@link #readFields}; presumably the caller that selects the node type consumes them — TODO confirm.
     */
    @Override
    public void write(DataOutput dataOutput) throws IOException {
        dataOutput.writeBoolean(false);
        dataOutput.writeChar('A');
        super.write(dataOutput);
        dataOutput.writeInt(this.parents.size());
        for (Map.Entry<Refinement, Provider<? extends Node>> entry : this.parents.entrySet()) {
            entry.getKey().write(dataOutput);
            dataOutput.writeInt(entry.getValue().id.intValue());
        }
    }

    /** Reads the superclass fields and the parent entries written by {@link #write} into {@link #parents}. */
    @Override
    public void readFields(DataInput dataInput, Model model) throws IOException {
        super.readFields(dataInput, model);
        int parentCount = dataInput.readInt();
        for (int n = 0; n < parentCount; n++) {
            Refinement refinement = Refinement.read(dataInput, model);
            this.parents.put(refinement, model.lookupNodeProvider(dataInput.readInt()));
        }
    }
}
