/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.rl4j.policy;

import java.util.Random;
import org.deeplearning4j.rl4j.learning.StepCountable;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.policy.Policy;
import org.deeplearning4j.rl4j.space.ActionSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Epsilon-greedy wrapper around another {@link Policy}.
 *
 * <p>On each call to {@link #nextAction(INDArray)}, the wrapped policy is used when a draw from
 * {@code rd.nextFloat()} exceeds the current epsilon. Otherwise a random action is taken from
 * {@code mdp.getActionSpace().randomAction()}.
 *
 * <p>Epsilon follows a linear schedule driven by the learner's step counter:
 * {@code 1 - (step - updateStart) / epsilonNbStep}, clamped to {@code [minEpsilon, 1]}.
 *
 * @param <O> observation type
 * @param <A> action type
 * @param <AS> action-space type producing actions of type {@code A}
 */
public class EpsGreedy<O extends Encodable, A, AS extends ActionSpace<A>> extends Policy<O, A> {
    private static final Logger log = LoggerFactory.getLogger(EpsGreedy.class);

    /** Policy consulted for the non-random (greedy) branch. */
    private final Policy<O, A> policy;
    /** Source of the action space used for random exploration. */
    private final MDP<O, A, AS> mdp;
    /** Step count at which epsilon starts to decrease from 1. */
    private final int updateStart;
    /** Decay length: epsilon drops by 1.0 per {@code epsilonNbStep} steps after {@code updateStart}. */
    private final int epsilonNbStep;
    /** Random source for the explore/exploit draw. */
    private final Random rd;
    /** Lower bound epsilon is clamped to. */
    private final float minEpsilon;
    /** Provides the step counter that drives the epsilon schedule. */
    private final StepCountable learning;

    /**
     * Creates an epsilon-greedy policy.
     *
     * @param policy policy used when exploiting
     * @param mdp MDP whose action space supplies random actions when exploring
     * @param updateStart step count at which epsilon begins to decay
     * @param epsilonNbStep number of steps for epsilon to drop by 1.0; {@code 0} means epsilon
     *     jumps straight to {@code minEpsilon} once {@code updateStart} is reached
     * @param rd random source for the explore/exploit decision
     * @param minEpsilon lower bound for epsilon
     * @param learning source of the current step counter
     */
    public EpsGreedy(Policy<O, A> policy, MDP<O, A, AS> mdp, int updateStart, int epsilonNbStep,
            Random rd, float minEpsilon, StepCountable learning) {
        this.policy = policy;
        this.mdp = mdp;
        this.updateStart = updateStart;
        this.epsilonNbStep = epsilonNbStep;
        this.rd = rd;
        this.minEpsilon = minEpsilon;
        this.learning = learning;
    }

    /** Returns the neural net of the wrapped policy. */
    @Override
    public NeuralNet getNeuralNet() {
        return policy.getNeuralNet();
    }

    /**
     * Chooses an action: the wrapped policy's choice with probability roughly {@code 1 - epsilon},
     * otherwise a random action from the MDP's action space.
     *
     * <p>Logs epsilon and the step counter whenever {@code step % 500 == 1}.
     *
     * @param input observation passed through to the wrapped policy
     * @return the selected action
     */
    @Override
    public A nextAction(INDArray input) {
        // Read the counter once so the log line and the epsilon schedule see the same value.
        long step = learning.getStepCounter();
        float ep = epsilonAt(step);
        if (step % 500 == 1) {
            log.info("EP: {} {}", ep, step);
        }
        if (rd.nextFloat() > ep) {
            return policy.nextAction(input);
        }
        // AS extends ActionSpace<A>, so randomAction() already yields A; no cast needed.
        return mdp.getActionSpace().randomAction();
    }

    /**
     * Returns the current exploration rate from the linear decay schedule.
     *
     * @return epsilon in {@code [minEpsilon, 1]} (assuming {@code minEpsilon <= 1})
     */
    public float getEpsilon() {
        return epsilonAt(learning.getStepCounter());
    }

    /** Computes epsilon for the given step counter value. */
    private float epsilonAt(long step) {
        float elapsed = (float) (step - updateStart);
        float decayed;
        if (epsilonNbStep == 0) {
            // Zero-length decay. The raw formula gives +/-Infinity on either side of updateStart,
            // which clamps to 1 before and minEpsilon after. Exactly at updateStart it gives
            // 0/0 = NaN, and NaN survives Math.max/Math.min. Make that step behave like the
            // steps after it.
            decayed = elapsed < 0 ? 1.0f : minEpsilon;
        } else {
            decayed = 1.0f - elapsed / (float) epsilonNbStep;
        }
        return Math.min(1.0f, Math.max(minEpsilon, decayed));
    }
}

