/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.rl4j.learning.async.nstep.discrete;

import java.util.Random;
import java.util.Stack;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.rl4j.learning.Learning;
import org.deeplearning4j.rl4j.learning.async.AsyncGlobal;
import org.deeplearning4j.rl4j.learning.async.AsyncThreadDiscrete;
import org.deeplearning4j.rl4j.learning.async.MiniTrans;
import org.deeplearning4j.rl4j.learning.async.nstep.discrete.AsyncNStepQLearningDiscrete;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.dqn.IDQN;
import org.deeplearning4j.rl4j.policy.DQNPolicy;
import org.deeplearning4j.rl4j.policy.EpsGreedy;
import org.deeplearning4j.rl4j.policy.Policy;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.util.DataManager;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

/**
 * Asynchronous learning thread that trains an {@link IDQN} from n-step
 * returns on an MDP whose actions live in a {@link DiscreteSpace}.
 *
 * <p>The thread acts through an {@link EpsGreedy} policy wrapped around a
 * {@link DQNPolicy}, and converts a stack of recorded {@link MiniTrans}
 * entries into gradients via {@link #calcGradient(IDQN, Stack)}.
 *
 * @param <O> observation type of the MDP
 */
public class AsyncNStepQLearningThreadDiscrete<O extends Encodable>
extends AsyncThreadDiscrete<O, IDQN> {
    /** Configuration read for the seed, the discount factor and the epsilon-greedy settings. */
    protected final AsyncNStepQLearningDiscrete.AsyncNStepQLConfiguration conf;
    /** MDP this thread interacts with. */
    protected final MDP<O, Integer, DiscreteSpace> mdp;
    /** Global handle passed to the superclass and exposed through {@link #getAsyncGlobal()}. */
    protected final AsyncGlobal<IDQN> asyncGlobal;
    /** Number of this thread; also added to the configured seed. */
    protected final int threadNumber;
    /** Data manager exposed through {@link #getDataManager()}. */
    protected final DataManager dataManager;
    /** Random source handed to the epsilon-greedy policy. */
    private final Random random;

    /**
     * Creates the thread and seeds both the MDP's action space and the
     * thread-local random generator with {@code conf.getSeed() + threadNumber}.
     *
     * @param mdp          MDP to act on
     * @param asyncGlobal  global handle forwarded to the superclass
     * @param conf         learning configuration
     * @param threadNumber number of this thread, used as seed offset
     * @param dataManager  data manager returned by {@link #getDataManager()}
     */
    public AsyncNStepQLearningThreadDiscrete(MDP<O, Integer, DiscreteSpace> mdp, AsyncGlobal<IDQN> asyncGlobal, AsyncNStepQLearningDiscrete.AsyncNStepQLConfiguration conf, int threadNumber, DataManager dataManager) {
        super(asyncGlobal, threadNumber);
        this.mdp = mdp;
        this.asyncGlobal = asyncGlobal;
        this.conf = conf;
        this.threadNumber = threadNumber;
        this.dataManager = dataManager;

        // Offset the seed by the thread number so each thread gets its own seed value.
        DiscreteSpace actionSpace = mdp.getActionSpace();
        actionSpace.setSeed(conf.getSeed() + threadNumber);
        this.random = new Random(conf.getSeed() + threadNumber);
    }

    /**
     * Builds the acting policy: an {@link EpsGreedy} wrapper around a
     * {@link DQNPolicy} backed by the given network.
     *
     * @param nn network the wrapped {@link DQNPolicy} is created from
     * @return the epsilon-greedy policy, configured from {@link #conf}
     */
    @Override
    public Policy<O, Integer> getPolicy(IDQN nn) {
        Policy<O, Integer> wrapped = new DQNPolicy(nn);
        return new EpsGreedy<O, Integer, DiscreteSpace>(
                wrapped,
                this.mdp,
                this.conf.getUpdateStart(),
                this.conf.getEpsilonNbStep(),
                this.random,
                this.conf.getMinEpsilon(),
                this);
    }

    /**
     * Turns the recorded transitions into one batch and asks the network for
     * its gradients.
     *
     * <p>The top entry of {@code rewards} only provides the starting value of
     * the running return. Every remaining entry is then popped (top first) and
     * written to the batch from the last row to the first: the running return
     * becomes {@code reward + gamma * previousReturn}, the observation goes
     * into the input row, and the target row is the entry's first output array
     * with the taken action's slot set to the running return.
     *
     * <p>The stack is empty when this method returns.
     *
     * @param current network whose {@code gradient(input, targets)} is invoked
     * @param rewards stack of transitions; must hold at least one entry
     * @return whatever {@code current.gradient} returns for the assembled batch
     * @throws java.util.EmptyStackException if {@code rewards} is empty
     */
    @Override
    public Gradient[] calcGradient(IDQN current, Stack<MiniTrans<Integer>> rewards) {
        // The first pop seeds the return; it does not get a batch row of its own.
        double discountedReturn = rewards.pop().getReward();
        int batchSize = rewards.size();

        INDArray input = Nd4j.create(Learning.makeShape(batchSize, this.observationShape()));
        int actionCount = this.mdp.getActionSpace().getSize();
        INDArray targets = Nd4j.create(batchSize, actionCount);

        int rowIndex = batchSize;
        while (--rowIndex >= 0) {
            MiniTrans<Integer> step = rewards.pop();
            discountedReturn = step.getReward() + this.conf.getGamma() * discountedReturn;
            input.putRow(rowIndex, step.getObs());

            // NOTE(review): putScalar is applied to the array held by the transition itself;
            // presumably this updates that stored output in place — confirm against INDArray.
            INDArray qValues = step.getOutput()[0];
            int actionIndex = step.getAction();
            targets.putRow(rowIndex, qValues.putScalar(actionIndex, discountedReturn));
        }
        return current.gradient(input, targets);
    }

    /**
     * Picks the per-observation shape: the history processor's configured
     * shape when a history processor is set, otherwise the shape of the MDP's
     * observation space.
     */
    private int[] observationShape() {
        if (this.getHistoryProcessor() != null) {
            return this.getHistoryProcessor().getConf().getShape();
        }
        return this.mdp.getObservationSpace().getShape();
    }

    /** @return the configuration given at construction */
    @Override
    public AsyncNStepQLearningDiscrete.AsyncNStepQLConfiguration getConf() {
        return this.conf;
    }

    /** @return the MDP given at construction */
    @Override
    public MDP<O, Integer, DiscreteSpace> getMdp() {
        return this.mdp;
    }

    /** @return the global handle given at construction */
    @Override
    public AsyncGlobal<IDQN> getAsyncGlobal() {
        return this.asyncGlobal;
    }

    /** @return the thread number given at construction */
    @Override
    public int getThreadNumber() {
        return this.threadNumber;
    }

    /** @return the data manager given at construction */
    @Override
    public DataManager getDataManager() {
        return this.dataManager;
    }
}

