package org.deeplearning4j.spark.ml.classification;

import org.apache.spark.annotation.DeveloperApi;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.ml.classification.ClassificationModel;
import org.apache.spark.ml.param.IntParam;
import org.apache.spark.ml.param.Param;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.types.StructType;
import org.deeplearning4j.spark.ml.param.shared.HasEpochs;
import org.deeplearning4j.spark.ml.param.shared.HasMultiLayerConfiguration;
import org.nd4j.linalg.api.ndarray.INDArray;
import scala.NotImplementedError;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;

/* compiled from: MultiLayerNetworkClassification.scala */
@DeveloperApi
@ScalaSignature(bytes = "\u0006\u0001e4A!\u0001\u0002\u0001\u001b\t\u0001c*Z;sC2tU\r^<pe.\u001cE.Y:tS\u001aL7-\u0019;j_:lu\u000eZ3m\u0015\t\u0019A!\u0001\bdY\u0006\u001c8/\u001b4jG\u0006$\u0018n\u001c8\u000b\u0005\u00151\u0011AA7m\u0015\t9\u0001\"A\u0003ta\u0006\u00148N\u0003\u0002\n\u0015\u0005qA-Z3qY\u0016\f'O\\5oORR'\"A\u0006\u0002\u0007=\u0014xm\u0001\u0001\u0014\u0007\u0001q\u0011\u0005\u0005\u0003\u0010+]yR\"\u0001\t\u000b\u0005\r\t\"BA\u0003\u0013\u0015\t91C\u0003\u0002\u0015\u0015\u00051\u0011\r]1dQ\u0016L!A\u0006\t\u0003'\rc\u0017m]:jM&\u001c\u0017\r^5p]6{G-\u001a7\u0011\u0005aiR\"A\r\u000b\u0005iY\u0012A\u00027j]\u0006dwM\u0003\u0002\u001d%\u0005)Q\u000e\u001c7jE&\u0011a$\u0007\u0002\u0007-\u0016\u001cGo\u001c:\u0011\u0005\u0001\u0002Q\"\u0001\u0002\u0011\u0005\u0001\u0012\u0013BA\u0012\u0003\u0005\u0005rU-\u001e:bY:+Go^8sW\u000ec\u0017m]:jM&\u001c\u0017\r^5p]B\u000b'/Y7t\u0011!)\u0003A!b\u0001\n\u00032\u0013aA;jIV\tq\u0005\u0005\u0002)]9\u0011\u0011\u0006L\u0007\u0002U)\t1&A\u0003tG\u0006d\u0017-\u0003\u0002.U\u00051\u0001K]3eK\u001aL!a\f\u0019\u0003\rM#(/\u001b8h\u0015\ti#\u0006\u0003\u00053\u0001\t\u0005\t\u0015!\u0003(\u0003\u0011)\u0018\u000e\u001a\u0011\t\u0011Q\u0002!Q1A\u0005BU\n!B\\;n\u00072\f7o]3t+\u00051\u0004CA\u00158\u0013\tA$FA\u0002J]RD\u0001B\u000f\u0001\u0003\u0002\u0003\u0006IAN\u0001\f]Vl7\t\\1tg\u0016\u001c\b\u0005\u0003\u0005=\u0001\t\u0015\r\u0011\"\u0001>\u00035qW\r^<pe.\u0004\u0016M]1ngV\ta\bE\u0002@\u0005\u0012k\u0011\u0001\u0011\u0006\u0003\u0003J\t\u0011B\u0019:pC\u0012\u001c\u0017m\u001d;\n\u0005\r\u0003%!\u0003\"s_\u0006$7-Y:u!\t)U*D\u0001G\u0015\t9\u0005*A\u0004oI\u0006\u0014(/Y=\u000b\u0005%S\u0015aA1qS*\u0011!d\u0013\u0006\u0003\u0019*\tAA\u001c35U&\u0011aJ\u0012\u0002\t\u0013:#\u0015I\u001d:bs\"A\u0001\u000b\u0001B\u0001B\u0003%a(\u0001\boKR<xN]6QCJ\fWn\u001d\u0011\t\rI\u0003A\u0011\u0001\u0003T\u0003\u0019a\u0014N\\5u}Q!q\u0004V+W\u0011\u0015)\u0013\u000b1\u0001(\u0011\u0015!\u0014\u000b1\u00017\u0011\u0015a\u0014\u000b1\u
0001?\u0011\u0015A\u0006\u0001\"\u0015Z\u0003)\u0001(/\u001a3jGR\u0014\u0016m\u001e\u000b\u0003/iCQaW,A\u0002]\t\u0001BZ3biV\u0014Xm\u001d\u0005\u0006;\u0002!\tEX\u0001\niJ\fgn\u001d4pe6$\"aX3\u0011\u0005\u0001\u001cW\"A1\u000b\u0005\t\u0014\u0012aA:rY&\u0011A-\u0019\u0002\n\t\u0006$\u0018M\u0012:b[\u0016DQA\u001a/A\u0002}\u000bq\u0001Z1uCN,G\u000fC\u0003i\u0001\u0011\u0005\u0013.\u0001\u0003d_BLHCA\u0010k\u0011\u0015Yw\r1\u0001m\u0003\u0015)\u0007\u0010\u001e:b!\ti\u0007/D\u0001o\u0015\ty\u0017#A\u0003qCJ\fW.\u0003\u0002r]\nA\u0001+\u0019:b[6\u000b\u0007\u000f\u000b\u0002\u0001gB\u0011Ao^\u0007\u0002k*\u0011aOE\u0001\u000bC:tw\u000e^1uS>t\u0017B\u0001=v\u00051!UM^3m_B,'/\u00119j\u0001")
/* loaded from: input_file:org/deeplearning4j/spark/ml/classification/NeuralNetworkClassificationModel.class */
/**
 * Spark ML classification model that carries the parameters of a neural network as a
 * broadcast {@link INDArray}.
 *
 * <p>This source was produced by decompiling a Scala class, so it contains compiler-generated
 * members: the {@code ..._setter_$..._$eq} methods initialise the {@code epochs} and
 * {@code conf} params declared by the {@link HasEpochs} and {@link HasMultiLayerConfiguration}
 * traits, and the {@code Cclass} helpers hold the trait method bodies.
 */
public class NeuralNetworkClassificationModel extends ClassificationModel<Vector, NeuralNetworkClassificationModel> implements NeuralNetworkClassificationParams {
    private final String uid;
    private final int numClasses;
    private final Broadcast<INDArray> networkParams;
    // Not declared final: both fields are assigned through the trait setter methods below,
    // which are ordinary instance methods. Java rejects an assignment to a final field
    // outside a constructor or initializer. They are written once, from the constructor.
    private IntParam epochs;
    private Param<String> conf;

    /** Returns the {@code epochs} param created in the constructor. */
    @Override // org.deeplearning4j.spark.ml.param.shared.HasEpochs
    public IntParam epochs() {
        return this.epochs;
    }

    /**
     * Compiler-generated trait setter that stores the {@code epochs} param.
     * Invoked from the constructor of this class.
     */
    @Override // org.deeplearning4j.spark.ml.param.shared.HasEpochs
    public void org$deeplearning4j$spark$ml$param$shared$HasEpochs$_setter_$epochs_$eq(IntParam intParam) {
        this.epochs = intParam;
    }

    /** Delegates to the {@link HasEpochs} trait implementation. */
    @Override // org.deeplearning4j.spark.ml.param.shared.HasEpochs
    public int getEpochs() {
        return HasEpochs.Cclass.getEpochs(this);
    }

    /** Returns the {@code conf} param created in the constructor. */
    @Override // org.deeplearning4j.spark.ml.param.shared.HasMultiLayerConfiguration
    public Param<String> conf() {
        return this.conf;
    }

    /**
     * Compiler-generated trait setter that stores the {@code conf} param.
     * Invoked from the constructor of this class. The raw {@code Param} parameter type is
     * kept as emitted by the decompiler.
     */
    @Override // org.deeplearning4j.spark.ml.param.shared.HasMultiLayerConfiguration
    public void org$deeplearning4j$spark$ml$param$shared$HasMultiLayerConfiguration$_setter_$conf_$eq(Param param) {
        this.conf = param;
    }

    /** Delegates to the {@link HasMultiLayerConfiguration} trait implementation. */
    @Override // org.deeplearning4j.spark.ml.param.shared.HasMultiLayerConfiguration
    public String getConf() {
        return HasMultiLayerConfiguration.Cclass.getConf(this);
    }

    /** Returns the identifier passed to the constructor. */
    public String uid() {
        return this.uid;
    }

    /** Returns the number of classes passed to the constructor. */
    public int numClasses() {
        return this.numClasses;
    }

    /** Returns the broadcast handle for the network parameters passed to the constructor. */
    public Broadcast<INDArray> networkParams() {
        return this.networkParams;
    }

    /**
     * Not supported by this model.
     *
     * @param vector the feature vector (unused)
     * @return never returns normally
     * @throws NotImplementedError always
     */
    public Vector predictRaw(Vector vector) {
        throw new NotImplementedError();
    }

    /**
     * Builds a new {@link DataFrame} by running a per-partition function over the input.
     *
     * <p>The output schema is obtained from {@code transformSchema(dataFrame.schema(), true)}.
     * The row computation itself lives in the compiler-generated closure class
     * {@code NeuralNetworkClassificationModel$$anonfun$3}, which is not part of this file.
     *
     * @param dataFrame the input data set
     * @return a data frame created on the input's {@code sqlContext()} with the transformed schema
     */
    public DataFrame transform(DataFrame dataFrame) {
        StructType transformSchema = transformSchema(dataFrame.schema(), true);
        return dataFrame.sqlContext().createDataFrame(dataFrame.mapPartitions(new NeuralNetworkClassificationModel$$anonfun$3(this, transformSchema), ClassTag$.MODULE$.apply(Row.class)), transformSchema);
    }

    /**
     * Creates a new model with the same uid, class count and broadcast parameters, then
     * applies {@code copyValues} with the given extra params.
     *
     * @param paramMap extra params passed to {@code copyValues}
     * @return the new model instance
     */
    /* renamed from: copy, reason: merged with bridge method [inline-methods] and merged with bridge method [inline-methods] and merged with bridge method [inline-methods] and merged with bridge method [inline-methods] */
    public NeuralNetworkClassificationModel m8copy(ParamMap paramMap) {
        return (NeuralNetworkClassificationModel) copyValues(new NeuralNetworkClassificationModel(uid(), numClasses(), networkParams()), paramMap);
    }

    /**
     * Stores the given values and creates the {@code conf} and {@code epochs} params.
     *
     * @param str the model uid
     * @param i the number of classes
     * @param broadcast broadcast handle for the network parameters
     */
    public NeuralNetworkClassificationModel(String str, int i, Broadcast<INDArray> broadcast) {
        this.uid = str;
        this.numClasses = i;
        this.networkParams = broadcast;
        // Trait field initialisation, in the order emitted by the Scala compiler.
        org$deeplearning4j$spark$ml$param$shared$HasMultiLayerConfiguration$_setter_$conf_$eq(new Param(this, "conf", "multilayer configuration"));
        org$deeplearning4j$spark$ml$param$shared$HasEpochs$_setter_$epochs_$eq(new IntParam(this, "epochs", "number of epochs"));
    }
}
