/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.rl4j.network.ac;

import java.util.Arrays;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.LSTM;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.deeplearning4j.nn.conf.preprocessor.CnnToFeedForwardPreProcessor;
import org.deeplearning4j.nn.conf.preprocessor.FeedForwardToRnnPreProcessor;
import org.deeplearning4j.nn.conf.preprocessor.RnnToCnnPreProcessor;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.rl4j.network.ac.ActorCriticCompGraph;
import org.deeplearning4j.rl4j.network.ac.ActorCriticFactoryCompGraph;
import org.deeplearning4j.rl4j.network.ac.ActorCriticLoss;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.learning.config.IUpdater;
import org.nd4j.linalg.lossfunctions.ILossFunction;
import org.nd4j.linalg.lossfunctions.LossFunctions;

public final class ActorCriticFactoryCompGraphStdConv
implements ActorCriticFactoryCompGraph {
    private final Configuration conf;

    @Override
    public ActorCriticCompGraph buildActorCritic(int[] shapeInputs, int numOutputs) {
        if (shapeInputs.length == 1) {
            throw new AssertionError((Object)"Impossible to apply convolutional layer on a shape == 1");
        }
        int h = ((shapeInputs[1] - 8) / 4 + 1 - 4) / 2 + 1;
        int w = ((shapeInputs[2] - 8) / 4 + 1 - 4) / 2 + 1;
        ComputationGraphConfiguration.GraphBuilder confB = new NeuralNetConfiguration.Builder().seed(12345L).optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).updater((IUpdater)(this.conf.getUpdater() != null ? this.conf.getUpdater() : new Adam())).weightInit(WeightInit.XAVIER).l2(this.conf.getL2()).graphBuilder().addInputs(new String[]{"input"}).addLayer("0", (Layer)((ConvolutionLayer.Builder)((ConvolutionLayer.Builder)((ConvolutionLayer.Builder)new ConvolutionLayer.Builder(new int[]{8, 8}).nIn(shapeInputs[0])).nOut(16)).stride(new int[]{4, 4}).activation(Activation.RELU)).build(), new String[]{"input"});
        confB.addLayer("1", (Layer)((ConvolutionLayer.Builder)((ConvolutionLayer.Builder)((ConvolutionLayer.Builder)new ConvolutionLayer.Builder(new int[]{4, 4}).nIn(16)).nOut(32)).stride(new int[]{2, 2}).activation(Activation.RELU)).build(), new String[]{"0"});
        confB.addLayer("2", (Layer)((DenseLayer.Builder)((DenseLayer.Builder)((DenseLayer.Builder)new DenseLayer.Builder().nIn(w * h * 32)).nOut(256)).activation(Activation.RELU)).build(), new String[]{"1"});
        if (this.conf.isUseLSTM()) {
            confB.addLayer("3", (Layer)((LSTM.Builder)((LSTM.Builder)((LSTM.Builder)new LSTM.Builder().nIn(256)).nOut(256)).activation(Activation.TANH)).build(), new String[]{"2"});
            confB.addLayer("value", (Layer)((RnnOutputLayer.Builder)((RnnOutputLayer.Builder)((RnnOutputLayer.Builder)new RnnOutputLayer.Builder(LossFunctions.LossFunction.MSE).activation(Activation.IDENTITY)).nIn(256)).nOut(1)).build(), new String[]{"3"});
            confB.addLayer("softmax", (Layer)((RnnOutputLayer.Builder)((RnnOutputLayer.Builder)((RnnOutputLayer.Builder)new RnnOutputLayer.Builder((ILossFunction)new ActorCriticLoss()).activation(Activation.SOFTMAX)).nIn(256)).nOut(numOutputs)).build(), new String[]{"3"});
        } else {
            confB.addLayer("value", (Layer)((OutputLayer.Builder)((OutputLayer.Builder)((OutputLayer.Builder)new OutputLayer.Builder(LossFunctions.LossFunction.MSE).activation(Activation.IDENTITY)).nIn(256)).nOut(1)).build(), new String[]{"2"});
            confB.addLayer("softmax", (Layer)((OutputLayer.Builder)((OutputLayer.Builder)((OutputLayer.Builder)new OutputLayer.Builder((ILossFunction)new ActorCriticLoss()).activation(Activation.SOFTMAX)).nIn(256)).nOut(numOutputs)).build(), new String[]{"2"});
        }
        confB.setOutputs(new String[]{"value", "softmax"});
        if (this.conf.isUseLSTM()) {
            confB.inputPreProcessor("0", (InputPreProcessor)new RnnToCnnPreProcessor(shapeInputs[1], shapeInputs[2], shapeInputs[0]));
            confB.inputPreProcessor("2", (InputPreProcessor)new CnnToFeedForwardPreProcessor((long)h, (long)w, 32L));
            confB.inputPreProcessor("3", (InputPreProcessor)new FeedForwardToRnnPreProcessor());
        } else {
            confB.setInputTypes(new InputType[]{InputType.convolutional((long)shapeInputs[1], (long)shapeInputs[2], (long)shapeInputs[0])});
        }
        ComputationGraphConfiguration cgconf = confB.pretrain(false).backprop(true).build();
        ComputationGraph model = new ComputationGraph(cgconf);
        model.init();
        if (this.conf.getListeners() != null) {
            model.setListeners(this.conf.getListeners());
        } else {
            model.setListeners(new TrainingListener[]{new ScoreIterationListener(50)});
        }
        return new ActorCriticCompGraph(model);
    }

    public ActorCriticFactoryCompGraphStdConv(Configuration conf) {
        this.conf = conf;
    }

    public Configuration getConf() {
        return this.conf;
    }

    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof ActorCriticFactoryCompGraphStdConv)) {
            return false;
        }
        ActorCriticFactoryCompGraphStdConv other = (ActorCriticFactoryCompGraphStdConv)o;
        Configuration this$conf = this.getConf();
        Configuration other$conf = other.getConf();
        return !(this$conf == null ? other$conf != null : !((Object)this$conf).equals(other$conf));
    }

    public int hashCode() {
        int PRIME = 59;
        int result = 1;
        Configuration $conf = this.getConf();
        result = result * 59 + ($conf == null ? 43 : ((Object)$conf).hashCode());
        return result;
    }

    public String toString() {
        return "ActorCriticFactoryCompGraphStdConv(conf=" + this.getConf() + ")";
    }

    public static final class Configuration {
        private final double l2;
        private final IUpdater updater;
        private final TrainingListener[] listeners;
        private final boolean useLSTM;

        public static ConfigurationBuilder builder() {
            return new ConfigurationBuilder();
        }

        public double getL2() {
            return this.l2;
        }

        public IUpdater getUpdater() {
            return this.updater;
        }

        public TrainingListener[] getListeners() {
            return this.listeners;
        }

        public boolean isUseLSTM() {
            return this.useLSTM;
        }

        public boolean equals(Object o) {
            if (o == this) {
                return true;
            }
            if (!(o instanceof Configuration)) {
                return false;
            }
            Configuration other = (Configuration)o;
            if (Double.compare(this.getL2(), other.getL2()) != 0) {
                return false;
            }
            IUpdater this$updater = this.getUpdater();
            IUpdater other$updater = other.getUpdater();
            if (this$updater == null ? other$updater != null : !this$updater.equals(other$updater)) {
                return false;
            }
            if (!Arrays.deepEquals(this.getListeners(), other.getListeners())) {
                return false;
            }
            return this.isUseLSTM() == other.isUseLSTM();
        }

        public int hashCode() {
            int PRIME = 59;
            int result = 1;
            long $l2 = Double.doubleToLongBits(this.getL2());
            result = result * 59 + (int)($l2 >>> 32 ^ $l2);
            IUpdater $updater = this.getUpdater();
            result = result * 59 + ($updater == null ? 43 : $updater.hashCode());
            result = result * 59 + Arrays.deepHashCode(this.getListeners());
            result = result * 59 + (this.isUseLSTM() ? 79 : 97);
            return result;
        }

        public String toString() {
            return "ActorCriticFactoryCompGraphStdConv.Configuration(l2=" + this.getL2() + ", updater=" + this.getUpdater() + ", listeners=" + Arrays.deepToString(this.getListeners()) + ", useLSTM=" + this.isUseLSTM() + ")";
        }

        public Configuration(double l2, IUpdater updater, TrainingListener[] listeners, boolean useLSTM) {
            this.l2 = l2;
            this.updater = updater;
            this.listeners = listeners;
            this.useLSTM = useLSTM;
        }

        public static class ConfigurationBuilder {
            private double l2;
            private IUpdater updater;
            private TrainingListener[] listeners;
            private boolean useLSTM;

            ConfigurationBuilder() {
            }

            public ConfigurationBuilder l2(double l2) {
                this.l2 = l2;
                return this;
            }

            public ConfigurationBuilder updater(IUpdater updater) {
                this.updater = updater;
                return this;
            }

            public ConfigurationBuilder listeners(TrainingListener[] listeners) {
                this.listeners = listeners;
                return this;
            }

            public ConfigurationBuilder useLSTM(boolean useLSTM) {
                this.useLSTM = useLSTM;
                return this;
            }

            public Configuration build() {
                return new Configuration(this.l2, this.updater, this.listeners, this.useLSTM);
            }

            public String toString() {
                return "ActorCriticFactoryCompGraphStdConv.Configuration.ConfigurationBuilder(l2=" + this.l2 + ", updater=" + this.updater + ", listeners=" + Arrays.deepToString(this.listeners) + ", useLSTM=" + this.useLSTM + ")";
            }
        }
    }
}

