package ai.libs.jaicore.search.algorithms.standard.mcts.brue;

import ai.libs.jaicore.basic.sets.Pair;
import ai.libs.jaicore.search.algorithms.standard.mcts.ActionPredictionFailedException;
import ai.libs.jaicore.search.algorithms.standard.mcts.IPathUpdatablePolicy;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;
import org.api4.java.datastructure.graph.ILabeledPath;

/**
 * BRUE-style tree policy for MCTS.
 *
 * <p>For every (node, action) pair the policy keeps a selection count and an empirical mean of
 * observed scores. {@link #getAction} greedily picks the action with the best estimate (lowest
 * score, or highest if {@code maximize} is set), breaking ties uniformly at random via the
 * supplied {@link Random}. {@link #updatePath} refines the estimate of exactly one pair per
 * observed path: the pair at the current switching point, which cycles through
 * {@code timeHorizon, timeHorizon - 1, ..., 1}.
 *
 * <p>This class is not thread-safe (it uses unsynchronized {@link HashMap}s and counters).
 *
 * @param <N> node type
 * @param <A> action (arc label) type
 */
public class BRUEPolicy<N, A> implements IPathUpdatablePolicy<N, A, Double> {

    /** Number of times each (node, action) pair has been chosen by {@link #getAction}. */
    private final Map<Pair<N, A>, Integer> nCounter;

    /** Number of scores that have been averaged into {@link #qHat} for each pair. */
    private final Map<Pair<N, A>, Integer> updateCounter;

    /** Empirical mean of the scores observed for each (node, action) pair. */
    private final Map<Pair<N, A>, Double> qHat;

    private final Random random;
    private final int timeHorizon;
    private final boolean maximize;

    /** Number of paths passed to {@link #updatePath} so far. */
    private int n;

    /**
     * Creates a policy.
     *
     * @param maximize {@code true} to prefer high scores, {@code false} to prefer low scores
     * @param timeHorizon period of the switching-point cycle; must be positive
     * @param random source of randomness used to break ties between equally good actions
     * @throws IllegalArgumentException if {@code timeHorizon} is not positive
     */
    public BRUEPolicy(boolean maximize, int timeHorizon, Random random) {
        if (timeHorizon <= 0) {
            throw new IllegalArgumentException("timeHorizon must be positive but is " + timeHorizon);
        }
        this.nCounter = new HashMap<>();
        this.updateCounter = new HashMap<>();
        this.qHat = new HashMap<>();
        this.n = 0;
        this.maximize = maximize;
        this.timeHorizon = timeHorizon;
        this.random = random;
    }

    /**
     * Creates a policy with a time horizon of 1000 and a {@link Random} seeded with 0.
     *
     * @param maximize {@code true} to prefer high scores, {@code false} to prefer low scores
     */
    public BRUEPolicy(boolean maximize) {
        this(maximize, 1000, new Random(0L));
    }

    /**
     * Chooses the action with the best empirical estimate at {@code node}.
     *
     * <p>Actions that have no estimate yet are treated as ties with the best estimated action
     * (or with each other if no action has an estimate). Ties are broken uniformly at random.
     * The selection count of the chosen (node, action) pair is incremented.
     *
     * @param node the node at which an action is to be chosen
     * @param actionsWithSuccessors the available actions mapped to their successor nodes
     * @return the chosen action
     * @throws IllegalArgumentException if {@code actionsWithSuccessors} is empty
     */
    @Override // ai.libs.jaicore.search.algorithms.standard.mcts.IPolicy
    public A getAction(N node, Map<A, N> actionsWithSuccessors) throws ActionPredictionFailedException {
        if (actionsWithSuccessors.isEmpty()) {
            throw new IllegalArgumentException();
        }
        List<A> bestActions = new ArrayList<>();
        List<A> unvisitedActions = new ArrayList<>();
        double bestScore = Double.NaN;
        for (A action : actionsWithSuccessors.keySet()) {
            // Estimates are stored under (parent node, action), matching updatePath and nCounter.
            Double score = this.qHat.get(new Pair<>(node, action));
            if (score == null) {
                unvisitedActions.add(action);
            } else if (bestActions.isEmpty() || this.isBetter(score, bestScore)) {
                bestActions.clear();
                bestActions.add(action);
                bestScore = score;
            } else if (score == bestScore) {
                bestActions.add(action);
            }
        }
        // Unvisited actions tie with the best known action, independent of iteration order.
        bestActions.addAll(unvisitedActions);
        if (bestActions.size() > 1) {
            Collections.shuffle(bestActions, this.random);
        }
        A chosen = bestActions.get(0);
        this.nCounter.merge(new Pair<>(node, chosen), 1, Integer::sum);
        return chosen;
    }

    /**
     * Returns whether {@code candidate} is strictly better than {@code reference} under the
     * configured optimization direction.
     */
    private boolean isBetter(double candidate, double reference) {
        return this.maximize ? candidate > reference : candidate < reference;
    }

    /**
     * Incorporates the score of an observed path into the estimate of the (node, action) pair at
     * the current switching point.
     *
     * <p>Every call advances the switching-point cycle. If the path has fewer than
     * {@code switchingPoint + 2} nodes, no estimate is changed. Otherwise the pair formed by
     * node {@code switchingPoint - 1} and arc {@code switchingPoint - 1} of the path gets
     * {@code score} averaged into its running mean.
     * NOTE(review): the last arc of a path is never updated by this bound — confirm intended.
     *
     * @param path the path observed in the last rollout
     * @param score the score observed for that path
     * @param unused not used by this policy
     * @throws IllegalStateException if the pair to update was never selected via {@link #getAction}
     */
    @Override // ai.libs.jaicore.search.algorithms.standard.mcts.IPathUpdatablePolicy
    public void updatePath(ILabeledPath<N, A> path, Double score, int unused) {
        // The switching-point formula is defined for a 1-based path counter.
        this.n++;
        int switchingPoint = this.getSwitchingPoint(this.n);
        if (switchingPoint > path.getNumberOfNodes() - 2) {
            return;
        }
        Pair<N, A> pair = new Pair<>(path.getNodes().get(switchingPoint - 1), path.getArcs().get(switchingPoint - 1));
        if (!this.nCounter.containsKey(pair)) {
            throw new IllegalStateException("No visit stats for updated pair " + pair + " available.");
        }
        final double observed = score.doubleValue();
        final int updates = this.updateCounter.merge(pair, 1, Integer::sum);
        // First observation initializes the mean; later ones use the incremental-mean update.
        this.qHat.merge(pair, observed, (old, ignored) -> old + (observed - old) / updates);
    }

    /**
     * Computes the switching point for the {@code i}-th path.
     *
     * <p>Returns {@code timeHorizon - ((i - 1) mod timeHorizon)} using a non-negative modulus,
     * so the result is always in {@code [1, timeHorizon]}; for {@code i = 1, 2, ...} it cycles
     * through {@code timeHorizon, timeHorizon - 1, ..., 1}.
     *
     * @param i the (1-based) index of the path
     * @return the switching point, between 1 and {@code timeHorizon} inclusive
     */
    public int getSwitchingPoint(int i) {
        return this.timeHorizon - Math.floorMod(i - 1, this.timeHorizon);
    }
}
