package ai.libs.jaicore.search.algorithms.standard.mcts.comparison;

import ai.libs.jaicore.search.algorithms.standard.mcts.ActionPredictionFailedException;
import ai.libs.jaicore.search.algorithms.standard.mcts.IPathUpdatablePolicy;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.Objects;
import java.util.function.ToDoubleFunction;
import org.apache.commons.math3.stat.descriptive.DescriptiveStatistics;
import org.api4.java.datastructure.graph.ILabeledPath;

/**
 * Tree policy that explores until every successor has a fixed number of observations, then exploits.
 *
 * <p>For every node, the policy keeps a {@link DescriptiveStatistics} of all scores propagated
 * through it via {@link #updatePath}. When choosing among the successors of a node:
 * <ul>
 *   <li>If the least-observed successor has fewer than {@code k} observations, the action leading
 *       to it is returned (exploration).</li>
 *   <li>Otherwise, the action leading to the successor whose statistics yield the lowest
 *       {@code metric} value is returned (exploitation; lower is better).</li>
 * </ul>
 *
 * <p>Not thread-safe: the per-node statistics live in an unsynchronized {@link HashMap}.
 *
 * @param <N> node type
 * @param <A> action (edge label) type
 */
public class FixedCommitmentPolicy<N, A> implements IPathUpdatablePolicy<N, A, Double> {

    /** Scores observed per node; entries are created lazily on first access. */
    private final Map<N, DescriptiveStatistics> observationsPerNode = new HashMap<>();

    /** Number of observations every successor needs before the policy switches to exploitation. */
    private final int k;

    /** Maps a successor's score statistics to a value that is minimized during exploitation. */
    private final ToDoubleFunction<DescriptiveStatistics> metric;

    /**
     * Creates a policy with the given commitment threshold and exploitation metric.
     *
     * @param k minimum number of observations each successor must have before exploiting
     * @param metric function applied to a successor's statistics; the successor with the lowest
     *     value is chosen once all successors have at least {@code k} observations
     * @throws NullPointerException if {@code metric} is null
     */
    public FixedCommitmentPolicy(int k, ToDoubleFunction<DescriptiveStatistics> metric) {
        this.k = k;
        this.metric = Objects.requireNonNull(metric, "metric");
    }

    /**
     * Selects the action leading to the successor that should be descended into next.
     *
     * <p>Ties are broken in favour of the entry encountered first in the map's iteration order.
     * Successors whose metric evaluates to NaN are only chosen for exploitation if no successor
     * has a non-NaN metric.
     *
     * @param node the node whose successors are being chosen from (not read by this policy)
     * @param actionsWithSuccessors map from each available action to the successor it leads to
     * @return the action towards the least-observed successor if that successor has fewer than
     *     {@code k} observations, otherwise the action towards the successor with the lowest metric
     * @throws NullPointerException if {@code actionsWithSuccessors} is empty
     */
    @Override
    public A getAction(N node, Map<A, N> actionsWithSuccessors) throws ActionPredictionFailedException {
        Map.Entry<A, N> leastObserved = null;
        long fewestObservations = Long.MAX_VALUE;
        Map.Entry<A, N> bestScored = null;
        double bestScore = Double.NaN;

        for (Map.Entry<A, N> candidate : actionsWithSuccessors.entrySet()) {
            DescriptiveStatistics stats = statisticsOf(candidate.getValue());

            long observations = stats.getN();
            if (observations < fewestObservations) {
                leastObserved = candidate;
                fewestObservations = observations;
            }

            double score = this.metric.applyAsDouble(stats);
            // Seed with the first candidate so a successor is always available for exploitation,
            // even if no metric value is finite (previously this led to a NullPointerException).
            if (bestScored == null || isBetter(score, bestScore)) {
                bestScored = candidate;
                bestScore = score;
            }
        }

        Objects.requireNonNull(leastObserved, "No successors available to choose an action from.");
        return fewestObservations < this.k ? leastObserved.getKey() : bestScored.getKey();
    }

    /**
     * Records {@code score} as an observation for every node on the given path.
     *
     * @param path the path whose nodes receive the observation
     * @param score the observed score; unboxed, so it must not be null if the path has nodes
     * @param i not read by this implementation
     */
    @Override
    public void updatePath(ILabeledPath<N, A> path, Double score, int i) {
        for (N node : path.getNodes()) {
            statisticsOf(node).addValue(score);
        }
    }

    /** Returns the statistics recorded for {@code node}, registering an empty record if none exists. */
    private DescriptiveStatistics statisticsOf(N node) {
        return this.observationsPerNode.computeIfAbsent(node, unused -> new DescriptiveStatistics());
    }

    /** Returns whether {@code score} should replace {@code currentBest}: lower wins, NaN loses. */
    private static boolean isBetter(double score, double currentBest) {
        if (Double.isNaN(score)) {
            return false;
        }
        return Double.isNaN(currentBest) || score < currentBest;
    }
}
