package org.deeplearning4j.spark.ml.reconstruction;

import org.apache.spark.annotation.DeveloperApi;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.ml.param.IntParam;
import org.apache.spark.ml.param.Param;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.functions$;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.StructType;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.spark.ml.UnsupervisedLearnerParams;
import org.deeplearning4j.spark.ml.UnsupervisedModel;
import org.deeplearning4j.spark.ml.param.shared.HasEpochs;
import org.deeplearning4j.spark.ml.param.shared.HasLayerIndex;
import org.deeplearning4j.spark.ml.param.shared.HasMultiLayerConfiguration;
import org.deeplearning4j.spark.ml.param.shared.HasReconstructionCol;
import org.deeplearning4j.spark.ml.reconstruction.NeuralNetworkReconstructionParams;
import org.deeplearning4j.spark.sql.types.VectorUDT$;
import org.deeplearning4j.spark.util.package$conversions$;
import org.nd4j.linalg.api.ndarray.INDArray;
import scala.reflect.ScalaSignature;

/* compiled from: MultiLayerNetworkReconstruction.scala */
@DeveloperApi
@ScalaSignature(bytes = "\u0006\u0001\u0005Eb\u0001B\u0001\u0003\u00015\u0011\u0001ET3ve\u0006dg*\u001a;x_J\\'+Z2p]N$(/^2uS>tWj\u001c3fY*\u00111\u0001B\u0001\u000fe\u0016\u001cwN\\:ueV\u001cG/[8o\u0015\t)a!\u0001\u0002nY*\u0011q\u0001C\u0001\u0006gB\f'o\u001b\u0006\u0003\u0013)\ta\u0002Z3fa2,\u0017M\u001d8j]\u001e$$NC\u0001\f\u0003\ry'oZ\u0002\u0001'\r\u0001ab\b\t\u0005\u001fA\u0011R$D\u0001\u0005\u0013\t\tBAA\tV]N,\b/\u001a:wSN,G-T8eK2\u0004\"aE\u000e\u000e\u0003QQ!!\u0006\f\u0002\r1Lg.\u00197h\u0015\t9\u0002$A\u0003nY2L'M\u0003\u0002\b3)\u0011!DC\u0001\u0007CB\f7\r[3\n\u0005q!\"A\u0002,fGR|'\u000f\u0005\u0002\u001f\u00015\t!\u0001\u0005\u0002\u001fA%\u0011\u0011E\u0001\u0002\"\u001d\u0016,(/\u00197OKR<xN]6SK\u000e|gn\u001d;sk\u000e$\u0018n\u001c8QCJ\fWn\u001d\u0005\tG\u0001\u0011)\u0019!C!I\u0005\u0019Q/\u001b3\u0016\u0003\u0015\u0002\"A\n\u0017\u000f\u0005\u001dRS\"\u0001\u0015\u000b\u0003%\nQa]2bY\u0006L!a\u000b\u0015\u0002\rA\u0013X\rZ3g\u0013\ticF\u0001\u0004TiJLgn\u001a\u0006\u0003W!B\u0001\u0002\r\u0001\u0003\u0002\u0003\u0006I!J\u0001\u0005k&$\u0007\u0005\u0003\u00053\u0001\t\u0015\r\u0011\"\u00014\u00035qW\r^<pe.\u0004\u0016M]1ngV\tA\u0007E\u00026qij\u0011A\u000e\u0006\u0003oa\t\u0011B\u0019:pC\u0012\u001c\u0017m\u001d;\n\u0005e2$!\u0003\"s_\u0006$7-Y:u!\tY4)D\u0001=\u0015\tid(A\u0004oI\u0006\u0014(/Y=\u000b\u0005}\u0002\u0015aA1qS*\u0011Q#\u0011\u0006\u0003\u0005*\tAA\u001c35U&\u0011A\t\u0010\u0002\t\u0013:#\u0015I\u001d:bs\"Aa\t\u0001B\u0001B\u0003%A'\u0001\boKR<xN]6QCJ\fWn\u001d\u0011\t\r!\u0003A\u0011\u0001\u0003J\u0003\u0019a\u0014N\\5u}Q\u0019QDS&\t\u000b\r:\u0005\u0019A\u0013\t\u000bI:\u0005\u0019\u0001\u001b\t\u000b5\u0003A\u0011\t(\u0002\u000fA\u0014X\rZ5diR\u0011q*\u0016\t\u0003!Nk\u0011!\u0015\u0006\u0003%b\t1a]9m\u0013\t!\u0016KA\u0005ECR\fgI]1nK\")a\u000b\u0014a\u0001\u001f\u00069A-\u0019;bg\u0016$\b\"\u0002-\u0001\t#I\u0016a\u0003:fG>t7\u000f\u001e:vGR$2A\u0005.]\u0011\u0015Yv\u000b1\u0001\u0013\u0003!1W-\u0019;ve\u0016\u001c\b\"B/X\u0001\u00
04q\u0016A\u00037bs\u0016\u0014\u0018J\u001c3fqB\u0011qeX\u0005\u0003A\"\u00121!\u00138u\u0011\u0015\u0011\u0007\u0001\"\u0011d\u0003\u0011\u0019w\u000e]=\u0015\u0005u!\u0007\"B3b\u0001\u00041\u0017!B3yiJ\f\u0007CA4l\u001b\u0005A'BA5k\u0003\u0015\u0001\u0018M]1n\u0015\t)\u0001$\u0003\u0002mQ\nA\u0001+\u0019:b[6\u000b\u0007\u000fC\u0004o\u0001\u0001\u0007I\u0011B8\u0002\u001b9,Go^8sW\"{G\u000eZ3s+\u0005\u0001\bcA9wq6\t!O\u0003\u0002ti\u0006!A.\u00198h\u0015\u0005)\u0018\u0001\u00026bm\u0006L!a\u001e:\u0003\u0017QC'/Z1e\u0019>\u001c\u0017\r\u001c\t\u0003szl\u0011A\u001f\u0006\u0003wr\f!\"\\;mi&d\u0017-_3s\u0015\ti\b\"\u0001\u0002o]&\u0011qP\u001f\u0002\u0012\u001bVdG/\u001b'bs\u0016\u0014h*\u001a;x_J\\\u0007\"CA\u0002\u0001\u0001\u0007I\u0011BA\u0003\u0003EqW\r^<pe.Du\u000e\u001c3fe~#S-\u001d\u000b\u0005\u0003\u000f\ti\u0001E\u0002(\u0003\u0013I1!a\u0003)\u0005\u0011)f.\u001b;\t\u0013\u0005=\u0011\u0011AA\u0001\u0002\u0004\u0001\u0018a\u0001=%c!9\u00111\u0003\u0001!B\u0013\u0001\u0018A\u00048fi^|'o\u001b%pY\u0012,'\u000f\t\u0015\u0005\u0003#\t9\u0002E\u0002(\u00033I1!a\u0007)\u0005%!(/\u00198tS\u0016tG\u000fC\u0004\u0002 \u0001!I!!\t\u0002\u000f9,Go^8sWR\t\u0001\u0010K\u0002\u0001\u0003K\u0001B!a\n\u0002.5\u0011\u0011\u0011\u0006\u0006\u0004\u0003WA\u0012AC1o]>$\u0018\r^5p]&!\u0011qFA\u0015\u00051!UM^3m_B,'/\u00119j\u0001")
/* loaded from: input_file:org/deeplearning4j/spark/ml/reconstruction/NeuralNetworkReconstructionModel.class */
/*
 * Spark ML model that adds a "reconstruction" vector column to a DataFrame by running each
 * feature vector through MultiLayerNetwork.reconstruct. The network is rebuilt per thread from
 * the JSON configuration held in the `conf` param, and its parameters come from a Spark broadcast.
 *
 * NOTE(review): this file is decompiled from Scala; the `..._setter_$..._$eq` and `Cclass` members
 * are Scala-trait encoding glue and their call order in the constructor must be preserved.
 */
public class NeuralNetworkReconstructionModel extends UnsupervisedModel<Vector, NeuralNetworkReconstructionModel> implements NeuralNetworkReconstructionParams {
    /** Unique identifier of this pipeline stage. */
    private final String uid;
    /** Broadcast network parameters, loaded into every per-thread network instance. */
    private final Broadcast<INDArray> networkParams;
    /**
     * Per-thread network cache. Transient, so it is not serialized with the model; it is
     * (re)created lazily by {@link #network()} whenever it is null.
     */
    private transient ThreadLocal<MultiLayerNetwork> networkHolder;

    // The four param fields below are assigned through the Scala-trait setter methods
    // (..._setter_$..._$eq) rather than directly in the constructor, so they cannot be final:
    // javac rejects assignment to a final field outside a constructor or initializer.
    private Param<String> reconstructionCol;
    private IntParam layerIndex;
    private IntParam epochs;
    private Param<String> conf;

    /**
     * Scala super-accessor: exposes UnsupervisedLearnerParams' schema validation so that
     * NeuralNetworkReconstructionParams' override can delegate to it.
     */
    @Override // org.deeplearning4j.spark.ml.reconstruction.NeuralNetworkReconstructionParams
    public StructType org$deeplearning4j$spark$ml$reconstruction$NeuralNetworkReconstructionParams$$super$validateAndTransformSchema(StructType structType, boolean z, DataType dataType) {
        return UnsupervisedLearnerParams.Cclass.validateAndTransformSchema(this, structType, z, dataType);
    }

    /** Validates/transforms the schema using the NeuralNetworkReconstructionParams trait implementation. */
    @Override // org.deeplearning4j.spark.ml.UnsupervisedModel, org.deeplearning4j.spark.ml.UnsupervisedLearnerParams
    public StructType validateAndTransformSchema(StructType structType, boolean z, DataType dataType) {
        return NeuralNetworkReconstructionParams.Cclass.validateAndTransformSchema(this, structType, z, dataType);
    }

    /** @return the param holding the output (reconstruction) column name */
    @Override // org.deeplearning4j.spark.ml.param.shared.HasReconstructionCol
    public Param<String> reconstructionCol() {
        return this.reconstructionCol;
    }

    /** Trait field initializer for {@code reconstructionCol}; called once from the constructor. */
    @Override // org.deeplearning4j.spark.ml.param.shared.HasReconstructionCol
    public void org$deeplearning4j$spark$ml$param$shared$HasReconstructionCol$_setter_$reconstructionCol_$eq(Param param) {
        this.reconstructionCol = param;
    }

    @Override // org.deeplearning4j.spark.ml.param.shared.HasReconstructionCol
    public String getReconstructionCol() {
        return HasReconstructionCol.Cclass.getReconstructionCol(this);
    }

    /** @return the param holding the layer index (documented as one-based) */
    @Override // org.deeplearning4j.spark.ml.param.shared.HasLayerIndex
    public IntParam layerIndex() {
        return this.layerIndex;
    }

    /** Trait field initializer for {@code layerIndex}; called once from the constructor. */
    @Override // org.deeplearning4j.spark.ml.param.shared.HasLayerIndex
    public void org$deeplearning4j$spark$ml$param$shared$HasLayerIndex$_setter_$layerIndex_$eq(IntParam intParam) {
        this.layerIndex = intParam;
    }

    @Override // org.deeplearning4j.spark.ml.param.shared.HasLayerIndex
    public int getLayerIndex() {
        return HasLayerIndex.Cclass.getLayerIndex(this);
    }

    /** @return the param holding the number of epochs */
    @Override // org.deeplearning4j.spark.ml.param.shared.HasEpochs
    public IntParam epochs() {
        return this.epochs;
    }

    /** Trait field initializer for {@code epochs}; called once from the constructor. */
    @Override // org.deeplearning4j.spark.ml.param.shared.HasEpochs
    public void org$deeplearning4j$spark$ml$param$shared$HasEpochs$_setter_$epochs_$eq(IntParam intParam) {
        this.epochs = intParam;
    }

    @Override // org.deeplearning4j.spark.ml.param.shared.HasEpochs
    public int getEpochs() {
        return HasEpochs.Cclass.getEpochs(this);
    }

    /** @return the param holding the MultiLayerConfiguration as a JSON string */
    @Override // org.deeplearning4j.spark.ml.param.shared.HasMultiLayerConfiguration
    public Param<String> conf() {
        return this.conf;
    }

    /** Trait field initializer for {@code conf}; called once from the constructor. */
    @Override // org.deeplearning4j.spark.ml.param.shared.HasMultiLayerConfiguration
    public void org$deeplearning4j$spark$ml$param$shared$HasMultiLayerConfiguration$_setter_$conf_$eq(Param param) {
        this.conf = param;
    }

    @Override // org.deeplearning4j.spark.ml.param.shared.HasMultiLayerConfiguration
    public String getConf() {
        return HasMultiLayerConfiguration.Cclass.getConf(this);
    }

    public String uid() {
        return this.uid;
    }

    /** @return the broadcast network parameters this model was built with */
    public Broadcast<INDArray> networkParams() {
        return this.networkParams;
    }

    /**
     * Appends the reconstruction column to {@code dataFrame}.
     *
     * <p>The new column is named by {@code reconstructionCol}, typed with {@code VectorUDT}, and
     * computed by a UDF over the column named by {@code featuresCol}. If {@code reconstructionCol}
     * is null or empty, a warning is logged and the input is returned unchanged.
     *
     * @param dataFrame input rows, expected to contain the features column
     * @return {@code dataFrame} with the reconstruction column added, or {@code dataFrame} itself
     */
    @Override // org.deeplearning4j.spark.ml.UnsupervisedModel
    public DataFrame predict(DataFrame dataFrame) {
        String outputCol = (String) $(reconstructionCol());
        // The decompiled form of Scala's null-safe `!=` ("" != 0) does not compile in Java.
        // Test explicitly, and treat a null name like an empty one rather than handing it to withColumn.
        if (outputCol == null || outputCol.isEmpty()) {
            logWarning(new NeuralNetworkReconstructionModel$$anonfun$predict$1(this));
            return dataFrame;
        }
        // NOTE(review): the UDF is a compiler-generated class not visible here; presumably it
        // calls reconstruct(vector, getLayerIndex()) — confirm.
        return dataFrame.withColumn(outputCol, functions$.MODULE$.callUDF(new NeuralNetworkReconstructionModel$$anonfun$3(this), VectorUDT$.MODULE$.apply(), functions$.MODULE$.col((String) $(featuresCol()))));
    }

    /**
     * Reconstructs one feature vector with this thread's network.
     *
     * @param vector input features, converted to an INDArray
     * @param i layer argument passed straight to {@code MultiLayerNetwork.reconstruct}
     *          (NOTE(review): the layerIndex param is documented as one-based — confirm reconstruct agrees)
     * @return the reconstruction converted back to a Spark vector
     */
    public Vector reconstruct(Vector vector, int i) {
        return package$conversions$.MODULE$.toVector(network().reconstruct(package$conversions$.MODULE$.toINDArray(vector), i));
    }

    /*
     * Copies this model (same uid and broadcast parameters) and applies `paramMap` on top of the
     * current param values. JADX renamed this method from `copy` because of merged bridge methods;
     * in the bytecode it is `copy(ParamMap)`.
     */
    public NeuralNetworkReconstructionModel m13copy(ParamMap paramMap) {
        return (NeuralNetworkReconstructionModel) copyValues(new NeuralNetworkReconstructionModel(uid(), networkParams()), paramMap);
    }

    private ThreadLocal<MultiLayerNetwork> networkHolder() {
        return this.networkHolder;
    }

    private void networkHolder_$eq(ThreadLocal<MultiLayerNetwork> threadLocal) {
        this.networkHolder = threadLocal;
    }

    /**
     * Returns the calling thread's network, creating the per-thread holder on first use.
     *
     * <p>The holder is transient, so it is null after deserialization and is rebuilt here. This
     * is not synchronized: if two threads race, each may install its own holder, which only
     * costs an extra network build on a later call, not a wrong result.
     */
    private MultiLayerNetwork network() {
        if (networkHolder() == null) {
            networkHolder_$eq(new ThreadLocal<MultiLayerNetwork>() {
                @Override
                protected MultiLayerNetwork initialValue() {
                    return newNetwork();
                }
            });
        }
        return networkHolder().get();
    }

    /** Builds a fresh network from the {@code conf} JSON and loads the broadcast parameters into it. */
    private MultiLayerNetwork newNetwork() {
        MultiLayerNetwork multiLayerNetwork = new MultiLayerNetwork(MultiLayerConfiguration.fromJson((String) $(conf())));
        multiLayerNetwork.init();
        multiLayerNetwork.setParameters((INDArray) networkParams().value());
        return multiLayerNetwork;
    }

    /**
     * @param str       uid of this pipeline stage
     * @param broadcast broadcast network parameters
     */
    public NeuralNetworkReconstructionModel(String str, Broadcast<INDArray> broadcast) {
        this.uid = str;
        this.networkParams = broadcast;
        // Scala trait field initialization, in linearization order; do not reorder.
        org$deeplearning4j$spark$ml$param$shared$HasMultiLayerConfiguration$_setter_$conf_$eq(new Param(this, "conf", "multilayer configuration"));
        org$deeplearning4j$spark$ml$param$shared$HasEpochs$_setter_$epochs_$eq(new IntParam(this, "epochs", "number of epochs"));
        org$deeplearning4j$spark$ml$param$shared$HasLayerIndex$_setter_$layerIndex_$eq(new IntParam(this, "layerIndex", "layer index (one-based)"));
        org$deeplearning4j$spark$ml$param$shared$HasReconstructionCol$_setter_$reconstructionCol_$eq(new Param(this, "reconstructionCol", "reconstruction column name"));
        NeuralNetworkReconstructionParams.Cclass.$init$(this);
        this.networkHolder = null;
    }
}
