package org.deeplearning4j.spark.ml.classification;

import org.apache.spark.annotation.DeveloperApi;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.ml.classification.ClassificationModel;
import org.apache.spark.ml.param.IntParam;
import org.apache.spark.ml.param.Param;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.mllib.linalg.Vector;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.optimize.api.IterationListener;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.spark.ml.param.shared.HasEpochs;
import org.deeplearning4j.spark.ml.param.shared.HasMultiLayerConfiguration;
import org.deeplearning4j.spark.util.package$conversions$;
import org.nd4j.linalg.api.ndarray.INDArray;
import scala.reflect.ScalaSignature;

/* compiled from: MultiLayerNetworkClassification.scala */
@DeveloperApi
@ScalaSignature(bytes = "\u0006\u0001\u0005\u0015b\u0001B\u0001\u0003\u00015\u0011\u0001ET3ve\u0006dg*\u001a;x_J\\7\t\\1tg&4\u0017nY1uS>tWj\u001c3fY*\u00111\u0001B\u0001\u000fG2\f7o]5gS\u000e\fG/[8o\u0015\t)a!\u0001\u0002nY*\u0011q\u0001C\u0001\u0006gB\f'o\u001b\u0006\u0003\u0013)\ta\u0002Z3fa2,\u0017M\u001d8j]\u001e$$NC\u0001\f\u0003\ry'oZ\u0002\u0001'\r\u0001a\"\t\t\u0005\u001fU9r$D\u0001\u0011\u0015\t\u0019\u0011C\u0003\u0002\u0006%)\u0011qa\u0005\u0006\u0003))\ta!\u00199bG\",\u0017B\u0001\f\u0011\u0005M\u0019E.Y:tS\u001aL7-\u0019;j_:lu\u000eZ3m!\tAR$D\u0001\u001a\u0015\tQ2$\u0001\u0004mS:\fGn\u001a\u0006\u00039I\tQ!\u001c7mS\nL!AH\r\u0003\rY+7\r^8s!\t\u0001\u0003!D\u0001\u0003!\t\u0001#%\u0003\u0002$\u0005\t\tc*Z;sC2tU\r^<pe.\u001cE.Y:tS\u001aL7-\u0019;j_:\u0004\u0016M]1ng\"AQ\u0005\u0001BC\u0002\u0013\u0005c%A\u0002vS\u0012,\u0012a\n\t\u0003Q9r!!\u000b\u0017\u000e\u0003)R\u0011aK\u0001\u0006g\u000e\fG.Y\u0005\u0003[)\na\u0001\u0015:fI\u00164\u0017BA\u00181\u0005\u0019\u0019FO]5oO*\u0011QF\u000b\u0005\te\u0001\u0011\t\u0011)A\u0005O\u0005!Q/\u001b3!\u0011!!\u0004A!b\u0001\n\u0003*\u0014A\u00038v[\u000ec\u0017m]:fgV\ta\u0007\u0005\u0002*o%\u0011\u0001H\u000b\u0002\u0004\u0013:$\b\u0002\u0003\u001e\u0001\u0005\u0003\u0005\u000b\u0011\u0002\u001c\u0002\u00179,Xn\u00117bgN,7\u000f\t\u0005\ty\u0001\u0011)\u0019!C\u0001{\u0005ia.\u001a;x_J\\\u0007+\u0019:b[N,\u0012A\u0010\t\u0004\u007f\t#U\"\u0001!\u000b\u0005\u0005\u0013\u0012!\u00032s_\u0006$7-Y:u\u0013\t\u0019\u0005IA\u0005Ce>\fGmY1tiB\u0011Q)T\u0007\u0002\r*\u0011q\tS\u0001\b]\u0012\f'O]1z\u0015\tI%*A\u0002ba&T!AG&\u000b\u00051S\u0011\u0001\u00028ei)L!A\u0014$\u0003\u0011%sE)\u0011:sCfD\u0001\u0002\u0015\u0001\u0003\u0002\u0003\u0006IAP\u0001\u000f]\u0016$xo\u001c:l!\u0006\u0014\u0018-\\:!\u0011\u0019\u0011\u0006\u0001\"\u0001\u0005'\u00061A(\u001b8jiz\"Ba\b+V-\")Q%\u0015a\u0001O!)A'\u0015a\u0001m!)A(\u0015a\u0001}!)\u0001\f\u0001C)3\u0006Q\u0001O]3eS\u000e$(+Y<\u0015\u0005]Q\u0006\"B.X\u0001\u00049\u0012\u0001\u0
0034fCR,(/Z:\t\u000bu\u0003A\u0011\t0\u0002\t\r|\u0007/\u001f\u000b\u0003?}CQ\u0001\u0019/A\u0002\u0005\fQ!\u001a=ue\u0006\u0004\"AY3\u000e\u0003\rT!\u0001Z\t\u0002\u000bA\f'/Y7\n\u0005\u0019\u001c'\u0001\u0003)be\u0006lW*\u00199\t\u000f!\u0004\u0001\u0019!C\u0005S\u0006ia.\u001a;x_J\\\u0007j\u001c7eKJ,\u0012A\u001b\t\u0004WB\u0014X\"\u00017\u000b\u00055t\u0017\u0001\u00027b]\u001eT\u0011a\\\u0001\u0005U\u00064\u0018-\u0003\u0002rY\nYA\u000b\u001b:fC\u0012dunY1m!\t\u0019\b0D\u0001u\u0015\t)h/\u0001\u0006nk2$\u0018\u000e\\1zKJT!a\u001e\u0005\u0002\u00059t\u0017BA=u\u0005EiU\u000f\u001c;j\u0019\u0006LXM\u001d(fi^|'o\u001b\u0005\bw\u0002\u0001\r\u0011\"\u0003}\u0003EqW\r^<pe.Du\u000e\u001c3fe~#S-\u001d\u000b\u0004{\u0006\u0005\u0001CA\u0015\u007f\u0013\ty(F\u0001\u0003V]&$\b\u0002CA\u0002u\u0006\u0005\t\u0019\u00016\u0002\u0007a$\u0013\u0007C\u0004\u0002\b\u0001\u0001\u000b\u0015\u00026\u0002\u001d9,Go^8sW\"{G\u000eZ3sA!\"\u0011QAA\u0006!\rI\u0013QB\u0005\u0004\u0003\u001fQ#!\u0003;sC:\u001c\u0018.\u001a8u\u0011\u001d\t\u0019\u0002\u0001C\u0005\u0003+\tqA\\3uo>\u00148\u000eF\u0001sQ\r\u0001\u0011\u0011\u0004\t\u0005\u00037\t\t#\u0004\u0002\u0002\u001e)\u0019\u0011q\u0004\n\u0002\u0015\u0005tgn\u001c;bi&|g.\u0003\u0003\u0002$\u0005u!\u0001\u0004#fm\u0016dw\u000e]3s\u0003BL\u0007")
/* loaded from: input_file:org/deeplearning4j/spark/ml/classification/NeuralNetworkClassificationModel.class */
/**
 * Spark ML classification model whose raw predictions are produced by a
 * {@link MultiLayerNetwork}.
 *
 * <p>The network parameters are held in a Spark {@link Broadcast}; the network
 * itself is rebuilt on demand, once per thread, from the JSON configuration
 * stored in the {@code conf} param (see {@link #network()}).
 *
 * <p>The holder of the per-thread network is {@code transient} and starts out
 * {@code null}; it is created on the first call that needs the network. That
 * lazy creation is not synchronized.
 */
public class NeuralNetworkClassificationModel extends ClassificationModel<Vector, NeuralNetworkClassificationModel> implements NeuralNetworkClassificationParams {
    private final String uid;
    private final int numClasses;
    private final Broadcast<INDArray> networkParams;
    private transient ThreadLocal<MultiLayerNetwork> networkHolder;

    // Not final: both params are assigned through the trait setter methods
    // below (called from the constructor), and Java forbids assigning a final
    // field outside a constructor or initializer.
    private IntParam epochs;
    private Param<String> conf;

    /** Returns the {@code epochs} param object. */
    @Override // org.deeplearning4j.spark.ml.param.shared.HasEpochs
    public IntParam epochs() {
        return this.epochs;
    }

    /** Trait setter for the {@code epochs} param; called from the constructor. */
    @Override // org.deeplearning4j.spark.ml.param.shared.HasEpochs
    public void org$deeplearning4j$spark$ml$param$shared$HasEpochs$_setter_$epochs_$eq(IntParam intParam) {
        this.epochs = intParam;
    }

    /** Returns the epochs value, as resolved by {@code HasEpochs.Cclass.getEpochs}. */
    @Override // org.deeplearning4j.spark.ml.param.shared.HasEpochs
    public int getEpochs() {
        return HasEpochs.Cclass.getEpochs(this);
    }

    /** Returns the {@code conf} param object. */
    @Override // org.deeplearning4j.spark.ml.param.shared.HasMultiLayerConfiguration
    public Param<String> conf() {
        return this.conf;
    }

    /**
     * Trait setter for the {@code conf} param; called from the constructor.
     *
     * <p>The parameter type is deliberately left raw so that the method
     * signature stays exactly as it was.
     */
    @Override // org.deeplearning4j.spark.ml.param.shared.HasMultiLayerConfiguration
    @SuppressWarnings({"rawtypes", "unchecked"})
    public void org$deeplearning4j$spark$ml$param$shared$HasMultiLayerConfiguration$_setter_$conf_$eq(Param param) {
        this.conf = param;
    }

    /** Returns the configuration value, as resolved by {@code HasMultiLayerConfiguration.Cclass.getConf}. */
    @Override // org.deeplearning4j.spark.ml.param.shared.HasMultiLayerConfiguration
    public String getConf() {
        return HasMultiLayerConfiguration.Cclass.getConf(this);
    }

    /** Returns the identifier passed to the constructor. */
    public String uid() {
        return this.uid;
    }

    /** Returns the number of classes passed to the constructor. */
    public int numClasses() {
        return this.numClasses;
    }

    /** Returns the broadcast handle holding the network parameters. */
    public Broadcast<INDArray> networkParams() {
        return this.networkParams;
    }

    /**
     * Feeds the feature vector through the current thread's network and
     * returns the network output converted back to a Spark vector.
     *
     * @param vector the feature vector
     * @return the network output for {@code vector}
     */
    public Vector predictRaw(Vector vector) {
        INDArray features = package$conversions$.MODULE$.toINDArray(vector);
        INDArray output = network().output(features);
        return package$conversions$.MODULE$.toVector(output);
    }

    /**
     * Creates a new model with the same uid, class count and broadcast
     * parameters, then applies {@code copyValues} with the given extra params.
     *
     * <p>This restores the method's original name; the decompiler had renamed
     * it to {@code m8copy} when merging bridge methods.
     *
     * @param paramMap extra params passed on to {@code copyValues}
     * @return the copied model
     */
    public NeuralNetworkClassificationModel copy(ParamMap paramMap) {
        NeuralNetworkClassificationModel duplicate =
                new NeuralNetworkClassificationModel(uid(), numClasses(), networkParams());
        return (NeuralNetworkClassificationModel) copyValues(duplicate, paramMap);
    }

    /**
     * Decompiler-generated alias of {@link #copy(ParamMap)}, kept so that any
     * code written against this name keeps working.
     *
     * @deprecated use {@link #copy(ParamMap)}
     */
    @Deprecated
    public NeuralNetworkClassificationModel m8copy(ParamMap paramMap) {
        return copy(paramMap);
    }

    private ThreadLocal<MultiLayerNetwork> networkHolder() {
        return this.networkHolder;
    }

    private void networkHolder_$eq(ThreadLocal<MultiLayerNetwork> threadLocal) {
        this.networkHolder = threadLocal;
    }

    /**
     * Returns the calling thread's network, creating the thread-local holder
     * first if it does not exist yet.
     *
     * <p>Each thread's network is built from the JSON held in the {@code conf}
     * param, given a {@link ScoreIterationListener}, initialised, and then
     * loaded with the broadcast parameters.
     */
    private MultiLayerNetwork network() {
        if (networkHolder() == null) {
            networkHolder_$eq(new ThreadLocal<MultiLayerNetwork>() {
                @Override // java.lang.ThreadLocal
                protected MultiLayerNetwork initialValue() {
                    // Refer to the enclosing model explicitly instead of the
                    // decompiler's synthetic $outer field.
                    NeuralNetworkClassificationModel model = NeuralNetworkClassificationModel.this;
                    String json = (String) model.$(model.conf());
                    MultiLayerNetwork multiLayerNetwork = new MultiLayerNetwork(MultiLayerConfiguration.fromJson(json));
                    multiLayerNetwork.setListeners(new IterationListener[]{new ScoreIterationListener(1)});
                    multiLayerNetwork.init();
                    multiLayerNetwork.setParameters((INDArray) model.networkParams().value());
                    return multiLayerNetwork;
                }
            });
        }
        return networkHolder().get();
    }

    /**
     * Creates a model.
     *
     * @param str       unique identifier of this model
     * @param i         number of classes
     * @param broadcast broadcast handle holding the network parameters
     */
    public NeuralNetworkClassificationModel(String str, int i, Broadcast<INDArray> broadcast) {
        // uid is assigned before the params are created because each param is
        // constructed with this model as its parent.
        this.uid = str;
        this.numClasses = i;
        this.networkParams = broadcast;
        org$deeplearning4j$spark$ml$param$shared$HasMultiLayerConfiguration$_setter_$conf_$eq(new Param<String>(this, "conf", "multilayer configuration"));
        org$deeplearning4j$spark$ml$param$shared$HasEpochs$_setter_$epochs_$eq(new IntParam(this, "epochs", "number of epochs"));
        this.networkHolder = null;
    }
}
