package ai.libs.jaicore.search.algorithms.mdp.mcts.thompson;

import ai.libs.jaicore.basic.sets.Pair;
import ai.libs.jaicore.search.algorithms.mdp.mcts.ActionPredictionFailedException;
import ai.libs.jaicore.search.algorithms.mdp.mcts.IPathUpdatablePolicy;
import com.google.common.eventbus.EventBus;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.function.Predicate;
import org.apache.commons.math3.distribution.GammaDistribution;
import org.apache.commons.math3.distribution.NormalDistribution;
import org.api4.java.common.control.ILoggingCustomizable;
import org.api4.java.common.event.IRelaxedEventEmitter;
import org.api4.java.datastructure.graph.ILabeledPath;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:ai/libs/jaicore/search/algorithms/mdp/mcts/thompson/DNGPolicy.class */
/**
 * MCTS tree policy that keeps Bayesian beliefs about state values and transition probabilities and selects
 * actions by Thompson sampling.
 *
 * <p>For every state seen on an updated path (except the last node of the path) the policy keeps Normal-Gamma
 * parameters {@code mu}, {@code lambda}, {@code alpha}, {@code beta} over the state's value. For every
 * state/action pair it keeps the counts of observed successor states ({@code rho}) and the first observed
 * immediate reward. In sampling mode, transition probabilities are drawn from the Dirichlet distribution given by
 * the successor counts, and state values from the Normal-Gamma belief. Otherwise, empirical transition
 * frequencies and the belief means {@code mu} are used.
 *
 * <p>All state is held in plain {@link HashMap}s; the class performs no synchronization.
 *
 * @param <N> node (state) type
 * @param <A> action (arc) type
 */
public class DNGPolicy<N, A> implements IPathUpdatablePolicy<N, A, Double>, ILoggingCustomizable, IRelaxedEventEmitter {

    /** Prior {@code alpha} assigned to a newly seen state. */
    private static final double INIT_ALPHA = 1.0d;

    /** Prior {@code mu} assigned to a newly seen state. */
    private static final double INIT_MU = 0.5d;

    private final boolean maximize;
    private final double initLambda;
    private final double initBeta;
    private final double gammaMDP;
    private final Predicate<N> terminalStatePredicate;
    private final double varianceFactor;

    private Logger logger = LoggerFactory.getLogger(DNGPolicy.class);
    private final EventBus eventBus = new EventBus();

    /* Normal-Gamma belief parameters, keyed by state. */
    private final Map<N, Double> alpha = new HashMap<>();
    private final Map<N, Double> beta = new HashMap<>();
    private final Map<N, Double> mu = new HashMap<>();
    private final Map<N, Double> lambda = new HashMap<>();

    /* state -> action -> successor state -> number of times this transition was observed. */
    private final Map<N, Map<A, Map<N, Integer>>> rho = new HashMap<>();

    /* state -> action -> first immediate reward observed for that pair (later observations are ignored). */
    private final Map<N, Map<A, Double>> rewardsMDP = new HashMap<>();

    private boolean sampling = true;

    /**
     * Creates a new policy.
     *
     * @param gammaMDP discount factor applied to successor-state values when computing Q-values
     * @param terminalStatePredicate identifies terminal states; their value is taken to be 0
     * @param varianceFactor weight of the {@code sqrt(sampled precision)} term subtracted from sampled state values
     * @param initLambda prior {@code lambda} for newly seen states; the prior {@code beta} is {@code 1 / initLambda}
     * @param maximize {@code true} to prefer actions with high Q-values, {@code false} to prefer low ones
     */
    public DNGPolicy(final double gammaMDP, final Predicate<N> terminalStatePredicate, final double varianceFactor, final double initLambda, final boolean maximize) {
        this.gammaMDP = gammaMDP;
        this.terminalStatePredicate = terminalStatePredicate;
        this.varianceFactor = varianceFactor;
        this.initLambda = initLambda;
        this.initBeta = 1.0 / initLambda;
        this.maximize = maximize;
    }

    /** @return whether Q-values are sampled (Thompson sampling) rather than computed from posterior means */
    public boolean isSampling() {
        return this.sampling;
    }

    /** @param sampling {@code true} to sample Q-values, {@code false} to use posterior means */
    public void setSampling(final boolean sampling) {
        this.sampling = sampling;
    }

    @Override // ai.libs.jaicore.search.algorithms.mdp.mcts.IPolicy
    public A getAction(final N n, final Collection<A> collection) throws ActionPredictionFailedException, InterruptedException {
        return this.sampleWithThompson(n, collection);
    }

    /**
     * Computes a Q-value for every candidate action and returns the best one. The best action has the highest
     * value when maximizing and the lowest otherwise; on ties the earlier action is kept.
     *
     * @param n the current state
     * @param collection the candidate actions; must contain at least one action
     * @return the best action
     * @throws InterruptedException if the thread is interrupted while evaluating successor states
     * @throws NullPointerException if {@code collection} is empty
     * @throws IllegalStateException if no statistics exist for some state/action pair
     */
    public A sampleWithThompson(final N n, final Collection<A> collection) throws InterruptedException {
        this.logger.info("Determining best action for state {}", n);
        A bestAction = null;
        double bestScore = this.maximize ? -Double.MAX_VALUE : Double.MAX_VALUE;
        for (A action : collection) {
            double qValue = this.getQValue(n, action);
            this.logger.debug("Score for action {} is {}", action, qValue);
            this.eventBus.post(new DNGQSampleEvent(null, n, action, qValue));
            // The comparison direction must depend on the optimization direction.
            boolean isBetter = this.maximize ? qValue > bestScore : qValue < bestScore;
            if (bestAction == null || isBetter) {
                bestAction = action;
                bestScore = qValue;
                this.logger.debug("Considering this as the new best action.");
            }
        }
        Objects.requireNonNull(bestAction, "Best action cannot be null if there were " + collection.size() + " options!");
        this.logger.info("Recommending action {}", bestAction);
        return bestAction;
    }

    /**
     * Computes {@code r(n, a) + gammaMDP * E[V(s')]}. Here {@code r(n, a)} is the first recorded reward for the
     * pair. The expectation over successors {@code s'} uses Dirichlet-sampled transition probabilities in
     * sampling mode, and empirical transition frequencies otherwise.
     *
     * @param n the state
     * @param a the action taken in {@code n}
     * @return the (sampled or mean) Q-value
     * @throws InterruptedException if the thread is interrupted while evaluating successor states
     * @throws IllegalStateException if no transition statistics exist for the pair {@code (n, a)}
     */
    public double getQValue(final N n, final A a) throws InterruptedException {
        Map<A, Map<N, Integer>> countsPerAction = this.rho.get(n);
        Map<N, Integer> successorCounts = countsPerAction != null ? countsPerAction.get(a) : null;
        if (successorCounts == null) {
            throw new IllegalStateException("Have no rho vector for state/action pair " + n + "/" + a);
        }
        this.logger.debug("Now determining q-value of action {}. Sampling: {}", a, this.sampling);
        double expectedSuccessorValue = this.sampling ? this.getSampledSuccessorValue(successorCounts) : this.getMeanSuccessorValue(successorCounts);
        double reward = this.rewardsMDP.get(n).get(a);
        double qValue = reward + this.gammaMDP * expectedSuccessorValue;
        this.logger.debug("Considering a reward of {} + {} * {} = {}", reward, this.gammaMDP, expectedSuccessorValue, qValue);
        return qValue;
    }

    /**
     * Draws transition probabilities from Dirichlet(counts) and returns the probability-weighted sum of
     * (sampled) successor values. The Dirichlet draw is formed by normalizing independent Gamma(count, 1) draws.
     */
    private double getSampledSuccessorValue(final Map<N, Integer> successorCounts) throws InterruptedException {
        List<N> successors = new ArrayList<>(successorCounts.keySet());
        int k = successors.size();
        double[] gammaDraws = new double[k];
        double sum = 0.0;
        for (int i = 0; i < k; i++) {
            double shape = successorCounts.get(successors.get(i));
            double draw = new GammaDistribution(shape, 1.0).sample();
            gammaDraws[i] = draw;
            sum += draw;
        }
        if (sum == 0.0) {
            throw new IllegalStateException("The gamma estimates must not sum up to 0!");
        }
        double weighted = 0.0;
        for (int i = 0; i < k; i++) {
            weighted += (gammaDraws[i] / sum) * this.getValue(successors.get(i));
        }
        return weighted;
    }

    /** Returns the successor values weighted by the empirical transition frequencies {@code count / total}. */
    private double getMeanSuccessorValue(final Map<N, Integer> successorCounts) throws InterruptedException {
        double total = successorCounts.values().stream().mapToInt(Integer::intValue).sum();
        if (total == 0) {
            throw new IllegalStateException("The rho vector must contain at least one observed transition!");
        }
        double weighted = 0.0;
        for (Map.Entry<N, Integer> entry : successorCounts.entrySet()) {
            weighted += (entry.getValue() / total) * this.getValue(entry.getKey());
        }
        return weighted;
    }

    /**
     * Draws {@code (value, precision)} from the Normal-Gamma belief of state {@code n}:
     * {@code precision ~ Gamma(shape = alpha, rate = beta)}, {@code value ~ Normal(mu, 1 / (lambda * precision))}.
     * If the resulting variance is not a positive finite number, {@code mu} is returned as the value.
     *
     * @param n a state for which beliefs have been established
     * @return pair of sampled value (X) and sampled precision (Y)
     */
    public Pair<Double, Double> sampleWithNormalGamma(final N n) {
        double alphaN = this.alpha.get(n);
        double betaN = this.beta.get(n);
        double muN = this.mu.get(n);
        // beta is a rate (see the conjugate update in updateBelief); commons-math expects a scale.
        double precision = new GammaDistribution(alphaN, 1.0 / betaN).sample();
        double variance = 1.0 / (this.lambda.get(n) * precision);
        // commons-math's NormalDistribution takes a standard deviation, not a variance.
        double value = variance > 0.0 && Double.isFinite(variance) ? new NormalDistribution(muN, Math.sqrt(variance)).sample() : muN;
        return new Pair<>(value, precision);
    }

    /**
     * Returns the value of state {@code n}.
     * <ul>
     * <li>Terminal states have value 0.</li>
     * <li>In non-sampling mode the value is the belief mean {@code mu}.</li>
     * <li>Otherwise a (value, precision) pair is drawn and {@code value - varianceFactor * sqrt(precision)} is
     * returned.</li>
     * </ul>
     * NOTE(review): {@code sqrt(precision)} is the inverse of the standard deviation; confirm that this, rather
     * than the standard deviation itself, is the intended penalty term.
     *
     * @param n the state
     * @return the (sampled or mean) value of {@code n}
     * @throws InterruptedException if the current thread has been interrupted (the interrupt flag is cleared)
     * @throws IllegalStateException if {@code n} is non-terminal and has no belief yet
     */
    public double getValue(final N n) throws InterruptedException {
        boolean terminal = this.terminalStatePredicate.test(n);
        if (Thread.interrupted()) {
            throw new InterruptedException();
        }
        if (terminal) {
            this.logger.debug("Returning value of 0 for terminal state {}", n);
            return 0.0d;
        }
        if (!this.mu.containsKey(n)) {
            throw new IllegalStateException("No belief has been established yet for non-terminal state " + n);
        }
        if (!this.sampling) {
            double mean = this.mu.get(n);
            this.logger.debug("Returning fixed value of {}", mean);
            return mean;
        }
        Pair<Double, Double> sample = this.sampleWithNormalGamma(n);
        double value = sample.getX() - (this.varianceFactor * Math.sqrt(sample.getY()));
        this.logger.debug("Returning sampled value of {}", value);
        return value;
    }

    /**
     * Propagates the scores of a path backwards and updates beliefs and transition counts. Node {@code i}
     * receives the discounted score {@code d_i = score_i + gammaMDP * d_(i+1)}. A state seen for the first time
     * only gets its prior belief and a transition count of 1; later visits perform the Normal-Gamma update with
     * {@code d_i}. The last node of the path receives no belief.
     *
     * <p>A {@code null} score is recorded as NaN, which then also propagates into the accumulated score of
     * every earlier node on the path.
     *
     * @param path the path that was played out
     * @param scores per-arc scores; entry {@code i} belongs to the arc leaving node {@code i}
     */
    @Override // ai.libs.jaicore.search.algorithms.mdp.mcts.IPathUpdatablePolicy
    public void updatePath(final ILabeledPath<N, A> path, final List<Double> scores) {
        List<N> nodes = path.getNodes();
        List<A> arcs = path.getArcs();
        int numberOfNodes = path.getNumberOfNodes();
        this.logger.info("Updating path with scores {}", scores);
        double accumulatedScore = 0.0d;
        for (int i = numberOfNodes - 2; i >= 0; i--) {
            N state = nodes.get(i);
            A action = arcs.get(i);
            N successor = nodes.get(i + 1);
            Double observed = scores.get(i);
            double reward = observed != null ? observed : Double.NaN;
            this.rewardsMDP.computeIfAbsent(state, s -> new HashMap<>()).putIfAbsent(action, reward);
            accumulatedScore = reward + (this.gammaMDP * accumulatedScore);
            this.logger.debug("Updating statistics for {}-th node with accumulated score {}. State here is: {}", i, accumulatedScore, state);
            if (this.lambda.containsKey(state)) {
                this.updateBelief(state, action, successor, accumulatedScore);
            } else {
                this.initializeBelief(state, action, successor);
            }
        }
    }

    /** Applies the conjugate Normal-Gamma update for one observation and increments the transition count. */
    private void updateBelief(final N state, final A action, final N successor, final double observation) {
        double lambdaOld = this.lambda.get(state);
        double muOld = this.mu.get(state);
        this.alpha.put(state, this.alpha.get(state) + 0.5);
        this.beta.put(state, this.beta.get(state) + (((lambdaOld * Math.pow(observation - muOld, 2.0)) / (lambdaOld + 1.0)) / 2.0));
        this.mu.put(state, ((muOld * lambdaOld) + observation) / (lambdaOld + 1.0));
        this.lambda.put(state, lambdaOld + 1.0);
        this.rho.get(state).computeIfAbsent(action, x -> new HashMap<>()).merge(successor, 1, Integer::sum);
        this.eventBus.post(new DNGBeliefUpdateEvent(null, state, this.mu.get(state), this.alpha.get(state), this.beta.get(state), this.lambda.get(state)));
    }

    /** Sets the prior belief of a newly seen state and records its first observed transition. */
    private void initializeBelief(final N state, final A action, final N successor) {
        this.lambda.put(state, this.initLambda);
        this.mu.put(state, INIT_MU);
        this.alpha.put(state, INIT_ALPHA);
        this.beta.put(state, this.initBeta);
        Map<N, Integer> successorCounts = new HashMap<>();
        successorCounts.put(successor, 1);
        Map<A, Map<N, Integer>> countsPerAction = new HashMap<>();
        countsPerAction.put(action, successorCounts);
        this.rho.put(state, countsPerAction);
    }

    @Override
    public String getLoggerName() {
        return this.logger.getName();
    }

    @Override
    public void setLoggerName(final String name) {
        this.logger = LoggerFactory.getLogger(name);
        this.logger.info("Logger is now {}", name);
    }

    @Override
    public void registerListener(final Object listener) {
        this.eventBus.register(listener);
    }
}
