/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.rl4j.policy;

import java.util.Random;
import org.deeplearning4j.rl4j.network.dqn.IDQN;
import org.deeplearning4j.rl4j.policy.Policy;
import org.deeplearning4j.rl4j.space.Encodable;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.ops.transforms.Transforms;

/**
 * Policy that samples an action from a Boltzmann (softmax) distribution over the
 * Q-values produced by a DQN: action {@code a} is chosen with probability
 * proportional to {@code exp(Q(s, a))}, i.e. with an implicit temperature of 1.
 *
 * <p>The random generator is seeded with a fixed value ({@code 123}), so every
 * instance draws the same sequence of random numbers.
 *
 * @param <O> observation type
 */
public class BoltzmannQ<O extends Encodable>
extends Policy<O, Integer> {
    private final IDQN dqn;
    private final Random rd = new Random(123L);

    /**
     * Creates a Boltzmann policy over the given network's Q-values.
     *
     * @param dqn network whose output is interpreted as one Q-value per action
     */
    public BoltzmannQ(IDQN dqn) {
        this.dqn = dqn;
    }

    /**
     * Returns the network this policy samples from.
     *
     * @return the wrapped DQN
     */
    @Override
    public IDQN getNeuralNet() {
        return this.dqn;
    }

    /**
     * Samples an action index with probability proportional to {@code exp(Q)}.
     *
     * @param input observation tensor fed to the DQN
     * @return the sampled action index, or {@code -1} if the network output has no columns
     */
    @Override
    public Integer nextAction(INDArray input) {
        INDArray output = this.dqn.output(input);
        // NOTE(review): assumes the output is a single row holding one Q-value per
        // action column (as the original code did); only the first row is read.
        int numActions = (int) output.columns();

        // Read the Q-values once and find the maximum for a numerically stable softmax.
        double[] weights = new double[numActions];
        double maxQ = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < numActions; ++i) {
            weights[i] = output.getDouble((long) i);
            maxQ = Math.max(maxQ, weights[i]);
        }

        // exp(q - max) is proportional to exp(q) but cannot overflow.
        double sum = 0.0;
        for (int i = 0; i < numActions; ++i) {
            weights[i] = Math.exp(weights[i] - maxQ);
            sum += weights[i];
        }

        // Roulette-wheel selection: walk the cumulative distribution until it
        // exceeds a uniform draw in [0, sum).
        double picked = this.rd.nextDouble() * sum;
        double cumulative = 0.0;
        for (int i = 0; i < numActions; ++i) {
            cumulative += weights[i];
            if (picked < cumulative) {
                return i;
            }
        }
        // Only reached through floating-point round-off (or NaN Q-values); fall back to
        // the last action. With zero actions this yields -1.
        return numActions - 1;
    }
}

