package ai.libs.jaicore.search.algorithms.mdp.mcts.comparison;

import ai.libs.jaicore.search.algorithms.mdp.mcts.ActionPredictionFailedException;
import ai.libs.jaicore.search.algorithms.mdp.mcts.IPathUpdatablePolicy;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.function.ToDoubleFunction;
import org.apache.commons.math3.stat.descriptive.DescriptiveStatistics;
import org.api4.java.datastructure.graph.ILabeledPath;

/**
 * A policy that first explores every available action a fixed number of times and then commits to
 * the action whose observations score best according to a given metric.
 *
 * <p>For each node, the policy keeps one {@link DescriptiveStatistics} per action. While the
 * least-observed action of a node has fewer than {@code k} observations, that action is returned.
 * Ties go to the action that comes first in the iteration order of the given collection. Once every
 * action has at least {@code k} observations, the action with the <em>smallest</em> metric value is
 * returned; lower metric values are considered better.
 *
 * <p>Observations are recorded through {@link #updatePath}. Every (node, action) pair on a path
 * receives the sum of the scores from that pair up to the end of the path.
 *
 * <p>State is kept in plain, unsynchronized {@link HashMap}s.
 *
 * @param <N> the node type
 * @param <A> the action type
 */
public class FixedCommitmentPolicy<N, A> implements IPathUpdatablePolicy<N, A, Double> {

    /** Observed (cumulative) path scores, per node and per action taken in that node. */
    private final Map<N, Map<A, DescriptiveStatistics>> observationsPerNode = new HashMap<>();

    /** Number of observations every action needs before the policy commits to the best one. */
    private final int k;

    /** Maps the observations of an action to a score; lower scores are better. */
    private final ToDoubleFunction<DescriptiveStatistics> metric;

    /**
     * Creates a fixed-commitment policy.
     *
     * @param k number of observations each action must have before the policy stops exploring
     * @param metric function scoring the observations of an action (lower is better); may be called
     *     on empty statistics and may return {@code NaN} for them
     * @throws NullPointerException if {@code metric} is null
     */
    public FixedCommitmentPolicy(final int k, final ToDoubleFunction<DescriptiveStatistics> metric) {
        this.k = k;
        this.metric = Objects.requireNonNull(metric, "metric");
    }

    /**
     * Chooses an action for the given node.
     *
     * <p>Statistics for previously unseen actions are created (empty) as a side effect.
     *
     * @param node the node in which an action is to be taken
     * @param actions the available actions; must not be empty
     * @return the least-observed action if it has fewer than {@code k} observations, otherwise the
     *     action minimizing the metric (non-NaN scores are preferred over NaN scores)
     * @throws IllegalArgumentException if {@code actions} is empty
     * @throws NullPointerException if {@code actions} is null
     */
    @Override
    public A getAction(final N node, final Collection<A> actions) throws ActionPredictionFailedException {
        Objects.requireNonNull(actions, "actions");
        if (actions.isEmpty()) {
            throw new IllegalArgumentException("Cannot choose an action for node " + node + ": no actions available.");
        }
        final Map<A, DescriptiveStatistics> statsPerAction = this.observationsPerNode.computeIfAbsent(node, n -> new HashMap<>());

        A leastVisitedAction = null;
        long fewestVisits = Long.MAX_VALUE;
        A bestAction = null;
        double bestScore = Double.NaN;
        for (A action : actions) {
            final DescriptiveStatistics stats = statsPerAction.computeIfAbsent(action, a -> new DescriptiveStatistics());

            final long visits = stats.getN();
            if (visits < fewestVisits) {
                fewestVisits = visits;
                leastVisitedAction = action;
            }

            /* Take the first action as the initial best so that a best action always exists, even if
             * the metric yields NaN (e.g. the mean of empty statistics) or huge values; a real
             * (non-NaN) score always beats a NaN one. */
            final double score = this.metric.applyAsDouble(stats);
            if (bestAction == null || score < bestScore || (Double.isNaN(bestScore) && !Double.isNaN(score))) {
                bestScore = score;
                bestAction = action;
            }
        }
        return fewestVisits < this.k ? leastVisitedAction : bestAction;
    }

    /**
     * Records the outcome of a path.
     *
     * <p>The path is walked backwards. Each (node, outgoing arc) pair at position {@code i} receives
     * the sum of {@code scores.get(i)} and all later scores on the path.
     *
     * @param path the path whose (node, arc) pairs are updated
     * @param scores the score of each arc, aligned with the arcs of {@code path}
     */
    @Override
    public void updatePath(final ILabeledPath<N, A> path, final List<Double> scores) {
        final List<N> nodes = path.getNodes();
        final List<A> arcs = path.getArcs();
        double cumulativeScore = 0.0;
        for (int i = nodes.size() - 2; i >= 0; i--) {
            cumulativeScore += scores.get(i).doubleValue();
            this.observationsPerNode.computeIfAbsent(nodes.get(i), n -> new HashMap<>())
                    .computeIfAbsent(arcs.get(i), a -> new DescriptiveStatistics())
                    .addValue(cumulativeScore);
        }
    }
}
