/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.rl4j.learning.async.a3c.discrete;

import java.util.Random;
import java.util.Stack;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.rl4j.learning.Learning;
import org.deeplearning4j.rl4j.learning.async.AsyncGlobal;
import org.deeplearning4j.rl4j.learning.async.AsyncThreadDiscrete;
import org.deeplearning4j.rl4j.learning.async.MiniTrans;
import org.deeplearning4j.rl4j.learning.async.a3c.discrete.A3CDiscrete;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.ac.IActorCritic;
import org.deeplearning4j.rl4j.policy.ACPolicy;
import org.deeplearning4j.rl4j.policy.Policy;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.util.DataManager;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;

public class A3CThreadDiscrete<O extends Encodable>
extends AsyncThreadDiscrete<O, IActorCritic> {
    protected final A3CDiscrete.A3CConfiguration conf;
    protected final MDP<O, Integer, DiscreteSpace> mdp;
    protected final AsyncGlobal<IActorCritic> asyncGlobal;
    protected final int threadNumber;
    protected final DataManager dataManager;
    private final Random random;

    public A3CThreadDiscrete(MDP<O, Integer, DiscreteSpace> mdp, AsyncGlobal<IActorCritic> asyncGlobal, A3CDiscrete.A3CConfiguration a3cc, int threadNumber, DataManager dataManager) {
        super(asyncGlobal, threadNumber);
        this.conf = a3cc;
        this.asyncGlobal = asyncGlobal;
        this.threadNumber = threadNumber;
        this.mdp = mdp;
        this.dataManager = dataManager;
        ((DiscreteSpace)mdp.getActionSpace()).setSeed(this.conf.getSeed() + threadNumber);
        this.random = new Random(this.conf.getSeed() + threadNumber);
    }

    @Override
    protected Policy<O, Integer> getPolicy(IActorCritic net) {
        return new ACPolicy(net, this.random);
    }

    @Override
    public Gradient[] calcGradient(IActorCritic iac, Stack<MiniTrans<Integer>> rewards) {
        MiniTrans<Integer> minTrans = rewards.pop();
        int size = rewards.size();
        boolean recurrent = this.getAsyncGlobal().getCurrent().isRecurrent();
        int[] shape = this.getHistoryProcessor() == null ? this.mdp.getObservationSpace().getShape() : this.getHistoryProcessor().getConf().getShape();
        int[] nshape = recurrent ? Learning.makeShape(1, shape, size) : Learning.makeShape(size, shape);
        INDArray input = Nd4j.create((int[])nshape);
        INDArray targets = recurrent ? Nd4j.create((int[])new int[]{1, 1, size}) : Nd4j.create((int)size, (int)1);
        INDArray logSoftmax = recurrent ? Nd4j.zeros((int[])new int[]{1, ((DiscreteSpace)this.mdp.getActionSpace()).getSize(), size}) : Nd4j.zeros((long)size, (long)((DiscreteSpace)this.mdp.getActionSpace()).getSize());
        double r = minTrans.getReward();
        for (int i = size - 1; i >= 0; --i) {
            minTrans = rewards.pop();
            r = minTrans.getReward() + this.conf.getGamma() * r;
            if (recurrent) {
                input.get(new INDArrayIndex[]{NDArrayIndex.point((long)0L), NDArrayIndex.all(), NDArrayIndex.point((long)i)}).assign(minTrans.getObs());
            } else {
                input.putRow((long)i, minTrans.getObs());
            }
            targets.putScalar((long)i, r);
            double expectedV = minTrans.getOutput()[0].getDouble(0L);
            double advantage = r - expectedV;
            if (recurrent) {
                logSoftmax.putScalar(0L, (long)minTrans.getAction().intValue(), (long)i, advantage);
                continue;
            }
            logSoftmax.putScalar((long)i, (long)minTrans.getAction().intValue(), advantage);
        }
        return iac.gradient(input, new INDArray[]{targets, logSoftmax});
    }

    @Override
    public A3CDiscrete.A3CConfiguration getConf() {
        return this.conf;
    }

    @Override
    public MDP<O, Integer, DiscreteSpace> getMdp() {
        return this.mdp;
    }

    @Override
    public AsyncGlobal<IActorCritic> getAsyncGlobal() {
        return this.asyncGlobal;
    }

    @Override
    public int getThreadNumber() {
        return this.threadNumber;
    }

    @Override
    public DataManager getDataManager() {
        return this.dataManager;
    }
}

