package ai.libs.jaicore.search.algorithms.mdp.mcts;

import ai.libs.jaicore.basic.algorithm.AAlgorithm;
import ai.libs.jaicore.basic.algorithm.EAlgorithmState;
import ai.libs.jaicore.search.probleminputs.IMDP;
import ai.libs.jaicore.search.probleminputs.MDPUtils;
import ai.libs.jaicore.timing.TimedComputation;
import com.google.common.eventbus.Subscribe;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Random;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import org.api4.java.algorithm.Timeout;
import org.api4.java.algorithm.exceptions.AlgorithmException;
import org.api4.java.algorithm.exceptions.AlgorithmExecutionCanceledException;
import org.api4.java.algorithm.exceptions.AlgorithmTimeoutedException;
import org.api4.java.common.control.ILoggingCustomizable;
import org.api4.java.common.event.IEvent;
import org.api4.java.common.event.IRelaxedEventEmitter;
import org.api4.java.datastructure.graph.ILabeledPath;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:ai/libs/jaicore/search/algorithms/mdp/mcts/MCTS.class */
/**
 * Monte Carlo Tree Search (MCTS) over an {@link IMDP} with {@code Double} rewards.
 *
 * <p>Rollouts descend with a tree policy ({@link IPathUpdatablePolicy}) and continue with a default policy
 * ({@link IPolicy}); the tree policy is updated with the score of each rollout and is also the result of
 * {@link #call()}.
 *
 * <p>NOTE(review): this class was recovered with a decompiler. The step method {@link #nextWithException()}
 * could not be decompiled and currently only throws {@link UnsupportedOperationException}, so {@link #call()}
 * fails as soon as it performs a step. The fragments the decompiler did recover are preserved next to that
 * method.
 *
 * @param <N> type of the MDP states
 * @param <A> type of the MDP actions
 */
public class MCTS<N, A> extends AAlgorithm<IMDP<N, A, Double>, IPolicy<N, A>> {

    /** Divisor converting bytes to MB for the memory figure in the iteration summary. */
    private static final long BYTES_PER_MB = 1048576L;

    private static final Runtime RUNTIME = Runtime.getRuntime();

    /** Logger of this algorithm; replaced by {@link #setLoggerName(String)}. */
    private Logger logger = LoggerFactory.getLogger(MCTS.class);

    /** The MDP being searched. */
    private final IMDP<N, A, Double> mdp;

    /** Rollout depth bound computed by {@code MDPUtils.getTimeHorizon} from the constructor arguments. */
    private final int maxDepth;

    private final MDPUtils utils = new MDPUtils();
    private final IPathUpdatablePolicy<N, A, Double> treePolicy;
    private final IPolicy<N, A> defaultPolicy;

    /** Whether the default policy is a {@link UniformRandomPolicy}. */
    private final boolean uniformSamplingDefaultPolicy;

    /** Random source of the default policy if it is a {@link UniformRandomPolicy}, otherwise {@code null}. */
    private final Random randomSourceOfUniformSamplingPolicy;

    /** Iteration count that corresponds to 100% in the progress log. */
    private final int maxIterations;

    private int iterations = 0;
    private final Collection<N> tpReadyStates = new HashSet<>();
    private final Map<N, Collection<A>> applicableActionsPerState = new HashMap<>();
    private final Map<N, List<A>> untriedActionsOfIncompleteStates = new HashMap<>();

    /** Last progress percentage that was logged. */
    private int lastProgressReport = 0;

    // Accumulated timing statistics (milliseconds), updated by summarizeIteration.
    private int msSpentInRollouts;
    private int msSpentInTreePolicyQueries;
    private int msSpentInTreePolicyUpdates;

    /** Whether actions leading into exhausted states are put on a taboo list. */
    private final boolean tabooExhaustedNodes;

    /** Per state, the actions that must no longer be considered (only filled if tabooExhaustedNodes is set). */
    private final Map<N, Collection<A>> tabooActions = new HashMap<>();

    private ILabeledPath<N, A> enforcedPrefixPath = null;

    /**
     * Creates a new MCTS instance.
     *
     * @param mdp the MDP to search; must not be {@code null}
     * @param treePolicy the policy used inside the tree and updated with rollout scores; must not be
     *        {@code null}. If it is an {@link IRelaxedEventEmitter}, all of its events are re-posted by this
     *        algorithm.
     * @param defaultPolicy the rollout policy; must not be {@code null}
     * @param maxIterations number of iterations corresponding to 100% progress in the progress log
     *        (presumably also the iteration budget — confirm against the step logic)
     * @param gamma first argument of {@code MDPUtils.getTimeHorizon}; presumably the discount factor
     * @param epsilon second argument of {@code MDPUtils.getTimeHorizon}; presumably a precision bound
     * @param tabooExhaustedNodes whether actions that lead into exhausted states are tabooed
     * @throws NullPointerException if {@code mdp}, {@code treePolicy} or {@code defaultPolicy} is {@code null}
     */
    public MCTS(final IMDP<N, A, Double> mdp, final IPathUpdatablePolicy<N, A, Double> treePolicy, final IPolicy<N, A> defaultPolicy, final int maxIterations, final double gamma, final double epsilon, final boolean tabooExhaustedNodes) {
        super(Objects.requireNonNull(mdp, "mdp"));
        this.mdp = mdp;
        this.treePolicy = Objects.requireNonNull(treePolicy, "treePolicy");
        this.defaultPolicy = Objects.requireNonNull(defaultPolicy, "defaultPolicy");
        this.uniformSamplingDefaultPolicy = defaultPolicy instanceof UniformRandomPolicy;
        // Raw cast: only the non-generic getRandom() is needed from the policy.
        this.randomSourceOfUniformSamplingPolicy = this.uniformSamplingDefaultPolicy ? ((UniformRandomPolicy) defaultPolicy).getRandom() : null;
        this.maxIterations = maxIterations;
        this.maxDepth = MDPUtils.getTimeHorizon(gamma, epsilon);
        this.tabooExhaustedNodes = tabooExhaustedNodes;
        if (treePolicy instanceof IRelaxedEventEmitter) {
            // Forward every event emitted by the tree policy to the listeners of this algorithm.
            ((IRelaxedEventEmitter) treePolicy).registerListener(new Object() {
                @Subscribe
                public void receiveEvent(final IEvent event) {
                    MCTS.this.post(event);
                }
            });
        }
    }

    /**
     * Computes the actions that may still be explored from the head of the given path.
     *
     * <p>Starts from a copy of {@code applicableActions}. If exhausted nodes are tabooed, actions on the taboo
     * list of the head state are removed; if none remain and the path has more than one node, the action that
     * led to the head is put on the taboo list of the head's parent.
     *
     * @param path the path whose head is the state under consideration
     * @param applicableActions the actions applicable in the head state
     * @return a new, mutable list with the potential actions (possibly empty)
     */
    public List<A> getPotentialActions(final ILabeledPath<N, A> path, final Collection<A> applicableActions) {
        final N head = path.getHead();
        final List<A> potentialActions = new ArrayList<>(applicableActions);
        if (potentialActions.isEmpty()) {
            this.logger.warn("Computing potential actions for an empty set of applicable actions makes no sense! Returning an empty set for node {}.", head);
            return potentialActions;
        }
        this.logger.debug("Computing potential actions based on {} applicable ones for state {}", applicableActions.size(), head);
        if (this.tabooExhaustedNodes) {
            final Collection<A> tabooed = this.tabooActions.get(head);
            this.logger.debug("Found {} tabooed actions for this state.", tabooed != null ? tabooed.size() : 0);
            if (tabooed != null) {
                potentialActions.removeIf(tabooed::contains);
            }
            if (potentialActions.isEmpty() && path.getNumberOfNodes() > 1) {
                // All actions of the head are exhausted, hence the head is exhausted: taboo the arc leading to it.
                tabooLastActionOfPath(path);
            }
        }
        return potentialActions;
    }

    /**
     * Asks the MDP for the actions applicable in {@code state}, bounded by the remaining algorithm time minus a
     * one-second safety margin.
     *
     * @param state the state to query
     * @return an unmodifiable view of the applicable actions
     * @throws InterruptedException if interrupted; termination is checked via checkAndConductTermination first
     */
    private Collection<A> getApplicableActions(final N state) throws AlgorithmTimeoutedException, ExecutionException, InterruptedException, AlgorithmExecutionCanceledException {
        // NOTE(review): if less than 1000 ms remain, this timeout is non-positive; how TimedComputation treats
        // that is not visible here — confirm.
        final Timeout timeout = new Timeout(getRemainingTimeToDeadline().milliseconds() - 1000, TimeUnit.MILLISECONDS);
        this.logger.debug("Computing all applicable actions with timeout {}.", timeout);
        try {
            final Collection<A> computed = TimedComputation.compute(() -> this.mdp.getApplicableActions(state), timeout, "Timeout bound hit.");
            final Collection<A> actions = Collections.unmodifiableCollection(computed);
            this.logger.debug("Number of applicable actions is {}", actions.size());
            return actions;
        } catch (InterruptedException e) {
            // Give a pending cancel/timeout the chance to surface as its own exception before propagating.
            checkAndConductTermination();
            throw e;
        }
    }

    /*
     * The body of nextWithException() could not be decompiled (1919 instructions). The fragments the decompiler
     * recovered are kept verbatim below as a guide for reconstructing it; they show that an iteration
     * optionally taboos the last action, logs progress in 5% steps relative to maxIterations, scores the
     * playout (NaN if the head is not terminal, otherwise the sum of the rewards), updates the tree policy
     * via updatePath, posts an MCTSIterationCompletedEvent, calls summarizeIteration and unregisters the
     * active thread.
     */
    /* JADX WARN: Code restructure failed: missing block: B:54:0x05c1, code lost:
    
        if (r16.tabooExhaustedNodes == false) goto L100;
     */
    /* JADX WARN: Code restructure failed: missing block: B:56:0x05c7, code lost:
    
        if (r35 != 1) goto L100;
     */
    /* JADX WARN: Code restructure failed: missing block: B:57:0x05ca, code lost:
    
        tabooLastActionOfPath(r0);
     */
    /* JADX WARN: Code restructure failed: missing block: B:58:0x05d0, code lost:
    
        r0 = (int) java.lang.Math.round((r16.iterations * 100.0d) / r16.maxIterations);
     */
    /* JADX WARN: Code restructure failed: missing block: B:59:0x05eb, code lost:
    
        if (r0 <= r16.lastProgressReport) goto L105;
     */
    /* JADX WARN: Code restructure failed: missing block: B:61:0x05f2, code lost:
    
        if ((r0 % 5) != 0) goto L105;
     */
    /* JADX WARN: Code restructure failed: missing block: B:62:0x05f5, code lost:
    
        r16.logger.info("Progress: {}%", java.lang.Long.valueOf(java.lang.Math.round((r16.iterations * 100.0d) / r16.maxIterations)));
        r16.lastProgressReport = r0;
     */
    /* JADX WARN: Code restructure failed: missing block: B:63:0x061b, code lost:
    
        r0 = r0.contains(null);
        r0 = r16.mdp.isTerminalState(r0.getHead());
     */
    /* JADX WARN: Code restructure failed: missing block: B:64:0x0639, code lost:
    
        if (r0 == false) goto L108;
     */
    /* JADX WARN: Code restructure failed: missing block: B:65:0x063c, code lost:
    
        r0 = Double.NaN;
     */
    /* JADX WARN: Code restructure failed: missing block: B:66:0x065d, code lost:
    
        r42 = r0;
        r16.logger.info("Found playout of length {}. Head is goal: {}. (Undiscounted) score of path is {}.", new java.lang.Object[]{java.lang.Integer.valueOf(r0.getNumberOfNodes()), java.lang.Boolean.valueOf(r0), java.lang.Double.valueOf(r42)});
        r16.logger.debug("Found leaf node with score {}. Now propagating this score over the path with actions {}. Leaf state is: {}.", new java.lang.Object[]{java.lang.Double.valueOf(r42), r0.getArcs(), r0.getHead()});
     */
    /* JADX WARN: Code restructure failed: missing block: B:67:0x06bd, code lost:
    
        if (r0.isPoint() != false) goto L112;
     */
    /* JADX WARN: Code restructure failed: missing block: B:68:0x06c0, code lost:
    
        r0 = java.lang.System.currentTimeMillis();
        r16.treePolicy.updatePath(r0, r0);
        r27 = 0 + (java.lang.System.currentTimeMillis() - r0);
     */
    /* JADX WARN: Code restructure failed: missing block: B:69:0x06dd, code lost:
    
        r0 = new ai.libs.jaicore.search.algorithms.mdp.mcts.MCTSIterationCompletedEvent(r16, r16.treePolicy, new ai.libs.jaicore.search.model.other.SearchGraphPath((org.api4.java.datastructure.graph.ILabeledPath) r0), r0);
        post(r0);
        summarizeIteration(java.lang.System.currentTimeMillis() - r0, r21, r23, r19, r20, r25, r27, r29);
     */
    /* JADX WARN: Code restructure failed: missing block: B:70:0x0717, code lost:
    
        r16.logger.debug("Unregistering thread {}", java.lang.Thread.currentThread());
        unregisterActiveThread();
     */
    /* JADX WARN: Code restructure failed: missing block: B:71:0x072a, code lost:
    
        return r0;
     */
    /* JADX WARN: Code restructure failed: missing block: B:72:0x0642, code lost:
    
        r0 = ((java.lang.Double) r0.stream().reduce(java.lang.Double.valueOf(0.0d), (v0, v1) -> { // java.util.function.BinaryOperator.apply(java.lang.Object, java.lang.Object):java.lang.Object
            return lambda$nextWithException$2(v0, v1);
        })).doubleValue();
     */
    /**
     * Performs one MCTS step.
     *
     * <p>NOTE(review): not recoverable from the decompiled class; this stub always throws.
     *
     * @throws UnsupportedOperationException always, until the method is reconstructed
     */
    @Override
    public org.api4.java.algorithm.events.IAlgorithmEvent nextWithException() throws InterruptedException, AlgorithmExecutionCanceledException, AlgorithmTimeoutedException, AlgorithmException {
        throw new UnsupportedOperationException("Method not decompiled: ai.libs.jaicore.search.algorithms.mdp.mcts.MCTS.nextWithException():org.api4.java.algorithm.events.IAlgorithmEvent");
    }

    /**
     * Accumulates the timing statistics of one iteration and logs a summary including current heap usage.
     *
     * @param msForRollout duration of the whole rollout
     * @param msForApplicableActions time spent computing applicable actions
     * @param msForSuccessors time spent computing successors
     * @param numTreePolicyQueries number of tree-policy queries
     * @param numDefaultPolicyQueries number of default-policy queries
     * @param msForTreePolicyQueries time spent in tree-policy queries
     * @param msForTreePolicyUpdates time spent updating the tree policy
     * @param msForDefaultPolicyQueries time spent in default-policy queries
     */
    private void summarizeIteration(final long msForRollout, final long msForApplicableActions, final long msForSuccessors, final int numTreePolicyQueries, final int numDefaultPolicyQueries, final long msForTreePolicyQueries, final long msForTreePolicyUpdates, final long msForDefaultPolicyQueries) {
        this.msSpentInRollouts = (int) (this.msSpentInRollouts + msForRollout);
        this.msSpentInTreePolicyQueries = (int) (this.msSpentInTreePolicyQueries + msForTreePolicyQueries);
        this.msSpentInTreePolicyUpdates = (int) (this.msSpentInTreePolicyUpdates + msForTreePolicyUpdates);
        final long usedMemoryMb = (RUNTIME.totalMemory() - RUNTIME.freeMemory()) / BYTES_PER_MB;
        this.logger.info("Finished rollout in {}ms. Time for computing applicable actions was {}ms and for computing successors {}ms. Time for TP {} queries was {}ms, time to update TP {}ms, time for {} DP queries was {}ms. Currently used memory: {}MB.", new Object[] { msForRollout, msForApplicableActions, msForSuccessors, numTreePolicyQueries, msForTreePolicyQueries, msForTreePolicyUpdates, numDefaultPolicyQueries, msForDefaultPolicyQueries, usedMemoryMb });
    }

    /**
     * Puts the action leading into the head of {@code path} on the taboo list of the head's parent.
     *
     * @param path a path with at least one arc
     * @throws IllegalArgumentException if {@code path} is a point
     */
    private void tabooLastActionOfPath(final ILabeledPath<N, A> path) {
        if (path.isPoint()) {
            throw new IllegalArgumentException("The path is a point, which has no first action to taboo.");
        }
        final N parentOfHead = path.getParentOfHead();
        final A lastAction = path.getOutArc(parentOfHead);
        this.tabooActions.computeIfAbsent(parentOfHead, state -> new HashSet<>()).add(lastAction);
        this.logger.debug("Adding action {} to taboo list of state {}", lastAction, parentOfHead);
    }

    /** @return the value of the iteration counter */
    public int getNumberOfRealizedPlayouts() {
        return this.iterations;
    }

    /** @return the tree policy (also the result of {@link #call()}) */
    public IPathUpdatablePolicy<N, A, Double> getTreePolicy() {
        return this.treePolicy;
    }

    /**
     * Runs steps until {@code hasNext()} is false and returns the tree policy.
     *
     * @return the tree policy
     */
    @Override
    public IPolicy<N, A> call() throws InterruptedException, AlgorithmExecutionCanceledException, AlgorithmTimeoutedException, AlgorithmException {
        while (hasNext()) {
            nextWithException();
        }
        return this.treePolicy;
    }

    /**
     * Decompiler-generated alias of {@link #call()}, kept for source compatibility.
     *
     * @return the tree policy
     * @deprecated use {@link #call()}
     */
    @Deprecated
    public IPolicy<N, A> mo2call() throws InterruptedException, AlgorithmExecutionCanceledException, AlgorithmTimeoutedException, AlgorithmException {
        return call();
    }

    /**
     * Attempts to enforce a prefix on all rollouts. Enforced prefixes are not supported yet.
     *
     * @param prefix a path starting at the initial state of the MDP
     * @throws IllegalArgumentException if the root of {@code prefix} is not the MDP's initial state
     * @throws UnsupportedOperationException always, for a valid prefix
     */
    public void enforcePrefixPathOnAllRollouts(final ILabeledPath<N, A> prefix) {
        if (!prefix.getRoot().equals(this.mdp.getInitState())) {
            throw new IllegalArgumentException("Illegal prefix, since root does not coincide with algorithm root. Proposed root is: " + prefix.getRoot());
        }
        this.enforcedPrefixPath = prefix;
        // Net effect: the root leaves the tree-policy-ready set and the last node of the prefix joins it.
        N previous = null;
        for (N node : prefix.getNodes()) {
            if (previous != null) {
                this.tpReadyStates.remove(previous);
                this.tpReadyStates.add(node);
            }
            previous = node;
        }
        // NOTE(review): the state above is modified before failing, so a caller catching this exception sees
        // the stored prefix and the changed ready-states; consider failing before mutating.
        throw new UnsupportedOperationException("Currently, enforced prefixes are ignored!");
    }

    /**
     * @return an unmodifiable accessor of the enforced prefix, or {@code null} if no prefix was set
     */
    public ILabeledPath<N, A> getEnforcedPrefixPath() {
        return this.enforcedPrefixPath != null ? this.enforcedPrefixPath.getUnmodifiableAccessor() : null;
    }

    /**
     * Sets the logger of this algorithm and propagates derived names to the MDP ({@code .mdp}), the tree policy
     * ({@code .tp}), the default policy ({@code .dp}) and the utils ({@code .utils}) where they are customizable.
     *
     * @param name the new logger name
     */
    @Override
    public void setLoggerName(final String name) {
        this.logger = LoggerFactory.getLogger(name);
        super.setLoggerName(name + ".abstract");
        if (this.mdp instanceof ILoggingCustomizable) {
            ((ILoggingCustomizable) this.mdp).setLoggerName(name + ".mdp");
        }
        if (this.treePolicy instanceof ILoggingCustomizable) {
            this.logger.info("Setting logger of tree policy to {}.tp", name);
            ((ILoggingCustomizable) this.treePolicy).setLoggerName(name + ".tp");
        } else {
            this.logger.info("Not setting logger of tree policy, because {} is not customizable.", this.treePolicy.getClass().getName());
        }
        if (this.defaultPolicy instanceof ILoggingCustomizable) {
            this.logger.info("Setting logger of default policy to {}.dp", name);
            ((ILoggingCustomizable) this.defaultPolicy).setLoggerName(name + ".dp");
        } else {
            this.logger.info("Not setting logger of default policy, because {} is not customizable.", this.defaultPolicy.getClass().getName());
        }
        this.utils.setLoggerName(name + ".utils");
    }

    /** @throws UnsupportedOperationException always; not implemented */
    public boolean hasTreePolicyReachedLeafs() {
        throw new UnsupportedOperationException("Currently not implemented.");
    }

    /** @return the name of this algorithm's logger */
    public String getLoggerName() {
        return this.logger.getName();
    }

    /** @return the number of states currently marked as ready for the tree policy */
    public int getNumberOfNodesInMemory() {
        return this.tpReadyStates.size();
    }

    /** @return accumulated rollout time in milliseconds */
    public int getMsSpentInRollouts() {
        return this.msSpentInRollouts;
    }

    /** @return accumulated tree-policy query time in milliseconds */
    public int getMsSpentInTreePolicyQueries() {
        return this.msSpentInTreePolicyQueries;
    }

    /** @return accumulated tree-policy update time in milliseconds */
    public int getMsSpentInTreePolicyUpdates() {
        return this.msSpentInTreePolicyUpdates;
    }

    /** @return whether actions leading into exhausted states are tabooed */
    public boolean isTabooExhaustedNodes() {
        return this.tabooExhaustedNodes;
    }
}
