package org.deeplearning4j.rl4j.learning.async.nstep.discrete;

import java.util.List;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.rl4j.experience.StateActionReward;
import org.deeplearning4j.rl4j.helper.INDArrayHelper;
import org.deeplearning4j.rl4j.learning.async.UpdateAlgorithm;
import org.deeplearning4j.rl4j.network.dqn.IDQN;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

/* loaded from: input_file:org/deeplearning4j/rl4j/learning/async/nstep/discrete/QLearningUpdateAlgorithm.class */
/**
 * Update algorithm that turns a sequence of {@link StateActionReward} steps into
 * gradients for an {@link IDQN}.
 *
 * <p>For every step a target row is built from the network's own first output for that
 * step's observation, with the entry of the action that was taken overwritten by the
 * discounted return accumulated from the end of the sequence backwards.
 */
public class QLearningUpdateAlgorithm implements UpdateAlgorithm<IDQN> {
    private final int actionSpaceSize;
    private final double gamma;

    /**
     * @param i stored as the action-space size; used as the column count of the target matrix
     * @param d stored as gamma, the factor applied to the running return at every step
     */
    public QLearningUpdateAlgorithm(int i, double d) {
        this.gamma = d;
        this.actionSpaceSize = i;
    }

    /**
     * Builds the feature and target batches for the given steps and asks the network for
     * the resulting gradients.
     *
     * <p>The "2" suffix only exists to avoid a collision with
     * {@link #computeGradients(IDQN, List)}, which delegates here.
     *
     * @param idqn network queried through {@code outputAll} and {@code gradient}
     * @param list the steps to learn from; its last element is read unconditionally, so the
     *             list is expected to be non-empty
     * @return whatever {@code idqn.gradient(features, targets)} returns
     */
    public Gradient[] computeGradients2(IDQN idqn, List<StateActionReward<Integer>> list) {
        final int stepCount = list.size();
        final StateActionReward<Integer> lastStep = list.get(stepCount - 1);
        final INDArray lastObservation = lastStep.getObservation().getChannelData(0);

        // One row per step: observations on one side, per-action targets on the other.
        final INDArray features = INDArrayHelper.createBatchForShape(stepCount, lastObservation.shape());
        final INDArray targets = Nd4j.create(new int[]{stepCount, this.actionSpaceSize});

        // Starting value of the running return: zero when the last step is terminal,
        // otherwise the maximum of the network's first output for the last observation.
        double discountedReturn;
        if (lastStep.isTerminal()) {
            discountedReturn = 0.0d;
        } else {
            final INDArray lastOutput = idqn.outputAll(lastObservation)[0];
            discountedReturn = Nd4j.max(lastOutput).getDouble(0L);
        }

        // Walk from the newest step to the oldest so each return builds on the following one.
        int index = stepCount - 1;
        while (index >= 0) {
            final StateActionReward<Integer> step = list.get(index);
            final INDArray observation = step.getObservation().getChannelData(0);
            features.putRow(index, observation);

            discountedReturn = step.getReward() + this.gamma * discountedReturn;

            // Start from the network's current output and replace only the taken action's entry.
            final INDArray currentOutput = idqn.outputAll(observation)[0];
            final INDArray targetRow = currentOutput.putScalar(step.getAction().intValue(), discountedReturn);
            targets.putRow(index, targetRow);

            index--;
        }

        return idqn.gradient(features, targets);
    }

    /**
     * Entry point required by {@link UpdateAlgorithm}; forwards to
     * {@link #computeGradients2(IDQN, List)} after narrowing the raw list type.
     *
     * @param idqn network to compute gradients for
     * @param list raw list whose elements are treated as {@code StateActionReward<Integer>}
     * @return the gradients produced by {@link #computeGradients2(IDQN, List)}
     */
    @Override // org.deeplearning4j.rl4j.learning.async.UpdateAlgorithm
    public Gradient[] computeGradients(IDQN idqn, List list) {
        // The raw parameter carries no element type; the cast is unchecked by nature.
        @SuppressWarnings("unchecked")
        final List<StateActionReward<Integer>> typedSteps = (List<StateActionReward<Integer>>) list;
        return computeGradients2(idqn, typedSteps);
    }
}
