package ai.libs.jaicore.search.algorithms.standard.mcts;

import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import org.api4.java.common.control.ILoggingCustomizable;
import org.api4.java.datastructure.graph.ILabeledPath;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:ai/libs/jaicore/search/algorithms/standard/mcts/SPUCBPolicy.class */
/**
 * SP-UCB selection policy: the UCB score of {@link UCBPolicy} plus a variance-based exploration term.
 *
 * <p>For a child with visit count {@code n}, mean {@code mean}, and sum of squared playout scores
 * {@code sumSq}, the additional term is {@code sqrt((sumSq - n * mean^2 + D) / n)}. Here {@code D}
 * is the constant passed to the constructor. The term is added when maximizing and subtracted
 * when minimizing.
 *
 * @param <N> node type
 * @param <A> arc type
 */
public class SPUCBPolicy<N, A> extends UCBPolicy<N, A> implements ILoggingCustomizable {
    private String loggerName;
    private Logger logger = LoggerFactory.getLogger(SPUCBPolicy.class);

    /** Constant {@code D} added under the square root of the variance term. */
    private final double bigD;

    /** Running sum of squared playout scores, keyed by node label. */
    // NOTE(review): NodeLabel is used as a HashMap key; confirm it does not override equals/hashCode
    // based on mutable statistics (visits/mean), otherwise lookups break as the label is updated.
    private final Map<AUpdatingPolicy<N, A>.NodeLabel, Double> squaredObservations = new HashMap<>();

    /**
     * Creates a policy using the superclass' default optimization direction.
     *
     * @param bigD constant {@code D} added under the square root of the variance term
     */
    public SPUCBPolicy(final double bigD) {
        super();
        this.bigD = bigD;
    }

    /**
     * Creates a policy with an explicit optimization direction.
     *
     * @param maximize forwarded to {@link UCBPolicy}; determines the sign of the variance term
     * @param bigD constant {@code D} added under the square root of the variance term
     */
    public SPUCBPolicy(final boolean maximize, final double bigD) {
        super(maximize);
        this.bigD = bigD;
    }

    @Override
    public String getLoggerName() {
        return this.loggerName;
    }

    /**
     * Sets the logger of this policy to {@code name}.
     * The superclass logger is set to {@code name + "._updating"}.
     */
    @Override
    public void setLoggerName(final String name) {
        this.loggerName = name;
        super.setLoggerName(name + "._updating");
        this.logger = LoggerFactory.getLogger(name);
    }

    /**
     * Updates the superclass statistics along {@code path}.
     * It then adds the square of {@code playoutScore} to the running sum of squared observations
     * of every node on the path.
     *
     * @param path the path whose nodes are updated
     * @param playoutScore the observed playout score; must not be {@code null} if the path is non-empty
     * @param lengthOfPlayoutPath forwarded unchanged to the superclass
     */
    @Override
    public void updatePath(final ILabeledPath<N, A> path, final Double playoutScore, final int lengthOfPlayoutPath) {
        super.updatePath(path, playoutScore, lengthOfPlayoutPath);
        for (N node : path.getNodes()) {
            this.squaredObservations.merge(this.getLabelOfNode(node), playoutScore * playoutScore, Double::sum);
        }
    }

    /**
     * Computes the SP-UCB score of {@code child} as a successor of {@code node}.
     * The score is the superclass UCB score plus the signed variance term.
     *
     * <p>All statistics in the variance term come from the child's label: squared sum, mean and
     * visit count. The term must characterise the child being scored, not its parent.
     *
     * @param node the parent node
     * @param child the child node to score
     * @return the UCB score plus the signed variance term
     */
    @Override
    public double getScore(final N node, final N child) {
        final double ucbScore = super.getScore(node, child);
        final AUpdatingPolicy<N, A>.NodeLabel childLabel = this.getLabelOfNode(child);
        final int childVisits = childLabel.visits;
        final double sumOfSquares = this.squaredObservations.getOrDefault(childLabel, 0.0);
        final double expectedSquares = childVisits * childLabel.mean * childLabel.mean;
        // Mathematically sumOfSquares - n*mean^2 = n*variance >= 0. Clamp so floating-point
        // cancellation cannot turn the square root into NaN.
        final double radicand = Math.max(0.0, sumOfSquares - expectedSquares + this.bigD);
        // NOTE(review): childVisits == 0 divides by zero (Infinity/NaN). This matches the previous
        // behaviour; presumably unvisited children are handled upstream. Confirm.
        final double varianceTerm = (this.isMaximize() ? 1 : -1) * Math.sqrt(radicand / childVisits);
        final double score = ucbScore + varianceTerm;
        if (this.logger.isDebugEnabled()) {
            this.logger.debug("Computed score for child {}: {} = {} + {}", child, score, ucbScore, varianceTerm);
        }
        return score;
    }
}
