package org.deeplearning4j.scalnet.models;

import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.scalnet.layers.core.Layer;
import org.deeplearning4j.scalnet.layers.core.Node;
import org.deeplearning4j.scalnet.layers.core.Preprocessor;
import org.deeplearning4j.scalnet.logging.Logging;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.slf4j.Logger;
import scala.MatchError;
import scala.Predef$;
import scala.Predef$ArrowAssoc$;
import scala.Tuple2;
import scala.collection.TraversableLike;
import scala.collection.immutable.List;
import scala.collection.immutable.List$;
import scala.collection.immutable.Map;
import scala.collection.immutable.Nil$;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.ObjectRef;

/* compiled from: Sequential.scala */
@ScalaSignature(bytes = "\u0006\u0001\u0005=f\u0001B\u000f\u001f\u0001\u001dB\u0001\u0002\u000f\u0001\u0003\u0002\u0003\u0006I!\u000f\u0005\ty\u0001\u0011\t\u0011)A\u0005{!A\u0001\t\u0001B\u0001B\u0003%\u0011\tC\u0003E\u0001\u0011\u0005Q\tC\u0004K\u0001\u0001\u0007I\u0011B&\t\u000f\t\u0004\u0001\u0019!C\u0005G\"1\u0011\u000e\u0001Q!\n1CqA\u001b\u0001A\u0002\u0013%1\u000eC\u0004v\u0001\u0001\u0007I\u0011\u0002<\t\ra\u0004\u0001\u0015)\u0003m\u0011\u0015I\b\u0001\"\u0001l\u0011\u0015Q\b\u0001\"\u0001L\u0011\u001dY\bA1A\u0005\nqDa! \u0001!\u0002\u0013I\u0004\"\u0002@\u0001\t\u0013y\bbBA\u0003\u0001\u0011\u0005\u0011q\u0001\u0005\b\u0003\u0017\u0001A\u0011AA\u0007\u0011\u001d\t\t\u0002\u0001C\u0001\u0003'Aq!a\u0006\u0001\t\u0003\nI\u0002C\u0005\u0002n\u0001\t\n\u0011\"\u0001\u0002p!I\u0011Q\u0011\u0001\u0012\u0002\u0013\u0005\u0011qQ\u0004\b\u0003\u0017s\u0002\u0012AAG\r\u0019ib\u0004#\u0001\u0002\u0010\"1Ai\u0006C\u0001\u0003#Cq!a%\u0018\t\u0003\t)\nC\u0005\u0002\u001e^\t\n\u0011\"\u0001\u0002 
\"I\u00111U\f\u0012\u0002\u0013\u0005\u0011Q\u0015\u0005\n\u0003S;\u0012\u0013!C\u0001\u0003W\u0013!bU3rk\u0016tG/[1m\u0015\ty\u0002%\u0001\u0004n_\u0012,Gn\u001d\u0006\u0003C\t\nqa]2bY:,GO\u0003\u0002$I\u0005qA-Z3qY\u0016\f'O\\5oORR'\"A\u0013\u0002\u0007=\u0014xm\u0001\u0001\u0014\t\u0001AcF\r\t\u0003S1j\u0011A\u000b\u0006\u0002W\u0005)1oY1mC&\u0011QF\u000b\u0002\u0007\u0003:L(+\u001a4\u0011\u0005=\u0002T\"\u0001\u0010\n\u0005Er\"!B'pI\u0016d\u0007CA\u001a7\u001b\u0005!$BA\u001b!\u0003\u001dawnZ4j]\u001eL!a\u000e\u001b\u0003\u000f1{wmZ5oO\u0006IQ.\u001b8j\u0005\u0006$8\r\u001b\t\u0003SiJ!a\u000f\u0016\u0003\u000f\t{w\u000e\\3b]\u0006A!-[1t\u0013:LG\u000f\u0005\u0002*}%\u0011qH\u000b\u0002\u0007\t>,(\r\\3\u0002\u000fItwmU3fIB\u0011\u0011FQ\u0005\u0003\u0007*\u0012A\u0001T8oO\u00061A(\u001b8jiz\"BAR$I\u0013B\u0011q\u0006\u0001\u0005\u0006q\u0011\u0001\r!\u000f\u0005\u0006y\u0011\u0001\r!\u0010\u0005\u0006\u0001\u0012\u0001\r!Q\u0001\u000f?B\u0014X\r\u001d:pG\u0016\u001c8o\u001c:t+\u0005a\u0005\u0003B'U/js!A\u0014*\u0011\u0005=SS\"\u0001)\u000b\u0005E3\u0013A\u0002\u001fs_>$h(\u0003\u0002TU\u00051\u0001K]3eK\u001aL!!\u0016,\u0003\u00075\u000b\u0007O\u0003\u0002TUA\u0011\u0011\u0006W\u0005\u00033*\u00121!\u00138u!\tY\u0006-D\u0001]\u0015\tif,\u0001\u0003d_J,'BA0!\u0003\u0019a\u0017-_3sg&\u0011\u0011\r\u0018\u0002\u0005\u001d>$W-\u0001\n`aJ,\u0007O]8dKN\u001cxN]:`I\u0015\fHC\u00013h!\tIS-\u0003\u0002gU\t!QK\\5u\u0011\u001dAg!!AA\u00021\u000b1\u0001\u001f\u00132\u0003=y\u0006O]3qe>\u001cWm]:peN\u0004\u0013aC0j]B,Ho\u00155ba\u0016,\u0012\u0001\u001c\t\u0004[J<fB\u00018q\u001d\tyu.C\u0001,\u0013\t\t(&A\u0004qC\u000e\\\u0017mZ3\n\u0005M$(\u0001\u0002'jgRT!!\u001d\u0016\u0002\u001f}Kg\u000e];u'\"\f\u0007/Z0%KF$\"\u0001Z<\t\u000f!L\u0011\u0011!a\u0001Y\u0006aq,\u001b8qkR\u001c\u0006.\u00199fA\u0005Q\u0011N\u001c9viNC\u0017\r]3\u0002!\u001d,G\u000f\u0015:faJ|7-Z:t_J\u001c\u0018\u0001\u00038p\u0019\u0006LXM]:\u0016\u0003e\n\u0011B\\8MCf,'o\u001d\u0011\u0002\u0015\u0015l\u0007\u
000f^=TQ\u0006\u0004X\rF\u0002:\u0003\u0003Aa!a\u0001\u0010\u0001\u0004Q\u0016!\u00027bs\u0016\u0014\u0018aD5oM\u0016\u0014\u0018J\u001c9viNC\u0017\r]3\u0015\u00071\fI\u0001\u0003\u0004\u0002\u0004A\u0001\rAW\u0001\u000bG\",7m[*iCB,Gc\u00013\u0002\u0010!1\u00111A\tA\u0002i\u000b1!\u00193e)\r!\u0017Q\u0003\u0005\u0007\u0003\u0007\u0011\u0002\u0019\u0001.\u0002\u000f\r|W\u000e]5mKR9A-a\u0007\u0002J\u0005u\u0003bBA\u000f'\u0001\u0007\u0011qD\u0001\rY>\u001c8OR;oGRLwN\u001c\t\u0005\u0003C\t\u0019E\u0004\u0003\u0002$\u0005ub\u0002BA\u0013\u0003oqA!a\n\u000229!\u0011\u0011FA\u0017\u001d\ry\u00151F\u0005\u0002K%\u0019\u0011q\u0006\u0013\u0002\t9$GG[\u0005\u0005\u0003g\t)$\u0001\u0004mS:\fGn\u001a\u0006\u0004\u0003_!\u0013\u0002BA\u001d\u0003w\tQ\u0002\\8tg\u001a,hn\u0019;j_:\u001c(\u0002BA\u001a\u0003kIA!a\u0010\u0002B\u0005iAj\\:t\rVt7\r^5p]NTA!!\u000f\u0002<%!\u0011QIA$\u00051aun]:Gk:\u001cG/[8o\u0015\u0011\ty$!\u0011\t\u0013\u0005-3\u0003%AA\u0002\u00055\u0013!C8qi&l\u0017N_3s!\u0011\ty%!\u0017\u000e\u0005\u0005E#\u0002BA*\u0003+\n1!\u00199j\u0015\r\t9FI\u0001\u0003]:LA!a\u0017\u0002R\t)r\n\u001d;j[&T\u0018\r^5p]\u0006cwm\u001c:ji\"l\u0007\"CA0'A\u0005\t\u0019AA1\u0003\u001d)\b\u000fZ1uKJ\u0004B!a\u0019\u0002j5\u0011\u0011Q\r\u0006\u0005\u0003O\n)&\u0001\u0003d_:4\u0017\u0002BA6\u0003K\u0012q!\u00169eCR,'/A\td_6\u0004\u0018\u000e\\3%I\u00164\u0017-\u001e7uII*\"!!\u001d+\t\u00055\u00131O\u0016\u0003\u0003k\u0002B!a\u001e\u0002\u00026\u0011\u0011\u0011\u0010\u0006\u0005\u0003w\ni(A\u0005v]\u000eDWmY6fI*\u0019\u0011q\u0010\u0016\u0002\u0015\u0005tgn\u001c;bi&|g.\u0003\u0003\u0002\u0004\u0006e$!E;oG\",7m[3e-\u0006\u0014\u0018.\u00198dK\u0006\t2m\\7qS2,G\u0005Z3gCVdG\u000fJ\u001a\u0016\u0005\u0005%%\u0006BA1\u0003g\n!bU3rk\u0016tG/[1m!\tysc\u0005\u0002\u0018QQ\u0011\u0011QR\u0001\u0006CB\u0004H.\u001f\u000b\b\r\u0006]\u0015\u0011TAN\u0011\u001dA\u0014\u0004%AA\u0002eBq\u0001P\r\u0011\u0002\u0003\u0007Q\bC\u0004A3A\u0005\t\u0019A!\u0002\u001f\u0005\u0004\b\u000f\\=%I\u00164\u
0017-\u001e7uIE*\"!!)+\u0007e\n\u0019(A\bbaBd\u0017\u0010\n3fM\u0006,H\u000e\u001e\u00133+\t\t9KK\u0002>\u0003g\nq\"\u00199qYf$C-\u001a4bk2$HeM\u000b\u0003\u0003[S3!QA:\u0001")
/* loaded from: input_file:org/deeplearning4j/scalnet/models/Sequential.class */
/**
 * Sequential model: an ordered list of {@link Node} layers plus optional preprocessors keyed by
 * layer index, compiled into a DL4J {@link MultiLayerNetwork} by {@link #compile}.
 *
 * <p>This is decompiler (JADX) output for a class compiled from {@code Sequential.scala}. Several
 * members are Scala compiler artifacts: trait forwarders, lazy-val plumbing, default-argument
 * accessors and lambda bodies.
 *
 * <p>NOTE(review): this text does not compile as Java as it stands. See the {@code ??} type in
 * {@code logger$lzycompute} and the self-calling trait forwarders below. Treat it as a readable
 * view of the bytecode rather than as buildable source.
 */
public class Sequential implements Model {
    // Constructor arguments; compile() forwards them to buildModelConfig(...).
    private final boolean miniBatch;
    private final double biasInit;
    private final long rngSeed;
    // Preprocessor nodes keyed by the (boxed Integer) layer count at the time they were added.
    // compile() registers each one as an input preprocessor for that layer index.
    private Map<Object, Node> _preprocessors;
    // Model input shape recorded by checkShape(); starts as Nil.
    private List<Object> _inputShape;
    // Evaluated once, in the constructor.
    // NOTE(review): the constructor sets inputShape, layers and preprocessors to empty right
    // before computing this, so it is always true. As a result checkShape() overwrites
    // _inputShape on every add(). If the intent was "nothing added yet", this should be
    // re-evaluated per call. Confirm against Sequential.scala.
    private final boolean noLayers;
    // Non-preprocessor nodes in insertion order; backing field for layers()/layers_$eq.
    private List<Node> layers;
    // Network built and initialised by compile(); never assigned by the constructor, so null until then.
    private MultiLayerNetwork model;
    // Lazily computed logger; only valid once bitmap$0 is true.
    private Logger logger;
    // Set to true (while holding this instance's lock) once logger has been initialised.
    // Volatile so logger() can test it without locking.
    private volatile boolean bitmap$0;

    /**
     * Static factory that delegates to the Scala companion object {@code Sequential$}.
     *
     * @param z presumably the miniBatch flag (same order as the constructor); confirm in Sequential$
     * @param d presumably the biasInit value; confirm in Sequential$
     * @param j presumably the rngSeed; confirm in Sequential$
     * @return the instance returned by {@code Sequential$.MODULE$.apply}
     */
    public static Sequential apply(boolean z, double d, long j) {
        return Sequential$.MODULE$.apply(z, d, j);
    }

    /*
     * Model trait forwarders (getLayers through getNetwork).
     *
     * NOTE(review): as decompiled, each of these calls itself with the same arguments. As Java
     * that would recurse until StackOverflowError. In the bytecode they presumably delegate to
     * the Model trait's default implementations (Scala 2.12+ emits static forwarders such as
     * Model.getLayers$(this)), and the decompiler lost that call target. Confirm against the
     * class file before relying on this text.
     */

    /** Forwarder for Model.getLayers (see the NOTE above). */
    @Override // org.deeplearning4j.scalnet.models.Model
    public List<Node> getLayers() {
        List<Node> layers;
        layers = getLayers();
        return layers;
    }

    /** Forwarder for Model.buildModelConfig; compile() uses it to create the base configuration builder. */
    @Override // org.deeplearning4j.scalnet.models.Model
    public NeuralNetConfiguration.Builder buildModelConfig(OptimizationAlgorithm optimizationAlgorithm, Updater updater, boolean z, double d, long j) {
        NeuralNetConfiguration.Builder buildModelConfig;
        buildModelConfig = buildModelConfig(optimizationAlgorithm, updater, z, d, j);
        return buildModelConfig;
    }

    /** Forwarder for Model.buildOutput; compile() calls it with the chosen loss function. */
    @Override // org.deeplearning4j.scalnet.models.Model
    public void buildOutput(LossFunctions.LossFunction lossFunction) {
        buildOutput(lossFunction);
    }

    /** Forwarder for Model.fit over a DataSetIterator. The cast to the same type is a decompiler no-op. */
    @Override // org.deeplearning4j.scalnet.models.Model
    public void fit(DataSetIterator dataSetIterator, int i, List<TrainingListener> list) {
        fit(dataSetIterator, i, (List<TrainingListener>) list);
    }

    /** Forwarder for Model.fit over a single DataSet. The cast to the same type is a decompiler no-op. */
    @Override // org.deeplearning4j.scalnet.models.Model
    public void fit(DataSet dataSet, int i, List<TrainingListener> list) {
        fit(dataSet, i, (List<TrainingListener>) list);
    }

    /** Forwarder for Model.predict(INDArray). */
    @Override // org.deeplearning4j.scalnet.models.Model
    public INDArray predict(INDArray iNDArray) {
        INDArray predict;
        predict = predict(iNDArray);
        return predict;
    }

    /** Forwarder for Model.predict(DataSet). */
    @Override // org.deeplearning4j.scalnet.models.Model
    public INDArray predict(DataSet dataSet) {
        INDArray predict;
        predict = predict(dataSet);
        return predict;
    }

    /** Forwarder for Model.evaluate(DataSetIterator). */
    @Override // org.deeplearning4j.scalnet.models.Model
    public Evaluation evaluate(DataSetIterator dataSetIterator) {
        Evaluation evaluate;
        evaluate = evaluate(dataSetIterator);
        return evaluate;
    }

    /** Forwarder for Model.evaluate(DataSetIterator, int). */
    @Override // org.deeplearning4j.scalnet.models.Model
    public Evaluation evaluate(DataSetIterator dataSetIterator, int i) {
        Evaluation evaluate;
        evaluate = evaluate(dataSetIterator, i);
        return evaluate;
    }

    /** Forwarder for Model.evaluate(DataSet). */
    @Override // org.deeplearning4j.scalnet.models.Model
    public Evaluation evaluate(DataSet dataSet) {
        Evaluation evaluate;
        evaluate = evaluate(dataSet);
        return evaluate;
    }

    /** Forwarder for Model.evaluate(DataSet, int). */
    @Override // org.deeplearning4j.scalnet.models.Model
    public Evaluation evaluate(DataSet dataSet, int i) {
        Evaluation evaluate;
        evaluate = evaluate(dataSet, i);
        return evaluate;
    }

    /** Forwarder for Model.toString. */
    @Override // org.deeplearning4j.scalnet.models.Model
    public String toString() {
        String model;
        model = toString();
        return model;
    }

    /** Forwarder for Model.toJson. */
    @Override // org.deeplearning4j.scalnet.models.Model
    public String toJson() {
        String json;
        json = toJson();
        return json;
    }

    /** Forwarder for Model.toYaml. */
    @Override // org.deeplearning4j.scalnet.models.Model
    public String toYaml() {
        String yaml;
        yaml = toYaml();
        return yaml;
    }

    /** Forwarder for Model.getNetwork. */
    @Override // org.deeplearning4j.scalnet.models.Model
    public MultiLayerNetwork getNetwork() {
        MultiLayerNetwork network;
        network = getNetwork();
        return network;
    }

    /** Accessor for the layers field (member required by Model). */
    @Override // org.deeplearning4j.scalnet.models.Model
    public List<Node> layers() {
        return this.layers;
    }

    /** Mutator for the layers field (Scala {@code layers_=}, required by Model). */
    @Override // org.deeplearning4j.scalnet.models.Model
    public void layers_$eq(List<Node> list) {
        this.layers = list;
    }

    /** Accessor for the compiled network; null until compile() has run. */
    @Override // org.deeplearning4j.scalnet.models.Model
    public MultiLayerNetwork model() {
        return this.model;
    }

    /** Mutator for the compiled network (Scala {@code model_=}, required by Model). */
    @Override // org.deeplearning4j.scalnet.models.Model
    public void model_$eq(MultiLayerNetwork multiLayerNetwork) {
        this.model = multiLayerNetwork;
    }

    /**
     * Initialises {@link #logger} at most once, while holding this instance's lock, and returns it.
     *
     * <p>NOTE(review): {@code ??} is a JADX placeholder for a type it could not infer; it is not
     * valid Java. Also, as decompiled, the initialiser calls {@code logger()} while
     * {@code bitmap$0} is still false, which would re-enter this method. The bytecode presumably
     * calls the Logging trait's initialiser instead. Confirm against the class file.
     */
    /* JADX WARN: Multi-variable type inference failed */
    /* JADX WARN: Type inference failed for: r0v0 */
    /* JADX WARN: Type inference failed for: r0v1, types: [java.lang.Throwable] */
    /* JADX WARN: Type inference failed for: r0v8, types: [org.deeplearning4j.scalnet.models.Sequential] */
    private Logger logger$lzycompute() {
        Logger logger;
        ?? r0 = this;
        synchronized (r0) {
            // Re-check under the lock: another thread may already have initialised the logger.
            if (!this.bitmap$0) {
                logger = logger();
                this.logger = logger;
                r0 = this;
                r0.bitmap$0 = true;
            }
        }
        return this.logger;
    }

    /** Returns the logger. Takes the lock (via logger$lzycompute) only while bitmap$0 is still false. */
    @Override // org.deeplearning4j.scalnet.logging.Logging
    public Logger logger() {
        return !this.bitmap$0 ? logger$lzycompute() : this.logger;
    }

    /** Private getter for the preprocessor map. */
    private Map<Object, Node> _preprocessors() {
        return this._preprocessors;
    }

    /** Private setter for the preprocessor map. */
    private void _preprocessors_$eq(Map<Object, Node> map) {
        this._preprocessors = map;
    }

    /** Private getter for the recorded input shape. */
    private List<Object> _inputShape() {
        return this._inputShape;
    }

    /** Private setter for the recorded input shape. */
    private void _inputShape_$eq(List<Object> list) {
        this._inputShape = list;
    }

    /** Returns the input shape recorded by checkShape(); Nil until add() is first called. */
    public List<Object> inputShape() {
        return _inputShape();
    }

    /** Returns the preprocessors, keyed by the layer count at the time each one was added. */
    public Map<Object, Node> getPreprocessors() {
        return _preprocessors();
    }

    /** Returns the constructor-time noLayers flag (see the NOTE on the field: always true). */
    private boolean noLayers() {
        return this.noLayers;
    }

    /**
     * Returns true when {@code node} would be the model's first entry and has no usable input shape.
     * "First entry" means no layers yet and no preprocessor registered at index 0. "No usable
     * shape" means a declared input shape that is a single-element list whose only value is 0
     * (presumably the "not specified" marker; confirm in Node).
     */
    private boolean emptyShape(Node node) {
        return !_preprocessors().contains(BoxesRunTime.boxToInteger(layers().length())) && !layers().nonEmpty() && node.inputShape().lengthCompare(1) == 0 && BoxesRunTime.unboxToInt(node.inputShape().head()) == 0;
    }

    /**
     * Infers the input shape for the next node to be added. In order of precedence:
     * <ul>
     *   <li>the output shape of a preprocessor registered at the next layer index, if there is one;</li>
     *   <li>otherwise the output shape of the last layer;</li>
     *   <li>otherwise (no layers yet) {@code node}'s own input shape.</li>
     * </ul>
     *
     * @param node the node about to be added
     * @return the inferred shape (elements presumably boxed Integers; see the unboxToInt in emptyShape)
     */
    public List<Object> inferInputShape(Node node) {
        return _preprocessors().contains(BoxesRunTime.boxToInteger(layers().length())) ? ((Node) _preprocessors().apply(BoxesRunTime.boxToInteger(layers().length()))).outputShape() : (List) layers().lastOption().map(node2 -> {
            return node2.outputShape();
        }).getOrElse(() -> {
            return node.inputShape();
        });
    }

    /**
     * Validates {@code node} before it is added and records the model input shape.
     *
     * @param node the node about to be added
     * @throws IllegalArgumentException if {@link #emptyShape} reports that the first node has no
     *     usable input shape
     */
    public void checkShape(Node node) {
        if (emptyShape(node)) {
            throw new IllegalArgumentException("Input layer must have non-empty inputShape");
        }
        // noLayers is fixed at construction and always true, so this runs on every add().
        if (noLayers()) {
            _inputShape_$eq(node.inputShape());
        }
    }

    /**
     * Adds {@code node} to the model.
     *
     * <p>The node is first reshaped to the inferred input shape. If the result is a
     * {@link Preprocessor}, it is stored in {@code _preprocessors} under the current layer count,
     * so compile() attaches it to the next layer added. Otherwise it is appended to {@code layers}.
     *
     * @param node the layer or preprocessor to add
     * @throws IllegalArgumentException from {@link #checkShape} when the first node has an empty
     *     input shape
     */
    public void add(Node node) {
        List<Object> inferInputShape = inferInputShape(node);
        checkShape(node);
        Node reshapeInput = node.reshapeInput(inferInputShape);
        // The BoxedUnit locals below are decompiler artifacts of a Unit-valued Scala if/else; they are unused.
        if (reshapeInput instanceof Preprocessor) {
            _preprocessors_$eq(_preprocessors().$plus(Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc(BoxesRunTime.boxToInteger(layers().length())), reshapeInput)));
            BoxedUnit boxedUnit = BoxedUnit.UNIT;
        } else {
            layers_$eq((List) layers().$colon$plus(reshapeInput, List$.MODULE$.canBuildFrom()));
            BoxedUnit boxedUnit2 = BoxedUnit.UNIT;
        }
    }

    /**
     * Builds and initialises the underlying {@link MultiLayerNetwork}. The steps are:
     * <ol>
     *   <li>create the base configuration via {@code buildModelConfig}, using this model's
     *       miniBatch, biasInit and rngSeed;</li>
     *   <li>call {@code buildOutput(lossFunction)};</li>
     *   <li>register every layer by index, and every preprocessor by its layer index, on the list
     *       builder, logging each one;</li>
     *   <li>build the network, store it via {@code model_$eq} and call {@code init()} on it.</li>
     * </ol>
     *
     * <p>NOTE(review): add() appends any non-Preprocessor Node to {@code layers}, but here every
     * layer is cast to {@link Layer}. A Node that is neither a Preprocessor nor a Layer would fail
     * with ClassCastException at this point.
     *
     * @param lossFunction loss passed to {@code buildOutput}
     * @param optimizationAlgorithm passed to {@code buildModelConfig}; default in compile$default$2
     * @param updater passed to {@code buildModelConfig}; default in compile$default$3
     */
    @Override // org.deeplearning4j.scalnet.models.Model
    public void compile(LossFunctions.LossFunction lossFunction, OptimizationAlgorithm optimizationAlgorithm, Updater updater) {
        NeuralNetConfiguration.Builder buildModelConfig = buildModelConfig(optimizationAlgorithm, updater, this.miniBatch, this.biasInit, this.rngSeed);
        buildOutput(lossFunction);
        // ObjectRef boxes the list builder so the lambdas below can read and update it.
        ObjectRef create = ObjectRef.create(buildModelConfig.list());
        // withFilter(tuple != null) followed by MatchError on null is compiler output for a
        // pattern-matching foreach over (node, index) pairs.
        ((TraversableLike) layers().zipWithIndex(List$.MODULE$.canBuildFrom())).withFilter(tuple2 -> {
            return BoxesRunTime.boxToBoolean($anonfun$compile$1(tuple2));
        }).foreach(tuple22 -> {
            if (tuple22 == null) {
                throw new MatchError(tuple22);
            }
            Node node = (Node) tuple22._1();
            int _2$mcI$sp = tuple22._2$mcI$sp();
            this.logger().info(new StringBuilder(8).append("Layer ").append(_2$mcI$sp).append(": ").append(node.getClass().getSimpleName()).toString());
            this.logger().info(new StringBuilder(7).append(" size: ").append(node.describe()).toString());
            return ((NeuralNetConfiguration.ListBuilder) create.elem).layer(_2$mcI$sp, ((Layer) node).compile());
        });
        // Same pattern over (layerIndex, preprocessor) map entries.
        _preprocessors().withFilter(tuple23 -> {
            return BoxesRunTime.boxToBoolean($anonfun$compile$3(tuple23));
        }).foreach(tuple24 -> {
            if (tuple24 == null) {
                throw new MatchError(tuple24);
            }
            int _1$mcI$sp = tuple24._1$mcI$sp();
            Node node = (Node) tuple24._2();
            this.logger().info(new StringBuilder(15).append("Preprocessor ").append(_1$mcI$sp).append(": ").append(node.getClass().getSimpleName()).toString());
            this.logger().info(new StringBuilder(7).append(" size: ").append(node.describe()).toString());
            return ((NeuralNetConfiguration.ListBuilder) create.elem).inputPreProcessor(Predef$.MODULE$.int2Integer(_1$mcI$sp), ((Preprocessor) node).compile());
        });
        model_$eq(new MultiLayerNetwork(((NeuralNetConfiguration.ListBuilder) create.elem).build()));
        model().init();
    }

    /** Default for compile's optimizationAlgorithm argument (Scala default-argument accessor). */
    public OptimizationAlgorithm compile$default$2() {
        return OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT;
    }

    /** Default for compile's updater argument (Scala default-argument accessor). */
    public Updater compile$default$3() {
        return Updater.SGD;
    }

    /** Compiler-generated lambda body: the non-null filter for the layer loop in compile(). */
    public static final /* synthetic */ boolean $anonfun$compile$1(Tuple2 tuple2) {
        return tuple2 != null;
    }

    /** Compiler-generated lambda body: the non-null filter for the preprocessor loop in compile(). */
    public static final /* synthetic */ boolean $anonfun$compile$3(Tuple2 tuple2) {
        return tuple2 != null;
    }

    /**
     * Creates an empty model with no layers, no preprocessors and an empty (Nil) input shape.
     *
     * @param z stored as miniBatch and later forwarded to buildModelConfig
     * @param d stored as biasInit and later forwarded to buildModelConfig
     * @param j stored as rngSeed and later forwarded to buildModelConfig
     */
    public Sequential(boolean z, double d, long j) {
        this.miniBatch = z;
        this.biasInit = d;
        this.rngSeed = j;
        // Runs the Logging trait's initialiser (Scala trait encoding).
        Logging.$init$(this);
        layers_$eq(Nil$.MODULE$);
        this._preprocessors = Predef$.MODULE$.Map().apply(Nil$.MODULE$);
        this._inputShape = Nil$.MODULE$;
        // Evaluated here, after the three collections were just set empty, so this is always true.
        this.noLayers = inputShape().isEmpty() && layers().isEmpty() && _preprocessors().isEmpty();
    }
}
