package org.aika.corpus;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;
import java.util.TreeSet;
import org.aika.Utils;
import org.aika.corpus.Conflicts;
import org.aika.lattice.NodeActivation;
import org.aika.lattice.OrNode;
import org.aika.neuron.Activation;
import org.aika.neuron.INeuron;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * A node of the binary search tree that {@link #computeBestInterpretation(Document)} explores to
 * choose among conflicting interpretation nodes of a {@link Document}.
 *
 * <p>Every non-leaf node decides one {@link Candidate}: it spawns a child that selects the
 * candidate and a child that excludes it. A node created without a candidate is a leaf. Its
 * accumulated weight is compared against the best result found so far
 * ({@code document.selectedSearchNode}).
 *
 * <p>Selection/exclusion state is not copied per node. Instead, each node owns a visit stamp
 * ({@link #visited}) that it writes into the mark fields of interpretation nodes. A mark is
 * considered "covered" when it was written by this node or one of its selected ancestors (see
 * {@link #isCovered(int)}).
 */
public class SearchNode implements Comparable<SearchNode> {
    private static final Logger log = LoggerFactory.getLogger(SearchNode.class);

    /** Number of search steps after which the search is flagged as interrupted. */
    public static int MAX_SEARCH_STEPS = 100000;

    /** Id drawn from {@code Document.searchNodeIdCounter}; larger ids belong to newer nodes. */
    public int id;
    /** Nearest ancestor on this path from which the excluding branch was taken, or null. */
    public SearchNode excludedParent;
    /** Nearest ancestor on this path from which the selecting branch was taken, or null. */
    public SearchNode selectedParent;
    /** Visit stamp drawn from {@code Document.visitedCounter}; written into marks by this node. */
    public int visited;
    /** Interpretation nodes contributed by this node (its candidate plus what it expands to). */
    List<InterprNode> refinement;
    /** Candidate decided by this node; null for leaf nodes. */
    Candidate candidate;
    /** Depth of this node; the children decide {@code candidates[level + 1]}. */
    int level;
    DebugState debugState;
    /** Weight change reported by the document's value queue when this node was created. */
    INeuron.NormWeight weightDelta = INeuron.NormWeight.ZERO_WEIGHT;
    /** {@link #weightDelta} added to the parent's accumulated weight; left unset without a parent. */
    INeuron.NormWeight accumulatedWeight;
    /** Activation state changes recorded while this node was active. */
    public List<StateChange> modifiedActs = new ArrayList<>();

    /**
     * An interpretation node whose conflicts the search has to resolve.
     *
     * <p>Candidates are ordered by {@code minBegin}, then by {@code maxEnd} in reverse order, then
     * by {@code minRid}, and finally by their creation index {@code id}.
     */
    public static class Candidate implements Comparable<Candidate> {
        public InterprNode refinement;
        int id;
        Integer minBegin;
        Integer maxEnd;
        Integer minRid;
        /** Decisions cached for this candidate (true = select), keyed by the deciding node. */
        public TreeMap<SearchNode, Boolean> cache = new TreeMap<>();
        /** Per-{@link DebugState} counters, indexed by the state's ordinal. */
        int[] debugCounts = new int[DebugState.values().length];

        /**
         * @param refinement the conflicting interpretation node to decide
         * @param id         creation index, used as the final ordering tie-breaker
         */
        public Candidate(InterprNode refinement, int id) {
            this.refinement = refinement;
            if (refinement.act != null) {
                // Guard the range the same way the multi-activation branch below does.
                if (refinement.act.key.r != null) {
                    this.minBegin = refinement.act.key.r.begin;
                    this.maxEnd = refinement.act.key.r.end;
                }
                this.minRid = refinement.act.key.rid;
            } else {
                // No single activation: take the envelope over all of the node's activations.
                for (NodeActivation nodeActivation : refinement.getActivations()) {
                    if (nodeActivation.key.r != null) {
                        this.minBegin = Utils.nullSafeMin(this.minBegin, nodeActivation.key.r.begin);
                        this.maxEnd = Utils.nullSafeMax(this.maxEnd, nodeActivation.key.r.end);
                    }
                    this.minRid = Utils.nullSafeMin(this.minRid, nodeActivation.key.rid);
                }
            }
            this.id = id;
        }

        @Override
        public int compareTo(Candidate other) {
            int byBegin = Utils.compareInteger(this.minBegin, other.minBegin);
            if (byBegin != 0) {
                return byBegin;
            }
            // Arguments swapped on purpose: maxEnd is compared in reverse order.
            int byEnd = Utils.compareInteger(other.maxEnd, this.maxEnd);
            if (byEnd != 0) {
                return byEnd;
            }
            int byRid = Utils.compareInteger(this.minRid, other.minRid);
            if (byRid != 0) {
                return byRid;
            }
            return Integer.compare(this.id, other.id);
        }
    }

    /** Result of {@link #getCoverage(InterprNode)} for an interpretation node. */
    public enum Coverage {
        SELECTED,
        UNKNOWN,
        EXCLUDED
    }

    /**
     * How a node's candidate was handled in {@link #search}: LIMITED when it was already selected
     * or excluded, CACHED when a cached decision was reused, EXPLORE otherwise.
     */
    public enum DebugState {
        CACHED,
        LIMITED,
        EXPLORE
    }

    /** Snapshot of an activation's rounds before and after a search step. */
    public static class StateChange {
        public Activation act;
        public Activation.Rounds oldRounds;
        public Activation.Rounds newRounds;

        /** Which snapshot {@link #restoreState(Mode)} restores. */
        public enum Mode {
            OLD,
            NEW
        }

        /**
         * Records a copy of {@code act}'s current rounds as its old state.
         *
         * <p>This happens at most once per visit stamp {@code v}. The new record is appended to
         * {@code changes} when that list is non-null.
         */
        public static void saveOldState(List<StateChange> changes, Activation act, long v) {
            if (act.currentStateChange != null && act.currentStateV == v) {
                return;
            }
            StateChange change = new StateChange();
            change.oldRounds = act.rounds.copy();
            act.currentStateChange = change;
            act.currentStateV = v;
            change.act = act;
            if (changes != null) {
                changes.add(change);
            }
        }

        /**
         * Stores {@code act}'s current rounds as the new state of its pending change.
         *
         * <p>The live rounds object is stored, not a copy. {@link #restoreState} copies it when it
         * restores.
         */
        public static void saveNewState(Activation act) {
            act.currentStateChange.newRounds = act.rounds;
        }

        /** Replaces the activation's rounds with a copy of the requested snapshot. */
        public void restoreState(Mode mode) {
            Activation.Rounds source = mode == Mode.OLD ? this.oldRounds : this.newRounds;
            this.act.rounds = source.copy();
        }
    }

    /**
     * Creates a search node and asks the document's value queue for its weight change.
     *
     * @param document       the document being searched; its id and visit counters are advanced
     * @param selectedParent nearest ancestor from which the selecting branch was taken
     * @param excludedParent nearest ancestor from which the excluding branch was taken
     * @param candidate      candidate decided by this node, or null to create a leaf
     * @param level          depth of this node
     * @param changed        interpretation nodes whose marks changed for this step; passed to
     *                       {@code document.vQueue.adjustWeight}
     */
    public SearchNode(Document document, SearchNode selectedParent, SearchNode excludedParent,
                      Candidate candidate, int level, List<InterprNode> changed) {
        this.id = document.searchNodeIdCounter++;
        this.level = level;
        this.visited = document.visitedCounter++;
        this.selectedParent = selectedParent;
        this.excludedParent = excludedParent;
        if (candidate != null) {
            this.refinement = expandRefinement(
                    Collections.singletonList(candidate.refinement), document.visitedCounter++);
            this.candidate = candidate;
        }
        this.weightDelta = document.vQueue.adjustWeight(this, changed);
        SearchNode parent = getParent();
        if (parent != null) {
            this.accumulatedWeight = this.weightDelta.add(parent.accumulatedWeight);
        }
        if (Document.OPTIMIZE_DEBUG_OUTPUT) {
            log.info("Search Step: {}  Candidate Weight Delta: {}", this.id, this.weightDelta);
            log.info(document.neuronActivationsToString(true, true, false) + "\n");
        }
    }

    /** Adds the refinements of this node and all of its selected ancestors to {@code results}. */
    private void collectResults(Collection<InterprNode> results) {
        for (SearchNode node = this; node != null; node = node.selectedParent) {
            if (node.refinement != null) {
                results.addAll(node.refinement);
            }
        }
    }

    /**
     * Runs the search from this (root) node and stores the outcome in the document.
     *
     * <p>Sets {@code document.bestInterpretation} to the bottom node followed by the refinements of
     * the best leaf found. It also reconstructs that leaf's activation state. A warning is logged if
     * the search was interrupted by {@link #MAX_SEARCH_STEPS}.
     */
    public void computeBestInterpretation(Document document) {
        ArrayList<InterprNode> results = new ArrayList<>();
        results.add(document.bottom);
        int[] searchSteps = new int[1];

        List<InterprNode> rootRefinement = expandRootRefinement(document);
        this.refinement = expandRefinement(rootRefinement, document.visitedCounter++);
        if (Document.OPTIMIZE_DEBUG_OUTPUT) {
            log.info("Root SearchNode:{}", this);
        }

        Candidate[] candidates = generateCandidates(document);
        Candidate first = candidates.length > this.level + 1 ? candidates[this.level + 1] : null;

        List<InterprNode> changed = new ArrayList<>(this.refinement);
        markSelected(changed, this.refinement);
        markExcluded(changed, this.refinement);

        new SearchNode(document, this, null, first, this.level + 1, changed)
                .search(document, searchSteps, candidates);

        SearchNode best = document.selectedSearchNode;
        if (best != null) {
            best.reconstructSelectedResult(document);
            best.collectResults(results);
        }
        document.bestInterpretation = results;

        if (document.interrupted) {
            log.warn("The search for the best interpretation has been interrupted. Too many search steps!");
        }
    }

    /**
     * Re-applies the NEW activation state along the selected-ancestor chain, oldest first.
     *
     * <p>Neurons whose activations end with a positive final value are recorded in
     * {@code document.finallyActivatedNeurons}.
     */
    private void reconstructSelectedResult(Document document) {
        if (this.selectedParent != null) {
            this.selectedParent.reconstructSelectedResult(document);
        }
        changeState(StateChange.Mode.NEW);
        for (StateChange change : this.modifiedActs) {
            Activation act = change.act;
            if (act.finalState != null && act.finalState.value > 0.0d) {
                document.finallyActivatedNeurons.add(((OrNode) act.key.n).neuron.get());
            }
        }
    }

    /**
     * Prints one line per ancestor (level >= 0) to standard output.
     *
     * <p>Each line shows the debug state and the candidate's cache/debug counters. Nodes without a
     * candidate or activation print the fields that are available.
     */
    public void dumpDebugState() {
        for (SearchNode node = this; node != null && node.level >= 0; node = node.getParent()) {
            String line = node.level + " " + node.debugState;
            Candidate c = node.candidate;
            if (c != null) {
                line += " CS:" + c.cache.size()
                        + " LIMITED:" + c.debugCounts[DebugState.LIMITED.ordinal()]
                        + " CACHED:" + c.debugCounts[DebugState.CACHED.ordinal()]
                        + " EXPLORE:" + c.debugCounts[DebugState.EXPLORE.ordinal()];
                // The candidate's interpretation node may carry no single activation.
                if (c.refinement.act != null) {
                    line += " " + c.refinement.act.key.r
                            + " " + ((OrNode) c.refinement.act.key.n).neuron.get().label;
                }
            }
            System.out.println(line);
        }
    }

    /**
     * Recursively explores the select and exclude branches of this node's candidate.
     *
     * <p>Branches ruled out by existing marks are skipped, as is the branch a cached decision rules
     * out. On return, each child's activation state is reset to OLD.
     *
     * @param searchSteps single-element step counter shared across the whole search
     * @param candidates  all candidates, indexed by the level that decides them
     * @return the best normalized weight reached below this node, or 0 once interrupted
     */
    private double search(Document document, int[] searchSteps, Candidate[] candidates) {
        if (this.candidate == null) {
            return processResult(document);
        }

        boolean alreadySelected = checkSelected(this.refinement);
        boolean alreadyExcluded = checkExcluded(this.refinement, document.visitedCounter++);

        // Dump only once, when the limit is first exceeded, not again at every deeper level.
        if (searchSteps[0] > MAX_SEARCH_STEPS && !document.interrupted) {
            document.interrupted = true;
            dumpDebugState();
        }
        searchSteps[0]++;

        if (Document.OPTIMIZE_DEBUG_OUTPUT) {
            log.info("Search Step: {}", this.id);
            log.info(toString());
            log.info(document.neuronActivationsToString(true, true, false) + "\n");
        }

        boolean limited = alreadySelected || alreadyExcluded;
        this.debugState = limited ? DebugState.LIMITED : DebugState.EXPLORE;
        // getCachedDecision() switches debugState to CACHED when it finds a usable entry.
        Boolean cachedDecision = limited ? null : getCachedDecision();
        this.candidate.debugCounts[this.debugState.ordinal()]++;

        Candidate next = candidates.length > this.level + 1 ? candidates[this.level + 1] : null;

        double selectedWeight = 0.0d;
        if (!alreadyExcluded) {
            List<InterprNode> changed = new ArrayList<>();
            changed.add(this.candidate.refinement);
            markSelected(changed, this.refinement);
            markExcluded(changed, this.refinement);
            if (cachedDecision == null || cachedDecision) {
                SearchNode child = new SearchNode(document, this, this.excludedParent, next, this.level + 1, changed);
                selectedWeight = child.search(document, searchSteps, candidates);
                child.changeState(StateChange.Mode.OLD);
            }
        }

        if (document.interrupted) {
            return 0.0d;
        }

        double excludedWeight = 0.0d;
        if (!alreadySelected) {
            // Temporarily flag the candidate so getCoverage() reports it as EXCLUDED below here.
            this.candidate.refinement.markedExcludedRefinement = true;
            List<InterprNode> changed = Collections.singletonList(this.candidate.refinement);
            if (cachedDecision == null || !cachedDecision) {
                SearchNode child = new SearchNode(document, this.selectedParent, this, next, this.level + 1, changed);
                excludedWeight = child.search(document, searchSteps, candidates);
                child.changeState(StateChange.Mode.OLD);
            }
            this.candidate.refinement.markedExcludedRefinement = false;
        }

        if (cachedDecision == null && !limited) {
            this.candidate.cache.put(this, selectedWeight >= excludedWeight);
        }
        return Math.max(selectedWeight, excludedWeight);
    }

    /**
     * Scores a leaf and records it in the document if it beats the current best.
     *
     * @return this leaf's normalized accumulated weight
     */
    private double processResult(Document document) {
        double weight = this.accumulatedWeight.getNormWeight();
        double bestSoFar = document.selectedSearchNode != null
                ? document.selectedSearchNode.accumulatedWeight.getNormWeight()
                : -1.0d;
        if (weight > bestSoFar) {
            document.selectedSearchNode = this;
            document.bottom.storeFinalWeight(document.visitedCounter++);
        }
        return weight;
    }

    /** Wraps every conflicting interpretation node of the document in a {@link Candidate}, sorted. */
    public Candidate[] generateCandidates(Document document) {
        TreeSet<Candidate> sorted = new TreeSet<>();
        int nextId = 0;
        for (InterprNode conflicting : collectConflicts(document)) {
            sorted.add(new Candidate(conflicting, nextId++));
        }
        return sorted.toArray(new Candidate[0]);
    }

    /** True when every node in {@code nodes} is already covered as selected. */
    private boolean checkSelected(List<InterprNode> nodes) {
        for (InterprNode n : nodes) {
            if (!isCovered(n.markedSelected)) {
                return false;
            }
        }
        return true;
    }

    /** True when any node in {@code nodes}, or any of its ancestors, is covered as excluded. */
    private boolean checkExcluded(List<InterprNode> nodes, int v) {
        for (InterprNode n : nodes) {
            if (checkExcluded(n, v)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Walks {@code n} and its parents looking for an excluded mark.
     *
     * <p>Stamp {@code v} prevents visiting a node twice in one check.
     */
    private boolean checkExcluded(InterprNode n, int v) {
        if (n.visitedCheckExcluded == v) {
            return false;
        }
        n.visitedCheckExcluded = v;
        if (isCovered(n.markedExcluded)) {
            return true;
        }
        for (InterprNode parent : n.parents) {
            if (checkExcluded(parent, v)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Collects the interpretation nodes that take part in a conflict.
     *
     * <p>These are children of the bottom node with primary conflicts, plus the secondary nodes of
     * their secondary conflicts.
     */
    public static Set<InterprNode> collectConflicts(Document document) {
        TreeSet<InterprNode> conflicting = new TreeSet<>();
        // NOTE(review): this stamp is consumed but never used here; kept to preserve counter values.
        document.visitedCounter++;
        for (InterprNode n : document.bottom.children) {
            if (!n.conflicts.primary.isEmpty()) {
                conflicting.add(n);
            }
            for (Conflicts.Conflict conflict : n.conflicts.secondary.values()) {
                conflicting.add(conflict.secondary);
            }
        }
        return conflicting;
    }

    /**
     * Returns the bottom node plus every child of it without or-nodes and without conflicts.
     *
     * <p>These nodes can be selected unconditionally at the root.
     */
    private static List<InterprNode> expandRootRefinement(Document document) {
        List<InterprNode> result = new ArrayList<>();
        result.add(document.bottom);
        for (InterprNode n : document.bottom.children) {
            boolean noOrNodes = n.orInterprNodes == null || n.orInterprNodes.isEmpty();
            if (noOrNodes && n.conflicts.primary.isEmpty() && n.conflicts.secondary.isEmpty()) {
                result.add(n);
            }
        }
        return result;
    }

    /**
     * Repeatedly expands {@code nodes} through {@link #expandRefinementRecursiveStep} until no
     * further nodes are added (fixpoint).
     */
    private List<InterprNode> expandRefinement(List<InterprNode> nodes, int v) {
        List<InterprNode> expanded = new ArrayList<>();
        for (InterprNode n : nodes) {
            markExpandRefinement(n, v);
            expanded.add(n);
        }
        for (InterprNode n : nodes) {
            expandRefinementRecursiveStep(expanded, n, v);
        }
        return nodes.size() == expanded.size() ? expanded : expandRefinement(expanded, v);
    }

    /** Stamps {@code n} and all of its ancestors with {@code v} in {@code markedExpandRefinement}. */
    private void markExpandRefinement(InterprNode n, int v) {
        if (n.markedExpandRefinement == v) {
            return;
        }
        n.markedExpandRefinement = v;
        for (InterprNode parent : n.parents) {
            markExpandRefinement(parent, v);
        }
    }

    /** True when {@code n} has a directly conflicting node that is not yet covered as excluded. */
    private boolean hasUncoveredConflicts(InterprNode n) {
        if (!n.conflicts.hasConflicts()) {
            return false;
        }
        List<InterprNode> conflicting = new ArrayList<>();
        Conflicts.collectDirectConflicting(conflicting, n);
        for (InterprNode c : conflicting) {
            if (!isCovered(c.markedExcluded)) {
                return true;
            }
        }
        return false;
    }

    /**
     * One expansion pass from {@code n}.
     *
     * <p>Or-nodes that reference {@code n}, have no uncovered conflicts and are not selected are
     * added to {@code results}. The pass then recurses into non-bottom parents and into children
     * whose parents were all visited in this pass or are already selected.
     */
    private void expandRefinementRecursiveStep(Collection<InterprNode> results, InterprNode n, int v) {
        if (n.visitedExpandRefinementRecursiveStep == v) {
            return;
        }
        n.visitedExpandRefinementRecursiveStep = v;

        if (n.refByOrInterprNode != null) {
            for (InterprNode orNode : n.refByOrInterprNode) {
                if (orNode.markedExpandRefinement != v
                        && !hasUncoveredConflicts(orNode)
                        && !isCovered(orNode.markedSelected)) {
                    markExpandRefinement(orNode, v);
                    results.add(orNode);
                }
            }
        }

        for (InterprNode parent : n.parents) {
            if (!parent.isBottom()) {
                expandRefinementRecursiveStep(results, parent, v);
            }
        }

        if (n.isBottom()) {
            return;
        }

        for (InterprNode child : n.children) {
            // NOTE(review): this aborts the remaining siblings once one child was already visited;
            // `continue` looks like the intended semantics. Kept as-is to preserve search behavior.
            if (child.visitedExpandRefinementRecursiveStep == v) {
                return;
            }
            boolean allParentsReached = true;
            for (InterprNode childParent : child.parents) {
                if (childParent.visitedExpandRefinementRecursiveStep != v
                        && !isCovered(childParent.markedSelected)) {
                    allParentsReached = false;
                    break;
                }
            }
            if (allParentsReached) {
                expandRefinementRecursiveStep(results, child, v);
            }
        }
    }

    /** Classifies {@code n} relative to this node's path. */
    public Coverage getCoverage(InterprNode n) {
        if (n.markedExcludedRefinement) {
            return Coverage.EXCLUDED;
        }
        if (isCovered(n.markedSelected)) {
            return Coverage.SELECTED;
        }
        if (isCovered(n.markedExcluded)) {
            return Coverage.EXCLUDED;
        }
        return Coverage.UNKNOWN;
    }

    /**
     * True when {@code mark} equals the visit stamp of this node or of one of its selected
     * ancestors.
     *
     * <p>The walk stops early once the stamps fall below {@code mark}, because older ancestors
     * carry smaller stamps.
     */
    public boolean isCovered(int mark) {
        for (SearchNode node = this; node != null; node = node.selectedParent) {
            if (node.visited == mark) {
                return true;
            }
            if (mark > node.visited) {
                return false;
            }
        }
        return false;
    }

    /** Marks every node of {@code nodes} as selected by this node. */
    private void markSelected(List<InterprNode> changed, List<InterprNode> nodes) {
        for (InterprNode n : nodes) {
            markSelected(changed, n);
        }
    }

    /**
     * Marks {@code n} as selected by this node unless that is already covered.
     *
     * <p>Newly marked non-bottom nodes are appended to {@code changed} when it is non-null.
     */
    private void markSelected(List<InterprNode> changed, InterprNode n) {
        if (isCovered(n.markedSelected)) {
            return;
        }
        n.markedSelected = this.visited;
        if (!n.isBottom() && changed != null) {
            changed.add(n);
        }
    }

    /** Marks everything conflicting with any node of {@code nodes} as excluded. */
    private void markExcluded(List<InterprNode> changed, List<InterprNode> nodes) {
        for (InterprNode n : nodes) {
            markExcluded(changed, n);
        }
    }

    /** Marks every node conflicting with {@code n}, and their dependents, as excluded. */
    private void markExcluded(List<InterprNode> changed, InterprNode n) {
        List<InterprNode> conflicting = new ArrayList<>();
        Conflicts.collectAllConflicting(conflicting, n, n.doc.visitedCounter++);
        for (InterprNode c : conflicting) {
            markExcludedRecursiveStep(changed, c);
        }
    }

    /**
     * Marks {@code n} and all of its descendants as excluded by this node.
     *
     * <p>Nodes linked by LCS are marked too, once all of their or-nodes are excluded. Newly marked
     * nodes are appended to {@code changed} (children first) when it is non-null.
     */
    private void markExcludedRecursiveStep(List<InterprNode> changed, InterprNode n) {
        if (isCovered(n.markedExcluded)) {
            return;
        }
        n.markedExcluded = this.visited;
        for (InterprNode child : n.children) {
            markExcludedRecursiveStep(changed, child);
        }
        if (n.linkedByLCS != null) {
            for (InterprNode linked : n.linkedByLCS) {
                if (checkOrNodeExcluded(linked)) {
                    markExcludedRecursiveStep(changed, linked);
                }
            }
        }
        if (changed != null) {
            changed.add(n);
        }
    }

    /** True when every entry of {@code n.orInterprNodes} is covered as excluded. */
    private boolean checkOrNodeExcluded(InterprNode n) {
        for (InterprNode orNode : n.orInterprNodes.values()) {
            if (!isCovered(orNode.markedExcluded)) {
                return false;
            }
        }
        return true;
    }

    /** True when all parents of {@code n} are covered as selected. */
    public boolean containedInSelectedBranch(InterprNode n) {
        for (InterprNode parent : n.parents) {
            if (!isCovered(parent.markedSelected)) {
                return false;
            }
        }
        return true;
    }

    /** Renders the selected-ancestor path, oldest first, each segment prefixed by " - ". */
    public String pathToString(Document document) {
        String prefix = this.selectedParent != null ? this.selectedParent.pathToString(document) : "";
        return prefix + " - " + toString(document);
    }

    /**
     * Lists the {@code primId}s of the primitive nodes behind this node's refinement, each followed
     * by ", ".
     *
     * <p>Note: this advances {@code document.interprIdCounter} once per refinement node.
     */
    public String toString(Document document) {
        TreeSet<InterprNode> primitives = new TreeSet<>();
        for (InterprNode n : this.refinement) {
            n.collectPrimitiveNodes(primitives, document.interprIdCounter++);
        }
        StringBuilder sb = new StringBuilder();
        for (InterprNode p : primitives) {
            sb.append(p.primId).append(", ");
        }
        return sb.toString();
    }

    /** Restores every activation modified at this node to the requested snapshot. */
    public void changeState(StateChange.Mode mode) {
        for (StateChange change : this.modifiedActs) {
            change.restoreState(mode);
        }
    }

    /** Orders search nodes by creation id. */
    @Override
    public int compareTo(SearchNode other) {
        return Integer.compare(this.id, other.id);
    }

    /** Returns the immediate parent: whichever of the two recorded ancestors is newer. */
    public SearchNode getParent() {
        return getDecision() ? this.selectedParent : this.excludedParent;
    }

    /** True when this node was reached by selecting its parent's candidate. */
    private boolean getDecision() {
        return this.excludedParent == null || this.selectedParent.id > this.excludedParent.id;
    }

    /**
     * Looks for a cached decision of this node's candidate that still applies on the current path.
     *
     * <p>A cache entry applies when, walking both paths up to the root in lockstep, every
     * differing decision was taken at an ancestor whose refinement does not affect any
     * still-unknown interpretation node. Entries that do not apply are skipped. (The decompiled
     * version looped forever on such an entry.)
     *
     * @return the cached decision (true = select), or null if no entry applies
     */
    public Boolean getCachedDecision() {
        nextEntry:
        for (Map.Entry<SearchNode, Boolean> entry : this.candidate.cache.entrySet()) {
            SearchNode current = this;
            SearchNode cached = entry.getKey();
            do {
                if (current.getDecision() != cached.getDecision() && affectsUnknown(current.getParent())) {
                    continue nextEntry;
                }
                current = current.getParent();
                cached = cached.getParent();
            } while (current.selectedParent != null);

            this.debugState = DebugState.CACHED;
            return entry.getValue();
        }
        return null;
    }

    /**
     * True when a refinement node of {@code node} has an activation with a recurrent, non-negative
     * output synapse whose output interpretation is still {@link Coverage#UNKNOWN} from this node's
     * point of view.
     */
    public boolean affectsUnknown(SearchNode node) {
        for (InterprNode n : node.refinement) {
            if (n.act == null) {
                continue;
            }
            for (Activation.SynapseActivation sa : n.act.neuronOutputs) {
                if (sa.s.key.isRecurrent
                        && !sa.s.isNegative()
                        && getCoverage(sa.output.key.o) == Coverage.UNKNOWN) {
                    return true;
                }
            }
        }
        return false;
    }
}
