package org.deeplearning4j.spark.ml.reconstruction;

import org.apache.spark.SparkContext;
import org.apache.spark.annotation.DeveloperApi;
import org.apache.spark.ml.param.IntParam;
import org.apache.spark.ml.param.Param;
import org.apache.spark.ml.param.ParamPair;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.storage.StorageLevel;
import org.apache.spark.storage.StorageLevel$;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.spark.ml.UnsupervisedLearner;
import org.deeplearning4j.spark.ml.UnsupervisedLearnerParams;
import org.deeplearning4j.spark.ml.nn.ParameterAveragingTrainingStrategy;
import org.deeplearning4j.spark.ml.param.shared.HasEpochs;
import org.deeplearning4j.spark.ml.param.shared.HasLayerIndex;
import org.deeplearning4j.spark.ml.param.shared.HasMultiLayerConfiguration;
import org.deeplearning4j.spark.ml.param.shared.HasReconstructionCol;
import org.deeplearning4j.spark.ml.reconstruction.NeuralNetworkReconstructionParams;
import org.deeplearning4j.spark.ml.util.Identifiable$;
import org.nd4j.linalg.api.ndarray.INDArray;
import scala.Predef$;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;

/* compiled from: MultiLayerNetworkReconstruction.scala */
@DeveloperApi
@ScalaSignature(bytes = "\u0006\u000114A!\u0001\u0002\u0001\u001b\tYb*Z;sC2tU\r^<pe.\u0014VmY8ogR\u0014Xo\u0019;j_:T!a\u0001\u0003\u0002\u001dI,7m\u001c8tiJ,8\r^5p]*\u0011QAB\u0001\u0003[2T!a\u0002\u0005\u0002\u000bM\u0004\u0018M]6\u000b\u0005%Q\u0011A\u00043fKBdW-\u0019:oS:<GG\u001b\u0006\u0002\u0017\u0005\u0019qN]4\u0004\u0001M\u0019\u0001A\u0004\u0012\u0011\u000b=\u0001\"#H\u0010\u000e\u0003\u0011I!!\u0005\u0003\u0003'Us7/\u001e9feZL7/\u001a3MK\u0006\u0014h.\u001a:\u0011\u0005MYR\"\u0001\u000b\u000b\u0005U1\u0012A\u00027j]\u0006dwM\u0003\u0002\u00181\u0005)Q\u000e\u001c7jE*\u0011q!\u0007\u0006\u00035)\ta!\u00199bG\",\u0017B\u0001\u000f\u0015\u0005\u00191Vm\u0019;peB\u0011a\u0004A\u0007\u0002\u0005A\u0011a\u0004I\u0005\u0003C\t\u0011\u0001ET3ve\u0006dg*\u001a;x_J\\'+Z2p]N$(/^2uS>tWj\u001c3fYB\u0011adI\u0005\u0003I\t\u0011\u0011ET3ve\u0006dg*\u001a;x_J\\'+Z2p]N$(/^2uS>t\u0007+\u0019:b[ND\u0001B\n\u0001\u0003\u0006\u0004%\teJ\u0001\u0004k&$W#\u0001\u0015\u0011\u0005%zcB\u0001\u0016.\u001b\u0005Y#\"\u0001\u0017\u0002\u000bM\u001c\u0017\r\\1\n\u00059Z\u0013A\u0002)sK\u0012,g-\u0003\u00021c\t11\u000b\u001e:j]\u001eT!AL\u0016\t\u0011M\u0002!\u0011!Q\u0001\n!\nA!^5eA!)Q\u0007\u0001C\u0001m\u00051A(\u001b8jiz\"\"!H\u001c\t\u000b\u0019\"\u0004\u0019\u0001\u0015\t\u000bU\u0002A\u0011A\u001d\u0015\u0003uAQa\u000f\u0001\u0005\u0002q\nqa]3u\u0007>tg\r\u0006\u0002>}5\t\u0001\u0001C\u0003@u\u0001\u0007\u0001&A\u0003wC2,X\rC\u0003<\u0001\u0011\u0005\u0011\t\u0006\u0002>\u0005\")q\b\u0011a\u0001\u0007B\u0011A)S\u0007\u0002\u000b*\u0011aiR\u0001\u0005G>tgM\u0003\u0002I\u0011\u0005\u0011aN\\\u0005\u0003\u0015\u0016\u0013q#T;mi&d\u0015-_3s\u0007>tg-[4ve\u0006$\u0018n\u001c8\t\u000b1\u0003A\u0011A'\u0002\u0013M,G/\u00129pG\"\u001cHCA\u001fO\u0011\u0015y4\n1\u0001P!\tQ\u0003+\u0003\u0002RW\t\u0019\u0011J\u001c;\t\u000bM\u0003A\u0011\u0001+\u0002\u001bM,G\u000fT1zKJLe\u000eZ3y)\tiT\u000bC\u0003@%\u0002\u0007q\nC\u0003X\u0001\u0011\u0005\u0001,\u0001\u000btKR\u0014VmY8ogR\u0014Xo\u0019;j
_:\u001cu\u000e\u001c\u000b\u0003{eCQa\u0010,A\u0002!BQa\u0017\u0001\u0005Rq\u000bQ\u0001\\3be:$\"aH/\t\u000byS\u0006\u0019A0\u0002\u000f\u0011\fG/Y:fiB\u0011\u0001mY\u0007\u0002C*\u0011!\rG\u0001\u0004gFd\u0017B\u00013b\u0005%!\u0015\r^1Ge\u0006lW\r\u000b\u0002\u0001MB\u0011qM[\u0007\u0002Q*\u0011\u0011\u000eG\u0001\u000bC:tw\u000e^1uS>t\u0017BA6i\u00051!UM^3m_B,'/\u00119j\u0001")
/* loaded from: input_file:org/deeplearning4j/spark/ml/reconstruction/NeuralNetworkReconstruction.class */
/*
 * Spark ML estimator that trains a network on the feature column of a DataFrame.
 * The network is described by a MultiLayerConfiguration held as JSON in the "conf" param.
 * Training produces a NeuralNetworkReconstructionModel that carries the trained parameters
 * (an INDArray) as a Spark broadcast variable.
 *
 * Params and the defaults assigned in the constructor:
 *   - conf              : multilayer configuration JSON (no default; must be set)
 *   - epochs            : number of epochs (default 1)
 *   - layerIndex        : layer index, one-based (default 1)
 *   - reconstructionCol : reconstruction column name (default "reconstruction")
 *
 * NOTE(review): this file is decompiled Scala.
 *   - The org$...$_setter_$..._$eq methods and the *.Cclass.* delegations are
 *     compiler-generated plumbing for the mixed-in Scala traits (HasReconstructionCol,
 *     HasLayerIndex, HasEpochs, HasMultiLayerConfiguration, UnsupervisedLearnerParams /
 *     NeuralNetworkReconstructionParams).
 *   - Those setters assign `final` fields outside a constructor. The JVM accepts this,
 *     but Java source does not, so this file is not expected to recompile as-is.
 *   - Edit the Scala source instead.
 */
public class NeuralNetworkReconstruction extends UnsupervisedLearner<Vector, NeuralNetworkReconstruction, NeuralNetworkReconstructionModel> implements NeuralNetworkReconstructionParams {
    /** Unique identifier of this estimator instance; passed on to the fitted model. */
    private final String uid;
    /** Param: name of the reconstruction output column (default "reconstruction"). */
    private final Param<String> reconstructionCol;
    /** Param: layer index, one-based per its description (default 1). */
    private final IntParam layerIndex;
    /** Param: number of training epochs (default 1). */
    private final IntParam epochs;
    /** Param: MultiLayerConfiguration serialized as JSON (no default). */
    private final Param<String> conf;

    /**
     * Compiler-generated bridge that lets the NeuralNetworkReconstructionParams trait make a
     * {@code super} call to UnsupervisedLearnerParams' validateAndTransformSchema.
     */
    @Override // org.deeplearning4j.spark.ml.reconstruction.NeuralNetworkReconstructionParams
    public StructType org$deeplearning4j$spark$ml$reconstruction$NeuralNetworkReconstructionParams$$super$validateAndTransformSchema(StructType structType, boolean z, DataType dataType) {
        return UnsupervisedLearnerParams.Cclass.validateAndTransformSchema(this, structType, z, dataType);
    }

    /**
     * Validates the input schema and returns the transformed schema.
     * Delegates to the NeuralNetworkReconstructionParams trait implementation.
     */
    @Override // org.deeplearning4j.spark.ml.UnsupervisedLearner, org.deeplearning4j.spark.ml.UnsupervisedLearnerParams
    public StructType validateAndTransformSchema(StructType structType, boolean z, DataType dataType) {
        return NeuralNetworkReconstructionParams.Cclass.validateAndTransformSchema(this, structType, z, dataType);
    }

    /** Returns the reconstructionCol param object (not its value). */
    @Override // org.deeplearning4j.spark.ml.param.shared.HasReconstructionCol
    public Param<String> reconstructionCol() {
        return this.reconstructionCol;
    }

    /** Trait field initializer for reconstructionCol; called only from the constructor. */
    @Override // org.deeplearning4j.spark.ml.param.shared.HasReconstructionCol
    public void org$deeplearning4j$spark$ml$param$shared$HasReconstructionCol$_setter_$reconstructionCol_$eq(Param param) {
        this.reconstructionCol = param;
    }

    /** Returns the current value of reconstructionCol (via the HasReconstructionCol trait). */
    @Override // org.deeplearning4j.spark.ml.param.shared.HasReconstructionCol
    public String getReconstructionCol() {
        return HasReconstructionCol.Cclass.getReconstructionCol(this);
    }

    /** Returns the layerIndex param object (not its value). */
    @Override // org.deeplearning4j.spark.ml.param.shared.HasLayerIndex
    public IntParam layerIndex() {
        return this.layerIndex;
    }

    /** Trait field initializer for layerIndex; called only from the constructor. */
    @Override // org.deeplearning4j.spark.ml.param.shared.HasLayerIndex
    public void org$deeplearning4j$spark$ml$param$shared$HasLayerIndex$_setter_$layerIndex_$eq(IntParam intParam) {
        this.layerIndex = intParam;
    }

    /** Returns the current value of layerIndex (via the HasLayerIndex trait). */
    @Override // org.deeplearning4j.spark.ml.param.shared.HasLayerIndex
    public int getLayerIndex() {
        return HasLayerIndex.Cclass.getLayerIndex(this);
    }

    /** Returns the epochs param object (not its value). */
    @Override // org.deeplearning4j.spark.ml.param.shared.HasEpochs
    public IntParam epochs() {
        return this.epochs;
    }

    /** Trait field initializer for epochs; called only from the constructor. */
    @Override // org.deeplearning4j.spark.ml.param.shared.HasEpochs
    public void org$deeplearning4j$spark$ml$param$shared$HasEpochs$_setter_$epochs_$eq(IntParam intParam) {
        this.epochs = intParam;
    }

    /** Returns the current value of epochs (via the HasEpochs trait). */
    @Override // org.deeplearning4j.spark.ml.param.shared.HasEpochs
    public int getEpochs() {
        return HasEpochs.Cclass.getEpochs(this);
    }

    /** Returns the conf param object (not its value). */
    @Override // org.deeplearning4j.spark.ml.param.shared.HasMultiLayerConfiguration
    public Param<String> conf() {
        return this.conf;
    }

    /** Trait field initializer for conf; called only from the constructor. */
    @Override // org.deeplearning4j.spark.ml.param.shared.HasMultiLayerConfiguration
    public void org$deeplearning4j$spark$ml$param$shared$HasMultiLayerConfiguration$_setter_$conf_$eq(Param param) {
        this.conf = param;
    }

    /** Returns the current conf JSON (via the HasMultiLayerConfiguration trait). */
    @Override // org.deeplearning4j.spark.ml.param.shared.HasMultiLayerConfiguration
    public String getConf() {
        return HasMultiLayerConfiguration.Cclass.getConf(this);
    }

    /** Returns this estimator's unique identifier. */
    public String uid() {
        return this.uid;
    }

    /**
     * Sets the network configuration from its JSON form.
     *
     * @param str MultiLayerConfiguration serialized as JSON
     * @return this estimator, for chaining
     */
    public NeuralNetworkReconstruction setConf(String str) {
        return (NeuralNetworkReconstruction) set(conf(), str);
    }

    /**
     * Sets the network configuration.
     * The configuration is stored as JSON via {@code toJson()}.
     *
     * @param multiLayerConfiguration network configuration to train
     * @return this estimator, for chaining
     */
    public NeuralNetworkReconstruction setConf(MultiLayerConfiguration multiLayerConfiguration) {
        return (NeuralNetworkReconstruction) set(conf(), multiLayerConfiguration.toJson());
    }

    /**
     * Sets the number of training epochs.
     *
     * @param i epoch count, passed to ParameterAveragingTrainingStrategy in {@link #learn}
     * @return this estimator, for chaining
     */
    public NeuralNetworkReconstruction setEpochs(int i) {
        return (NeuralNetworkReconstruction) set(epochs(), BoxesRunTime.boxToInteger(i));
    }

    /**
     * Sets the layer index.
     * The index is one-based according to the param description.
     * NOTE(review): it is not read in {@link #learn}; presumably the fitted model uses it.
     *
     * @param i layer index
     * @return this estimator, for chaining
     */
    public NeuralNetworkReconstruction setLayerIndex(int i) {
        return (NeuralNetworkReconstruction) set(layerIndex(), BoxesRunTime.boxToInteger(i));
    }

    /**
     * Sets the name of the reconstruction output column.
     *
     * @param str column name
     * @return this estimator, for chaining
     */
    public NeuralNetworkReconstruction setReconstructionCol(String str) {
        return (NeuralNetworkReconstruction) set(reconstructionCol(), str);
    }

    /**
     * Trains the configured network on the features column of {@code dataFrame}.
     *
     * <p>If the input's underlying RDD has no storage level (StorageLevel.NONE), the projected
     * feature frame is persisted with MEMORY_AND_DISK for the duration of training and
     * unpersisted afterwards. An input that is already cached is left as-is.
     *
     * @param dataFrame input data; only the column named by {@code featuresCol} is used
     * @return a model carrying this estimator's uid and the trained parameters as a broadcast
     */
    /* JADX WARN: Can't rename method to resolve collision */
    @Override // org.deeplearning4j.spark.ml.UnsupervisedLearner
    public NeuralNetworkReconstructionModel learn(DataFrame dataFrame) {
        SparkContext sparkContext = dataFrame.sqlContext().sparkContext();
        // $(param) reads the param's explicit or default value; conf has no default, so it must be set.
        MultiLayerConfiguration fromJson = MultiLayerConfiguration.fromJson((String) $(conf()));
        // Project down to the single features column.
        DataFrame select = dataFrame.select((String) $(featuresCol()), Predef$.MODULE$.wrapRefArray(new String[0]));
        StorageLevel storageLevel = dataFrame.rdd().getStorageLevel();
        StorageLevel NONE = StorageLevel$.MODULE$.NONE();
        // Null-safe equality (Scala ==): true when the caller did not cache the input.
        boolean z = storageLevel != null ? storageLevel.equals(NONE) : NONE == null;
        if (z) {
            // Cache only when the caller has not, presumably so training does not recompute the input.
            select.persist(StorageLevel$.MODULE$.MEMORY_AND_DISK());
        } else {
            // No-op: decompiler residue of the Scala if-expression's Unit branch.
            BoxedUnit boxedUnit = BoxedUnit.UNIT;
        }
        // NeuralNetworkReconstruction$$anonfun$1 is a compiled Scala closure not visible here.
        // Presumably it maps each Row to the training input; verify in the Scala source.
        INDArray train = new ParameterAveragingTrainingStrategy(fromJson, BoxesRunTime.unboxToInt($(epochs()))).train(select.rdd(), new NeuralNetworkReconstruction$$anonfun$1(this));
        if (z) {
            // Release the cache this method created.
            select.unpersist();
        } else {
            // No-op: decompiler residue.
            BoxedUnit boxedUnit2 = BoxedUnit.UNIT;
        }
        // Broadcast the trained parameters so the model can use them on executors.
        return new NeuralNetworkReconstructionModel(uid(), sparkContext.broadcast(train, ClassTag$.MODULE$.apply(INDArray.class)));
    }

    /**
     * Creates an estimator with the given uid.
     *
     * <p>Initialization order follows the Scala trait linearization:
     * <ol>
     *   <li>create the params;</li>
     *   <li>run the NeuralNetworkReconstructionParams trait initializer;</li>
     *   <li>set the defaults (epochs=1, layerIndex=1, reconstructionCol="reconstruction").</li>
     * </ol>
     *
     * @param str unique identifier for this estimator
     */
    public NeuralNetworkReconstruction(String str) {
        this.uid = str;
        org$deeplearning4j$spark$ml$param$shared$HasMultiLayerConfiguration$_setter_$conf_$eq(new Param(this, "conf", "multilayer configuration"));
        org$deeplearning4j$spark$ml$param$shared$HasEpochs$_setter_$epochs_$eq(new IntParam(this, "epochs", "number of epochs"));
        org$deeplearning4j$spark$ml$param$shared$HasLayerIndex$_setter_$layerIndex_$eq(new IntParam(this, "layerIndex", "layer index (one-based)"));
        org$deeplearning4j$spark$ml$param$shared$HasReconstructionCol$_setter_$reconstructionCol_$eq(new Param(this, "reconstructionCol", "reconstruction column name"));
        NeuralNetworkReconstructionParams.Cclass.$init$(this);
        // $minus$greater is Scala's `->`, which builds a ParamPair(param, value).
        setDefault(Predef$.MODULE$.wrapRefArray(new ParamPair[]{epochs().$minus$greater(BoxesRunTime.boxToInteger(1))}));
        setDefault(Predef$.MODULE$.wrapRefArray(new ParamPair[]{layerIndex().$minus$greater(BoxesRunTime.boxToInteger(1))}));
        setDefault(Predef$.MODULE$.wrapRefArray(new ParamPair[]{reconstructionCol().$minus$greater("reconstruction")}));
    }

    /** Creates an estimator with a random uid prefixed "nnReconstruction". */
    public NeuralNetworkReconstruction() {
        this(Identifiable$.MODULE$.randomUID("nnReconstruction"));
    }
}
