/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.rl4j.network.dqn;

import java.io.IOException;
import java.io.OutputStream;
import java.util.Collection;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.api.NeuralNetwork;
import org.deeplearning4j.nn.api.Trainable;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.rl4j.network.dqn.IDQN;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.api.ndarray.INDArray;

/**
 * Deep Q-network backed by a single {@link MultiLayerNetwork}.
 *
 * <p>Every operation delegates to the wrapped network. The multi-head variants of the
 * {@link IDQN} API are mapped onto the single output of that network:
 * <ul>
 *   <li>overloads taking {@code INDArray[]} labels use only element {@code 0};</li>
 *   <li>{@link #outputAll(INDArray)} returns a one-element array.</li>
 * </ul>
 *
 * @param <NN> the concrete DQN type returned by {@link #clone()} and accepted by {@link #copy}
 */
public class DQN<NN extends DQN>
implements IDQN<NN> {
    /** The wrapped network; all training and inference is delegated to it. */
    protected final MultiLayerNetwork mln;
    // NOTE(review): not read anywhere in this class; kept because it is package-visible and
    // may be referenced elsewhere in the package -- confirm before removing.
    int i = 0;

    /**
     * Wraps an existing network.
     *
     * @param mln the network to delegate to
     */
    public DQN(MultiLayerNetwork mln) {
        this.mln = mln;
    }

    /**
     * Returns the networks backing this DQN.
     *
     * @return a one-element array holding the wrapped network
     */
    @Override
    public NeuralNetwork[] getNeuralNetworks() {
        return new NeuralNetwork[]{this.mln};
    }

    /**
     * Restores a DQN from a model file previously written by {@link #save(String)}.
     *
     * @param path file-system path of the serialized {@link MultiLayerNetwork}
     * @return a new DQN wrapping the restored network
     * @throws IOException if the file cannot be read or deserialized
     */
    public static DQN load(String path) throws IOException {
        return new DQN(ModelSerializer.restoreMultiLayerNetwork(path));
    }

    /**
     * Reports whether this network is recurrent.
     *
     * @return always {@code false}; this implementation treats the network as feed-forward
     */
    @Override
    public boolean isRecurrent() {
        return false;
    }

    /** No-op: this implementation keeps no per-episode state to reset. */
    @Override
    public void reset() {
    }

    /**
     * Performs one fit step of the wrapped network on the given batch.
     *
     * @param input  network input batch
     * @param labels target values for {@code input}
     */
    @Override
    public void fit(INDArray input, INDArray labels) {
        this.mln.fit(input, labels);
    }

    /**
     * Multi-head overload of {@link #fit(INDArray, INDArray)}; only {@code labels[0]} is used.
     *
     * @param input  network input batch
     * @param labels target arrays; must contain at least one element
     */
    @Override
    public void fit(INDArray input, INDArray[] labels) {
        this.fit(input, labels[0]);
    }

    /**
     * Runs inference on the wrapped network.
     *
     * @param batch network input batch
     * @return the network output for {@code batch}
     */
    @Override
    public INDArray output(INDArray batch) {
        return this.mln.output(batch);
    }

    /**
     * Multi-head variant of {@link #output(INDArray)}.
     *
     * @param batch network input batch
     * @return a one-element array holding {@code output(batch)}
     */
    @Override
    public INDArray[] outputAll(INDArray batch) {
        return new INDArray[]{this.output(batch)};
    }

    /**
     * Creates a new DQN wrapping a clone of the underlying network.
     *
     * <p>The clone's network is given this network's listeners.
     *
     * @return the cloned DQN, typed as {@code NN}
     */
    @Override
    @SuppressWarnings("unchecked") // NN is bounded by DQN; the cast mirrors the IDQN contract
    public NN clone() {
        DQN<NN> nn = new DQN<NN>(this.mln.clone());
        nn.mln.setListeners(this.mln.getListeners());
        return (NN) nn;
    }

    /**
     * Overwrites this network's parameters with those of {@code from}.
     *
     * @param from the DQN whose parameters are copied into this one
     */
    @Override
    public void copy(NN from) {
        this.mln.setParams(((DQN) from).mln.params());
    }

    /**
     * Computes gradients and score for one batch without updating parameters.
     *
     * <p>After the gradient is computed, each registered {@link TrainingListener} is notified
     * through {@code onGradientCalculation}.
     *
     * @param input  network input batch
     * @param labels target values for {@code input}
     * @return a one-element array holding the network's gradient
     */
    @Override
    public Gradient[] gradient(INDArray input, INDArray labels) {
        this.mln.setInput(input);
        this.mln.setLabels(labels);
        this.mln.computeGradientAndScore();
        // Typed collection: a raw Collection would yield Object in the for-each and not compile.
        Collection<TrainingListener> iterationListeners = this.mln.getListeners();
        if (iterationListeners != null) {
            for (TrainingListener listener : iterationListeners) {
                listener.onGradientCalculation(this.mln);
            }
        }
        return new Gradient[]{this.mln.gradient()};
    }

    /**
     * Multi-head overload of {@link #gradient(INDArray, INDArray)}; only {@code labels[0]} is
     * used.
     *
     * @param input  network input batch
     * @param labels target arrays; must contain at least one element
     * @return a one-element array holding the network's gradient
     */
    @Override
    public Gradient[] gradient(INDArray input, INDArray[] labels) {
        return this.gradient(input, labels[0]);
    }

    /**
     * Applies a previously computed gradient to the wrapped network's parameters.
     *
     * <p>Steps, in order:
     * <ol>
     *   <li>The network's updater transforms {@code gradient[0]}.</li>
     *   <li>The transformed gradient is subtracted in place from the parameters.</li>
     *   <li>Listeners receive {@code iterationDone} with the pre-increment iteration count.</li>
     *   <li>The configuration's iteration count is incremented.</li>
     * </ol>
     *
     * @param gradient  gradients to apply; only element {@code 0} is used
     * @param batchSize mini-batch size passed through to the updater
     */
    @Override
    public void applyGradient(Gradient[] gradient, int batchSize) {
        MultiLayerConfiguration mlnConf = this.mln.getLayerWiseConfigurations();
        int iterationCount = mlnConf.getIterationCount();
        int epochCount = mlnConf.getEpochCount();
        this.mln.getUpdater().update(this.mln, gradient[0], iterationCount, epochCount, batchSize,
                LayerWorkspaceMgr.noWorkspaces());
        this.mln.params().subi(gradient[0].gradient());
        Collection<TrainingListener> iterationListeners = this.mln.getListeners();
        if (iterationListeners != null) {
            for (TrainingListener listener : iterationListeners) {
                listener.iterationDone(this.mln, iterationCount, epochCount);
            }
        }
        mlnConf.setIterationCount(iterationCount + 1);
    }

    /**
     * Returns the wrapped network's current score.
     *
     * @return the value reported by {@code mln.score()}
     */
    @Override
    public double getLatestScore() {
        return this.mln.score();
    }

    /**
     * Serializes the wrapped network, including updater state, to a stream.
     *
     * @param stream destination stream
     * @throws IOException if writing fails
     */
    @Override
    public void save(OutputStream stream) throws IOException {
        ModelSerializer.writeModel(this.mln, stream, true);
    }

    /**
     * Serializes the wrapped network, including updater state, to a file.
     *
     * @param path destination file path
     * @throws IOException if writing fails
     */
    @Override
    public void save(String path) throws IOException {
        ModelSerializer.writeModel(this.mln, path, true);
    }
}

