package ai.libs.jaicore.search.algorithms.standard.mcts;

import java.util.Map;
import org.api4.java.common.control.ILoggingCustomizable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:ai/libs/jaicore/search/algorithms/standard/mcts/UCBPolicy.class */
/**
 * Upper-confidence-bound (UCB) policy for choosing among the children of a node.
 *
 * <p>The score of a child is its observed mean plus an exploration term:
 * {@code mean(child) + sign * c * sqrt(ln(visits(parent)) / visits(child))}, where {@code c} is
 * the exploration constant and {@code sign} is {@code +1} if {@link #isMaximize()} is true and
 * {@code -1} otherwise. The action with the largest (maximizing) or smallest (minimizing) score
 * is chosen.
 *
 * @param <T> type of the nodes whose labels (visits, mean) are read
 * @param <A> type of the actions among which one is chosen
 */
public class UCBPolicy<T, A> extends AUpdatingPolicy<T, A> implements ILoggingCustomizable {

    /** Exploration constant used when none is given explicitly: sqrt(2). */
    private static final double DEFAULT_EXPLORATION_CONSTANT = Math.sqrt(2.0d);

    private String loggerName;

    /* Not final/static: setLoggerName replaces the logger of this instance. */
    private Logger logger = LoggerFactory.getLogger(UCBPolicy.class);

    private double explorationConstant = DEFAULT_EXPLORATION_CONSTANT;

    /**
     * Creates a policy with the default exploration constant sqrt(2).
     */
    public UCBPolicy() {
    }

    /**
     * Creates a policy with a custom exploration constant.
     *
     * @param d the exploration constant; stored as given, without validation
     */
    public UCBPolicy(double d) {
        this.explorationConstant = d;
    }

    /**
     * Creates a policy with the default exploration constant sqrt(2).
     *
     * @param z forwarded unchanged to the {@code AUpdatingPolicy(boolean)} constructor
     *          (presumably the maximize flag -- confirm against AUpdatingPolicy)
     */
    public UCBPolicy(boolean z) {
        super(z);
    }

    /**
     * @return the name most recently passed to {@link #setLoggerName(String)}, or {@code null}
     *         if it has not been called on this instance
     */
    @Override
    public String getLoggerName() {
        return this.loggerName;
    }

    /**
     * Switches this policy to the logger with the given name. The superclass is given the
     * derived name {@code str + "._updating"}.
     *
     * @param str name of the logger to use from now on
     */
    @Override
    public void setLoggerName(String str) {
        this.loggerName = str;
        super.setLoggerName(str + "._updating");
        this.logger = LoggerFactory.getLogger(str);
    }

    /**
     * Computes the UCB score of a child node relative to its parent.
     *
     * <p>Visit counts are not checked: with zero visits of the child, or fewer than two visits of
     * the parent, the exploration term follows floating-point arithmetic and may be infinite or
     * NaN.
     *
     * @param t the parent node
     * @param t2 the child node to be scored
     * @return the mean of the child plus the signed exploration term
     */
    @Override
    public double getScore(T t, T t2) {
        AUpdatingPolicy<T, A>.NodeLabel parentLabel = getLabelOfNode(t);
        AUpdatingPolicy<T, A>.NodeLabel childLabel = getLabelOfNode(t2);

        /* exploration is a bonus when maximizing and a discount when minimizing */
        final int sign = isMaximize() ? 1 : -1;
        final double explorationTerm = sign * this.explorationConstant * Math.sqrt(Math.log(parentLabel.visits) / childLabel.visits);
        final double score = childLabel.mean + explorationTerm;

        /* guard: avoids boxing seven values per call when tracing is disabled */
        if (this.logger.isTraceEnabled()) {
            this.logger.trace("Computed UCB score {} = {} + {} * {} * sqrt(log({})/{}). That is, exploration term is {}",
                    new Object[] {score, childLabel.mean, sign, this.explorationConstant, parentLabel.visits, childLabel.visits, explorationTerm});
        }
        return score;
    }

    /**
     * @return the exploration constant currently used by {@link #getScore(Object, Object)}
     */
    public double getExplorationConstant() {
        return this.explorationConstant;
    }

    /**
     * @param d the new exploration constant; stored as given, without validation
     */
    public void setExplorationConstant(double d) {
        this.explorationConstant = d;
    }

    /**
     * Picks the action with the best score: the largest if {@link #isMaximize()} is true,
     * otherwise the smallest. Among equally good actions, the one encountered first in the
     * iteration order of the map is kept.
     *
     * <p>The first entry is always accepted as the initial candidate, so a non-empty map yields
     * one of its keys even if all scores are infinite.
     *
     * @param map scores per action; values must not be {@code null}
     * @return the chosen action, or {@code null} if the map is empty
     */
    @Override
    public A getActionBasedOnScores(Map<A, Double> map) {
        this.logger.debug("Getting action for scores {}", map);
        final boolean maximize = isMaximize();

        A best = null;
        double bestScore = maximize ? Double.NEGATIVE_INFINITY : Double.POSITIVE_INFINITY;
        boolean hasCandidate = false; // separate flag, since a key of the map may itself be null

        for (Map.Entry<A, Double> entry : map.entrySet()) {
            A key = entry.getKey();
            double score = entry.getValue().doubleValue();

            /* negated non-strict comparison on purpose: keeps the established handling of NaN scores */
            boolean improves = !hasCandidate || (maximize ? !(score <= bestScore) : !(score >= bestScore));
            if (improves) {
                this.logger.trace("Updating best choice {} with {} since it is better than the current solution with performance {}", new Object[] {best, key, bestScore});
                bestScore = score;
                best = key;
                hasCandidate = true;
            } else {
                this.logger.trace("Skipping current solution {} since its score {} is not better than the currently best {}.", new Object[] {key, score, bestScore});
            }
        }
        return best;
    }
}
