package network.aika.neuron.activation;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.Map;
import java.util.TreeMap;
import java.util.TreeSet;
import network.aika.Document;
import network.aika.Utils;
import network.aika.neuron.activation.Activation;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * A node of the interpretation search tree.
 *
 * <p>The node at {@code level} i decides candidate {@code document.candidates.get(i)}. Its
 * selected child continues the search with that candidate's activation set to
 * {@link Decision#SELECTED}; its excluded child continues with it set to
 * {@link Decision#EXCLUDED}. The selected child is created with this node as
 * {@code selectedParent} and the excluded child with this node as {@code excludedParent}. The
 * other parent reference is inherited, and {@link #getDecision()} tells the two apart by
 * comparing node ids.
 *
 * <p>The constructor propagates the parent's decision through {@code document.vQueue}, or replays
 * the stored state of an unmodified cached node, and records the resulting weight change.
 */
public class SearchNode implements Comparable<SearchNode> {
    private static final Logger log = LoggerFactory.getLogger(SearchNode.class);

    /** Upper bound on {@code document.searchStepCounter}; exceeding it aborts the search. */
    public static int MAX_SEARCH_STEPS = Integer.MAX_VALUE;

    /**
     * When true, a node whose cached counterpart (the candidate's {@code cachedSearchNode} with the
     * same decision) is unmodified replays that node's stored state instead of recomputing it.
     * When false, the state is always recomputed and compared with the cache, and a mismatch is
     * logged as an error.
     */
    public static boolean ENABLE_CACHING = true;

    /** Enables reuse of cached candidate decisions and repeating the select step after exclude. */
    public static boolean OPTIMIZE_SEARCH = true;

    /** Enables per-activation search-state collection and debug output for every search result. */
    public static boolean COMPUTE_SOFT_MAX = false;

    /** Unique id taken from {@code document.searchNodeIdCounter}; also defines the ordering. */
    public int id;
    /** Nearest ancestor whose candidate was excluded on the path to this node (may be null). */
    SearchNode excludedParent;
    /** Nearest ancestor whose candidate was selected on the path to this node (may be null). */
    SearchNode selectedParent;
    /** Visit stamp taken from {@code document.visitedCounter} at construction. */
    long visited;
    /** Candidate decided at this node; assigned in {@link #initStep}, null before that. */
    public Candidate candidate;
    /** Depth of this node; indexes {@code document.candidates}. */
    int level;
    /** Classification of how this node's decision was reached, for debug statistics. */
    DebugState debugState;
    /** Weight change produced by applying the parent's decision. */
    double weightDelta;
    /** Sum of {@link #weightDelta} along the path from the root to this node. */
    public double accumulatedWeight;
    /** Decision forced before exploring, or UNKNOWN if both branches may be explored. */
    private Decision preDecision;
    /** Stamp of the search run that last touched this node; a new stamp resets {@link #step}. */
    private long processVisited;
    /** True if this node lies on the path to the best result found so far. */
    private boolean bestPath;
    /** Activation state changes caused at this node, ordered by activation id. */
    public Map<Activation, Activation.StateChange> modifiedActs = new TreeMap<>(Activation.ACTIVATION_ID_COMP);
    private Step step = Step.INIT;
    private SearchNode selectedChild = null;
    private SearchNode excludedChild = null;
    private double selectedWeight = 0.0d;
    private double excludedWeight = 0.0d;
    private double selectedWeightExpSum = 0.0d;
    private double excludedWeightExpSum = 0.0d;
    /** Set by a returning child to its own decision so that branch is not explored again. */
    private Decision skip = Decision.UNKNOWN;

    /** How a node's decision was obtained; see {@link #storeDebugInfos()}. */
    public enum DebugState {
        /** A cached candidate decision was reused. */
        CACHED,
        /** A pre-decision restricted the node to a single branch. */
        LIMITED,
        /** Both branches were open for exploration. */
        EXPLORE
    }

    /** Decision state of a candidate activation. */
    public enum Decision {
        SELECTED('S'),
        EXCLUDED('E'),
        UNKNOWN('U');

        /** Single-letter abbreviation of the decision. */
        char s;

        Decision(char c) {
            this.s = c;
        }
    }

    /** Predicate consulted before the select step; returning true skips selecting the activation. */
    public interface SkipSelectStep {
        boolean evaluate(Activation activation);
    }

    /** States of the per-node state machine driven by {@link #search}. */
    public enum Step {
        INIT,
        PREPARE_SELECT,
        SELECT,
        POST_SELECT,
        PREPARE_EXCLUDE,
        EXCLUDE,
        POST_EXCLUDE,
        FINAL
    }

    /** Thrown by {@link #search} when the configured time budget is exceeded. */
    public static class TimeoutException extends RuntimeException {
        public TimeoutException(String str) {
            super(str);
        }
    }

    /**
     * Creates a node and applies the parent's decision.
     *
     * <p>If the parent's candidate has a cached node with the same decision that is not modified,
     * and {@link #ENABLE_CACHING} is set, the cached state is replayed. Otherwise the state is
     * computed through {@code doc.vQueue.process(this)}.
     *
     * @param doc the document being searched; its id and visit counters are advanced
     * @param selParent nearest ancestor with a SELECTED decision (may be null)
     * @param exclParent nearest ancestor with an EXCLUDED decision (may be null)
     * @param level depth of the new node
     */
    public SearchNode(Document doc, SearchNode selParent, SearchNode exclParent, int level) {
        this.accumulatedWeight = 0.0d;
        this.id = doc.searchNodeIdCounter++;
        this.level = level;
        this.visited = doc.visitedCounter++;
        this.selectedParent = selParent;
        this.excludedParent = exclParent;

        SearchNode parent = getParent();
        Candidate parentCandidate = parent != null ? parent.candidate : null;
        SearchNode cachedNode = null;
        boolean modified = true;

        if (parentCandidate != null) {
            parentCandidate.currentSearchNode = this;
            cachedNode = parentCandidate.cachedSearchNode;
            if (cachedNode == null || cachedNode.getDecision() != getDecision()) {
                // No reusable cached node: force the decided activation and its outputs to be recomputed.
                Activation act = parentCandidate.activation;
                act.markDirty(this.visited);
                act.getOutputLinks(false).forEach(link -> link.output.markDirty(this.visited));
            } else {
                modified = cachedNode.isModified();
                if (modified) {
                    // debugComputed[2]: a cached node existed but was invalidated.
                    parentCandidate.debugComputed[2]++;
                }
            }
        }

        if (modified) {
            this.weightDelta = doc.vQueue.process(this);
            markDirty();
            if (parentCandidate != null) {
                parentCandidate.cachedSearchNode = this;
            }
        } else if (ENABLE_CACHING) {
            // Replay the cached node's results instead of recomputing them.
            cachedNode.changeState(Activation.Mode.NEW);
            this.weightDelta = cachedNode.weightDelta;
            for (Activation act : cachedNode.modifiedActs.keySet()) {
                act.saveOldState(this.modifiedActs, doc.visitedCounter++);
                act.saveNewState();
            }
        } else {
            // Caching disabled: recompute and cross-check against the cached results.
            this.weightDelta = doc.vQueue.process(this);
            if (Math.abs(this.weightDelta - cachedNode.weightDelta) > 1.0E-5d || !compareNewState(cachedNode)) {
                log.error("Cached search node activation do not match the newly computed results.");
                log.info("Computed results:");
                dumpDebugState();
                log.info("Cached results:");
                cachedNode.dumpDebugState();
            }
        }

        if (parentCandidate != null) {
            // debugComputed[1]: recomputed, debugComputed[0]: reused from cache.
            parentCandidate.debugComputed[modified ? 1 : 0]++;
        }
        if (parent != null) {
            this.accumulatedWeight = this.weightDelta + parent.accumulatedWeight;
        }
    }

    /**
     * Checks whether the state recorded in this (cached) node is still valid.
     *
     * @return true if a recorded activation was marked dirty after this node was visited or its
     *     decision differs from the recorded new state, or if an active recorded activation has a
     *     decided output marked dirty after this node was visited
     */
    private boolean isModified() {
        for (Activation.StateChange change : this.modifiedActs.values()) {
            Activation act = change.getActivation();
            if (act.markedDirty > this.visited || change.newState != act.decision) {
                return true;
            }
            if (change.newRounds.isActive()
                    && act.getOutputLinks(false).anyMatch(link ->
                            link.output.decision != Decision.UNKNOWN && link.output.markedDirty > this.visited)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Marks the outputs of every activation whose new rounds differ from those of the parent
     * candidate's cached node, or that appear in only one of the two nodes.
     */
    private void markDirty() {
        SearchNode parent = getParent();
        if (parent == null || parent.candidate == null) {
            return;
        }
        SearchNode cached = parent.candidate.cachedSearchNode;

        // Typed set: with a raw TreeSet the loop variable would be Object and not compile.
        TreeSet<Activation> acts = new TreeSet<>(Activation.ACTIVATION_ID_COMP);
        acts.addAll(this.modifiedActs.keySet());
        if (cached != null) {
            acts.addAll(cached.modifiedActs.keySet());
        }

        for (Activation act : acts) {
            Activation.StateChange current = this.modifiedActs.get(act);
            Activation.StateChange previous = cached != null ? cached.modifiedActs.get(act) : null;
            if (current == null || previous == null || !current.newRounds.compare(previous.newRounds)) {
                act.getOutputLinks(false).forEach(link -> link.output.markDirty(this.visited));
            }
        }
    }

    /**
     * Compares the new states recorded in this node with those of another node.
     *
     * @param other node to compare with
     * @return true if both nodes record the same activations with equal new rounds; false if the
     *     key sets differ (previously this case could throw a NullPointerException)
     */
    public boolean compareNewState(SearchNode other) {
        if (this.modifiedActs == null && other.modifiedActs == null) {
            return true;
        }
        if (this.modifiedActs == null || other.modifiedActs == null
                || this.modifiedActs.size() != other.modifiedActs.size()) {
            return false;
        }
        for (Map.Entry<Activation, Activation.StateChange> entry : this.modifiedActs.entrySet()) {
            Activation.StateChange otherChange = other.modifiedActs.get(entry.getKey());
            if (otherChange == null || !entry.getValue().newRounds.compare(otherChange.newRounds)) {
                return false;
            }
        }
        return true;
    }

    /** Prints the path from this node up to the root to standard output, one line per level. */
    public void dumpDebugState() {
        String weights = "";
        Decision childDecision = Decision.UNKNOWN;
        for (SearchNode node = this; node != null && node.level >= 0; node = node.getParent()) {
            String cand = node.candidate != null ? node.candidate.toString() : "";
            System.out.println(node.level + " " + node.debugState + " DECISION:" + childDecision + weights
                    + " " + cand + " MOD-ACTS:" + node.modifiedActs.size());
            childDecision = node.getDecision();
            weights = " AW:" + Utils.round(node.accumulatedWeight) + " DW:" + Utils.round(node.weightDelta);
        }
    }

    /**
     * Runs the interpretation search iteratively from {@code root}.
     *
     * <p>Each node steps through {@link Step}. Descending into a child and returning to the parent
     * is done by reassigning the current node instead of recursing. The weight and weight-exp-sum
     * of the last finished subtree are carried in local variables to the parent's POST_* step.
     *
     * @param doc the document being searched
     * @param root node to start from
     * @param v stamp of this search run; nodes with a different stamp restart at {@link Step#INIT}
     * @param timeoutMs optional time budget in milliseconds, checked whenever a leaf is reached
     * @throws TimeoutException if {@code timeoutMs} is non-null and has been exceeded
     */
    public static void search(Document doc, SearchNode root, long v, Long timeoutMs) throws TimeoutException {
        SearchNode node = root;
        double returnedWeight = 0.0d;
        double returnedWeightExpSum = 0.0d;
        long startTime = System.currentTimeMillis();

        do {
            if (node.processVisited != v) {
                node.step = Step.INIT;
                node.processVisited = v;
            }

            switch (node.step) {
                case INIT:
                    if (node.level < doc.candidates.size()) {
                        node.initStep(doc);
                        node.step = Step.PREPARE_SELECT;
                    } else {
                        // Leaf: every candidate has been decided.
                        if (timeoutMs != null && System.currentTimeMillis() > startTime + timeoutMs.longValue()) {
                            throw new TimeoutException("Interpretation search took too long: "
                                    + (System.currentTimeMillis() - startTime) + "ms");
                        }
                        returnedWeight = node.processResult(doc);
                        returnedWeightExpSum = Math.exp(returnedWeight);
                        node.step = Step.FINAL;
                        node = node.getParent();
                    }
                    break;
                case PREPARE_SELECT:
                    node.step = node.prepareSelectStep(doc) ? Step.SELECT : Step.PREPARE_EXCLUDE;
                    break;
                case SELECT:
                    node.step = Step.POST_SELECT;
                    node = node.selectedChild;
                    break;
                case POST_SELECT:
                    node.selectedWeight = returnedWeight;
                    node.selectedWeightExpSum = returnedWeightExpSum;
                    node.storeSearchState(returnedWeightExpSum);
                    node.postReturn(node.selectedChild);
                    node.step = Step.PREPARE_EXCLUDE;
                    break;
                case PREPARE_EXCLUDE:
                    node.step = node.prepareExcludeStep(doc) ? Step.EXCLUDE : Step.FINAL;
                    break;
                case EXCLUDE:
                    node.step = Step.POST_EXCLUDE;
                    node = node.excludedChild;
                    break;
                case POST_EXCLUDE:
                    node.excludedWeight = returnedWeight;
                    node.excludedWeightExpSum = returnedWeightExpSum;
                    node.postReturn(node.excludedChild);
                    // A cached decision invalidated during exclude (candidate.repeat) re-runs selection.
                    node.step = (node.candidate.repeat && OPTIMIZE_SEARCH) ? Step.PREPARE_SELECT : Step.FINAL;
                    break;
                case FINAL: {
                    returnedWeight = node.finalStep();
                    returnedWeightExpSum = node.getWeightExpSum();
                    SearchNode parent = node.getParent();
                    if (parent != null) {
                        parent.skip = node.getDecision();
                    }
                    node = parent;
                    break;
                }
            }
        } while (node != null);
    }

    /** @return sum of the selected and excluded branch weight-exp-sums */
    public double getWeightExpSum() {
        return this.selectedWeightExpSum + this.excludedWeightExpSum;
    }

    /**
     * Binds this node to its candidate and derives {@link #preDecision}.
     *
     * <p>The derivation proceeds in order: the activation's input decision; EXCLUDED if the
     * activation is not activeable or conflicts with a selected activation; SELECTED if the
     * candidate has no conflicts; otherwise the cached decision when {@link #OPTIMIZE_SEARCH} is on.
     *
     * @throws RuntimeException if {@code doc.searchStepCounter} exceeds {@link #MAX_SEARCH_STEPS}
     */
    private void initStep(Document doc) {
        this.candidate = doc.candidates.get(this.level);
        Activation act = this.candidate.activation;
        boolean activeable = act.isActiveable();

        this.preDecision = act.inputDecision;
        if (this.preDecision == Decision.UNKNOWN && (!activeable || checkExcluded(act))) {
            this.preDecision = Decision.EXCLUDED;
        }
        if (this.preDecision == Decision.UNKNOWN && !this.candidate.isConflicting()) {
            this.preDecision = Decision.SELECTED;
        }
        if (this.preDecision == Decision.UNKNOWN && OPTIMIZE_SEARCH) {
            Decision cached = getCachedDecision();
            // Reuse the weight of the branch that will not be explored again.
            switch (cached) {
                case SELECTED:
                    this.excludedWeightExpSum = this.candidate.alternativeCachedWeightExpSum;
                    break;
                case EXCLUDED:
                    this.selectedWeightExpSum = this.candidate.alternativeCachedWeightExpSum;
                    break;
                default:
                    break;
            }
            this.preDecision = cached;
        }

        if (doc.searchStepCounter > MAX_SEARCH_STEPS) {
            dumpDebugState();
            throw new RuntimeException("Max search step exceeded.");
        }
        doc.searchStepCounter++;
        storeDebugInfos();
    }

    /** @return the candidate's cached decision, or UNKNOWN if this node is pre-excluded */
    private Decision getCachedDecision() {
        return this.preDecision != Decision.EXCLUDED ? this.candidate.cachedDecision : Decision.UNKNOWN;
    }

    /**
     * Selects the candidate's activation and creates the selected child, unless selection is
     * ruled out by the pre-decision, the skip marker, or the model's {@link SkipSelectStep}.
     *
     * @return true if a selected child was created
     */
    private boolean prepareSelectStep(Document doc) {
        this.candidate.repeat = false;
        if (this.preDecision == Decision.EXCLUDED
                || this.skip == Decision.SELECTED
                || doc.model.getSkipSelectStep().evaluate(this.candidate.activation)) {
            return false;
        }

        this.candidate.activation.setDecision(Decision.SELECTED, this.visited);
        if (this.candidate.cachedDecision == Decision.UNKNOWN) {
            invalidateCachedDecisions();
        }
        this.selectedChild = new SearchNode(doc, this, this.excludedParent, this.level + 1);
        this.candidate.debugDecisionCounts[0]++;
        return true;
    }

    /**
     * Excludes the candidate's activation and creates the excluded child, unless exclusion is
     * ruled out by the pre-decision, the skip marker, or {@link #generatesUnsuppressedExcluded()}.
     *
     * @return true if an excluded child was created
     */
    private boolean prepareExcludeStep(Document doc) {
        if (this.preDecision == Decision.SELECTED || this.skip == Decision.EXCLUDED) {
            return false;
        }
        if (this.preDecision != Decision.EXCLUDED && generatesUnsuppressedExcluded()) {
            return false;
        }

        this.candidate.activation.setDecision(Decision.EXCLUDED, this.visited);
        this.excludedChild = new SearchNode(doc, this.selectedParent, this, this.level + 1);
        this.candidate.debugDecisionCounts[1]++;
        return true;
    }

    /**
     * Decides whether the exclude branch should be pruned for the current candidate.
     *
     * @return true if any activation conflicting with the candidate's activation is already
     *     EXCLUDED, or if the activation has no conflicts at all
     *
     * <p>NOTE(review): this preserves the behavior of the decompiled original. Its inner loop over
     * {@code act.getConflicts()} tested {@code act2 == candidate.activation || act2.decision ==
     * EXCLUDED} with an empty body, so that check had no effect. Judging by the method name, the
     * intent was presumably to return true only for an excluded conflict that no other activation
     * suppresses. Confirm against the upstream source before relying on this pruning rule.
     */
    private boolean generatesUnsuppressedExcluded() {
        boolean noConflicts = true;
        for (Activation conflict : this.candidate.activation.getConflicts()) {
            noConflicts = false;
            if (conflict.decision == Decision.EXCLUDED) {
                return true;
            }
        }
        return noConflicts;
    }

    /** Reverts the state changes of a finished child and resets this candidate's activation. */
    private void postReturn(SearchNode child) {
        child.changeState(Activation.Mode.OLD);
        Activation act = this.candidate.activation;
        act.setDecision(Decision.UNKNOWN, this.visited);
        act.rounds.reset();
    }

    /**
     * Picks this node's final decision, updates the candidate's decision cache, and drops child
     * references that are not on the best path.
     *
     * @return the weight of the chosen branch
     */
    private double finalStep() {
        Decision decision;
        Decision cached = getCachedDecision();
        if (cached != Decision.UNKNOWN) {
            decision = cached;
        } else {
            if (this.preDecision != Decision.UNKNOWN) {
                decision = this.preDecision;
            } else {
                decision = this.selectedWeight >= this.excludedWeight ? Decision.SELECTED : Decision.EXCLUDED;
            }
            if (this.preDecision != Decision.EXCLUDED) {
                this.candidate.cachedDecision = decision;
                switch (decision) {
                    case SELECTED:
                        this.candidate.alternativeCachedWeightExpSum = this.excludedWeightExpSum;
                        break;
                    case EXCLUDED:
                        this.candidate.alternativeCachedWeightExpSum = this.selectedWeightExpSum;
                        break;
                    default:
                        break;
                }
            }
        }

        SearchNode chosenChild = decision == Decision.SELECTED ? this.selectedChild : this.excludedChild;
        if (chosenChild != null && chosenChild.bestPath) {
            this.candidate.bestChildNode = chosenChild;
            this.bestPath = true;
        }
        // Keep only the child on the best path reachable by the chosen decision.
        if (!this.bestPath || decision != Decision.SELECTED) {
            this.selectedChild = null;
        }
        if (!this.bestPath || decision != Decision.EXCLUDED) {
            this.excludedChild = null;
        }
        return decision == Decision.SELECTED ? this.selectedWeight : this.excludedWeight;
    }

    /** Invalidates cached decisions of activations reached through non-negative output links. */
    private void invalidateCachedDecisions() {
        this.candidate.activation.getOutputLinks(false)
                .filter(link -> !link.synapse.isNegative())
                .forEach(link -> invalidateCachedDecision(link.output));
    }

    /**
     * Resets the activation's cached EXCLUDED decision (flagging its candidate for repeat) and
     * resets cached SELECTED decisions of its conflicting activations.
     */
    public static void invalidateCachedDecision(Activation activation) {
        Candidate own = activation.candidate;
        if (own != null && own.cachedDecision == Decision.EXCLUDED) {
            own.cachedDecision = Decision.UNKNOWN;
            own.repeat = true;
        }
        for (Activation conflict : activation.getConflicts()) {
            Candidate other = conflict.candidate;
            if (other != null && other.cachedDecision == Decision.SELECTED) {
                other.cachedDecision = Decision.UNKNOWN;
            }
        }
    }

    /**
     * Handles a leaf node. The leaf becomes the document's selected node when there is none yet,
     * when it is deeper than the current one, or when its accumulated weight is higher.
     *
     * @return this node's accumulated weight
     */
    private double processResult(Document doc) {
        double weight = this.accumulatedWeight;
        SearchNode best = doc.selectedSearchNode;
        // Null guard: getSelectedAccumulatedWeight tolerates null, so the level check must too.
        if (best == null || this.level > best.level || weight > getSelectedAccumulatedWeight(doc)) {
            doc.selectedSearchNode = this;
            storeFinalState(this);
            this.bestPath = true;
        } else {
            this.bestPath = false;
        }

        if (COMPUTE_SOFT_MAX) {
            dumpDebugState();
            System.out.println(this.accumulatedWeight);
            System.out.println();
        }
        return this.accumulatedWeight;
    }

    /** Records the activation's last round and the given weight-exp-sum when soft-max is enabled. */
    private void storeSearchState(double weightExpSum) {
        if (!COMPUTE_SOFT_MAX) {
            return;
        }
        Activation act = this.candidate.activation;
        if (act.searchStates == null) {
            act.searchStates = new ArrayList<>();
        }
        act.searchStates.add(new Activation.AvgState(act.rounds.getLast(), weightExpSum));
    }

    /** Copies current rounds and decisions into the final fields for every candidate on the path. */
    private static void storeFinalState(SearchNode node) {
        for (SearchNode n = node; n != null; n = n.getParent()) {
            if (n.candidate != null) {
                Activation act = n.candidate.activation;
                act.finalRounds = act.rounds.copy();
                act.finalDecision = act.decision;
            }
        }
    }

    /** @return accumulated weight of the document's selected node, or -1.0 if there is none */
    private double getSelectedAccumulatedWeight(Document doc) {
        return doc.selectedSearchNode != null ? doc.selectedSearchNode.accumulatedWeight : -1.0d;
    }

    /** @return true if any activation conflicting with {@code activation} is SELECTED */
    private boolean checkExcluded(Activation activation) {
        for (Activation conflict : activation.getConflicts()) {
            if (conflict.decision == Decision.SELECTED) {
                return true;
            }
        }
        return false;
    }

    /** @return string of this node and its chain of selected parents, root first */
    public String pathToString() {
        String prefix = this.selectedParent != null ? this.selectedParent.pathToString() : "";
        return prefix + " - " + toString();
    }

    /** @return the candidate's activation id (or "-" if no candidate is bound) and this node's decision */
    @Override
    public String toString() {
        String actId = this.candidate != null ? String.valueOf(this.candidate.activation.id) : "-";
        return actId + " Decision:" + getDecision();
    }

    /** Restores every recorded activation state change in the given mode. */
    public void changeState(Activation.Mode mode) {
        for (Activation.StateChange change : this.modifiedActs.values()) {
            change.restoreState(mode);
        }
    }

    @Override
    public int compareTo(SearchNode other) {
        return Integer.compare(this.id, other.id);
    }

    /** @return the parent whose decision created this node */
    public SearchNode getParent() {
        return getDecision() == Decision.SELECTED ? this.selectedParent : this.excludedParent;
    }

    /**
     * Returns SELECTED if there is no excluded parent, or if the selected parent was created
     * after the excluded parent (higher id). Otherwise returns EXCLUDED.
     */
    public Decision getDecision() {
        boolean selected = this.excludedParent == null
                || (this.selectedParent != null && this.selectedParent.id > this.excludedParent.id);
        return selected ? Decision.SELECTED : Decision.EXCLUDED;
    }

    /** Classifies this node's {@link DebugState} and counts it in the candidate's statistics. */
    private void storeDebugInfos() {
        if (this.preDecision != Decision.UNKNOWN) {
            this.debugState = DebugState.LIMITED;
        } else if (getCachedDecision() != Decision.UNKNOWN) {
            this.debugState = DebugState.CACHED;
        } else {
            this.debugState = DebugState.EXPLORE;
        }
        this.candidate.debugCounts[this.debugState.ordinal()]++;
    }
}
