package org.deeplearning4j.rl4j.learning.async.a3c.discrete;

import java.util.List;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.rl4j.experience.StateActionReward;
import org.deeplearning4j.rl4j.learning.Learning;
import org.deeplearning4j.rl4j.learning.async.UpdateAlgorithm;
import org.deeplearning4j.rl4j.network.ac.IActorCritic;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;

/* loaded from: input_file:org/deeplearning4j/rl4j/learning/async/a3c/discrete/AdvantageActorCriticUpdateAlgorithm.class */
/**
 * {@link UpdateAlgorithm} for {@link IActorCritic} networks that turns a list of
 * {@link StateActionReward} steps into the arrays handed to
 * {@link IActorCritic#gradient}.
 *
 * <p>For every step, walking from the last one back to the first, a running return
 * {@code reward + gamma * previousReturn} is accumulated. Three arrays are filled:
 * the stacked observations, the running return per step, and a matrix that is zero
 * everywhere except at the taken action, where it holds the running return minus the
 * first scalar of the network's first output (presumably the critic's value estimate —
 * confirm against the {@code IActorCritic} implementation).
 *
 * <p>When {@code recurrent} is set, the arrays are built with a leading dimension of 1
 * and the step index as the last dimension; otherwise the step index is the first
 * dimension.
 */
public class AdvantageActorCriticUpdateAlgorithm implements UpdateAlgorithm<IActorCritic> {
    /** Selects the layout used for all three arrays (step index last when {@code true}). */
    private final boolean recurrent;
    /** Observation shape, forwarded to {@code Learning.makeShape}. */
    private final int[] shape;
    /** Size of the action dimension of the per-action array. */
    private final int actionSpaceSize;
    /** Multiplier applied to the accumulated return at every step back in time. */
    private final double gamma;

    /**
     * @param z    stored as {@code recurrent}; chooses between the two array layouts
     * @param iArr stored as {@code shape}; observation shape passed to {@code Learning.makeShape}
     * @param i    stored as {@code actionSpaceSize}; size of the action dimension
     * @param d    stored as {@code gamma}; factor applied to the running return
     */
    public AdvantageActorCriticUpdateAlgorithm(boolean z, int[] iArr, int i, double d) {
        this.shape = iArr;
        this.actionSpaceSize = i;
        this.gamma = d;
        this.recurrent = z;
    }

    /**
     * Builds the input and the two label arrays from {@code list} and returns
     * {@code iActorCritic.gradient(input, labels)}.
     *
     * <p>This is the typed implementation behind {@link #computeGradients}; the trailing
     * {@code 2} in the name appears to be a decompiler rename made to avoid a signature
     * collision with that method.
     *
     * <p>An empty {@code list} is not handled: the last element is fetched
     * unconditionally with {@code list.get(size - 1)}.
     *
     * @param iActorCritic network queried through {@code outputAll} and {@code gradient}
     * @param list         steps to process; element {@code size - 1} is treated as the most recent
     * @return whatever {@code iActorCritic.gradient} returns for the assembled arrays
     */
    public Gradient[] computeGradients2(IActorCritic iActorCritic, List<StateActionReward<Integer>> list) {
        final int stepCount = list.size();

        // Allocate the three arrays in the layout matching the network kind.
        final INDArray observations;
        final INDArray returnTargets;
        final INDArray actionTargets;
        if (this.recurrent) {
            observations = Nd4j.create(Learning.makeShape(1, this.shape, stepCount));
            returnTargets = Nd4j.create(new int[]{1, 1, stepCount});
            actionTargets = Nd4j.zeros(new int[]{1, this.actionSpaceSize, stepCount});
        } else {
            observations = Nd4j.create(Learning.makeShape(stepCount, this.shape));
            returnTargets = Nd4j.create(new int[]{stepCount, 1});
            actionTargets = Nd4j.zeros(new int[]{stepCount, this.actionSpaceSize});
        }

        // Seed the running return: zero after a terminal step, otherwise the first
        // scalar of the network's first output for the last observation.
        // Note: for a non-terminal last step the network is queried for that
        // observation here and once more in the loop below.
        final StateActionReward<Integer> lastStep = list.get(stepCount - 1);
        double runningReturn;
        if (lastStep.isTerminal()) {
            runningReturn = 0.0d;
        } else {
            INDArray lastObservation = lastStep.getObservation().getChannelData(0);
            runningReturn = iActorCritic.outputAll(lastObservation)[0].getDouble(0L);
        }

        // Walk the steps from newest to oldest so each return can build on the next one.
        int step = stepCount - 1;
        while (step >= 0) {
            StateActionReward<Integer> current = list.get(step);
            INDArray observation = current.getObservation().getChannelData(0);
            INDArray[] networkOutput = iActorCritic.outputAll(observation);

            runningReturn = current.getReward() + this.gamma * runningReturn;

            storeObservation(observations, step, observation);
            returnTargets.putScalar(step, runningReturn);

            double estimate = networkOutput[0].getDouble(0L);
            double advantage = runningReturn - estimate;
            storeAdvantage(actionTargets, step, current.getAction().intValue(), advantage);

            step--;
        }

        INDArray[] labels = {returnTargets, actionTargets};
        return iActorCritic.gradient(observations, labels);
    }

    /** Copies {@code observation} into slot {@code step} of {@code batch} using the active layout. */
    private void storeObservation(INDArray batch, int step, INDArray observation) {
        if (!this.recurrent) {
            batch.putRow(step, observation);
            return;
        }
        INDArrayIndex[] stepSlice = {NDArrayIndex.point(0L), NDArrayIndex.all(), NDArrayIndex.point(step)};
        batch.get(stepSlice).assign(observation);
    }

    /** Writes {@code advantage} at ({@code step}, {@code action}) of {@code target} using the active layout. */
    private void storeAdvantage(INDArray target, int step, int action, double advantage) {
        if (!this.recurrent) {
            target.putScalar(step, action, advantage);
            return;
        }
        target.putScalar(0L, action, step, advantage);
    }

    /**
     * Interface entry point; forwards to {@link #computeGradients2}.
     *
     * <p>The parameter is a raw {@code List} (decompiler marked this method as a
     * bridge/synthetic). The cast below is unchecked: elements are assumed to be
     * {@code StateActionReward<Integer>} — confirm against {@code UpdateAlgorithm}.
     */
    @Override // org.deeplearning4j.rl4j.learning.async.UpdateAlgorithm
    public Gradient[] computeGradients(IActorCritic iActorCritic, List list) {
        List<StateActionReward<Integer>> typedSteps = (List<StateActionReward<Integer>>) list;
        return computeGradients2(iActorCritic, typedSteps);
    }
}
