package ai.libs.jaicore.search.algorithms.standard.mcts;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import org.api4.java.common.control.ILoggingCustomizable;
import org.api4.java.datastructure.graph.ILabeledPath;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Base class for MCTS tree policies that keep, per node, the running mean of all scores propagated through that node
 * together with its visit count.
 *
 * <p>Subclasses decide how a successor is scored ({@link #getScore(Object, Object)}) and how an action is chosen from
 * those scores ({@link #getActionBasedOnScores(Map)}). This class takes care of the bookkeeping: every action of a node
 * is dictated once before any scoring happens, and {@link #updatePath(ILabeledPath, Double, int)} folds a new score into
 * the statistics of every node on a path.
 *
 * @param <N> node type
 * @param <A> action (edge label) type
 */
public abstract class AUpdatingPolicy<N, A> implements IPathUpdatablePolicy<N, A, Double>, ILoggingCustomizable {
    private String loggerName;
    private Logger logger = LoggerFactory.getLogger(AUpdatingPolicy.class);

    /* only stored here; its interpretation is left to subclasses via isMaximize() */
    private final boolean maximize;

    /* statistics of every node that has been dictated or lies on an updated path */
    private final Map<N, AUpdatingPolicy<N, A>.NodeLabel> labels = new HashMap<>();

    /**
     * Statistics of a single node: the arithmetic mean of all scores propagated through it and the number of such
     * propagations.
     */
    public class NodeLabel {
        protected double mean;
        protected int visits;

        public NodeLabel() {
            /* a fresh label has no visits and a mean of 0.0 */
        }

        @Override
        public String toString() {
            return "NodeLabel [mean=" + this.mean + ", visits=" + this.visits + "]";
        }
    }

    /**
     * Creates a policy with {@code maximize = true}.
     */
    public AUpdatingPolicy() {
        this(true);
    }

    /**
     * @param maximize flag exposed to subclasses via {@link #isMaximize()}
     */
    public AUpdatingPolicy(final boolean maximize) {
        this.maximize = maximize;
    }

    /**
     * Returns the statistics recorded for the given node.
     *
     * @param node the node whose label is requested
     * @return the label of the node (never {@code null})
     * @throws IllegalArgumentException if no label has been recorded for the node yet
     */
    public AUpdatingPolicy<N, A>.NodeLabel getLabelOfNode(final N node) {
        /* labels never map to null, so a null result means "absent" */
        AUpdatingPolicy<N, A>.NodeLabel label = this.labels.get(node);
        if (label == null) {
            throw new IllegalArgumentException("No label for node " + node);
        }
        return label;
    }

    /**
     * Computes the score of moving from {@code node} to {@code successor}. Called only once every successor of
     * {@code node} has a label.
     *
     * @param node the node an action is being chosen for
     * @param successor one of its successors
     * @return the score of the successor; must not be NaN
     */
    public abstract double getScore(N node, N successor);

    /**
     * Chooses an action given the score of each available action.
     *
     * @param scores the score of every available action
     * @return the chosen action; must not be {@code null}
     */
    public abstract A getActionBasedOnScores(Map<A, Double> scores);

    /**
     * Folds {@code playoutScore} into the running mean of every node on {@code path} and increments each node's visit
     * count. Nodes without a label get one on the fly.
     *
     * @param path the path whose nodes are updated, ordered from the first node to the last
     * @param playoutScore the score to propagate; must not be {@code null}
     * @param pathLength not read by this implementation (presumably the length of the playout path as defined by
     *        {@code IPathUpdatablePolicy} — confirm)
     * @throws IllegalStateException if a node ends up with more visits than its predecessor on the path, which cannot
     *         happen in a tree
     */
    @Override
    public void updatePath(final ILabeledPath<N, A> path, final Double playoutScore, final int pathLength) {
        this.logger.debug("Updating path {} with score {}", path, playoutScore);
        int visitsOfPredecessor = Integer.MAX_VALUE;
        for (N node : path.getNodes()) {
            AUpdatingPolicy<N, A>.NodeLabel label = this.labels.computeIfAbsent(node, key -> new NodeLabel());

            /* incremental mean: newMean = (visits * oldMean + score) / (visits + 1) */
            label.mean = (label.visits * label.mean + playoutScore) / (label.visits + 1);
            label.visits++;
            this.logger.trace("Updated label of node {}. Visits now {} with mean {}", node, label.visits, label.mean);

            /* in a tree every visit of a child is also a visit of its parent, so a child can never have more visits */
            if (label.visits > visitsOfPredecessor) {
                throw new IllegalStateException("Illegal visits stats of child " + label.visits + " compared to parent " + visitsOfPredecessor + "\nCheck whether the searched graph is really a tree!");
            }
            visitsOfPredecessor = label.visits;
        }
        this.logger.debug("Path update completed.");
    }

    /**
     * Chooses the action to take in {@code node}.
     *
     * <p>As long as some successor has no label, the first such action (in key-set iteration order) is returned and a
     * fresh label with zero visits is registered for its successor. Once every successor has a label, each action is
     * scored with {@link #getScore(Object, Object)} and the choice is delegated to {@link #getActionBasedOnScores(Map)}.
     *
     * @param node the node to choose an action for
     * @param actionsWithSuccessors every available action mapped to the successor it leads to
     * @return the chosen action (never {@code null})
     * @throws IllegalArgumentException if no action is available
     * @throws IllegalStateException if a score is NaN or the subclass selects {@code null}
     */
    @Override
    public A getAction(final N node, final Map<A, N> actionsWithSuccessors) {
        Set<A> possibleActions = actionsWithSuccessors.keySet();
        this.logger.debug("Deriving action for node {}. The {} options are: {}", node, possibleActions.size(), actionsWithSuccessors);
        if (possibleActions.isEmpty()) {
            throw new IllegalArgumentException("Cannot derive an action for node " + node + ", because no actions are available.");
        }

        /* play every action once before relying on scores */
        for (A action : possibleActions) {
            N successor = actionsWithSuccessors.get(action);
            if (!this.labels.containsKey(successor)) {
                this.labels.put(successor, new NodeLabel());
                this.logger.info("Dictating action {}, because this was never played before.", action);
                return action;
            }
        }

        this.logger.debug("All actions have been tried. Label is: {}", this.labels.get(node));
        Map<A, Double> scores = new HashMap<>();
        for (A action : possibleActions) {
            N successor = actionsWithSuccessors.get(action);
            AUpdatingPolicy<N, A>.NodeLabel successorLabel = this.labels.get(successor);

            /* a dictated successor should have been updated via updatePath before all its siblings were tried */
            assert successorLabel.visits != 0 : "Visits of node " + successor + " cannot be 0 if we already used this action before!";
            this.logger.trace("Considering action {} whose successor state has stats {} and {} visits", action, successorLabel.mean, successorLabel.visits);

            double score = this.getScore(node, successor);
            if (Double.isNaN(score)) {
                throw new IllegalStateException("Score of action " + action + " is NaN, which it must not be!");
            }
            scores.put(action, score);
        }

        A chosenAction = this.getActionBasedOnScores(scores);
        if (chosenAction == null) {
            throw new IllegalStateException("Would return null, but this must not be the case! Check the method that chooses an action given the scores.");
        }
        this.logger.info("Recommending action {}.", chosenAction);
        return chosenAction;
    }

    /**
     * @return the flag passed at construction ({@code true} for the no-argument constructor)
     */
    public boolean isMaximize() {
        return this.maximize;
    }

    /**
     * @return the name most recently passed to {@link #setLoggerName(String)}, or {@code null} if it was never called
     */
    public String getLoggerName() {
        return this.loggerName;
    }

    /**
     * Switches this policy to the SLF4J logger with the given name.
     *
     * @param name the new logger name
     */
    public void setLoggerName(final String name) {
        this.loggerName = name;
        this.logger = LoggerFactory.getLogger(name);
        this.logger.info("Set logger of {} to {}", this, name);
    }
}
