package ai.libs.jaicore.search.algorithms.mdp.mcts.spuct;

import ai.libs.jaicore.search.algorithms.mdp.mcts.NodeLabel;
import ai.libs.jaicore.search.algorithms.mdp.mcts.uct.UCBPolicy;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.api4.java.common.control.ILoggingCustomizable;
import org.api4.java.datastructure.graph.ILabeledPath;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:ai/libs/jaicore/search/algorithms/mdp/mcts/spuct/SPUCBPolicy.class */
/**
 * Single-player UCB selection policy (presumably following SP-MCTS by Schadd et al.).
 *
 * <p>The score of an action is the score of the plain {@link UCBPolicy} plus a
 * term built from the sum of squared observed returns:
 * {@code sign * sqrt((sumOfSquares - numPulls * mean^2 + bigD) / numPulls)}, where
 * {@code sign} is {@code +1} when maximizing and {@code -1} otherwise.
 *
 * @param <N> node type
 * @param <A> action type
 */
public class SPUCBPolicy<N, A> extends UCBPolicy<N, A> implements ILoggingCustomizable {
    private String loggerName;
    private Logger logger = LoggerFactory.getLogger(SPUCBPolicy.class);

    /** Constant added under the square root of the single-player term. */
    private final double bigD;

    /**
     * Running sum, per node label, of the squared discounted returns observed from that node.
     * NOTE(review): keyed per node, whereas {@code getScore} combines it with the per-(node, action)
     * pull count and mean -- confirm this granularity is intended.
     */
    private final Map<NodeLabel<A>, Double> squaredObservations = new HashMap<>();

    /**
     * Creates a maximizing policy.
     *
     * @param gamma first argument forwarded to {@code UCBPolicy(double, boolean)}
     *              (presumably the discount factor returned by {@code getGamma()} -- confirm)
     * @param bigD  constant added under the square root of the single-player term
     */
    public SPUCBPolicy(final double gamma, final double bigD) {
        this(gamma, true, bigD);
    }

    /**
     * Creates a policy.
     *
     * @param gamma    first argument forwarded to {@code UCBPolicy(double, boolean)}
     *                 (presumably the discount factor returned by {@code getGamma()} -- confirm)
     * @param maximize boolean forwarded to {@code UCBPolicy(double, boolean)}; presumably the
     *                 optimization direction later read via {@code isMaximize()}
     * @param bigD     constant added under the square root of the single-player term
     */
    public SPUCBPolicy(final double gamma, final boolean maximize, final double bigD) {
        super(gamma, maximize);
        this.bigD = bigD;
    }

    @Override
    public String getLoggerName() {
        return this.loggerName;
    }

    /**
     * Sets this policy's logger; the parent policy logs under {@code name + "._updating"}.
     *
     * @param name the logger name
     */
    @Override
    public void setLoggerName(final String name) {
        this.loggerName = name;
        super.setLoggerName(name + "._updating");
        this.logger = LoggerFactory.getLogger(name);
    }

    /**
     * Updates the parent statistics and then adds, for every node on the path except the last,
     * the square of the discounted return observed from that node.
     *
     * <p>The return is accumulated backwards as {@code score_i + gamma * return_{i+1}}. Once a
     * {@code null} score is met, the return becomes NaN for that node and all earlier ones.
     * NOTE(review): a NaN added to the running sum makes it NaN permanently.
     *
     * @param path   the rolled-out path
     * @param scores per-step scores aligned with the path's nodes; entries may be {@code null}
     */
    @Override
    public void updatePath(final ILabeledPath<N, A> path, final List<Double> scores) {
        super.updatePath(path, scores);
        final List<N> nodes = path.getNodes();
        double accumulatedReturn = 0.0;
        for (int i = nodes.size() - 2; i >= 0; i--) {
            final NodeLabel<A> label = this.getLabelOfNode(nodes.get(i));
            if (!Double.isNaN(accumulatedReturn)) {
                final Double score = scores.get(i);
                accumulatedReturn = score != null ? score + this.getGamma() * accumulatedReturn : Double.NaN;
            }
            this.squaredObservations.merge(label, Math.pow(accumulatedReturn, 2.0), Double::sum);
        }
    }

    /**
     * Computes the SP-UCB score of {@code action} in {@code node}: the parent UCB score plus the
     * signed single-player term.
     *
     * <p>When the action has not been pulled yet, the single-player term divides by zero and
     * therefore evaluates to an infinite or NaN value.
     *
     * @param node   the node in which the action is evaluated
     * @param action the action to score
     * @return the combined score
     */
    @Override
    public double getScore(final N node, final A action) {
        // Score of the plain UCB policy. The single-player term is added on top of it,
        // instead of adding the empirical mean a second time.
        final double ucbScore = super.getScore(node, action);

        final NodeLabel<A> label = this.getLabelOfNode(node);
        final int numPulls = label.getNumPulls(action);
        final double mean = super.getEmpiricalMean(node, action);
        final double sumOfSquares = this.squaredObservations.getOrDefault(label, 0.0);
        final double sign = this.isMaximize() ? 1 : -1;
        final double spTerm = sign * Math.sqrt((sumOfSquares - numPulls * Math.pow(mean, 2.0) + this.bigD) / numPulls);

        final double score = ucbScore + spTerm;
        this.logger.debug("Computed score for action {}: {} = {} + {}", action, score, ucbScore, spTerm);
        return score;
    }
}
