package ai.libs.jaicore.search.algorithms.standard.mcts.thompson;

import ai.libs.jaicore.basic.sets.Pair;
import ai.libs.jaicore.search.algorithms.standard.mcts.ActionPredictionFailedException;
import ai.libs.jaicore.search.algorithms.standard.mcts.IPathUpdatablePolicy;
import com.google.common.eventbus.EventBus;
import java.util.HashMap;
import java.util.Map;
import org.apache.commons.math3.distribution.GammaDistribution;
import org.apache.commons.math3.distribution.NormalDistribution;
import org.api4.java.ai.graphsearch.problem.implicit.graphgenerator.INodeGoalTester;
import org.api4.java.common.attributedobjects.IObjectEvaluator;
import org.api4.java.common.attributedobjects.ObjectEvaluationFailedException;
import org.api4.java.common.control.ILoggingCustomizable;
import org.api4.java.common.event.IRelaxedEventEmitter;
import org.api4.java.datastructure.graph.ILabeledPath;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:ai/libs/jaicore/search/algorithms/standard/mcts/thompson/DNGPolicy.class */
/**
 * Tree policy that keeps, for every node it has seen on a playout path, four belief parameters
 * ({@code alpha}, {@code beta}, {@code mu}, {@code lambda}) and chooses actions by sampling a value for
 * each successor from distributions built from these parameters.
 *
 * <p>The parameters are updated in {@link #updatePath} using the update rules of a Normal-Gamma
 * conjugate prior for a single observation. Action selection in {@link #getAction} picks the action
 * whose successor obtains the <em>smallest</em> sampled value, i.e. scores are minimized.
 *
 * <p>Belief updates and sampled values are published on an internal {@link EventBus}; listeners can be
 * attached via {@link #registerListener(Object)}.
 *
 * <p>The internal maps are plain {@link HashMap}s and nothing in this class synchronizes access to
 * them.
 *
 * @param <N> node type
 * @param <A> action (edge label) type
 */
public class DNGPolicy<N, A> implements IPathUpdatablePolicy<N, A, Double>, ILoggingCustomizable, IRelaxedEventEmitter {
    /** Initial value of {@code alpha} for a node that has not been updated yet. */
    private static final double INIT_ALPHA = 1.0d;
    /** Initial value of {@code mu} for a node that has not been updated yet. */
    private static final double INIT_MU = 1.0d;

    /** Initial value of {@code lambda} for a node that has not been updated yet. */
    private final double initLambda;
    /** Initial value of {@code beta}; derived as {@code 1 / initLambda} in the constructor. */
    private final double initBeta;

    private final INodeGoalTester<N, A> goalTester;
    private final IObjectEvaluator<N, Double> leafNodeEvaluator;
    /** Weight of the second sampled quantity that is subtracted from the sampled mean in {@link #getValue}. */
    private final double varianceFactor;

    /* not final/static because setLoggerName replaces it per instance */
    private Logger logger = LoggerFactory.getLogger(DNGPolicy.class);
    private final EventBus eventBus = new EventBus();

    private final Map<N, Double> alpha = new HashMap<>();
    private final Map<N, Double> beta = new HashMap<>();
    private final Map<N, Double> mu = new HashMap<>();
    private final Map<N, Double> lambda = new HashMap<>();

    /**
     * Creates a new policy.
     *
     * @param goalTester        decides whether a node is a goal node, in which case it is evaluated exactly
     * @param leafNodeEvaluator computes the score of goal nodes
     * @param varianceFactor    weight used in {@link #getValue} for the penalty term
     * @param initLambda        initial {@code lambda} of unseen nodes; the initial {@code beta} is set to
     *                          {@code 1 / initLambda}
     */
    public DNGPolicy(INodeGoalTester<N, A> goalTester, IObjectEvaluator<N, Double> leafNodeEvaluator, double varianceFactor, double initLambda) {
        this.goalTester = goalTester;
        this.leafNodeEvaluator = leafNodeEvaluator;
        this.varianceFactor = varianceFactor;
        this.initLambda = initLambda;
        this.initBeta = 1.0d / this.initLambda;
    }

    /**
     * Chooses an action for {@code node} by delegating to {@link #sampleWithThompson}.
     *
     * @param node                   the node for which an action is requested
     * @param actionsWithSuccessors  available actions mapped to the successor node each one leads to
     * @return the chosen action; {@code null} if the map is empty, if no sampled value was below
     *         {@link Double#MAX_VALUE}, or if the thread was interrupted while sampling
     * @throws ActionPredictionFailedException if the evaluation of a goal node failed
     */
    @Override // ai.libs.jaicore.search.algorithms.standard.mcts.IPolicy
    public A getAction(N node, Map<A, N> actionsWithSuccessors) throws ActionPredictionFailedException {
        try {
            return sampleWithThompson(node, actionsWithSuccessors);
        } catch (ObjectEvaluationFailedException e) {
            throw new ActionPredictionFailedException(e);
        } catch (InterruptedException e) {
            /* the interface does not allow an InterruptedException here, so only the flag is restored */
            this.logger.info("Policy thread has been interrupted. Re-interrupting myself, because no InterruptedException can be thrown here.");
            Thread.currentThread().interrupt();
            return null;
        }
    }

    /**
     * Updates the belief parameters of every node on {@code path} with the observed playout score and
     * posts one {@code DNGBeliefUpdateEvent} per node.
     *
     * <p>For a node with current parameters (mu, lambda, alpha, beta) and score r the new values are
     * <pre>
     *   alpha'  = alpha + 1/2
     *   beta'   = beta + lambda * (r - mu)^2 / (2 * (lambda + 1))
     *   mu'     = (mu * lambda + r) / (lambda + 1)
     *   lambda' = lambda + 1
     * </pre>
     * Nodes seen for the first time start from the initial values.
     *
     * @param path          the path whose nodes are updated
     * @param playoutScore  the score observed for the playout; must not be {@code null} if the path has nodes
     * @param playoutLength not used by this policy
     */
    @Override // ai.libs.jaicore.search.algorithms.standard.mcts.IPathUpdatablePolicy
    public void updatePath(ILabeledPath<N, A> path, Double playoutScore, int playoutLength) {
        for (N node : path.getNodes()) {
            final double score = playoutScore;

            /* read the current belief (or the prior for unseen nodes) before anything is overwritten */
            final double oldLambda = this.lambda.getOrDefault(node, this.initLambda);
            final double oldMu = this.mu.getOrDefault(node, INIT_MU);
            final double oldAlpha = this.alpha.getOrDefault(node, INIT_ALPHA);
            final double oldBeta = this.beta.getOrDefault(node, this.initBeta);

            final double newAlpha = oldAlpha + 0.5d;
            final double newBeta = oldBeta + (((oldLambda * Math.pow(score - oldMu, 2.0d)) / (oldLambda + 1.0d)) / 2.0d);
            final double newMu = ((oldMu * oldLambda) + score) / (oldLambda + 1.0d);
            final double newLambda = oldLambda + 1.0d;

            this.alpha.put(node, newAlpha);
            this.beta.put(node, newBeta);
            this.mu.put(node, newMu);
            this.lambda.put(node, newLambda);

            this.eventBus.post(new DNGBeliefUpdateEvent(null, node, newMu, newAlpha, newBeta, newLambda));
        }
    }

    /**
     * Samples one value per successor via {@link #getQValue} and returns the action with the smallest
     * sampled value. A {@code DNGQSampleEvent} is posted for every sampled value.
     *
     * @param state                  the node for which an action is chosen
     * @param actionsWithSuccessors  available actions mapped to their successor nodes
     * @return the action with the smallest sampled value, or {@code null} if the map is empty or no
     *         sampled value was strictly below {@link Double#MAX_VALUE} (e.g. all samples were NaN)
     * @throws ObjectEvaluationFailedException if the evaluation of a goal node failed
     * @throws InterruptedException            if the thread was interrupted during evaluation
     */
    public A sampleWithThompson(N state, Map<A, N> actionsWithSuccessors) throws ObjectEvaluationFailedException, InterruptedException {
        A bestAction = null;
        double bestScore = Double.MAX_VALUE;
        for (Map.Entry<A, N> entry : actionsWithSuccessors.entrySet()) {
            final A action = entry.getKey();
            final N successor = entry.getValue();
            final double sampledScore = getQValue(state, successor);
            this.eventBus.post(new DNGQSampleEvent(null, state, successor, action, sampledScore));
            if (sampledScore < bestScore) {
                bestAction = action;
                bestScore = sampledScore;
            }
        }
        return bestAction;
    }

    /**
     * Returns the sampled value of moving from {@code state} to {@code successor}. In this
     * implementation that is simply the sampled value of the successor; {@code state} is not used.
     *
     * @param state     the node in which the action is taken (ignored)
     * @param successor the node reached by the action
     * @return the value obtained from {@link #getValue} for {@code successor}
     * @throws ObjectEvaluationFailedException if the evaluation of a goal node failed
     * @throws InterruptedException            if the thread was interrupted during evaluation
     */
    public double getQValue(N state, N successor) throws ObjectEvaluationFailedException, InterruptedException {
        return getValue(successor);
    }

    /**
     * Draws a pair of samples from the current belief about {@code state}: first a sample from a gamma
     * distribution constructed with the node's {@code alpha} and {@code beta}, then a sample from a normal
     * distribution constructed with the node's {@code mu} and {@code 1 / (lambda * gammaSample)}.
     *
     * <p>If {@code state} has never been updated via {@link #updatePath}, the initial parameter values
     * are used instead of failing with a {@link NullPointerException}.
     *
     * @param state the node to sample for
     * @return a pair whose X component is the normal sample and whose Y component is the gamma sample
     */
    public Pair<Double, Double> sampleWithNormalGamma(N state) {
        final double alphaOfState = this.alpha.getOrDefault(state, INIT_ALPHA);
        final double betaOfState = this.beta.getOrDefault(state, this.initBeta);
        final double muOfState = this.mu.getOrDefault(state, INIT_MU);
        final double lambdaOfState = this.lambda.getOrDefault(state, this.initLambda);

        /*
         * NOTE(review): commons-math3 documents GammaDistribution(shape, scale) and
         * NormalDistribution(mean, standardDeviation). The update rule for beta in updatePath looks like
         * the one for a rate parameter, and 1 / (lambda * tau) looks like a variance - confirm whether
         * passing them unchanged is intended before relying on the exact distributions.
         */
        final double tau = new GammaDistribution(alphaOfState, betaOfState).sample();
        final double sampledMean = new NormalDistribution(muOfState, 1.0d / (lambdaOfState * tau)).sample();
        return new Pair<>(sampledMean, tau);
    }

    /**
     * Returns the value of {@code state}: the exact evaluation for goal nodes, otherwise a sampled value
     * {@code x - varianceFactor * sqrt(y)} where {@code (x, y)} is the result of
     * {@link #sampleWithNormalGamma}.
     *
     * @param state the node to evaluate
     * @return the exact or sampled value of the node
     * @throws ObjectEvaluationFailedException if the evaluation of a goal node failed
     * @throws InterruptedException            if the thread was interrupted during evaluation
     */
    public double getValue(N state) throws ObjectEvaluationFailedException, InterruptedException {
        if (this.goalTester.isGoal(state)) {
            return this.leafNodeEvaluator.evaluate(state);
        }
        final Pair<Double, Double> sample = sampleWithNormalGamma(state);
        /* NOTE(review): y is the gamma sample; verify that sqrt(y) is the intended penalty term */
        return sample.getX() - (this.varianceFactor * Math.sqrt(sample.getY()));
    }

    /**
     * @return the name of the logger currently used by this policy
     */
    public String getLoggerName() {
        return this.logger.getName();
    }

    /**
     * Replaces the logger of this policy by the one registered under the given name.
     *
     * @param name the name of the logger to use from now on
     */
    public void setLoggerName(String name) {
        this.logger = LoggerFactory.getLogger(name);
    }

    /**
     * Registers a listener on the internal event bus, on which belief-update and sample events are posted.
     *
     * @param listener the listener object to register
     */
    public void registerListener(Object listener) {
        this.eventBus.register(listener);
    }
}
