package org.deeplearning4j.rl4j.learning.async;

import java.util.Stack;
import org.deeplearning4j.gym.StepReply;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.rl4j.learning.Learning;
import org.deeplearning4j.rl4j.learning.async.AsyncThread;
import org.deeplearning4j.rl4j.learning.sync.Transition;
import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.policy.Policy;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

/* loaded from: input_file:org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscrete.class */
/**
 * Asynchronous learner thread specialised for {@link DiscreteSpace} action spaces, where
 * actions are represented as {@link Integer} values.
 *
 * <p>This class implements the sub-epoch rollout ({@link #trainSubEpoch}) and leaves the
 * conversion of the collected transitions into gradients to subclasses
 * ({@link #calcGradient}).
 *
 * @param <O>  observation type produced by the MDP
 * @param <NN> neural network type used by this thread
 */
public abstract class AsyncThreadDiscrete<O extends Encodable, NN extends NeuralNet> extends AsyncThread<O, Integer, DiscreteSpace, NN> {
    /**
     * Forwards both arguments unchanged to the {@link AsyncThread} constructor.
     *
     * @param asyncGlobal shared global state this thread reads networks from and enqueues gradients to
     * @param i           value passed through to the superclass constructor
     */
    public AsyncThreadDiscrete(AsyncGlobal<NN> asyncGlobal, int i) {
        super(asyncGlobal, i);
    }

    /**
     * Plays the MDP for at most {@code i * skipFrame} steps (or until the MDP reports it is done),
     * records one {@link MiniTrans} per non-skipped step, appends a final bootstrap entry,
     * computes gradients via {@link #calcGradient} and enqueues them on the async global.
     *
     * <p>{@code skipFrame} is taken from the history processor configuration when a history
     * processor is present, and is {@code 1} otherwise.
     *
     * @param o observation the sub-epoch starts from
     * @param i step budget; the loop runs for at most {@code i * skipFrame} iterations
     * @return the number of steps performed, the last observation, the accumulated (unscaled)
     *         reward of the recorded steps and the latest score of the network used
     */
    @Override // org.deeplearning4j.rl4j.learning.async.AsyncThread
    public AsyncThread.SubEpochReturn<O> trainSubEpoch(O o, int i) {
        NN current = getAsyncGlobal().cloneCurrent();
        Stack<MiniTrans<Integer>> transitions = new Stack<>();
        O obs = o;
        Policy<O, Integer> policy = getPolicy(current);
        Integer lastAction = null;
        INDArray[] history = null;
        boolean hasHistoryProcessor = getHistoryProcessor() != null;
        int skipFrame = hasHistoryProcessor ? getHistoryProcessor().getConf().getSkipFrame() : 1;
        double reward = 0.0d;
        // Scaled reward accumulated since the last recorded transition.
        double accuReward = 0.0d;
        int steps = 0;
        while (!getMdp().isDone() && steps < i * skipFrame) {
            INDArray input = Learning.getInput(getMdp(), obs);
            Integer action;
            // On a skipped frame the previous action is repeated. If there is no previous
            // action yet (the sub-epoch started on a skipped frame) the policy is consulted
            // instead, so that a null action is never sent to the MDP and the history is
            // always initialised before it is used below.
            if (getStepCounter() % skipFrame != 0 && lastAction != null) {
                action = lastAction;
            } else {
                if (history == null) {
                    if (hasHistoryProcessor) {
                        getHistoryProcessor().add(input);
                        history = getHistoryProcessor().getHistory();
                    } else {
                        history = new INDArray[]{input};
                    }
                }
                INDArray hstack = Transition.concat(history);
                if (hstack.shape().length > 2) {
                    hstack = hstack.reshape(Learning.makeShape(1, hstack.shape()));
                }
                action = policy.nextAction(hstack);
            }
            lastAction = action;

            StepReply<O> stepReply = getMdp().step(action);
            accuReward += stepReply.getReward() * getConf().getRewardFactor();

            // Only non-skipped frames (and the terminal step) produce a recorded transition.
            if (getStepCounter() % skipFrame == 0 || stepReply.isDone()) {
                obs = stepReply.getObservation();
                if (input.shape().length > 2) {
                    input = input.reshape(Learning.makeShape(1, input.shape()));
                }
                // NOTE(review): the network is evaluated on the single-frame input while the
                // stored observation is the concatenated history; confirm this is intended.
                transitions.add(new MiniTrans<>(Transition.concat(history), action, current.outputAll(input), accuReward));
                reward += stepReply.getReward();
                if (hasHistoryProcessor) {
                    getHistoryProcessor().add(Learning.getInput(getMdp(), stepReply.getObservation()));
                }
                history = hasHistoryProcessor
                        ? getHistoryProcessor().getHistory()
                        : new INDArray[]{Learning.getInput(getMdp(), stepReply.getObservation())};
                accuReward = 0.0d;
            }
            steps++;
        }

        // Final entry: zero reward when the MDP is done, otherwise the maximum of the first
        // network output for the last observation.
        INDArray lastInput = Learning.getInput(getMdp(), obs);
        if (getMdp().isDone()) {
            transitions.add(new MiniTrans<>(lastInput, null, null, 0.0d));
        } else {
            // A target update frequency of -1 selects the current network; otherwise the
            // target network clone is used.
            INDArray[] output = getConf().getTargetDqnUpdateFreq() == -1
                    ? current.outputAll(lastInput)
                    : getAsyncGlobal().cloneTarget().outputAll(lastInput);
            transitions.add(new MiniTrans<>(lastInput, null, output, Nd4j.max(output[0]).getDouble(0)));
        }

        getAsyncGlobal().enqueue(calcGradient(current, transitions), Integer.valueOf(steps));
        return new AsyncThread.SubEpochReturn<>(steps, obs, reward, current.getLatestScore());
    }

    /**
     * Computes the gradients for the given network from the transitions collected during a
     * sub-epoch.
     *
     * @param nn    network the transitions were collected with
     * @param stack collected transitions, in the order they were added by {@link #trainSubEpoch},
     *              with the final bootstrap entry on top
     * @return the gradients that {@link #trainSubEpoch} enqueues on the async global
     */
    public abstract Gradient[] calcGradient(NN nn, Stack<MiniTrans<Integer>> stack);
}
