package ai.libs.jaicore.search.algorithms.standard.mcts.tag;

import ai.libs.jaicore.graph.LabeledGraph;
import ai.libs.jaicore.search.algorithms.standard.mcts.ActionPredictionFailedException;
import ai.libs.jaicore.search.algorithms.standard.mcts.IGraphDependentPolicy;
import ai.libs.jaicore.search.algorithms.standard.mcts.IPathUpdatablePolicy;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;
import java.util.stream.Collectors;
import org.api4.java.common.control.ILoggingCustomizable;
import org.api4.java.datastructure.graph.ILabeledPath;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:ai/libs/jaicore/search/algorithms/standard/mcts/tag/TAGPolicy.class */
/**
 * MCTS tree policy that selects children by a threshold-ascent upper confidence bound.
 *
 * <p>For every node on an updated path, the policy remembers how often the node was visited and the {@code s}
 * best playout scores observed through it. "Best" means largest when maximizing and smallest when minimizing.
 * The worst of a parent's retained scores serves as the parent's threshold. The score of a child is
 * {@code (S + alpha + sqrt(2 * S * alpha + alpha^2)) / n}, where:
 * <ul>
 * <li>{@code S} is the number of the child's retained scores that are at least as good as the parent's threshold;</li>
 * <li>{@code n} is the child's visit count;</li>
 * <li>{@code alpha = ln(2 * N * K / delta)}, with {@code N} the parent's visit count and {@code K} the parent's
 * number of successors in the exploration graph.</li>
 * </ul>
 * The optimisation direction is encoded in how {@code S} is counted. The bound itself is therefore always
 * maximized by {@link #getAction}.
 *
 * <p>This class uses unsynchronized {@link HashMap}s and is not thread-safe.
 *
 * @param <T> node type
 * @param <A> action (edge label) type
 */
public class TAGPolicy<T, A> implements IPathUpdatablePolicy<T, A, Double>, IGraphDependentPolicy<T, A>, ILoggingCustomizable {

    private String loggerName;
    private Logger logger = LoggerFactory.getLogger(TAGPolicy.class);
    private LabeledGraph<T, A> explorationGraph;

    /** Not used by the score computation in this class; only exposed through its getter and setter. */
    private double explorationConstant;

    /** Number of best playout scores retained per node. */
    private final int s;

    /** Confidence parameter of the bound; enters the score via {@code alpha = ln(2 * N * K / delta)}. */
    private final double delta;

    /** {@code true} if larger playout scores are better, {@code false} if smaller ones are. */
    private final boolean isMaximize;

    /**
     * Per node, the (at most {@code s}) best playout scores seen so far. Each queue is ordered so that its head is
     * the WORST retained score: that element is both the eviction candidate and the node's threshold.
     */
    private final Map<T, PriorityQueue<Double>> statsPerNode = new HashMap<>();

    /** Per node, the number of updated paths that contained it. */
    private final Map<T, Integer> visitsPerNode = new HashMap<>();

    /** Creates a minimizing policy with {@code s = 10} and {@code delta = 0.01}. */
    public TAGPolicy() {
        this(false);
    }

    /**
     * Creates a policy with explicit parameters.
     *
     * @param explorationConstant value reported by {@link #getExplorationConstant()}
     * @param s number of best scores retained per node; must be positive
     * @param delta confidence parameter of the bound; must be positive
     * @param isMaximize whether larger playout scores are better
     * @throws IllegalArgumentException if {@code s} or {@code delta} is not positive
     */
    public TAGPolicy(final double explorationConstant, final int s, final double delta, final boolean isMaximize) {
        if (s <= 0) {
            // With s <= 0 no score would ever be retained and every threshold lookup would fail later.
            throw new IllegalArgumentException("s must be positive but is " + s);
        }
        if (!(delta > 0)) { // also rejects NaN
            throw new IllegalArgumentException("delta must be positive but is " + delta);
        }
        this.explorationConstant = explorationConstant;
        this.s = s;
        this.delta = delta;
        this.isMaximize = isMaximize;
    }

    /**
     * Creates a policy with {@code s = 10} and {@code delta = 0.01}.
     *
     * @param isMaximize whether larger playout scores are better
     */
    public TAGPolicy(final boolean isMaximize) {
        this(Math.sqrt(2.0), 10, 0.01, isMaximize);
    }

    @Override
    public String getLoggerName() {
        return this.loggerName;
    }

    @Override
    public void setLoggerName(final String name) {
        this.loggerName = name;
        this.logger = LoggerFactory.getLogger(name);
    }

    /**
     * Computes the threshold-ascent upper bound of {@code child} as a successor of {@code parent}.
     *
     * @param parent a node that has been visited via {@link #updatePath}
     * @param child a successor of {@code parent} that has been visited via {@link #updatePath}
     * @return the (non-negative) bound; larger values are more promising regardless of {@code isMaximize}
     * @throws IllegalArgumentException if {@code parent} or {@code child} has not been visited yet
     * @throws IllegalStateException if no exploration graph has been set, or if {@code alpha} turns out negative
     */
    public double getScoreOfChild(final T parent, final T child) {
        PriorityQueue<Double> parentBestScores = this.statsPerNode.get(parent);
        int parentVisits = this.visitsPerNode.getOrDefault(parent, 0);
        if (parentBestScores == null || parentBestScores.isEmpty() || parentVisits == 0) {
            throw new IllegalArgumentException("Cannot compute score for a child of a node with no visits!");
        }
        int childVisits = this.visitsPerNode.getOrDefault(child, 0);
        if (childVisits == 0) {
            throw new IllegalArgumentException("Cannot compute score for child with no visits!");
        }
        if (this.explorationGraph == null) {
            throw new IllegalStateException("No exploration graph has been set.");
        }

        // The head of the parent's queue is the worst of its s best scores, i.e. the threshold to beat.
        double threshold = parentBestScores.peek();
        long numGoodScores = this.statsPerNode.get(child).stream().filter(score -> this.isAtLeastAsGoodAs(score, threshold)).count();

        // Computed in double to avoid int overflow of 2 * N * K.
        int numSuccessors = this.explorationGraph.getSuccessors(parent).size();
        double alpha = Math.log(2.0 * parentVisits * numSuccessors / this.delta);
        if (alpha < 0) {
            throw new IllegalStateException("Alpha must not be negative. Check delta value (must be smaller than 1)");
        }

        double score = (numGoodScores + alpha + Math.sqrt(2 * numGoodScores * alpha + alpha * alpha)) / childVisits;
        this.logger.trace("Compute TAG score of {}", score);
        return score;
    }

    public double getExplorationConstant() {
        return this.explorationConstant;
    }

    public void setExplorationConstant(final double explorationConstant) {
        this.explorationConstant = explorationConstant;
    }

    /**
     * Chooses the action whose successor has the highest bound (see {@link #getScoreOfChild}).
     *
     * <p>The bound is maximized in both optimisation modes, because {@code isMaximize} is already accounted for
     * when counting how many of a child's scores reach the parent's threshold.
     *
     * @param node the current node
     * @param actionsWithSuccessors the available actions mapped to the successor each one leads to
     * @return the best action, or {@code null} if {@code actionsWithSuccessors} is empty
     * @throws IllegalStateException if a successor's score is NaN
     * @throws ActionPredictionFailedException declared by the interface; not thrown by this implementation
     */
    @Override
    public A getAction(final T node, final Map<A, T> actionsWithSuccessors) throws ActionPredictionFailedException {
        A bestAction = null;
        double bestScore = Double.NEGATIVE_INFINITY;
        for (Map.Entry<A, T> entry : actionsWithSuccessors.entrySet()) {
            A action = entry.getKey();
            double score = this.getScoreOfChild(node, entry.getValue());
            if (Double.isNaN(score)) {
                throw new IllegalStateException("Score for option " + action + " is NaN");
            }
            if (score > bestScore) {
                this.logger.trace("Updating best choice {} with {} since it is better than the current solution with performance {}", bestAction, action, bestScore);
                bestScore = score;
                bestAction = action;
            } else {
                this.logger.trace("Skipping current solution {} since its score {} is not better than the currently best {}.", action, score, bestScore);
            }
        }
        return bestAction;
    }

    /**
     * Records a playout for every node on {@code path}.
     *
     * <p>For each node, this increments its visit count and offers {@code playoutScore} to its set of {@code s}
     * best scores. The offered score replaces the current worst retained score if it is strictly better.
     *
     * @param path the path whose nodes are updated
     * @param playoutScore the observed playout score
     * @param lengthOfPlayoutPath not used by this policy
     */
    @Override
    public void updatePath(final ILabeledPath<T, A> path, final Double playoutScore, final int lengthOfPlayoutPath) {
        for (T node : path.getNodes()) {
            this.visitsPerNode.merge(node, 1, Integer::sum);
            PriorityQueue<Double> bestScores = this.statsPerNode.computeIfAbsent(node, k -> this.createBestScoresQueue());
            if (bestScores.size() < this.s) {
                bestScores.add(playoutScore);
            } else if (this.isStrictlyBetter(playoutScore, bestScores.peek())) {
                // Evict the worst retained score (the head) in favour of the better new one.
                bestScores.poll();
                bestScores.add(playoutScore);
            }
        }
    }

    @Override
    public void setGraph(final LabeledGraph<T, A> labeledGraph) {
        this.explorationGraph = labeledGraph;
    }

    /** Creates a queue whose head is the worst score w.r.t. the optimisation direction. */
    private PriorityQueue<Double> createBestScoresQueue() {
        if (this.isMaximize) {
            // Worst = smallest: a natural-order min-heap.
            return new PriorityQueue<>();
        }
        // Worst = largest: a max-heap.
        return new PriorityQueue<>((first, second) -> Double.compare(second, first));
    }

    /** Returns whether {@code candidate} is strictly better than {@code reference}. */
    private boolean isStrictlyBetter(final double candidate, final double reference) {
        return this.isMaximize ? candidate > reference : candidate < reference;
    }

    /** Returns whether {@code value} is at least as good as {@code threshold}. */
    private boolean isAtLeastAsGoodAs(final double value, final double threshold) {
        return this.isMaximize ? value >= threshold : value <= threshold;
    }
}
