package org.deeplearning4j.rl4j.agent.learning.algorithm.dqn;

import lombok.NonNull;
import org.deeplearning4j.rl4j.agent.learning.algorithm.dqn.BaseTransitionTDAlgorithm;
import org.deeplearning4j.rl4j.agent.learning.update.Features;
import org.deeplearning4j.rl4j.network.IOutputNeuralNet;
import org.nd4j.linalg.api.ndarray.INDArray;

/* loaded from: input_file:org/deeplearning4j/rl4j/agent/learning/algorithm/dqn/BaseDQNAlgorithm.class */
/**
 * Transition-based TD algorithm that holds a second, "target" Q-network in addition to the
 * Q-network managed by {@link BaseTransitionTDAlgorithm}.
 *
 * <p>During {@link #initComputation(Features, Features)}, both networks are evaluated on the
 * next-state features. Their {@code "Q"} outputs are cached in protected fields for use by
 * subclasses.
 */
public abstract class BaseDQNAlgorithm extends BaseTransitionTDAlgorithm {

    /** Target network; never null once constructed. */
    private final IOutputNeuralNet targetQNetwork;

    /** {@code "Q"} output of the superclass Q-network on the next-state features. */
    protected INDArray qNetworkNextFeatures;

    /** {@code "Q"} output of the target network on the next-state features. */
    protected INDArray targetQNetworkNextFeatures;

    /**
     * Creates the algorithm.
     *
     * <p>NOTE(review): the decompiler reported this constructor was originally {@code protected}.
     *
     * @param qNetwork       network handed to the superclass
     * @param targetQNetwork target network; must not be null
     * @param configuration  configuration handed to the superclass
     * @throws NullPointerException if {@code targetQNetwork} is null
     */
    public BaseDQNAlgorithm(IOutputNeuralNet qNetwork,
                            @NonNull IOutputNeuralNet targetQNetwork,
                            BaseTransitionTDAlgorithm.Configuration configuration) {
        super(qNetwork, configuration);
        // Kept in the leading "if null, throw" shape so Lombok's @NonNull does not add a second check.
        if (targetQNetwork == null) {
            throw new NullPointerException("targetQNetwork is marked non-null but is null");
        }
        this.targetQNetwork = targetQNetwork;
    }

    /**
     * Runs the superclass initialisation, then caches the {@code "Q"} outputs of the Q-network
     * and the target network for {@code nextFeatures}, in that order.
     *
     * <p>NOTE(review): the decompiler reported this method was originally {@code protected}.
     *
     * @param features     features passed through to the superclass
     * @param nextFeatures next-state features evaluated by both networks
     */
    @Override
    public void initComputation(Features features, Features nextFeatures) {
        super.initComputation(features, nextFeatures);
        // Assign each field immediately after its own evaluation. If the target network throws,
        // the Q-network result has already been stored.
        qNetworkNextFeatures = qOutputOf(qNetwork, nextFeatures);
        targetQNetworkNextFeatures = qOutputOf(targetQNetwork, nextFeatures);
    }

    /** Evaluates {@code network} on {@code input} and returns the output stored under key {@code "Q"}. */
    private static INDArray qOutputOf(IOutputNeuralNet network, Features input) {
        return network.output(input).get("Q");
    }
}
