/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.rl4j.network.ac;

import java.io.IOException;
import java.io.OutputStream;
import java.util.Collection;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.api.NeuralNetwork;
import org.deeplearning4j.nn.api.Trainable;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.recurrent.RnnOutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.rl4j.network.ac.IActorCritic;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.api.ndarray.INDArray;

/**
 * Actor-critic model backed by two independent {@link MultiLayerNetwork}s:
 * one producing the value estimate and one producing the policy.
 *
 * <p>Wherever an array of per-network items is passed or returned (labels,
 * outputs, gradients), index 0 belongs to the value network and index 1 to
 * the policy network.
 *
 * @param <NN> the concrete type returned by {@link #clone()} and accepted by
 *             {@link #copy(ActorCriticSeparate)}
 */
public class ActorCriticSeparate<NN extends ActorCriticSeparate>
implements IActorCritic<NN> {
    /** Network fitted against {@code labels[0]}. */
    protected final MultiLayerNetwork valueNet;
    /** Network fitted against {@code labels[1]}. */
    protected final MultiLayerNetwork policyNet;
    /** Whether the recurrent code paths (time stepping, state clearing) are used. */
    protected final boolean recurrent;

    /**
     * Wraps the two given networks.
     *
     * <p>The model is flagged as recurrent when the output layer of
     * {@code valueNet} is an {@link RnnOutputLayer}; the policy network is not
     * inspected for this decision.
     *
     * @param valueNet  network used for the value output
     * @param policyNet network used for the policy output
     */
    public ActorCriticSeparate(MultiLayerNetwork valueNet, MultiLayerNetwork policyNet) {
        this.valueNet = valueNet;
        this.policyNet = policyNet;
        this.recurrent = valueNet.getOutputLayer() instanceof RnnOutputLayer;
    }

    /**
     * @return a new two-element array holding the value network followed by
     *         the policy network
     */
    @Override
    public NeuralNetwork[] getNeuralNetworks() {
        return new NeuralNetwork[]{this.valueNet, this.policyNet};
    }

    /**
     * Restores both networks through {@link ModelSerializer} and wraps them.
     *
     * @param pathValue  location of the serialized value network
     * @param pathPolicy location of the serialized policy network
     * @return a model wrapping the two restored networks
     * @throws IOException if either network cannot be restored
     */
    public static ActorCriticSeparate load(String pathValue, String pathPolicy) throws IOException {
        MultiLayerNetwork restoredValueNet = ModelSerializer.restoreMultiLayerNetwork(pathValue);
        MultiLayerNetwork restoredPolicyNet = ModelSerializer.restoreMultiLayerNetwork(pathPolicy);
        return new ActorCriticSeparate(restoredValueNet, restoredPolicyNet);
    }

    /**
     * Clears the stored RNN state of both networks. Does nothing when the
     * model is not recurrent.
     */
    @Override
    public void reset() {
        if (this.recurrent) {
            this.valueNet.rnnClearPreviousState();
            this.policyNet.rnnClearPreviousState();
        }
    }

    /**
     * Fits both networks on the same input.
     *
     * @param input  features fed to both networks
     * @param labels {@code labels[0]} is used for the value network and
     *               {@code labels[1]} for the policy network
     */
    @Override
    public void fit(INDArray input, INDArray[] labels) {
        this.valueNet.fit(input, labels[0]);
        this.policyNet.fit(input, labels[1]);
    }

    /**
     * Evaluates both networks on {@code batch}.
     *
     * <p>A recurrent model goes through {@code rnnTimeStep}; otherwise the
     * plain {@code output} call is used.
     *
     * @param batch input passed to both networks
     * @return the value network output followed by the policy network output
     */
    @Override
    public INDArray[] outputAll(INDArray batch) {
        if (this.recurrent) {
            return new INDArray[]{this.valueNet.rnnTimeStep(batch), this.policyNet.rnnTimeStep(batch)};
        }
        return new INDArray[]{this.valueNet.output(batch), this.policyNet.output(batch)};
    }

    /**
     * Builds a new instance from clones of both networks and hands the
     * current listeners of each network to its clone.
     *
     * @return the new instance, cast to {@code NN}
     */
    @Override
    @SuppressWarnings("unchecked") // NN is only bounded by this class, so the cast cannot be checked at runtime
    public NN clone() {
        ActorCriticSeparate<NN> nn = new ActorCriticSeparate<NN>(this.valueNet.clone(), this.policyNet.clone());
        nn.valueNet.setListeners(this.valueNet.getListeners());
        nn.policyNet.setListeners(this.policyNet.getListeners());
        return (NN) nn;
    }

    /**
     * Overwrites the parameters of both networks with those of {@code from}.
     *
     * @param from model whose value and policy parameters are read
     */
    @Override
    public void copy(NN from) {
        ActorCriticSeparate<?> source = from;
        this.valueNet.setParams(source.valueNet.params());
        this.policyNet.setParams(source.policyNet.params());
    }

    /**
     * Computes the gradient of each network for the given input and labels,
     * notifying each network's listeners once its gradient has been computed.
     * The value network is processed before the policy network.
     *
     * @param input  input set on both networks
     * @param labels {@code labels[0]} for the value network, {@code labels[1]}
     *               for the policy network
     * @return the value network gradient followed by the policy network gradient
     */
    @Override
    public Gradient[] gradient(INDArray input, INDArray[] labels) {
        computeGradientAndNotify(this.valueNet, input, labels[0]);
        computeGradientAndNotify(this.policyNet, input, labels[1]);
        return new Gradient[]{this.valueNet.gradient(), this.policyNet.gradient()};
    }

    /**
     * Applies {@code gradient[0]} to the value network and then
     * {@code gradient[1]} to the policy network.
     *
     * @param gradient  per-network gradients, value first
     * @param batchSize batch size forwarded to each network's updater
     */
    @Override
    public void applyGradient(Gradient[] gradient, int batchSize) {
        applyGradientTo(this.valueNet, gradient[0], batchSize);
        applyGradientTo(this.policyNet, gradient[1], batchSize);
    }

    /**
     * @return the score of the value network; the policy network's score is
     *         not taken into account
     */
    @Override
    public double getLatestScore() {
        return this.valueNet.score();
    }

    /**
     * Not supported: this model consists of two networks and needs two streams.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public void save(OutputStream stream) throws IOException {
        throw new UnsupportedOperationException("Call save(streamValue, streamPolicy)");
    }

    /**
     * Not supported: this model consists of two networks and needs two paths.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public void save(String path) throws IOException {
        throw new UnsupportedOperationException("Call save(pathValue, pathPolicy)");
    }

    /**
     * Writes both networks through {@link ModelSerializer}, passing
     * {@code true} as the third {@code writeModel} argument.
     *
     * @param streamValue  destination of the value network
     * @param streamPolicy destination of the policy network
     * @throws IOException if writing either network fails
     */
    @Override
    public void save(OutputStream streamValue, OutputStream streamPolicy) throws IOException {
        ModelSerializer.writeModel(this.valueNet, streamValue, true);
        ModelSerializer.writeModel(this.policyNet, streamPolicy, true);
    }

    /**
     * Writes both networks through {@link ModelSerializer}, passing
     * {@code true} as the third {@code writeModel} argument.
     *
     * @param pathValue  destination of the value network
     * @param pathPolicy destination of the policy network
     * @throws IOException if writing either network fails
     */
    @Override
    public void save(String pathValue, String pathPolicy) throws IOException {
        ModelSerializer.writeModel(this.valueNet, pathValue, true);
        ModelSerializer.writeModel(this.policyNet, pathPolicy, true);
    }

    /**
     * @return the flag computed in the constructor from the value network's
     *         output layer
     */
    @Override
    public boolean isRecurrent() {
        return this.recurrent;
    }

    /** Sets input and labels on {@code net}, computes its gradient and score, then notifies its listeners. */
    private static void computeGradientAndNotify(MultiLayerNetwork net, INDArray input, INDArray netLabels) {
        net.setInput(input);
        net.setLabels(netLabels);
        net.computeGradientAndScore();
        // The element type must be spelled out: iterating a raw Collection yields Object.
        Collection<TrainingListener> listeners = net.getListeners();
        if (listeners != null) {
            for (TrainingListener listener : listeners) {
                listener.onGradientCalculation(net);
            }
        }
    }

    /** Runs {@code net}'s updater on {@code gradient}, subtracts the result from the parameters and advances the iteration count. */
    private static void applyGradientTo(MultiLayerNetwork net, Gradient gradient, int batchSize) {
        MultiLayerConfiguration conf = net.getLayerWiseConfigurations();
        int iterationCount = conf.getIterationCount();
        int epochCount = conf.getEpochCount();
        net.getUpdater().update(net, gradient, iterationCount, epochCount, batchSize, LayerWorkspaceMgr.noWorkspaces());
        net.params().subi(gradient.gradient());
        Collection<TrainingListener> listeners = net.getListeners();
        if (listeners != null) {
            for (TrainingListener listener : listeners) {
                // Listeners are told the iteration that just ran; the counter is bumped afterwards.
                listener.iterationDone(net, iterationCount, epochCount);
            }
        }
        conf.setIterationCount(iterationCount + 1);
    }
}

