package ai.libs.jaicore.search.algorithms.standard.mcts;

import ai.libs.jaicore.basic.algorithm.AlgorithmFinishedEvent;
import ai.libs.jaicore.basic.algorithm.AlgorithmInitializedEvent;
import ai.libs.jaicore.basic.algorithm.EAlgorithmState;
import ai.libs.jaicore.basic.sets.SetUtil;
import ai.libs.jaicore.search.algorithms.mdp.mcts.GraphBasedMDP;
import ai.libs.jaicore.search.algorithms.mdp.mcts.MCTS;
import ai.libs.jaicore.search.algorithms.mdp.mcts.MCTSFactory;
import ai.libs.jaicore.search.algorithms.mdp.mcts.MCTSIterationCompletedEvent;
import ai.libs.jaicore.search.algorithms.standard.bestfirst.events.EvaluatedSearchSolutionCandidateFoundEvent;
import ai.libs.jaicore.search.core.interfaces.AOptimalPathInORGraphSearch;
import ai.libs.jaicore.search.model.other.EvaluatedSearchGraphPath;
import ai.libs.jaicore.search.probleminputs.IMDP;
import com.google.common.eventbus.Subscribe;
import java.util.HashSet;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import org.api4.java.ai.graphsearch.problem.IPathSearchWithPathEvaluationsInput;
import org.api4.java.algorithm.Timeout;
import org.api4.java.algorithm.events.IAlgorithmEvent;
import org.api4.java.algorithm.exceptions.AlgorithmException;
import org.api4.java.algorithm.exceptions.AlgorithmExecutionCanceledException;
import org.api4.java.algorithm.exceptions.AlgorithmTimeoutedException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * OR-graph path search that delegates exploration to an {@link MCTS} instance running on a {@link GraphBasedMDP}
 * view of the search input.
 *
 * <p>Each completed MCTS rollout is scored by summing the rollout's scores. If the rollout path passes the goal
 * tester, it is offered to {@code updateBestSeenSolution}. Unless a path with the same hash code was already
 * reported, it is also returned and posted as an {@link EvaluatedSearchSolutionCandidateFoundEvent}.
 *
 * @param <I> the search input type
 * @param <N> the node type
 * @param <A> the arc (action) type
 */
public class MCTSPathSearch<I extends IPathSearchWithPathEvaluationsInput<N, A, Double>, N, A> extends AOptimalPathInORGraphSearch<I, N, A, Double> {
    private Logger logger = LoggerFactory.getLogger(MCTSPathSearch.class);
    private final GraphBasedMDP<N, A> mdp;
    private final MCTS<N, A> mcts;

    // NOTE(review): deduplication is by hashCode only, so two distinct paths whose hash codes collide are treated
    // as the same solution and the second one is suppressed.
    private final Set<Integer> hashCodesOfReturnedPaths = new HashSet<>();

    /**
     * Creates the search: wraps the input into a {@link GraphBasedMDP}, obtains an MCTS for it from the factory, and
     * forwards all MCTS events except its own initialized/finished events to this algorithm's listeners.
     *
     * @param input the path search input
     * @param mctsFactory factory used to create the MCTS algorithm for the derived MDP
     */
    public MCTSPathSearch(I input, MCTSFactory<N, A, ?> mctsFactory) {
        super(input);
        this.mdp = new GraphBasedMDP<>(input);
        this.mcts = mctsFactory.getAlgorithm(this.mdp);
        this.mcts.registerListener(new Object() {
            @Subscribe
            public void receiveMCTSEvent(IAlgorithmEvent event) {
                // The lifecycle of this search is reported by the search itself, not by the inner MCTS.
                if (event instanceof AlgorithmInitializedEvent || event instanceof AlgorithmFinishedEvent) {
                    return;
                }
                MCTSPathSearch.this.post(event);
            }
        });
    }

    /**
     * Advances the search by one step.
     *
     * <p>In state {@code CREATED}, the inner MCTS is stepped until it reports its initialization, then this search is
     * activated. In state {@code ACTIVE}, MCTS iterations are consumed until either a previously unreported goal
     * rollout is found (returned as a solution event) or the MCTS finishes (this search terminates).
     *
     * @return the event produced by this step
     * @throws IllegalStateException if called in any state other than {@code CREATED} or {@code ACTIVE}
     */
    @Override
    public IAlgorithmEvent nextWithException() throws InterruptedException, AlgorithmExecutionCanceledException, AlgorithmTimeoutedException, AlgorithmException {
        EAlgorithmState state = this.getState();
        switch (state) {
            case CREATED:
                this.mdp.setLoggerName(this.getLoggerName() + ".mdp");
                // Drive the inner MCTS through its own initialization before activating this search.
                while (!(this.mcts.next() instanceof AlgorithmInitializedEvent)) {
                    // keep stepping until the MCTS reports it is initialized
                }
                return this.activate();
            case ACTIVE:
                if (this.mcts.getState() != EAlgorithmState.ACTIVE) {
                    return this.terminate();
                }
                return this.consumeIterationsUntilNewSolution();
            default:
                throw new IllegalStateException("Cannot step MCTSPathSearch in state " + state);
        }
    }

    /**
     * Pulls events from the inner MCTS until a new goal rollout is found or the MCTS finishes.
     *
     * @return the posted solution event, or the termination event of this search
     */
    // The generic signature of MCTSIterationCompletedEvent is not visible here, so it is handled as a raw type.
    @SuppressWarnings({ "rawtypes", "unchecked" })
    private IAlgorithmEvent consumeIterationsUntilNewSolution() throws InterruptedException, AlgorithmExecutionCanceledException, AlgorithmTimeoutedException, AlgorithmException {
        while (true) {
            IAlgorithmEvent event = this.mcts.nextWithException();
            if (event instanceof AlgorithmFinishedEvent) {
                return this.terminate();
            }
            if (!(event instanceof MCTSIterationCompletedEvent)) {
                continue;
            }
            MCTSIterationCompletedEvent iterationEvent = (MCTSIterationCompletedEvent) event;
            double overallScore = SetUtil.sum(iterationEvent.getScores());
            this.logger.info("Registered rollout with score {}. Updating best seen solution correspondingly.", overallScore);
            EvaluatedSearchGraphPath solution = new EvaluatedSearchGraphPath(iterationEvent.getRollout(), Double.valueOf(overallScore));
            if (!this.getGoalTester().isGoal(solution)) {
                continue;
            }
            this.updateBestSeenSolution(solution);
            int hashCode = solution.hashCode();
            // Set.add returns false if the hash code was already recorded, i.e. the solution was reported before.
            if (this.hashCodesOfReturnedPaths.add(hashCode)) {
                EvaluatedSearchSolutionCandidateFoundEvent solutionEvent = new EvaluatedSearchSolutionCandidateFoundEvent(this, solution);
                this.post(solutionEvent);
                return solutionEvent;
            }
            this.logger.info("Skipping (and supressing) previously found solution with hash code {}", hashCode);
        }
    }

    /**
     * Sets the timeout of this search. The inner MCTS gets one second less, so this search has time to wrap up.
     *
     * @param timeout the timeout; must be at least two seconds
     * @throws IllegalArgumentException if the timeout is shorter than two seconds
     */
    @Override
    public void setTimeout(Timeout timeout) {
        if (timeout.seconds() < 2) {
            throw new IllegalArgumentException("Cannot run MCTS with a timeout of less than 2 seconds.");
        }
        super.setTimeout(timeout);
        this.mcts.setTimeout(new Timeout(timeout.seconds() - 1, TimeUnit.SECONDS));
    }

    /** Cancels this search and the inner MCTS. */
    @Override
    public void cancel() {
        super.cancel();
        this.mcts.cancel();
    }

    /**
     * Sets the logger name of this search.
     *
     * <p>The superclass gets {@code name + "._algorithm"}, this class logs under {@code name}, and the inner MCTS
     * logs under {@code name + ".mcts"}.
     *
     * @param str the base logger name
     */
    @Override
    public void setLoggerName(String str) {
        super.setLoggerName(str + "._algorithm");
        this.logger = LoggerFactory.getLogger(str);
        this.mcts.setLoggerName(str + ".mcts");
    }

    /** @return the name of the logger used by this class */
    @Override
    public String getLoggerName() {
        return this.logger.getName();
    }

    /** @return the MDP view of the search input that the MCTS operates on */
    public IMDP<N, A, Double> getMdp() {
        return this.mdp;
    }

    /** @return the inner MCTS algorithm */
    public MCTS<N, A> getMcts() {
        return this.mcts;
    }

    /** @return the number of nodes the inner MCTS currently holds in memory */
    public int getNumberOfNodesInMemory() {
        return this.mcts.getNumberOfNodesInMemory();
    }
}
