package org.deeplearning4j.rl4j.learning.async.a3c.discrete;

import java.util.Random;
import java.util.Stack;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.rl4j.learning.Learning;
import org.deeplearning4j.rl4j.learning.async.AsyncGlobal;
import org.deeplearning4j.rl4j.learning.async.AsyncThreadDiscrete;
import org.deeplearning4j.rl4j.learning.async.MiniTrans;
import org.deeplearning4j.rl4j.learning.async.a3c.discrete.A3CDiscrete;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.ac.IActorCritic;
import org.deeplearning4j.rl4j.policy.ACPolicy;
import org.deeplearning4j.rl4j.policy.Policy;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.util.DataManager;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;

/* loaded from: input_file:org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CThreadDiscrete.class */
public class A3CThreadDiscrete<O extends Encodable> extends AsyncThreadDiscrete<O, IActorCritic> {
    protected final A3CDiscrete.A3CConfiguration conf;
    protected final MDP<O, Integer, DiscreteSpace> mdp;
    protected final AsyncGlobal<IActorCritic> asyncGlobal;
    protected final int threadNumber;
    protected final DataManager dataManager;

    /**
     * Random source handed to every policy this worker builds. It is created once per thread, so
     * successive calls to {@link #getPolicy(IActorCritic)} continue a single random stream instead
     * of restarting it from the same seed on every call.
     */
    private final Random rnd;

    /**
     * Creates one A3C worker.
     *
     * @param mdp              environment this worker interacts with
     * @param asyncGlobal      shared global state (holds the global actor-critic network)
     * @param a3CConfiguration A3C hyper-parameters (seed, gamma, ...)
     * @param i                index of this worker thread
     * @param dataManager      data manager passed through to the base thread machinery
     */
    public A3CThreadDiscrete(MDP<O, Integer, DiscreteSpace> mdp, AsyncGlobal<IActorCritic> asyncGlobal,
                    A3CDiscrete.A3CConfiguration a3CConfiguration, int i, DataManager dataManager) {
        super(asyncGlobal, i);
        this.conf = a3CConfiguration;
        this.asyncGlobal = asyncGlobal;
        this.threadNumber = i;
        this.mdp = mdp;
        this.dataManager = dataManager;
        // Offset the configured seed by the thread index: runs stay reproducible for a given seed,
        // but workers no longer draw identical random sequences.
        mdp.getActionSpace().setSeed(this.conf.getSeed() + i);
        this.rnd = new Random(this.conf.getSeed() + i);
    }

    /**
     * Builds the actor-critic policy for the given network, backed by this worker's single
     * random source.
     *
     * @param net actor-critic network whose outputs drive the policy
     * @return a policy over integer (discrete) actions
     */
    @Override // org.deeplearning4j.rl4j.learning.async.AsyncThread
    public Policy<O, Integer> getPolicy(IActorCritic net) {
        return new ACPolicy<>(net, rnd);
    }

    /**
     * Computes the actor-critic gradients for one n-step rollout.
     *
     * <p>The top entry of {@code rewards} is popped first, and only its reward is used: it seeds
     * the running discounted return. This is presumably the bootstrap value pushed by the caller;
     * confirm in AsyncThreadDiscrete. The remaining entries are then popped from the latest
     * transition back to the earliest, accumulating {@code r = reward + gamma * r}. The stack is
     * fully consumed.
     *
     * @param iac     network used to compute the gradient
     * @param rewards rollout transitions; emptied by this call
     * @return the gradients returned by {@code iac.gradient(input, {criticTargets, actorTargets})}
     */
    @Override // org.deeplearning4j.rl4j.learning.async.AsyncThreadDiscrete
    public Gradient[] calcGradient(IActorCritic iac, Stack<MiniTrans<Integer>> rewards) {
        MiniTrans<Integer> last = rewards.pop();
        int size = rewards.size();

        // A recurrent net is trained on one sequence (batch of 1) with time on the last axis.
        // Otherwise each transition is one row of a mini-batch.
        boolean recurrent = getAsyncGlobal().getCurrent().isRecurrent();

        int[] shape = getHistoryProcessor() == null ? mdp.getObservationSpace().getShape()
                        : getHistoryProcessor().getConf().getShape();
        INDArray input = Nd4j.create(recurrent ? Learning.makeShape(1, shape, size)
                        : Learning.makeShape(size, shape));
        // Critic targets: the discounted return for each step.
        INDArray criticTargets = recurrent ? Nd4j.create(new int[] {1, 1, size}) : Nd4j.create(size, 1);
        // Actor targets: the advantage written at the taken action's index, zero elsewhere.
        int actionCount = mdp.getActionSpace().getSize();
        INDArray actorTargets = recurrent ? Nd4j.zeros(new int[] {1, actionCount, size})
                        : Nd4j.zeros(size, actionCount);

        double r = last.getReward();
        for (int i = size - 1; i >= 0; i--) {
            MiniTrans<Integer> trans = rewards.pop();
            r = trans.getReward() + (conf.getGamma() * r);

            if (recurrent) {
                input.get(new INDArrayIndex[] {NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.point(i)})
                                .assign(trans.getObs());
            } else {
                input.putRow(i, trans.getObs());
            }

            criticTargets.putScalar(i, r);

            // Assumes getOutput()[0] holds the critic's value estimate for this step; confirm in MiniTrans producers.
            double advantage = r - trans.getOutput()[0].getDouble(0);
            int action = trans.getAction().intValue();
            if (recurrent) {
                actorTargets.putScalar(0, action, i, advantage);
            } else {
                actorTargets.putScalar(i, action, advantage);
            }
        }
        return iac.gradient(input, new INDArray[] {criticTargets, actorTargets});
    }

    /**
     * Retained for source compatibility with code that called the decompiler-generated name.
     * Delegates to {@link #calcGradient(IActorCritic, Stack)}.
     */
    public Gradient[] calcGradient2(IActorCritic iActorCritic, Stack<MiniTrans<Integer>> stack) {
        return calcGradient(iActorCritic, stack);
    }

    /** @return the A3C configuration this worker was created with */
    @Override // org.deeplearning4j.rl4j.learning.async.AsyncThread
    public A3CDiscrete.A3CConfiguration getConf() {
        return this.conf;
    }

    /** @return the environment this worker interacts with */
    @Override // org.deeplearning4j.rl4j.learning.async.AsyncThread
    public MDP<O, Integer, DiscreteSpace> getMdp() {
        return this.mdp;
    }

    /** @return the shared global state passed at construction */
    @Override // org.deeplearning4j.rl4j.learning.async.AsyncThread
    public AsyncGlobal<IActorCritic> getAsyncGlobal() {
        return this.asyncGlobal;
    }

    /** @return the index of this worker thread */
    @Override // org.deeplearning4j.rl4j.learning.async.AsyncThread
    public int getThreadNumber() {
        return this.threadNumber;
    }

    /** @return the data manager passed at construction */
    @Override // org.deeplearning4j.rl4j.learning.async.AsyncThread
    public DataManager getDataManager() {
        return this.dataManager;
    }
}
