package org.deeplearning4j.rl4j.network.ac;

import java.io.OutputStream;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.ndarray.INDArray;

/* loaded from: input_file:org/deeplearning4j/rl4j/network/ac/ActorCriticSeparate.class */
/**
 * {@link IActorCritic} implementation that keeps two independent
 * {@link MultiLayerNetwork} instances: one referred to here as the value
 * network and one as the policy network.
 *
 * <p>Every array exchanged with this class (labels, outputs, gradients) uses
 * slot {@code 0} for the value network and slot {@code 1} for the policy
 * network.
 */
public class ActorCriticSeparate implements IActorCritic {
    /** Array slot used for everything belonging to {@link #valueNet}. */
    private static final int VALUE_INDEX = 0;

    /** Array slot used for everything belonging to {@link #policyNet}. */
    private static final int POLICY_INDEX = 1;

    /** Text printed by the {@code save} placeholders. */
    private static final String SAVE_NOT_IMPLEMENTED = "NOT IMPLEMENTED NOOO";

    protected final MultiLayerNetwork valueNet;
    protected final MultiLayerNetwork policyNet;

    /**
     * Wraps the two given networks; the references are stored as-is (no copy).
     *
     * @param multiLayerNetwork  network stored as {@link #valueNet}
     * @param multiLayerNetwork2 network stored as {@link #policyNet}
     */
    public ActorCriticSeparate(MultiLayerNetwork multiLayerNetwork, MultiLayerNetwork multiLayerNetwork2) {
        valueNet = multiLayerNetwork;
        policyNet = multiLayerNetwork2;
    }

    /**
     * Fits the value network first, then the policy network, on the same input.
     *
     * @param iNDArray    input passed to both networks
     * @param iNDArrayArr labels: slot 0 for the value network, slot 1 for the policy network
     */
    @Override // org.deeplearning4j.rl4j.network.ac.IActorCritic, org.deeplearning4j.rl4j.network.NeuralNet
    public void fit(INDArray iNDArray, INDArray[] iNDArrayArr) {
        // Each label is read only right before its own network is fitted.
        valueNet.fit(iNDArray, iNDArrayArr[VALUE_INDEX]);
        policyNet.fit(iNDArray, iNDArrayArr[POLICY_INDEX]);
    }

    /**
     * Runs the input through both networks.
     *
     * @param iNDArray input passed to both networks
     * @return two-element array: value network output, then policy network output
     */
    @Override // org.deeplearning4j.rl4j.network.ac.IActorCritic, org.deeplearning4j.rl4j.network.NeuralNet
    public INDArray[] outputAll(INDArray iNDArray) {
        INDArray[] outputs = new INDArray[2];
        outputs[VALUE_INDEX] = valueNet.output(iNDArray);
        outputs[POLICY_INDEX] = policyNet.output(iNDArray);
        return outputs;
    }

    /**
     * Creates a new instance wrapping clones of both networks.
     *
     * <p>Decompiler note: renamed from {@code clone}; merged with bridge method.
     *
     * @return a new {@code ActorCriticSeparate} built from {@code valueNet.clone()}
     *         and {@code policyNet.clone()}
     */
    @Override // org.deeplearning4j.rl4j.network.NeuralNet
    public ActorCriticSeparate m8clone() {
        MultiLayerNetwork valueCopy = valueNet.clone();
        MultiLayerNetwork policyCopy = policyNet.clone();
        return new ActorCriticSeparate(valueCopy, policyCopy);
    }

    /**
     * Computes gradients for both networks on the same input.
     *
     * @param iNDArray    input set on both networks
     * @param iNDArrayArr labels: slot 0 for the value network, slot 1 for the policy network
     * @return two-element array: value network gradient, then policy network gradient
     */
    @Override // org.deeplearning4j.rl4j.network.ac.IActorCritic, org.deeplearning4j.rl4j.network.NeuralNet
    public Gradient[] gradient(INDArray iNDArray, INDArray[] iNDArrayArr) {
        computeGradientFor(valueNet, iNDArray, iNDArrayArr, VALUE_INDEX);
        computeGradientFor(policyNet, iNDArray, iNDArrayArr, POLICY_INDEX);

        // Gradients are collected only after both networks have been computed.
        Gradient[] gradients = new Gradient[2];
        gradients[VALUE_INDEX] = valueNet.gradient();
        gradients[POLICY_INDEX] = policyNet.gradient();
        return gradients;
    }

    /**
     * Passes each gradient through its network's updater and then subtracts
     * the gradient's flattened array from that network's parameters in place.
     *
     * @param gradientArr gradients: slot 0 for the value network, slot 1 for the policy network
     * @param i           forwarded as the last argument of the updater's {@code update} call
     *                    (presumably the batch size - confirm against the updater API)
     */
    @Override // org.deeplearning4j.rl4j.network.ac.IActorCritic, org.deeplearning4j.rl4j.network.NeuralNet
    public void applyGradient(Gradient[] gradientArr, int i) {
        applyGradientTo(valueNet, gradientArr, VALUE_INDEX, i);
        applyGradientTo(policyNet, gradientArr, POLICY_INDEX, i);
    }

    /**
     * Returns the score of the value network; the policy network's score is not consulted.
     *
     * @return {@code valueNet.score()}
     */
    @Override // org.deeplearning4j.rl4j.network.ac.IActorCritic, org.deeplearning4j.rl4j.network.NeuralNet
    public double getLatestScore() {
        return valueNet.score();
    }

    /**
     * Placeholder: nothing is written to the stream; a notice is printed to standard output.
     *
     * @param outputStream ignored
     */
    @Override // org.deeplearning4j.rl4j.network.ac.IActorCritic, org.deeplearning4j.rl4j.network.NeuralNet
    public void save(OutputStream outputStream) {
        announceSaveNotImplemented();
    }

    /**
     * Placeholder: nothing is written anywhere; a notice is printed to standard output.
     *
     * @param str ignored
     */
    @Override // org.deeplearning4j.rl4j.network.ac.IActorCritic, org.deeplearning4j.rl4j.network.NeuralNet
    public void save(String str) {
        announceSaveNotImplemented();
    }

    /** Sets input and the selected label on {@code net}, then computes its gradient and score. */
    private static void computeGradientFor(MultiLayerNetwork net, INDArray input, INDArray[] labels, int slot) {
        net.setInput(input);
        net.setLabels(labels[slot]);
        net.computeGradientAndScore();
    }

    /** Runs the selected gradient through {@code net}'s updater and subtracts it from the parameters. */
    private static void applyGradientTo(MultiLayerNetwork net, Gradient[] gradients, int slot, int batchArg) {
        // The literal 1 is the third argument of update(...) - presumably an iteration number; confirm.
        net.getUpdater().update(net, gradients[slot], 1, batchArg);
        net.params().subi(gradients[slot].gradient());
    }

    /** Prints the shared "not implemented" notice used by both save overloads. */
    private static void announceSaveNotImplemented() {
        System.out.println(SAVE_NOT_IMPLEMENTED);
    }
}
