package org.deeplearning4j.rl4j.network.ac;

import java.io.IOException;
import java.io.OutputStream;
import java.util.Collection;
import java.util.Iterator;
import org.deeplearning4j.nn.api.NeuralNetwork;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.recurrent.RnnOutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.rl4j.network.ac.ActorCriticSeparate;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.api.ndarray.INDArray;

/* loaded from: input_file:org/deeplearning4j/rl4j/network/ac/ActorCriticSeparate.class */
/**
 * Actor-critic model backed by two independent {@link MultiLayerNetwork}s: one stored as the
 * value network and one stored as the policy network.
 *
 * <p>Wherever a pair of arrays is exchanged (labels, outputs, gradients), index {@code 0}
 * belongs to the value network and index {@code 1} to the policy network.
 *
 * @param <NN> the concrete type returned by the clone method and accepted by {@link #copy}
 */
public class ActorCriticSeparate<NN extends ActorCriticSeparate> implements IActorCritic<NN> {
    /** Network addressed by index 0 of every label/output/gradient pair. */
    protected final MultiLayerNetwork valueNet;
    /** Network addressed by index 1 of every label/output/gradient pair. */
    protected final MultiLayerNetwork policyNet;
    /** {@code true} when the value network's output layer is an {@link RnnOutputLayer}. */
    protected final boolean recurrent;

    /**
     * Wraps the two given networks.
     *
     * <p>The recurrent flag is derived from the value network's output layer only; the policy
     * network is not inspected.
     *
     * @param multiLayerNetwork  the value network
     * @param multiLayerNetwork2 the policy network
     */
    public ActorCriticSeparate(MultiLayerNetwork multiLayerNetwork, MultiLayerNetwork multiLayerNetwork2) {
        this.valueNet = multiLayerNetwork;
        this.policyNet = multiLayerNetwork2;
        this.recurrent = multiLayerNetwork.getOutputLayer() instanceof RnnOutputLayer;
    }

    /**
     * @return a new two-element array holding the value network followed by the policy network
     */
    @Override
    public NeuralNetwork[] getNeuralNetworks() {
        return new NeuralNetwork[]{this.valueNet, this.policyNet};
    }

    /**
     * Restores both networks through {@link ModelSerializer#restoreMultiLayerNetwork} and wraps them.
     *
     * @param str  location passed to the serializer for the value network
     * @param str2 location passed to the serializer for the policy network
     * @return a new instance wrapping the two restored networks
     * @throws IOException declared by this method; presumably raised when a model cannot be read
     */
    public static ActorCriticSeparate load(String str, String str2) throws IOException {
        MultiLayerNetwork restoredValueNet = ModelSerializer.restoreMultiLayerNetwork(str);
        MultiLayerNetwork restoredPolicyNet = ModelSerializer.restoreMultiLayerNetwork(str2);
        return new ActorCriticSeparate(restoredValueNet, restoredPolicyNet);
    }

    /**
     * Clears the stored RNN state of both networks. Does nothing when this model is not recurrent.
     */
    @Override
    public void reset() {
        if (!this.recurrent) {
            return;
        }
        this.valueNet.rnnClearPreviousState();
        this.policyNet.rnnClearPreviousState();
    }

    /**
     * Fits both networks on the same input, value network first.
     *
     * @param iNDArray    input passed to both networks
     * @param iNDArrayArr labels: element 0 for the value network, element 1 for the policy network
     */
    @Override
    public void fit(INDArray iNDArray, INDArray[] iNDArrayArr) {
        this.valueNet.fit(iNDArray, iNDArrayArr[0]);
        this.policyNet.fit(iNDArray, iNDArrayArr[1]);
    }

    /**
     * Evaluates both networks on the given input, value network first.
     *
     * <p>Recurrent models are evaluated with {@code rnnTimeStep}; all others with {@code output}.
     *
     * @param iNDArray input passed to both networks
     * @return a two-element array: value network output, then policy network output
     */
    @Override
    public INDArray[] outputAll(INDArray iNDArray) {
        if (this.recurrent) {
            return new INDArray[]{this.valueNet.rnnTimeStep(iNDArray), this.policyNet.rnnTimeStep(iNDArray)};
        }
        return new INDArray[]{this.valueNet.output(iNDArray), this.policyNet.output(iNDArray)};
    }

    /**
     * Creates a new instance from clones of both networks and hands each clone the listener
     * collection returned by the corresponding network of this instance.
     *
     * @return the new instance, cast to {@code NN}
     */
    @Override // org.deeplearning4j.rl4j.network.NeuralNet
    /* renamed from: clone, reason: merged with bridge method [inline-methods] */
    public NN m18clone() {
        // Unchecked cast: the object created here is always a plain ActorCriticSeparate.
        NN duplicate = (NN) new ActorCriticSeparate(this.valueNet.clone(), this.policyNet.clone());
        duplicate.valueNet.setListeners(this.valueNet.getListeners());
        duplicate.policyNet.setListeners(this.policyNet.getListeners());
        return duplicate;
    }

    /**
     * Overwrites the parameters of both networks with those of the given model.
     *
     * @param nn the model whose value and policy network parameters are read
     */
    @Override
    public void copy(NN nn) {
        this.valueNet.setParams(nn.valueNet.params());
        this.policyNet.setParams(nn.policyNet.params());
    }

    /**
     * Computes gradients for both networks on the same input, value network first, notifying
     * each network's training listeners once its gradient has been computed.
     *
     * @param iNDArray    input set on both networks
     * @param iNDArrayArr labels: element 0 for the value network, element 1 for the policy network
     * @return a two-element array: value network gradient, then policy network gradient
     */
    @Override
    public Gradient[] gradient(INDArray iNDArray, INDArray[] iNDArrayArr) {
        computeGradientFor(this.valueNet, iNDArray, iNDArrayArr[0]);
        computeGradientFor(this.policyNet, iNDArray, iNDArrayArr[1]);
        return new Gradient[]{this.valueNet.gradient(), this.policyNet.gradient()};
    }

    /**
     * Applies one gradient to each network, value network first.
     *
     * @param gradientArr gradients: element 0 for the value network, element 1 for the policy network
     * @param i           forwarded unchanged as the fifth argument of each updater's
     *                    {@code update} call (presumably the batch size - confirm against the
     *                    updater API)
     */
    @Override
    public void applyGradient(Gradient[] gradientArr, int i) {
        applyGradientTo(this.valueNet, gradientArr[0], i);
        applyGradientTo(this.policyNet, gradientArr[1], i);
    }

    /**
     * @return the score of the value network; the policy network's score is not consulted
     */
    @Override
    public double getLatestScore() {
        return this.valueNet.score();
    }

    /**
     * Not supported: this model needs one stream per network.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public void save(OutputStream outputStream) throws IOException {
        throw new UnsupportedOperationException("Call save(streamValue, streamPolicy)");
    }

    /**
     * Not supported: this model needs one path per network.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public void save(String str) throws IOException {
        throw new UnsupportedOperationException("Call save(pathValue, pathPolicy)");
    }

    /**
     * Writes both networks through {@link ModelSerializer#writeModel}, passing {@code true} as
     * its third argument (presumably "save updater state" - confirm against the serializer API).
     *
     * @param outputStream  destination for the value network
     * @param outputStream2 destination for the policy network
     * @throws IOException declared by this method; presumably raised when writing fails
     */
    @Override
    public void save(OutputStream outputStream, OutputStream outputStream2) throws IOException {
        ModelSerializer.writeModel(this.valueNet, outputStream, true);
        ModelSerializer.writeModel(this.policyNet, outputStream2, true);
    }

    /**
     * Writes both networks through {@link ModelSerializer#writeModel}, passing {@code true} as
     * its third argument (presumably "save updater state" - confirm against the serializer API).
     *
     * @param str  destination path for the value network
     * @param str2 destination path for the policy network
     * @throws IOException declared by this method; presumably raised when writing fails
     */
    @Override
    public void save(String str, String str2) throws IOException {
        ModelSerializer.writeModel(this.valueNet, str, true);
        ModelSerializer.writeModel(this.policyNet, str2, true);
    }

    /**
     * @return the recurrent flag computed in the constructor
     */
    @Override
    public boolean isRecurrent() {
        return this.recurrent;
    }

    /** Sets input and labels on one network, computes its gradient and score, then notifies its listeners. */
    private static void computeGradientFor(MultiLayerNetwork network, INDArray input, INDArray labels) {
        network.setInput(input);
        network.setLabels(labels);
        network.computeGradientAndScore();
        notifyGradientCalculation(network);
    }

    /** Calls {@code onGradientCalculation} on every listener of the network; tolerates a null collection. */
    private static void notifyGradientCalculation(MultiLayerNetwork network) {
        Collection<TrainingListener> listeners = network.getListeners();
        if (listeners == null) {
            return;
        }
        for (TrainingListener listener : listeners) {
            listener.onGradientCalculation(network);
        }
    }

    /** Runs one network's updater on a gradient, subtracts the result from its parameters, and advances its iteration count. */
    private static void applyGradientTo(MultiLayerNetwork network, Gradient gradient, int updateArg) {
        MultiLayerConfiguration configuration = network.getLayerWiseConfigurations();
        int iterationCount = configuration.getIterationCount();
        int epochCount = configuration.getEpochCount();

        network.getUpdater().update(network, gradient, iterationCount, epochCount, updateArg, LayerWorkspaceMgr.noWorkspaces());
        // The gradient's array is subtracted in place from the network's parameter array.
        network.params().subi(gradient.gradient());

        Collection<TrainingListener> listeners = network.getListeners();
        if (listeners != null) {
            for (TrainingListener listener : listeners) {
                // Listeners see the iteration count from before the increment below.
                listener.iterationDone(network, iterationCount, epochCount);
            }
        }
        configuration.setIterationCount(iterationCount + 1);
    }
}
