package org.aika.lattice;

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.Map;
import java.util.SortedMap;
import java.util.TreeMap;
import java.util.TreeSet;
import org.aika.Model;
import org.aika.Neuron;
import org.aika.Provider;
import org.aika.TrainConfig;
import org.aika.Utils;
import org.aika.corpus.Document;
import org.aika.corpus.InterprNode;
import org.aika.corpus.Range;
import org.aika.lattice.Node;
import org.aika.lattice.NodeActivation;
import org.aika.neuron.INeuron;
import org.aika.neuron.Synapse;
import org.apache.commons.math3.distribution.BinomialDistribution;
import org.apache.commons.math3.random.RandomGenerator;

/* loaded from: input_file:org/aika/lattice/AndNode.class */
/**
 * Conjunction node of the lattice.
 *
 * <p>Each entry of {@link #parents} maps a {@link Refinement} to a parent node; the constructor
 * registers this node as that parent's and-child under the same refinement. A node built by
 * {@link #createNextLevelNode} is placed one level above the node it refines, and
 * {@link #processAddedActivation} only accepts an activation when the number of non-removed
 * input activations equals {@code level}.
 */
public class AndNode extends Node<AndNode, NodeActivation<AndNode>> {

    /** Threshold for the binomial weight of {@link #updateWeight}; currently not consulted in this class. */
    private static double SIGNIFICANCE_THRESHOLD = 0.98d;

    /** A node with this many parents or more is no longer expandable (see {@link #isExpandable()}). */
    public static int MAX_AND_NODE_SIZE = 4;

    /** NOTE(review): not referenced anywhere in this class. */
    public static int MAX_NODES = 4;

    /** Exclusive upper bound on the rid spread accepted while discovering new nodes. */
    public static int MAX_RID_RANGE = 5;

    /** Parent nodes, keyed by the refinement under which this node is registered as their and-child. */
    SortedMap<Refinement, Provider<? extends Node>> parents = new TreeMap<>();

    /** {@link Model#numberOfPositions} value at which {@link #updateWeight} should recompute. */
    public volatile int numberOfPositionsNotify;

    /** {@code frequency} value at which {@link #updateWeight} should recompute. */
    private volatile int frequencyNotify;

    /** Last weight computed by {@link #updateWeight}; -1 until first computed. */
    private double weight = -1.0d;

    /**
     * Edge label of the lattice: an input node plus an optional {@code rid}.
     * {@link #getOffset()} and {@link #getRelativePosition()} split the rid into its non-positive
     * and non-negative parts. {@link #MIN} and {@link #MAX} are sentinels that sort before and
     * after every other refinement.
     */
    public static class Refinement implements Comparable<Refinement> {
        public static Refinement MIN = new Refinement(null, null);
        public static Refinement MAX = new Refinement(null, null);

        public Integer rid;
        public Provider<InputNode> input;

        /** Used only by {@link #read}. */
        private Refinement() {
        }

        public Refinement(Integer rid, Provider<InputNode> input) {
            this.rid = rid;
            this.input = input;
        }

        /**
         * Creates a refinement whose rid is {@code rid - offset}.
         * A null {@code rid} yields no rid; a null {@code offset} with a non-null {@code rid} yields 0.
         */
        public Refinement(Integer rid, Integer offset, Provider<InputNode> input) {
            if (rid == null) {
                this.rid = null;
            } else if (offset == null) {
                this.rid = 0;
            } else {
                this.rid = rid - offset;
            }
            this.input = input;
        }

        /** Returns {@code min(0, rid)}, or null when there is no rid. */
        public Integer getOffset() {
            if (rid == null) {
                return null;
            }
            return Math.min(0, rid);
        }

        /** Returns {@code max(0, rid)}, or null when there is no rid. */
        public Integer getRelativePosition() {
            if (rid == null) {
                return null;
            }
            return Math.max(0, rid);
        }

        /** Looks up the synapse to {@code neuron} on the input node at relative position + {@code rid}. */
        public Synapse getSynapse(Integer rid, Neuron neuron) {
            return input.get().getSynapse(Utils.nullSafeAdd(getRelativePosition(), false, rid, false), neuron);
        }

        public String toString() {
            return "(" + (this.rid != null ? this.rid + ":" : "") + this.input.get().logicToString() + ")";
        }

        /** Writes a presence flag, the rid (if present) and the input provider id. */
        public void write(DataOutput out) throws IOException {
            boolean hasRid = rid != null;
            out.writeBoolean(hasRid);
            if (hasRid) {
                out.writeInt(rid);
            }
            out.writeInt(input.id);
        }

        /** Reads the format produced by {@link #write}; always returns true. */
        public boolean readFields(DataInput in, Model m) throws IOException {
            if (in.readBoolean()) {
                rid = in.readInt();
            }
            input = m.lookupNodeProvider(in.readInt());
            return true;
        }

        public static Refinement read(DataInput in, Model m) throws IOException {
            Refinement ref = new Refinement();
            ref.readFields(in, m);
            return ref;
        }

        /**
         * Orders by input provider, then by rid ({@link Utils#compareInteger}). {@link #MIN} sorts
         * first, {@link #MAX} last, and every refinement compares equal to itself.
         */
        @Override
        public int compareTo(Refinement other) {
            if (this == other) {
                // Reflexivity; without this MIN.compareTo(MIN) would be -1.
                return 0;
            }
            if (this == MIN || other == MAX) {
                return -1;
            }
            if (this == MAX || other == MIN) {
                return 1;
            }
            int byInput = input.compareTo((Provider<?>) other.input);
            return byInput != 0 ? byInput : Utils.compareInteger(rid, other.rid);
        }
    }

    public AndNode() {
    }

    /**
     * Creates a node at {@code level} with the given parents and registers it as an and-child
     * of each parent (marking each parent provider modified).
     */
    public AndNode(Model model, int level, SortedMap<Refinement, Provider<? extends Node>> parents) {
        super(model, level);
        this.parents = parents;
        model.stat.nodes++;
        model.stat.nodesPerLevel[level]++;

        this.ridRequired = false;
        for (Map.Entry<Refinement, Provider<? extends Node>> e : parents.entrySet()) {
            Refinement ref = e.getKey();
            Node parent = e.getValue().get();
            parent.addAndChild(ref, this.provider);
            parent.provider.setModified();
            if (ref.rid != null) {
                this.ridRequired = true;
            }
        }
        this.endRequired = false;
    }

    /**
     * Returns true if any input activation's node allows the option. Returns false immediately
     * when this node was already visited with marker {@code v} in this thread.
     */
    @Override
    public boolean isAllowedOption(int threadId, InterprNode n, NodeActivation<?> act, long v) {
        Node.ThreadState<AndNode, NodeActivation<AndNode>> th = getThreadState(threadId, true);
        if (th.visitedAllowedOption == v) {
            return false;
        }
        th.visitedAllowedOption = v;
        for (NodeActivation<?> inputAct : act.inputs.values()) {
            if (inputAct.key.n.isAllowedOption(threadId, n, inputAct, v)) {
                return true;
            }
        }
        return false;
    }

    /** Delegates to the superclass only when exactly {@code level} input activations are not removed. */
    @Override
    public NodeActivation<AndNode> processAddedActivation(Document document, NodeActivation.Key<AndNode> key, Collection<NodeActivation> inputActs, boolean flag) {
        int numActive = 0;
        for (NodeActivation inputAct : inputActs) {
            if (!inputAct.isRemoved) {
                numActive++;
            }
        }
        if (numActive != this.level) {
            return null;
        }
        return super.processAddedActivation(document, key, inputActs, flag);
    }

    public void addActivation(Document document, NodeActivation.Key key, Collection<NodeActivation<?>> inputActs) {
        Node.addActivationAndPropagate(document, key, inputActs);
    }

    /**
     * For every AndNode activation among the outputs of {@code act}, calls
     * {@link Node#removeActivationAndPropagate} with {@code act} as the removed input.
     */
    public static void removeActivation(Document document, NodeActivation<?> act) {
        for (NodeActivation outAct : act.outputs.values()) {
            if (outAct.key.n instanceof AndNode) {
                Node.removeActivationAndPropagate(document, outAct, Collections.singleton(act));
            }
        }
    }

    @Override
    public void propagateAddedActivation(Document document, NodeActivation<AndNode> act, InterprNode interprNode) {
        apply(document, act, interprNode);
    }

    @Override
    public void propagateRemovedActivation(Document document, NodeActivation act) {
        removeFromNextLevel(document, act);
    }

    /**
     * Counts non-removed input activations, skipping one whose node equals the node of the
     * immediately preceding input (in iteration order), and returns true iff the count equals
     * the number of parents.
     */
    @Override
    boolean hasSupport(NodeActivation<AndNode> act) {
        int numParents = parents.size();
        int numSupported = 0;
        NodeActivation<?> previous = null;
        for (NodeActivation<?> inputAct : act.inputs.values()) {
            if (!inputAct.isRemoved && (previous == null || previous.key.n != inputAct.key.n)) {
                numSupported++;
            }
            previous = inputAct;
        }
        assert numSupported <= numParents;
        return numSupported == numParents;
    }

    /**
     * Sets {@code nullHypFreq} to n times the largest product, over all parents, of the capped
     * frequency ratios of the refinement's input node and of the parent node.
     */
    @Override
    public void computeNullHyp(Model model) {
        // NOTE(review): if sizeSum and instanceSum are int fields this is integer division — confirm.
        double avgSize = this.sizeSum / this.instanceSum;
        double n = (model.numberOfPositions - this.nOffset) / avgSize;
        double maxProduct = 0.0d;
        for (Map.Entry<Refinement, Provider<? extends Node>> e : parents.entrySet()) {
            Node parent = e.getValue().get();
            InputNode input = e.getKey().input.get();
            double inputRatio = Math.min(1.0d, input.frequency / ((model.numberOfPositions - input.nOffset) / avgSize));
            double parentRatio = Math.min(1.0d, Math.max(parent.frequency, parent.nullHypFreq) / ((model.numberOfPositions - parent.nOffset) / avgSize));
            maxProduct = Math.max(maxProduct, inputRatio * parentRatio);
        }
        this.nullHypFreq = maxProduct * n;
    }

    /**
     * Recomputes {@link #weight} as the binomial CDF P(X <= frequency - 1) with round(n) trials
     * and success probability nullHypFreq / n, where n = (numberOfPositions - nOffset) / (sizeSum / instanceSum).
     *
     * <p>Does nothing when there are no positions, the pattern evaluation rejects this node, or
     * the node was already handled for marker {@code v}. Otherwise recomputes only when a notify
     * threshold has been reached or nullHypFreq moved by at least 0.01, then reschedules both
     * thresholds via {@link #computeNotify}.
     */
    public void updateWeight(Document document, TrainConfig trainConfig, long v) {
        Node.ThreadState<AndNode, NodeActivation<AndNode>> th = getThreadState(document.threadId, true);
        Model model = document.m;
        if (model.numberOfPositions - this.nOffset == 0
                || !trainConfig.patternEvaluation.evaluate(this)
                || th.visitedComputeWeight == v) {
            return;
        }
        // Kept as a positive condition (not De Morgan-negated) so a NaN delta behaves as before.
        boolean due = this.numberOfPositionsNotify <= model.numberOfPositions
                || this.frequencyNotify <= this.frequency
                || Math.abs(this.nullHypFreq - this.oldNullHypFreq) >= 0.01d;
        if (!due) {
            return;
        }
        th.visitedComputeWeight = v;
        // NOTE(review): same possible integer division as in computeNullHyp — confirm field types.
        double n = (model.numberOfPositions - this.nOffset) / (this.sizeSum / this.instanceSum);

        // Removed and re-added around the update; presumably the queue orders by numberOfPositionsNotify — confirm.
        model.numberOfPositionsQueue.remove(this.provider);
        this.numberOfPositionsNotify = computeNotify(n) + model.numberOfPositions;
        model.numberOfPositionsQueue.add(this.provider);

        this.weight = new BinomialDistribution((RandomGenerator) null, (int) Math.round(n), this.nullHypFreq / n)
                .cumulativeProbability(this.frequency - 1);
        this.frequencyNotify = computeNotify(this.frequency) + this.frequency;
        this.oldNullHypFreq = this.nullHypFreq;
    }

    /** Returns {@code 1 + floor(d^1.15 - d)}, the increment until the next recomputation. */
    public int computeNotify(double d) {
        return 1 + ((int) Math.floor(Math.pow(d, 1.15d) - d));
    }

    /** Removes this node unless already removed or required, then cleans up each parent. */
    @Override
    public void cleanup(Model model) {
        if (this.isRemoved || isRequired()) {
            return;
        }
        remove(model);
        for (Provider<? extends Node> p : parents.values()) {
            p.get().cleanup(model);
        }
    }

    /**
     * Combines {@code act} with each other non-removed output activation of its input activations
     * that leads to an existing and-child, adding the next-level activation. The lookups run under
     * the parent's read lock. If {@code interprNode} is null the activation is also offered to
     * {@link OrNode#processCandidate}.
     */
    @Override
    @SuppressWarnings("unchecked")
    void apply(Document document, NodeActivation<AndNode> act, InterprNode interprNode) {
        if (act.isRemoved) {
            return;
        }
        for (NodeActivation<?> inputAct : act.inputs.values()) {
            Node<?, ?> parentNode = (Node<?, ?>) inputAct.key.n;
            parentNode.lock.acquireReadLock();
            try {
                Refinement ownRef = parentNode.reverseAndChildren.get(
                        new Node.ReverseAndRefinement(act.key.n.provider, act.key.rid, inputAct.key.rid));
                if (ownRef == null) {
                    continue;
                }
                for (NodeActivation<?> siblingAct : inputAct.outputs.values()) {
                    if (siblingAct == act || siblingAct.isRemoved) {
                        continue;
                    }
                    Refinement siblingRef = parentNode.reverseAndChildren.get(
                            new Node.ReverseAndRefinement(siblingAct.key.n.provider, siblingAct.key.rid, inputAct.key.rid));
                    if (siblingRef == null) {
                        continue;
                    }
                    Provider<AndNode> child = getAndChild(new Refinement(siblingRef.rid, ownRef.getOffset(), siblingRef.input));
                    if (child != null) {
                        addNextLevelActivation(document, act, (NodeActivation<AndNode>) siblingAct, child, interprNode);
                    }
                }
            } finally {
                parentNode.lock.releaseReadLock();
            }
        }
        if (interprNode == null) {
            OrNode.processCandidate(document, this, act, false);
        }
    }

    /**
     * For every AndNode activation sharing an input activation with {@code act} whose node passes
     * the pattern evaluation and lies within {@link #MAX_RID_RANGE}, tries to create a new
     * next-level node; created nodes are flagged discovered and added to {@code document.addedNodes}.
     */
    @Override
    public void discover(Document document, NodeActivation<AndNode> act, TrainConfig trainConfig) {
        for (NodeActivation<?> inputAct : act.inputs.values()) {
            Node<?, ?> parentNode = (Node<?, ?>) inputAct.key.n;
            parentNode.lock.acquireReadLock();
            try {
                Refinement ownRef = parentNode.reverseAndChildren.get(
                        new Node.ReverseAndRefinement(act.key.n.provider, act.key.rid, inputAct.key.rid));
                for (NodeActivation<?> siblingAct : inputAct.outputs.values()) {
                    if (!(siblingAct.key.n instanceof AndNode)) {
                        continue;
                    }
                    AndNode siblingNode = (AndNode) siblingAct.key.n;
                    Integer ridDelta = Utils.nullSafeSub(act.key.rid, false, siblingAct.key.rid, false);
                    if (act == siblingAct
                            || !trainConfig.patternEvaluation.evaluate(siblingNode)
                            || (ridDelta != null && ridDelta >= MAX_RID_RANGE)) {
                        continue;
                    }
                    Refinement siblingRef = parentNode.reverseAndChildren.get(
                            new Node.ReverseAndRefinement(siblingNode.provider, siblingAct.key.rid, inputAct.key.rid));
                    if (ownRef == null || siblingRef == null) {
                        // Same guard as apply(): without both reverse refinements no child can be formed.
                        continue;
                    }
                    AndNode created = createNextLevelNode(document.m, document.threadId, this,
                            new Refinement(siblingRef.rid, ownRef.getOffset(), siblingRef.input), true);
                    if (created != null) {
                        created.isDiscovered = true;
                        document.addedNodes.add(created);
                    }
                }
            } finally {
                parentNode.lock.releaseReadLock();
            }
        }
    }

    @Override
    public boolean isExpandable() {
        return parents.size() < MAX_AND_NODE_SIZE;
    }

    /** Returns true iff the largest rid among the refinements (at least 0) is below {@link #MAX_RID_RANGE}. */
    private static boolean checkRidRange(SortedMap<Refinement, Provider<? extends Node>> parents) {
        int maxRid = 0;
        for (Refinement ref : parents.keySet()) {
            if (ref.rid != null) {
                maxRid = Math.max(maxRid, ref.rid);
            }
        }
        return maxRid < MAX_RID_RANGE;
    }

    /**
     * Returns whether this node already has a parent for {@code ref}. A negative rid never
     * matches. A null or positive rid needs an exact key match. A rid of 0 matches any parent
     * with the same input provider whose rid is null or non-positive.
     */
    @Override
    public boolean contains(Refinement ref) {
        if (ref.rid != null && ref.rid < 0) {
            return false;
        }
        lock.acquireReadLock();
        try {
            if (ref.rid == null || ref.rid > 0) {
                return parents.containsKey(ref);
            }
            for (Refinement parentRef : parents.keySet()) {
                if ((parentRef.rid == null || parentRef.rid <= 0) && parentRef.input == ref.input) {
                    return true;
                }
            }
            return false;
        } finally {
            lock.releaseReadLock();
        }
    }

    /**
     * Returns the and-child of {@code node} for {@code ref}, creating it if needed.
     *
     * <p>If {@code discovering} is true, returns null when the child already exists, and also
     * rejects parent sets that fail {@link #checkRidRange}. Returns null when {@code node} already
     * contains {@code ref}, when no consistent parent set exists, or when another child was
     * registered concurrently. All parents are write-locked, in provider order, while the node
     * is constructed.
     */
    public static AndNode createNextLevelNode(Model model, int threadId, Node node, Refinement ref, boolean discovering) {
        Provider<AndNode> existing = node.getAndChild(ref);
        if (existing != null) {
            return discovering ? null : existing.get();
        }
        if (node.contains(ref)) {
            return null;
        }
        SortedMap<Refinement, Provider<? extends Node>> nextParents = computeNextLevelParents(model, threadId, node, ref, discovering);
        if (nextParents == null || (discovering && !checkRidRange(nextParents))) {
            return null;
        }

        // Sorted by provider; presumably so that concurrent callers lock parents in the same order — confirm.
        TreeSet<Provider<? extends Node>> lockOrder = new TreeSet<>(nextParents.values());
        for (Provider<? extends Node> p : lockOrder) {
            p.get().lock.acquireWriteLock();
        }
        try {
            if (node.andChildren == null || !node.andChildren.containsKey(ref)) {
                return new AndNode(model, node.level + 1, nextParents);
            }
            return null;
        } finally {
            for (Provider<? extends Node> p : lockOrder) {
                p.get().lock.releaseWriteLock();
            }
        }
    }

    /**
     * Adds an activation of {@code child} built from {@code first} and {@code second}: merged
     * range, the smaller rid and the combined interpretation node. Nothing is added when the
     * interpretations cannot be combined, or when {@code interprNode} is non-null and not
     * contained in the combination.
     */
    public static void addNextLevelActivation(Document document, NodeActivation<AndNode> first, NodeActivation<AndNode> second, Provider<AndNode> child, InterprNode interprNode) {
        NodeActivation.Key<AndNode> firstKey = first.key;
        InterprNode combined = InterprNode.add(document, true, firstKey.o, second.key.o);
        if (combined == null || (interprNode != null && !combined.contains(interprNode, false))) {
            return;
        }
        AndNode node = child.get();
        NodeActivation.Key key = new NodeActivation.Key(
                node,
                Range.mergeRange(firstKey.r, second.key.r),
                Utils.nullSafeMin(firstKey.rid, second.key.rid),
                combined);
        node.addActivation(document, key, prepareInputActs(first, second));
    }

    /**
     * Adds {@code delta} to numberOfNeuronRefs and propagates to all parents. A node is updated
     * at most once per marker {@code v} in a thread.
     */
    @Override
    public void changeNumberOfNeuronRefs(int threadId, long v, int delta) {
        Node.ThreadState<AndNode, NodeActivation<AndNode>> th = getThreadState(threadId, true);
        if (th.visitedNeuronRefsChange == v) {
            return;
        }
        th.visitedNeuronRefsChange = v;
        this.numberOfNeuronRefs += delta;
        for (Provider<? extends Node> p : parents.values()) {
            p.get().changeNumberOfNeuronRefs(threadId, v, delta);
        }
    }

    /** Returns a new mutable list containing {@code a} followed by {@code b}. */
    public static Collection<NodeActivation<?>> prepareInputActs(NodeActivation<?> a, NodeActivation<?> b) {
        ArrayList<NodeActivation<?>> acts = new ArrayList<>(2);
        acts.add(a);
        acts.add(b);
        return acts;
    }

    /**
     * Computes the parent map of the node obtained by refining {@code node} with {@code ref}.
     * Each collected refinement's input node contributes its parents for the remaining
     * refinements. Returns null if any contribution fails or a rid falls out of range.
     */
    public static SortedMap<Refinement, Provider<? extends Node>> computeNextLevelParents(Model model, int threadId, Node node, Refinement ref, boolean discovering) {
        Collection<Refinement> refinements = node.collectNodeAndRefinements(ref);
        long v = visitedCounter++;
        TreeMap<Refinement, Provider<? extends Node>> result = new TreeMap<>();
        for (Refinement pRef : refinements) {
            TreeSet<Refinement> others = new TreeSet<>(refinements);
            others.remove(pRef);
            try {
                if (!pRef.input.get().computeAndParents(model, threadId, pRef.getRelativePosition(), others, result, discovering, v)) {
                    return null;
                }
            } catch (Node.ThreadState.RidOutOfRange ignored) {
                // Treated like a failed contribution: no valid parent set exists.
                return null;
            }
        }
        return result;
    }

    /**
     * Returns {@code newRef} followed by this node's parent refinements, re-expressed relative
     * to {@code newRef}'s rid when both carry a rid.
     */
    @Override
    public Collection<Refinement> collectNodeAndRefinements(Refinement newRef) {
        ArrayList<Refinement> result = new ArrayList<>(parents.size() + 1);
        result.add(newRef);

        int numRidRefs = 0;
        for (Refinement parentRef : parents.keySet()) {
            if (parentRef.rid != null) {
                numRidRefs++;
            }
        }

        for (Refinement parentRef : parents.keySet()) {
            if (parentRef.rid != null && newRef.rid != null && (newRef.rid < 0 || numRidRefs == 1)) {
                result.add(new Refinement(parentRef.getRelativePosition(), newRef.rid, parentRef.input));
            } else if (parentRef.rid == null || newRef.rid == null || parentRef.getOffset() >= 0) {
                result.add(parentRef);
            } else {
                result.add(new Refinement(0, Math.min(-parentRef.getOffset(), newRef.getRelativePosition()), parentRef.input));
            }
        }
        return result;
    }

    /** Returns true if any parent's visited entry, looked up by rid minus the refinement offset, carries output marker {@code v}. */
    @Override
    public boolean isCovered(int threadId, Integer rid, long v) throws Node.ThreadState.RidOutOfRange {
        for (Map.Entry<Refinement, Provider<? extends Node>> e : parents.entrySet()) {
            Node parent = e.getValue().get();
            if (parent.getThreadState(threadId, true)
                    .lookupVisited(Utils.nullSafeSub(rid, true, e.getKey().getOffset(), false)).outputNode == v) {
                return true;
            }
        }
        return false;
    }

    /** Returns the neuron's bias plus the absolute weights of the synapses reached through every parent refinement. */
    @Override
    public double computeSynapseWeightSum(Integer rid, INeuron n) {
        double sum = n.bias;
        // The decompiled loop re-created the iterator on each test and never terminated.
        for (Refinement ref : parents.keySet()) {
            sum += Math.abs(ref.getSynapse(rid, (Neuron) n.provider).w);
        }
        return sum;
    }

    @Override
    protected NodeActivation<AndNode> createActivation(Document document, NodeActivation.Key key, boolean isTrainingAct) {
        NodeActivation<AndNode> act = new NodeActivation<>(document.activationIdCounter++, document, key);
        act.isTrainingAct = isTrainingAct;
        return act;
    }

    /** No-op for AndNode. */
    @Override
    public void deleteActivation(Document document, NodeActivation<AndNode> act) {
    }

    /** Removes this node and unregisters it as and-child from each parent under that parent's write lock. */
    @Override
    public void remove(Model model) {
        super.remove(model);
        for (Map.Entry<Refinement, Provider<? extends Node>> e : parents.entrySet()) {
            Node parent = e.getValue().get();
            parent.lock.acquireWriteLock();
            try {
                parent.removeAndChild(e.getKey());
                parent.provider.setModified();
            } finally {
                parent.lock.releaseWriteLock();
            }
        }
    }

    /** Renders the node as {@code AND[ref1,ref2,...]}. */
    @Override
    public String logicToString() {
        StringBuilder sb = new StringBuilder("AND[");
        String separator = "";
        for (Refinement ref : parents.keySet()) {
            sb.append(separator).append(ref);
            separator = ",";
        }
        return sb.append("]").toString();
    }

    @Override
    public String weightsToString() {
        return " -  F:" + this.frequency + "  BW:" + Utils.round(this.weight);
    }

    /** Serializes the type tag "A", the base node state, notify thresholds, weight and parents. */
    @Override
    public void write(DataOutput out) throws IOException {
        out.writeBoolean(false);
        out.writeUTF("A");
        super.write(out);
        out.writeInt(this.numberOfPositionsNotify);
        out.writeInt(this.frequencyNotify);
        out.writeDouble(this.weight);
        out.writeInt(parents.size());
        for (Map.Entry<Refinement, Provider<? extends Node>> e : parents.entrySet()) {
            e.getKey().write(out);
            out.writeInt(e.getValue().id);
        }
    }

    /** Reads the format produced by {@link #write}, after the type tag. */
    @Override
    public void readFields(DataInput in, Model model) throws IOException {
        super.readFields(in, model);
        this.numberOfPositionsNotify = in.readInt();
        this.frequencyNotify = in.readInt();
        this.weight = in.readDouble();
        int numParents = in.readInt();
        for (int k = 0; k < numParents; k++) {
            this.parents.put(Refinement.read(in, model), model.lookupNodeProvider(in.readInt()));
        }
    }
}
