package ai.libs.jaicore.search.algorithms.mdp.mcts.tag;

import ai.libs.jaicore.basic.sets.SetUtil;
import ai.libs.jaicore.search.algorithms.mdp.mcts.ActionPredictionFailedException;
import ai.libs.jaicore.search.algorithms.mdp.mcts.IPathUpdatablePolicy;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;
import org.api4.java.common.control.ILoggingCustomizable;
import org.api4.java.datastructure.graph.ILabeledPath;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:ai/libs/jaicore/search/algorithms/mdp/mcts/tag/TAGPolicy.class */
/**
 * TAG tree policy for MCTS over MDPs.
 *
 * <p>For every node the policy keeps, per action, a bounded priority queue of accumulated scores
 * observed below that action ({@link #updatePath}), a per-node threshold that is moved until at
 * most {@code s} observations in total survive it ({@link #adjustThreshold}), and counters for node
 * visits and action pulls. {@link #getAction} picks the action whose score, as computed by
 * {@link #getUtilityOfAction}, is best.
 *
 * <p>The exploration constant is stored and exposed through its getter/setter, but none of the
 * computations in this class read it.
 *
 * <p>The class uses plain {@link HashMap}s and performs no synchronization.
 *
 * @param <T> type of the nodes (states)
 * @param <A> type of the actions (arcs)
 */
public class TAGPolicy<T, A> implements IPathUpdatablePolicy<T, A, Double>, ILoggingCustomizable {

    private String loggerName;
    private Logger logger = LoggerFactory.getLogger(TAGPolicy.class);
    private double explorationConstant;

    /** Upper bound on the number of retained observations (per queue on insert, per node after thresholding). */
    private final int s;
    private final double delta;
    /** Absolute step by which the threshold is moved in each adjustment iteration. */
    private final double thresholdIncrement;
    private final boolean isMaximize;

    private final Map<T, Double> thresholdPerNode = new HashMap<>();
    private final Map<T, Map<A, PriorityQueue<Double>>> statsPerNode = new HashMap<>();
    private final Map<T, Map<A, Integer>> pullsPerNodeAction = new HashMap<>();
    private final Map<T, Integer> visitsPerNode = new HashMap<>();

    /** Creates a minimizing policy with the default parameters (see {@link #TAGPolicy(boolean)}). */
    public TAGPolicy() {
        this(false);
    }

    /**
     * Creates a fully parameterized policy.
     *
     * @param explorationConstant value returned by {@link #getExplorationConstant()}
     * @param s maximum number of observations retained
     * @param delta divisor inside the logarithm of the score computation
     * @param thresholdIncrement step by which the threshold is moved per adjustment iteration
     * @param isMaximize {@code true} if larger scores are better, {@code false} otherwise
     */
    public TAGPolicy(double explorationConstant, int s, double delta, double thresholdIncrement, boolean isMaximize) {
        this.explorationConstant = explorationConstant;
        this.s = s;
        this.delta = delta;
        this.thresholdIncrement = thresholdIncrement;
        this.isMaximize = isMaximize;
    }

    /**
     * Creates a policy with exploration constant {@code sqrt(2)}, {@code s = 10}, {@code delta = 1.0}
     * and a threshold increment of {@code 0.01}.
     *
     * @param isMaximize {@code true} if larger scores are better, {@code false} otherwise
     */
    public TAGPolicy(boolean isMaximize) {
        this(Math.sqrt(2.0d), 10, 1.0d, 0.01d, isMaximize);
    }

    /**
     * Chooses an action for the given node.
     *
     * <p>Side effects: registers unseen actions with a pull count of 1, increments the visit count
     * of the node, re-adjusts the node's threshold and increments the pull count of the chosen
     * action. If no action has a usable (non-NaN, improving) score, an element drawn via
     * {@code SetUtil.getRandomElement} is returned and no pull count is incremented.
     *
     * @param node the node for which an action is required
     * @param actions the actions applicable in {@code node}
     * @return the selected action
     * @throws ActionPredictionFailedException declared by the implemented interface; not thrown directly here
     */
    @Override // ai.libs.jaicore.search.algorithms.mdp.mcts.IPolicy
    public A getAction(T node, Collection<A> actions) throws ActionPredictionFailedException {
        this.logger.info("Getting action for node {}", node);

        /* make sure every applicable action has a pull counter (starting at 1) and count this visit */
        Map<A, Integer> pulls = this.pullsPerNodeAction.computeIfAbsent(node, n -> new HashMap<>());
        for (A action : actions) {
            pulls.putIfAbsent(action, 1);
        }
        this.visitsPerNode.merge(node, 1, Integer::sum);

        this.logger.debug("Adjusting threshold.");
        this.adjustThreshold(node);
        this.logger.debug("Threshold adjusted. Is now {}", this.thresholdPerNode.get(node));

        /* NOTE(review): the optimization direction of the reward (isMaximize) is also used to pick the
         * best/worst TAG score here; confirm that minimizing the score is intended for minimization problems. */
        A best = null;
        double bestScore = this.isMaximize ? -Double.MAX_VALUE : Double.MAX_VALUE;
        int numActions = actions.size();
        for (A candidate : actions) {
            double score = this.getUtilityOfAction(node, candidate, numActions);
            boolean improves = !Double.isNaN(score) && (this.isMaximize ? score > bestScore : score < bestScore);
            if (improves) {
                this.logger.trace("Updating best choice {} with {} since it is better than the current solution with performance {}", best, candidate, bestScore);
                bestScore = score;
                best = candidate;
            } else {
                this.logger.trace("Skipping current solution {} since its score {} is not better than the currently best {}.", candidate, score, bestScore);
            }
        }

        if (best == null) {
            this.logger.warn("All options have score NaN. Returning random element.");
            /* NOTE(review): the seed is the constant 0; verify whether a fixed seed is intended here. */
            return (A) SetUtil.getRandomElement(actions, 0L);
        }
        pulls.merge(best, 1, Integer::sum);
        return best;
    }

    /**
     * Moves the threshold of the given node until at most {@code s} observations (summed over all
     * actions of the node) remain, discarding every observation on the wrong side of the threshold.
     *
     * <p>The threshold starts at the stored value for the node (initially 0 when maximizing, 100
     * otherwise). The first pass prunes with the current threshold; every further pass first moves
     * the threshold by the increment (upwards when maximizing, downwards otherwise). If the node has
     * no observations yet, only the initial threshold is stored.
     *
     * @param node the node whose threshold is adjusted
     */
    public void adjustThreshold(T node) {
        Map<A, PriorityQueue<Double>> observations = this.statsPerNode.get(node);
        double threshold = this.thresholdPerNode.computeIfAbsent(node, n -> this.isMaximize ? 0.0d : 100.0d);
        this.logger.debug("Initial value for threshold is {}. Observations are: {}", threshold, observations);
        if (observations == null) {
            return;
        }

        double step = this.isMaximize ? this.thresholdIncrement : -this.thresholdIncrement;
        boolean firstPass = true;
        int retained;
        do {
            if (!firstPass) {
                double moved = threshold + step;
                if (moved == threshold) {
                    /* the threshold cannot change any more (zero or vanishing increment); looping on would never terminate */
                    this.logger.warn("Threshold {} cannot be moved any further by increment {}. Stopping adjustment.", threshold, step);
                    break;
                }
                threshold = moved;
            }
            firstPass = false;

            final double currentThreshold = threshold;
            retained = 0;
            for (PriorityQueue<Double> queue : observations.values()) {
                queue.removeIf(v -> this.isMaximize ? v < currentThreshold : v > currentThreshold);
                retained += queue.size();
            }
        } while (retained > this.s);

        this.logger.debug("Setting threshold to {}", threshold);
        this.thresholdPerNode.put(node, threshold);
    }

    /**
     * Computes the TAG score of an action in a node.
     *
     * <p>With {@code alpha = ln(2 * visits(node) * numActions / delta)} and {@code n} the number of
     * observations currently retained for the action, the score is
     * {@code (n + alpha + sqrt(2 * n * alpha + alpha^2)) / pulls(node, action)}.
     *
     * @param node the node in which the action is applied
     * @param action the action to score
     * @param numActions the number of actions available in {@code node}
     * @return the score, or {@code NaN} if no observation queue exists for the node/action pair
     * @throws IllegalStateException if no visit has been recorded for the node or if alpha is negative
     * @throws IllegalArgumentException if the action has no recorded pulls
     */
    public double getUtilityOfAction(T node, A action, int numActions) {
        Map<A, PriorityQueue<Double>> statsOfNode = this.statsPerNode.get(node);
        if (statsOfNode == null || !statsOfNode.containsKey(action)) {
            return Double.NaN;
        }

        Integer visits = this.visitsPerNode.get(node);
        if (visits == null) {
            throw new IllegalStateException("No visits recorded for node " + node + "; getAction must be called for it first.");
        }

        /* computed in double so that large visit counts cannot overflow int arithmetic */
        double alpha = Math.log((2.0d * visits * numActions) / this.delta);
        int retained = statsOfNode.get(action).size();
        if (alpha < 0.0d) {
            throw new IllegalStateException("Alpha must not be negative. Check delta value (must be smaller than 1)");
        }
        double numerator = retained + alpha + Math.sqrt((2 * retained * alpha) + Math.pow(alpha, 2.0d));

        Map<A, Integer> pullsOfNode = this.pullsPerNodeAction.get(node);
        int pulls = pullsOfNode == null ? 0 : pullsOfNode.getOrDefault(action, 0);
        if (pulls == 0) {
            throw new IllegalArgumentException("Cannot compute score for child with no visits!");
        }

        double score = numerator / pulls;
        this.logger.trace("Compute TAG score of {}", score);
        return score;
    }

    /** @return the exploration constant (not used by the computations of this class) */
    public double getExplorationConstant() {
        return this.explorationConstant;
    }

    /** @param explorationConstant the new exploration constant */
    public void setExplorationConstant(double explorationConstant) {
        this.explorationConstant = explorationConstant;
    }

    /**
     * Propagates the scores of a path backwards and records, for every (node, arc) pair on the
     * path, the sum of the scores from that position to the end of the path.
     *
     * <p>Walking from the last arc to the first, the scores are accumulated; processing stops
     * silently as soon as a score is {@code null} or the accumulated value becomes NaN. An
     * observation queue is created for every pair reached, even if nothing is stored in it.
     *
     * @param path the path that was traversed
     * @param scores one score per arc of the path; entries may be {@code null}
     */
    @Override // ai.libs.jaicore.search.algorithms.mdp.mcts.IPathUpdatablePolicy
    public void updatePath(ILabeledPath<T, A> path, List<Double> scores) {
        List<T> nodes = path.getNodes();
        List<A> arcs = path.getArcs();
        double accumulated = 0.0d;

        /* the last node has no outgoing arc on the path, so start at the second to last node */
        for (int i = path.getNumberOfNodes() - 2; i >= 0; i--) {
            PriorityQueue<Double> queue = this.statsPerNode
                    .computeIfAbsent(nodes.get(i), n -> new HashMap<>())
                    .computeIfAbsent(arcs.get(i), a -> this.newObservationQueue());
            assert !queue.contains(Double.NaN);

            Double score = scores.get(i);
            if (score == null) {
                return;
            }
            accumulated += score;
            if (Double.isNaN(accumulated)) {
                return;
            }

            if (queue.size() < this.s) {
                queue.add(accumulated);
            } else {
                /* NOTE(review): the head is the largest stored value when maximizing and the smallest when
                 * minimizing (see newObservationQueue), and in both modes a full queue only replaces its head
                 * by a larger observation; confirm that this is the intended replacement rule. */
                Double head = queue.peek();
                if (head != null && head < accumulated) {
                    queue.poll();
                    queue.add(accumulated);
                }
            }
            assert !queue.contains(Double.NaN);
        }
    }

    /** Creates an empty observation queue: descending order when maximizing, natural order otherwise. */
    private PriorityQueue<Double> newObservationQueue() {
        if (this.isMaximize) {
            return new PriorityQueue<Double>((x, y) -> Double.compare(y, x));
        }
        return new PriorityQueue<Double>();
    }

    /** @return the name last passed to {@link #setLoggerName(String)}, or {@code null} if none was set */
    public String getLoggerName() {
        return this.loggerName;
    }

    /**
     * Switches the logger used by this policy.
     *
     * @param name the name of the logger to obtain from the {@link LoggerFactory}
     */
    public void setLoggerName(String name) {
        this.loggerName = name;
        this.logger = LoggerFactory.getLogger(name);
    }
}
