/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.rl4j.network.dqn;

import java.util.Arrays;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.rl4j.network.dqn.DQN;
import org.deeplearning4j.rl4j.network.dqn.DQNFactory;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.learning.config.IUpdater;
import org.nd4j.linalg.lossfunctions.LossFunctions;

public final class DQNFactoryStdDense
implements DQNFactory {
    private final Configuration conf;

    @Override
    public DQN buildDQN(int[] numInputs, int numOutputs) {
        int nIn = 1;
        for (int i : numInputs) {
            nIn *= i;
        }
        NeuralNetConfiguration.ListBuilder confB = new NeuralNetConfiguration.Builder().seed(12345L).optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).updater((IUpdater)(this.conf.getUpdater() != null ? this.conf.getUpdater() : new Adam())).weightInit(WeightInit.XAVIER).l2(this.conf.getL2()).list().layer(0, (Layer)((DenseLayer.Builder)((DenseLayer.Builder)((DenseLayer.Builder)new DenseLayer.Builder().nIn(nIn)).nOut(this.conf.getNumHiddenNodes())).activation(Activation.RELU)).build());
        for (int i = 1; i < this.conf.getNumLayer(); ++i) {
            confB.layer(i, (Layer)((DenseLayer.Builder)((DenseLayer.Builder)((DenseLayer.Builder)new DenseLayer.Builder().nIn(this.conf.getNumHiddenNodes())).nOut(this.conf.getNumHiddenNodes())).activation(Activation.RELU)).build());
        }
        confB.layer(this.conf.getNumLayer(), (Layer)((OutputLayer.Builder)((OutputLayer.Builder)((OutputLayer.Builder)new OutputLayer.Builder(LossFunctions.LossFunction.MSE).activation(Activation.IDENTITY)).nIn(this.conf.getNumHiddenNodes())).nOut(numOutputs)).build());
        MultiLayerConfiguration mlnconf = confB.pretrain(false).backprop(true).build();
        MultiLayerNetwork model = new MultiLayerNetwork(mlnconf);
        model.init();
        if (this.conf.getListeners() != null) {
            model.setListeners(this.conf.getListeners());
        } else {
            model.setListeners(new TrainingListener[]{new ScoreIterationListener(50)});
        }
        return new DQN(model);
    }

    public DQNFactoryStdDense(Configuration conf) {
        this.conf = conf;
    }

    public Configuration getConf() {
        return this.conf;
    }

    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof DQNFactoryStdDense)) {
            return false;
        }
        DQNFactoryStdDense other = (DQNFactoryStdDense)o;
        Configuration this$conf = this.getConf();
        Configuration other$conf = other.getConf();
        return !(this$conf == null ? other$conf != null : !((Object)this$conf).equals(other$conf));
    }

    public int hashCode() {
        int PRIME = 59;
        int result = 1;
        Configuration $conf = this.getConf();
        result = result * 59 + ($conf == null ? 43 : ((Object)$conf).hashCode());
        return result;
    }

    public String toString() {
        return "DQNFactoryStdDense(conf=" + this.getConf() + ")";
    }

    public static final class Configuration {
        private final int numLayer;
        private final int numHiddenNodes;
        private final double l2;
        private final IUpdater updater;
        private final TrainingListener[] listeners;

        public static ConfigurationBuilder builder() {
            return new ConfigurationBuilder();
        }

        public int getNumLayer() {
            return this.numLayer;
        }

        public int getNumHiddenNodes() {
            return this.numHiddenNodes;
        }

        public double getL2() {
            return this.l2;
        }

        public IUpdater getUpdater() {
            return this.updater;
        }

        public TrainingListener[] getListeners() {
            return this.listeners;
        }

        public boolean equals(Object o) {
            if (o == this) {
                return true;
            }
            if (!(o instanceof Configuration)) {
                return false;
            }
            Configuration other = (Configuration)o;
            if (this.getNumLayer() != other.getNumLayer()) {
                return false;
            }
            if (this.getNumHiddenNodes() != other.getNumHiddenNodes()) {
                return false;
            }
            if (Double.compare(this.getL2(), other.getL2()) != 0) {
                return false;
            }
            IUpdater this$updater = this.getUpdater();
            IUpdater other$updater = other.getUpdater();
            if (this$updater == null ? other$updater != null : !this$updater.equals(other$updater)) {
                return false;
            }
            return Arrays.deepEquals(this.getListeners(), other.getListeners());
        }

        public int hashCode() {
            int PRIME = 59;
            int result = 1;
            result = result * 59 + this.getNumLayer();
            result = result * 59 + this.getNumHiddenNodes();
            long $l2 = Double.doubleToLongBits(this.getL2());
            result = result * 59 + (int)($l2 >>> 32 ^ $l2);
            IUpdater $updater = this.getUpdater();
            result = result * 59 + ($updater == null ? 43 : $updater.hashCode());
            result = result * 59 + Arrays.deepHashCode(this.getListeners());
            return result;
        }

        public String toString() {
            return "DQNFactoryStdDense.Configuration(numLayer=" + this.getNumLayer() + ", numHiddenNodes=" + this.getNumHiddenNodes() + ", l2=" + this.getL2() + ", updater=" + this.getUpdater() + ", listeners=" + Arrays.deepToString(this.getListeners()) + ")";
        }

        public Configuration(int numLayer, int numHiddenNodes, double l2, IUpdater updater, TrainingListener[] listeners) {
            this.numLayer = numLayer;
            this.numHiddenNodes = numHiddenNodes;
            this.l2 = l2;
            this.updater = updater;
            this.listeners = listeners;
        }

        public static class ConfigurationBuilder {
            private int numLayer;
            private int numHiddenNodes;
            private double l2;
            private IUpdater updater;
            private TrainingListener[] listeners;

            ConfigurationBuilder() {
            }

            public ConfigurationBuilder numLayer(int numLayer) {
                this.numLayer = numLayer;
                return this;
            }

            public ConfigurationBuilder numHiddenNodes(int numHiddenNodes) {
                this.numHiddenNodes = numHiddenNodes;
                return this;
            }

            public ConfigurationBuilder l2(double l2) {
                this.l2 = l2;
                return this;
            }

            public ConfigurationBuilder updater(IUpdater updater) {
                this.updater = updater;
                return this;
            }

            public ConfigurationBuilder listeners(TrainingListener[] listeners) {
                this.listeners = listeners;
                return this;
            }

            public Configuration build() {
                return new Configuration(this.numLayer, this.numHiddenNodes, this.l2, this.updater, this.listeners);
            }

            public String toString() {
                return "DQNFactoryStdDense.Configuration.ConfigurationBuilder(numLayer=" + this.numLayer + ", numHiddenNodes=" + this.numHiddenNodes + ", l2=" + this.l2 + ", updater=" + this.updater + ", listeners=" + Arrays.deepToString(this.listeners) + ")";
            }
        }
    }
}

