package org.aika.lattice;

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.util.Collection;
import java.util.Comparator;
import java.util.Iterator;
import java.util.Map;
import java.util.NavigableMap;
import java.util.Set;
import java.util.SortedSet;
import java.util.TreeMap;
import java.util.TreeSet;
import java.util.concurrent.atomic.AtomicInteger;
import org.aika.AbstractNode;
import org.aika.Model;
import org.aika.Provider;
import org.aika.ReadWriteLock;
import org.aika.Utils;
import org.aika.Writable;
import org.aika.corpus.Document;
import org.aika.corpus.InterpretationNode;
import org.aika.corpus.Range;
import org.aika.lattice.AndNode;
import org.aika.lattice.Node;
import org.aika.lattice.NodeActivation;
import org.aika.lattice.OrNode;
import org.aika.neuron.INeuron;
import org.aika.training.PatternDiscovery;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/aika/lattice/Node.class */
public abstract class Node<T extends Node, A extends NodeActivation<T>> extends AbstractNode<Provider<T>> implements Comparable<Node> {
    /**
     * Exclusive bound on the absolute value of a relative RID accepted by
     * {@link ThreadState#lookupVisited(Integer)}; also sizes each thread state's visited table
     * (2 * MAX_RELATIVE_RID slots). Set to 25 in the static initializer.
     */
    public static int MAX_RELATIVE_RID;
    /** Sentinel that {@link #compareTo(Node)} orders before every other node. */
    public static final Node MIN_NODE;
    /** Sentinel that {@link #compareTo(Node)} orders after every other node. */
    public static final Node MAX_NODE;
    private static final Logger log;
    /** Reverse index of {@link #andChildren}; created, filled and cleared together with it. */
    public TreeMap<ReverseAndRefinement, AndNode.Refinement> reverseAndChildren;
    /** AND children keyed by refinement; {@code null} while there are none. */
    public TreeMap<AndNode.Refinement, Provider<AndNode>> andChildren;
    /** OR children registered with the flag {@code false}; this is the set that {@link #write} serializes. */
    public TreeSet<OrNode.OrEntry> orChildren;
    /** OR children registered with the flag {@code true}; not serialized by {@link #write}. */
    public TreeSet<OrNode.OrEntry> allOrChildren;
    /** Level passed to the constructor; persisted by {@link #write}. */
    public int level;
    /** Optional statistic object created by the model's statistic factory; may be {@code null}. */
    public Writable statistic;
    /** Persisted flag; a node with this flag set is reported as required by {@link #isRequired()}. */
    public boolean isDiscovered;
    /** Set to {@code true} at the end of {@link #remove()}. */
    volatile boolean isRemoved;
    /** Per-thread scratch state, indexed by thread id and sized from {@code model.numberOfThreads}. */
    public ThreadState<T, A>[] threads;
    /** Orders keys by range (third argument {@code false}), then rid, then interpretation. */
    public static final Comparator<NodeActivation.Key> BEGIN_COMP;
    /** Orders keys by range (third argument {@code true}), then rid, then interpretation. */
    public static final Comparator<NodeActivation.Key> END_COMP;
    /** Orders keys by rid, then range (third argument {@code false}), then interpretation. */
    public static final Comparator<NodeActivation.Key> RID_COMP;
    // Assertion switch materialized by the decompiler: true when assertions are disabled for Node.
    static final /* synthetic */ boolean $assertionsDisabled;
    /** Reference counter adjusted by {@link #changeNumberOfNeuronRefs}; persisted by {@link #write}. */
    public AtomicInteger numberOfNeuronRefs = new AtomicInteger(0);
    /** Guards the OR-child sets and read access to {@link #andChildren}. */
    public ReadWriteLock lock = new ReadWriteLock();

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:org/aika/lattice/Node$ReverseAndRefinement.class */
    /**
     * Key type of {@code reverseAndChildren}: a child provider plus a direction flag derived from
     * comparing two (possibly {@code null}) integers. Ordered by provider first, then by flag.
     */
    public static class ReverseAndRefinement implements Comparable<ReverseAndRefinement> {
        boolean dir;
        Provider node;

        /**
         * @param provider provider of the child node
         * @param num      first value handed to {@code Utils.compareNullSafe}
         * @param num2     second value handed to {@code Utils.compareNullSafe}
         */
        public ReverseAndRefinement(Provider provider, Integer num, Integer num2) {
            this.dir = Utils.compareNullSafe(num, num2);
            this.node = provider;
        }

        /** Compares by provider; ties are broken by the direction flag ({@code false} before {@code true}). */
        @Override // java.lang.Comparable
        public int compareTo(ReverseAndRefinement reverseAndRefinement) {
            int byNode = this.node.compareTo(reverseAndRefinement.node);
            if (byNode != 0) {
                return byNode;
            }
            return Boolean.compare(this.dir, reverseAndRefinement.dir);
        }
    }

    /* loaded from: input_file:org/aika/lattice/Node$RidVisited.class */
    /** Visit marker for one relative RID within a {@link ThreadState}. */
    public static class RidVisited {
        // Id of the last computeAndParents pass that reached this RID; -1 means "never visited".
        public long computeParents = -1;
    }

    /* loaded from: input_file:org/aika/lattice/Node$ThreadState.class */
    /**
     * Scratch state a node keeps per worker thread: visit markers, queue bookkeeping and the
     * activations waiting to be processed by {@code processChanges}.
     */
    public static class ThreadState<T extends Node, A extends NodeActivation<T>> {
        public long lastUsed;
        public long visited;
        public long queueId;
        /** Marker used when the looked-up RID is {@code null}; created on first use. */
        private RidVisited nullRidVisited;
        public boolean isQueued = false;
        /** Markers for non-null RIDs, indexed by {@code rid + Node.MAX_RELATIVE_RID}; filled lazily. */
        private RidVisited[] ridVisited = new RidVisited[2 * Node.MAX_RELATIVE_RID];
        public NavigableMap<NodeActivation.Key, Set<NodeActivation<?>>> added = new TreeMap<>();

        /* loaded from: input_file:org/aika/lattice/Node$ThreadState$RidOutOfRange.class */
        /** Thrown when a relative RID lies outside the open interval (-MAX_RELATIVE_RID, MAX_RELATIVE_RID). */
        public static class RidOutOfRange extends Exception {
            public RidOutOfRange(String str) {
                super(str);
            }
        }

        /**
         * Returns the visit marker for the given relative RID, creating it on first access.
         *
         * @param num relative RID, or {@code null} for the dedicated null marker
         * @return the (possibly freshly created) marker
         * @throws RidOutOfRange if {@code num} is non-null and its absolute value is not below
         *                       {@code Node.MAX_RELATIVE_RID}; a warning is logged first
         */
        public RidVisited lookupVisited(Integer num) throws RidOutOfRange {
            if (num == null) {
                if (this.nullRidVisited == null) {
                    this.nullRidVisited = new RidVisited();
                }
                return this.nullRidVisited;
            }

            int rid = num.intValue();
            if (rid >= Node.MAX_RELATIVE_RID || rid <= (-Node.MAX_RELATIVE_RID)) {
                Node.log.warn("RID too large:" + num);
                throw new RidOutOfRange("RID too large:" + num);
            }

            // Shift the RID into the non-negative index range of the table.
            int slot = rid + Node.MAX_RELATIVE_RID;
            RidVisited marker = this.ridVisited[slot];
            if (marker == null) {
                marker = new RidVisited();
                this.ridVisited[slot] = marker;
            }
            return marker;
        }
    }

    /**
     * Returns the state slot of the given thread and stamps it with the model's current document id.
     *
     * @param i thread id (index into {@code threads})
     * @param z whether a missing slot should be created
     * @return the thread state, or {@code null} if none exists and {@code z} is {@code false}
     */
    public ThreadState<T, A> getThreadState(int i, boolean z) {
        ThreadState<T, A> state = this.threads[i];
        if (state == null && !z) {
            return null;
        }
        if (state == null) {
            state = new ThreadState<>();
            this.threads[i] = state;
        }
        state.lastUsed = this.provider.model.docIdCounter.get();
        return state;
    }

    /** Creates the activation object for {@code key}; called by {@link #processAddedActivation}. */
    abstract A createActivation(Document document, NodeActivation.Key key);

    /** Called by {@link #processAddedActivation} after the new activation was registered and linked. */
    public abstract void propagateAddedActivation(Document document, A a);

    // NOTE(review): not called within this class; semantics are defined by the subclasses.
    public abstract boolean isAllowedOption(int i, InterpretationNode interpretationNode, NodeActivation<?> nodeActivation, long j);

    // NOTE(review): not called within this class; presumably sums synapse weights for a neuron — confirm in subclasses.
    public abstract double computeSynapseWeightSum(Integer num, INeuron iNeuron);

    // NOTE(review): not called within this class; semantics are defined by the subclasses.
    abstract void apply(Document document, A a);

    // NOTE(review): not called within this class; presumably the pattern-discovery hook — confirm in subclasses.
    public abstract void discover(Document document, NodeActivation<T> nodeActivation, PatternDiscovery.Config config);

    /* JADX INFO: Access modifiers changed from: package-private */
    // NOTE(review): not called within this class; semantics are defined by the subclasses.
    public abstract Collection<AndNode.Refinement> collectNodeAndRefinements(AndNode.Refinement refinement);

    /* JADX INFO: Access modifiers changed from: package-private */
    // NOTE(review): not called within this class; semantics are defined by the subclasses.
    public abstract boolean contains(AndNode.Refinement refinement);

    // NOTE(review): not called within this class; semantics are defined by the subclasses.
    public abstract void cleanup();

    /** Textual form of the node's logic; used as the middle segment of {@link #toString()}. */
    public abstract String logicToString();

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Creates an uninitialized node. No provider or thread table is set up here;
     * {@link #readNode} assigns the provider and {@link #readFields} restores the remaining state.
     */
    public Node() {
    }

    /* JADX WARN: Type inference failed for: r1v5, types: [P extends org.aika.Provider<? extends org.aika.AbstractNode>, org.aika.Provider] */
    /**
     * Creates a node attached to {@code model}: allocates the per-thread state table, wraps the
     * node in a new provider, marks it modified and, if the model has a statistic factory,
     * creates its statistic object.
     *
     * @param model owning model
     * @param i     level of the node
     */
    public Node(Model model, int i) {
        this.level = i;
        this.threads = new ThreadState[model.numberOfThreads];
        this.provider = new Provider(model, this);
        setModified();
        if (model.nodeStatisticFactory == null) {
            return;
        }
        this.statistic = model.nodeStatisticFactory.createStatisticObject();
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /**
     * Registers an OR child under the write lock.
     *
     * @param orEntry entry to add
     * @param z       {@code true} to add to {@code allOrChildren}, {@code false} to add to {@code orChildren}
     */
    public void addOrChild(OrNode.OrEntry orEntry, boolean z) {
        this.lock.acquireWriteLock();
        try {
            if (z) {
                if (this.allOrChildren == null) {
                    this.allOrChildren = new TreeSet<>();
                }
                this.allOrChildren.add(orEntry);
            } else {
                if (this.orChildren == null) {
                    this.orChildren = new TreeSet<>();
                }
                this.orChildren.add(orEntry);
            }
        } finally {
            // Release even if the set insertion throws (e.g. from the entry's comparison),
            // otherwise the node would stay write-locked.
            this.lock.releaseWriteLock();
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /**
     * Unregisters an OR child under the write lock; a set that becomes empty is reset to {@code null}.
     *
     * @param orEntry entry to remove
     * @param z       {@code true} to remove from {@code allOrChildren}, {@code false} to remove from {@code orChildren}
     */
    public void removeOrChild(OrNode.OrEntry orEntry, boolean z) {
        this.lock.acquireWriteLock();
        try {
            if (z) {
                if (this.allOrChildren != null) {
                    this.allOrChildren.remove(orEntry);
                    if (this.allOrChildren.isEmpty()) {
                        this.allOrChildren = null;
                    }
                }
            } else if (this.orChildren != null) {
                this.orChildren.remove(orEntry);
                if (this.orChildren.isEmpty()) {
                    this.orChildren = null;
                }
            }
        } finally {
            // Release even if the set removal throws, otherwise the node would stay write-locked.
            this.lock.releaseWriteLock();
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /**
     * Registers an AND child in both the forward and the reverse index, creating the maps on
     * first use. No lock is taken here.
     *
     * @param refinement refinement leading to the child
     * @param provider   provider of the child node
     */
    public void addAndChild(AndNode.Refinement refinement, Provider<AndNode> provider) {
        if (this.andChildren == null) {
            this.andChildren = new TreeMap<>();
            this.reverseAndChildren = new TreeMap<>();
        }

        Provider<AndNode> previous = this.andChildren.put(refinement, provider);
        // With assertions enabled, a refinement must not be registered twice.
        if (!$assertionsDisabled && previous != null) {
            throw new AssertionError();
        }

        ReverseAndRefinement reverseKey = new ReverseAndRefinement(provider, refinement.rid, 0);
        this.reverseAndChildren.put(reverseKey, refinement);
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /**
     * Unregisters an AND child from both the forward and the reverse index; when the last child
     * is gone, both maps are reset to {@code null}. Removing a refinement that is not registered
     * is a no-op. No lock is taken here.
     *
     * @param refinement refinement whose child should be removed
     */
    public void removeAndChild(AndNode.Refinement refinement) {
        if (this.andChildren == null) {
            return;
        }
        Provider<AndNode> removed = this.andChildren.remove(refinement);
        if (removed != null) {
            // Only touch the reverse index when a child was actually registered: a key built
            // around a null provider would throw a NullPointerException in compareTo.
            this.reverseAndChildren.remove(new ReverseAndRefinement(removed, refinement.rid, 0));
        }
        if (this.andChildren.isEmpty()) {
            this.andChildren = null;
            this.reverseAndChildren = null;
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /**
     * Turns a pending key into an activation: creates it, registers it, links it to its input
     * activations and propagates it.
     *
     * @param document   document being processed
     * @param key        key of the activation to create
     * @param collection input activations the new activation is linked to
     * @return the newly created activation
     */
    public A processAddedActivation(Document document, NodeActivation.Key<T> key, Collection<NodeActivation> collection) {
        if (Document.APPLY_DEBUG_OUTPUT) {
            log.info("add: " + key + " - " + key.node);
        }

        A activation = createActivation(document, key);
        register(activation, document);
        activation.link(collection);
        propagateAddedActivation(document, activation);
        return activation;
    }

    /**
     * Stores the activation in the activation map of its key's interpretation, creating that map
     * on first use.
     *
     * @param a        activation to register
     * @param document not used by this implementation
     */
    public void register(A a, Document document) {
        NodeActivation.Key<T> activationKey = a.key;
        boolean mapMissing = activationKey.interpretation.activations == null;
        if (mapMissing) {
            activationKey.interpretation.activations = new TreeMap();
        }
        activationKey.interpretation.activations.put(activationKey, a);
    }

    /**
     * Drains the calling thread's pending activations: the pending map is swapped for a fresh one
     * first, then every drained entry is processed in key order.
     *
     * @param document document being processed; supplies the thread id
     */
    public void processChanges(Document document) {
        ThreadState<T, A> state = getThreadState(document.threadId, true);

        // Detach the batch before processing so additions made meanwhile land in the new map.
        NavigableMap<NodeActivation.Key, Set<NodeActivation<?>>> pending = state.added;
        state.added = new TreeMap();

        pending.forEach((pendingKey, inputs) -> processAddedActivation(document, pendingKey, inputs));
    }

    /**
     * Records input activations for {@code key} in the pending map of the key's node (for the
     * document's thread) and adds that node to the document queue.
     *
     * @param document   document being processed
     * @param key        key of the activation to be created later
     * @param collection input activations to merge into the pending set for {@code key}
     */
    public static <T extends Node, A extends NodeActivation<T>> void addActivationAndPropagate(Document document, NodeActivation.Key<T> key, Collection<NodeActivation<?>> collection) {
        ThreadState<T, A> state = key.node.getThreadState(document.threadId, true);

        Set pendingInputs = (Set) state.added.get(key);
        if (pendingInputs == null) {
            pendingInputs = new TreeSet();
            state.added.put(key, pendingInputs);
        }
        pendingInputs.addAll(collection);

        document.queue.add(key.node);
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /* JADX WARN: Multi-variable type inference failed */
    /**
     * Recursively walks the AND children of this node along every ordering of the given
     * refinements and records, for each refinement that is applied last, the provider of the node
     * reached just before it.
     *
     * @param model     model handed to {@code AndNode.createNextLevelNode} when a child is missing
     * @param i         thread id used to select the thread state
     * @param num       relative RID used for the visit marker and for building the child refinement
     * @param sortedSet refinements still to be applied; never modified (a copy is made per step)
     * @param map       output: last refinement mapped to the provider of its parent node
     * @param config    when non-null, missing children are not created and the walk fails instead
     * @param j         id of the current pass; a marker equal to {@code j} means "already handled"
     * @return {@code false} if a required child is missing and could not be created, else {@code true}
     * @throws ThreadState.RidOutOfRange if {@code num} is outside the supported RID range
     */
    public boolean computeAndParents(Model model, int i, Integer num, SortedSet<AndNode.Refinement> sortedSet, Map<AndNode.Refinement, Provider<? extends Node>> map, PatternDiscovery.Config config, long j) throws ThreadState.RidOutOfRange {
        RidVisited lookupVisited = getThreadState(i, true).lookupVisited(num);
        // This node/RID combination was already handled in pass j.
        if (lookupVisited.computeParents == j) {
            return true;
        }
        lookupVisited.computeParents = j;
        // Base case: one refinement left, so this node is its parent.
        if (sortedSet.size() == 1) {
            map.put(sortedSet.first(), this.provider);
            return true;
        }
        for (AndNode.Refinement refinement : sortedSet) {
            // Remaining refinements once the current one has been applied.
            TreeSet treeSet = new TreeSet((SortedSet) sortedSet);
            treeSet.remove(refinement);
            AndNode.Refinement refinement2 = new AndNode.Refinement(refinement.getRelativePosition(), num, refinement.input);
            this.lock.acquireReadLock();
            Provider provider = this.andChildren != null ? this.andChildren.get(refinement2) : null;
            this.lock.releaseReadLock();
            if (provider == null) {
                // With a discovery config present, missing children are never created here.
                if (config != null) {
                    return false;
                }
                // NOTE(review): the result of createNextLevelNode is dereferenced without a null
                // check; the test below only covers a null provider field.
                provider = AndNode.createNextLevelNode(model, i, this, refinement2, config).provider;
                if (provider == null) {
                    return false;
                }
            }
            if (!((AndNode) provider.get()).computeAndParents(model, i, Utils.nullSafeMin(refinement.getRelativePosition(), num), treeSet, map, config, j)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Removes this node together with all of its AND children and the OR children held in
     * {@code orChildren}, then flags the node as removed. With assertions enabled, removing a
     * node twice raises an {@link AssertionError}.
     */
    public void remove() {
        if (!$assertionsDisabled && this.isRemoved) {
            throw new AssertionError();
        }
        this.lock.acquireWriteLock();
        setModified();
        // NOTE(review): this loop only terminates if removing a child also detaches it from
        // andChildren (presumably via removeAndChild on its parents) — confirm in AndNode.remove.
        while (this.andChildren != null && !this.andChildren.isEmpty()) {
            this.andChildren.firstEntry().getValue().get().remove();
        }
        // pollFirst() detaches each entry itself, so this loop shrinks the set on every iteration.
        while (this.orChildren != null && !this.orChildren.isEmpty()) {
            this.orChildren.pollFirst().node.get().remove();
        }
        // NOTE(review): the lock is not released if a child's remove() throws.
        this.lock.releaseWriteLock();
        this.isRemoved = true;
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /**
     * Looks up the AND child registered for a refinement while holding the read lock.
     *
     * @param refinement refinement to look up
     * @return the child's provider, or {@code null} if there are no AND children or none matches
     */
    public Provider<AndNode> getAndChild(AndNode.Refinement refinement) {
        this.lock.acquireReadLock();
        try {
            return this.andChildren != null ? this.andChildren.get(refinement) : null;
        } finally {
            // Release even if the map lookup throws (e.g. for a null key).
            this.lock.releaseReadLock();
        }
    }

    /**
     * @return {@code true} if the node is flagged as discovered or its neuron reference count is positive
     */
    public boolean isRequired() {
        if (this.isDiscovered) {
            return true;
        }
        return this.numberOfNeuronRefs.get() > 0;
    }

    /**
     * Adjusts the neuron reference counter at most once per pass and thread.
     *
     * @param i  thread id selecting the thread state
     * @param j  pass id; if the thread state was already visited with this id, nothing happens
     * @param i2 delta added to the reference counter
     */
    public void changeNumberOfNeuronRefs(int i, long j, int i2) {
        ThreadState<T, A> state = getThreadState(i, true);
        if (state.visited != j) {
            state.visited = j;
            this.numberOfNeuronRefs.addAndGet(i2);
        }
    }

    /** Label shown as the first segment of {@link #toString()}; this base implementation returns an empty string. */
    public String getNeuronLabel() {
        return "";
    }

    /** Renders the node as {@code <neuron label> - <logic> - <weights>}. */
    public String toString() {
        StringBuilder text = new StringBuilder();
        text.append(getNeuronLabel());
        text.append(" - ");
        text.append(logicToString());
        text.append(" - ");
        text.append(weightsToString());
        return text.toString();
    }

    /** Weight description shown as the last segment of {@link #toString()}; this base implementation returns an empty string. */
    public String weightsToString() {
        return "";
    }

    /* JADX WARN: Multi-variable type inference failed */
    /**
     * Orders nodes by their providers, with {@code MIN_NODE} sorting before and {@code MAX_NODE}
     * after every other node. A node always compares equal to itself.
     */
    @Override // java.lang.Comparable
    public int compareTo(Node node) {
        if (this == node) {
            return 0;
        }
        // Sentinel checks, in precedence order: this==MIN, other==MIN, this==MAX, other==MAX.
        if (this == MIN_NODE) {
            return -1;
        }
        if (node == MIN_NODE) {
            return 1;
        }
        if (this == MAX_NODE) {
            return 1;
        }
        if (node == MAX_NODE) {
            return -1;
        }
        return this.provider.compareTo(node.provider);
    }

    /**
     * Serializes the node in this order: level, statistic presence flag (plus the statistic if
     * present), discovered flag, neuron reference count, AND children (count, then refinement and
     * provider id per child), OR children (count, then each entry). A missing child collection is
     * written as count 0. {@code allOrChildren} is not written.
     */
    @Override // org.aika.Writable
    public void write(DataOutput dataOutput) throws IOException {
        dataOutput.writeInt(this.level);

        boolean hasStatistic = this.statistic != null;
        dataOutput.writeBoolean(hasStatistic);
        if (hasStatistic) {
            this.statistic.write(dataOutput);
        }

        dataOutput.writeBoolean(this.isDiscovered);
        dataOutput.writeInt(this.numberOfNeuronRefs.get());

        if (this.andChildren == null) {
            dataOutput.writeInt(0);
        } else {
            dataOutput.writeInt(this.andChildren.size());
            for (Map.Entry<AndNode.Refinement, Provider<AndNode>> child : this.andChildren.entrySet()) {
                child.getKey().write(dataOutput);
                dataOutput.writeInt(child.getValue().id.intValue());
            }
        }

        if (this.orChildren != null) {
            dataOutput.writeInt(this.orChildren.size());
            for (OrNode.OrEntry orEntry : this.orChildren) {
                orEntry.write(dataOutput);
            }
        } else {
            dataOutput.writeInt(0);
        }
    }

    /**
     * Restores the state written by {@link #write}: level, optional statistic, discovered flag,
     * neuron reference count, AND children and OR children. Finally allocates a fresh per-thread
     * state table sized from the model.
     */
    @Override // org.aika.Writable
    public void readFields(DataInput dataInput, Model model) throws IOException {
        this.level = dataInput.readInt();

        boolean hasStatistic = dataInput.readBoolean();
        if (hasStatistic) {
            this.statistic = model.nodeStatisticFactory.createStatisticObject();
            this.statistic.readFields(dataInput, model);
        }

        this.isDiscovered = dataInput.readBoolean();
        this.numberOfNeuronRefs.set(dataInput.readInt());

        int andChildCount = dataInput.readInt();
        for (int n = 0; n < andChildCount; n++) {
            addAndChild(AndNode.Refinement.read(dataInput, model), model.lookupNodeProvider(dataInput.readInt()));
        }

        int orChildCount = dataInput.readInt();
        // The set is only created when there is at least one entry to store.
        if (orChildCount > 0 && this.orChildren == null) {
            this.orChildren = new TreeSet<>();
        }
        for (int n = 0; n < orChildCount; n++) {
            this.orChildren.add(OrNode.OrEntry.read(dataInput, model));
        }

        this.threads = new ThreadState[model.numberOfThreads];
    }

    /* JADX WARN: Multi-variable type inference failed */
    /**
     * Reads a node from the stream. The leading type tag selects the concrete class
     * ({@code 'A'} AndNode, {@code 'I'} InputNode, {@code 'O'} OrNode); the node is then attached
     * to {@code provider} and its fields are restored via {@link #readFields}.
     *
     * @param dataInput stream positioned at the node's type tag
     * @param provider  provider that will own the node; its model is used for deserialization
     * @return the deserialized node
     * @throws IOException if reading fails or the type tag is not one of the known tags
     */
    public static Node readNode(DataInput dataInput, Provider provider) throws IOException {
        char type = dataInput.readChar();
        Node node;
        switch (type) {
            case 'A':
                node = new AndNode();
                break;
            case 'I':
                node = new InputNode();
                break;
            case 'O':
                node = new OrNode();
                break;
            default:
                // Corrupt or incompatible data: report it instead of failing later with a NullPointerException.
                throw new IOException("Unknown node type: '" + type + "'");
        }
        node.provider = provider;
        node.readFields(dataInput, provider.model);
        return node;
    }

    // Static setup: assertion switch, RID bound, sentinel nodes, logger and the three key comparators.
    static {
        $assertionsDisabled = !Node.class.desiredAssertionStatus();
        MAX_RELATIVE_RID = 25;
        MIN_NODE = new InputNode();
        MAX_NODE = new InputNode();
        log = LoggerFactory.getLogger(Node.class);

        // Range (flag false) -> rid -> interpretation.
        BEGIN_COMP = (left, right) -> {
            int byRange = Range.compare(left.range, right.range, false);
            if (byRange != 0) {
                return byRange;
            }
            int byRid = Utils.compareInteger(left.rid, right.rid);
            if (byRid != 0) {
                return byRid;
            }
            return InterpretationNode.compare(left.interpretation, right.interpretation);
        };

        // Range (flag true) -> rid -> interpretation.
        END_COMP = (left, right) -> {
            int byRange = Range.compare(left.range, right.range, true);
            if (byRange != 0) {
                return byRange;
            }
            int byRid = Utils.compareInteger(left.rid, right.rid);
            if (byRid != 0) {
                return byRid;
            }
            return InterpretationNode.compare(left.interpretation, right.interpretation);
        };

        // Rid -> range (flag false) -> interpretation.
        RID_COMP = (left, right) -> {
            int byRid = Utils.compareInteger(left.rid, right.rid);
            if (byRid != 0) {
                return byRid;
            }
            int byRange = Range.compare(left.range, right.range, false);
            if (byRange != 0) {
                return byRange;
            }
            return InterpretationNode.compare(left.interpretation, right.interpretation);
        };
    }
}
