/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.rl4j.network.ac;

import java.io.IOException;
import java.io.OutputStream;
import java.util.Collection;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.api.NeuralNetwork;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.layers.recurrent.RnnOutputLayer;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.rl4j.network.ac.IActorCritic;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.api.ndarray.INDArray;

/**
 * Actor-critic network backed by a single DL4J {@link ComputationGraph}.
 *
 * <p>Every operation delegates to the wrapped graph. Because both heads live in one graph, the
 * two-stream {@code save(value, policy)} overloads are unsupported; use the single-target
 * {@code save} overloads instead.
 *
 * <p>The network is treated as recurrent when output layer 0 is an {@link RnnOutputLayer}. In that
 * case {@link #outputAll(INDArray)} uses {@code rnnTimeStep}, and {@link #reset()} clears the
 * stored RNN state.
 *
 * @param <NN> self type returned by {@link #clone()} and accepted by {@link #copy}
 */
public class ActorCriticCompGraph<NN extends ActorCriticCompGraph>
implements IActorCritic<NN> {
    /** The wrapped graph; all training, inference and serialization delegate to it. */
    protected final ComputationGraph cg;
    /** True when output layer 0 is an {@link RnnOutputLayer}; fixed at construction time. */
    protected final boolean recurrent;

    /**
     * Wraps an existing computation graph. The graph is used directly, not copied.
     *
     * @param cg the graph to wrap
     */
    public ActorCriticCompGraph(ComputationGraph cg) {
        this.cg = cg;
        this.recurrent = cg.getOutputLayer(0) instanceof RnnOutputLayer;
    }

    /** Returns the single underlying network. */
    @Override
    public NeuralNetwork[] getNeuralNetworks() {
        return new NeuralNetwork[] {cg};
    }

    /**
     * Restores a computation graph previously written with {@link ModelSerializer} and wraps it.
     *
     * @param path file path of the serialized model
     * @return a new wrapper around the restored graph
     * @throws IOException if the model cannot be read
     */
    public static ActorCriticCompGraph load(String path) throws IOException {
        return new ActorCriticCompGraph(ModelSerializer.restoreComputationGraph(path));
    }

    /** Fits the graph on a single input array against one label array per output. */
    @Override
    public void fit(INDArray input, INDArray[] labels) {
        cg.fit(new INDArray[] {input}, labels);
    }

    /** Clears stored RNN state for recurrent networks; does nothing otherwise. */
    @Override
    public void reset() {
        if (recurrent) {
            cg.rnnClearPreviousState();
        }
    }

    /**
     * Runs inference on {@code batch} and returns every output of the graph.
     *
     * <p>Recurrent networks go through {@code rnnTimeStep}, whose state is cleared by
     * {@link #reset()}. Other networks use a plain {@code output} call.
     */
    @Override
    public INDArray[] outputAll(INDArray batch) {
        INDArray[] inputs = new INDArray[] {batch};
        return recurrent ? cg.rnnTimeStep(inputs) : cg.output(inputs);
    }

    /**
     * Returns a new wrapper around a clone of the graph that shares this graph's listeners.
     *
     * <p>NOTE(review): the result is always a plain {@code ActorCriticCompGraph}. If {@code NN}
     * is a subclass, the unchecked cast below surfaces as a {@code ClassCastException} at the
     * caller.
     */
    @Override
    @SuppressWarnings("unchecked")
    public NN clone() {
        ActorCriticCompGraph<NN> copy = new ActorCriticCompGraph<NN>(cg.clone());
        copy.cg.setListeners(cg.getListeners());
        return (NN) copy;
    }

    /** Overwrites this graph's parameters with those of {@code from}. */
    @Override
    public void copy(NN from) {
        cg.setParams(((ActorCriticCompGraph) from).cg.params());
    }

    /**
     * Computes the gradient (and score) for one input/label set and notifies listeners.
     *
     * @param input the single graph input
     * @param labels one label array per graph output
     * @return a one-element array holding the graph's gradient
     */
    @Override
    public Gradient[] gradient(INDArray input, INDArray[] labels) {
        cg.setInput(0, input);
        cg.setLabels(labels);
        cg.computeGradientAndScore();
        // Typed collection: iterating the raw Collection yields Object and does not compile.
        Collection<TrainingListener> listeners = cg.getListeners();
        if (listeners != null) {
            for (TrainingListener listener : listeners) {
                listener.onGradientCalculation(cg);
            }
        }
        return new Gradient[] {cg.gradient()};
    }

    /**
     * Runs {@code gradient[0]} through the graph's updater, subtracts the result from the
     * parameters, notifies listeners and advances the configuration's iteration count.
     *
     * @param gradient gradients as produced by {@link #gradient}; only element 0 is used
     * @param batchSize batch size passed to the updater; forced to 1 for recurrent networks
     */
    @Override
    public void applyGradient(Gradient[] gradient, int batchSize) {
        if (recurrent) {
            // Presumably each episode is fed as one time series, i.e. a batch of 1 — confirm.
            batchSize = 1;
        }
        ComputationGraphConfiguration cgConf = cg.getConfiguration();
        int iterationCount = cgConf.getIterationCount();
        int epochCount = cgConf.getEpochCount();
        Gradient grad = gradient[0];
        cg.getUpdater().update(grad, iterationCount, epochCount, batchSize, LayerWorkspaceMgr.noWorkspaces());
        // Assumes the updater transformed grad in place and params() is a live view, so subi
        // updates the weights directly — confirm against the DL4J version in use.
        cg.params().subi(grad.gradient());
        Collection<TrainingListener> listeners = cg.getListeners();
        if (listeners != null) {
            for (TrainingListener listener : listeners) {
                listener.iterationDone(cg, iterationCount, epochCount);
            }
        }
        cgConf.setIterationCount(iterationCount + 1);
    }

    /** Returns the graph's current score as reported by {@code ComputationGraph.score()}. */
    @Override
    public double getLatestScore() {
        return cg.score();
    }

    /** Serializes the graph, including updater state, to {@code stream}. */
    @Override
    public void save(OutputStream stream) throws IOException {
        ModelSerializer.writeModel(cg, stream, true);
    }

    /** Serializes the graph, including updater state, to the file at {@code path}. */
    @Override
    public void save(String path) throws IOException {
        ModelSerializer.writeModel(cg, path, true);
    }

    /** Unsupported: this network is a single graph; use {@link #save(OutputStream)}. */
    @Override
    public void save(OutputStream streamValue, OutputStream streamPolicy) throws IOException {
        throw new UnsupportedOperationException("Call save(stream)");
    }

    /** Unsupported: this network is a single graph; use {@link #save(String)}. */
    @Override
    public void save(String pathValue, String pathPolicy) throws IOException {
        throw new UnsupportedOperationException("Call save(path)");
    }

    /** Returns whether output layer 0 is an {@link RnnOutputLayer}. */
    @Override
    public boolean isRecurrent() {
        return recurrent;
    }
}

