package org.deeplearning4j.spark.ml.classification;

import org.apache.spark.sql.Row;
import org.apache.spark.sql.types.StructType;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import scala.Array$;
import scala.Function3;
import scala.MatchError;
import scala.Predef$;
import scala.Serializable;
import scala.Tuple2;
import scala.collection.Iterable;
import scala.collection.Iterable$;
import scala.collection.Iterator;
import scala.collection.Seq$;
import scala.collection.TraversableLike;
import scala.collection.immutable.Nil$;
import scala.reflect.ClassTag$;
import scala.runtime.AbstractFunction1;

/* compiled from: MultiLayerNetworkClassification.scala */
/* loaded from: input_file:org/deeplearning4j/spark/ml/classification/NeuralNetworkClassificationModel$$anonfun$3.class */
/**
 * Decompiled Scala anonymous function (see the "compiled from" header above) that scores one
 * partition of input {@link Row}s with a DL4J {@link MultiLayerNetwork}.
 *
 * <p>It is a {@code Function1<Iterator<Row>, Iterator<Row>>}. Its shape suggests it is the body of
 * a per-partition transform on the enclosing {@code NeuralNetworkClassificationModel}.
 * NOTE(review): the call site (e.g. {@code mapPartitions}) is in the Scala source and not visible
 * here; confirm there.
 *
 * <p>This file is compiler output. Changes should be made in MultiLayerNetworkClassification.scala
 * rather than here.
 */
public class NeuralNetworkClassificationModel$$anonfun$3 extends AbstractFunction1<Iterator<Row>, Iterator<Row>> implements Serializable {
    public static final long serialVersionUID = 0;
    // Enclosing model instance captured by the closure; supplies conf, params and column names.
    private final /* synthetic */ NeuralNetworkClassificationModel $outer;
    // Captured schema of the input rows. It is used to locate the features column and to
    // enumerate field names.
    private final StructType schema$1;

    /**
     * Scores every row of one partition and returns the resulting rows.
     *
     * <p>Steps:
     * <ol>
     *   <li>Rebuild a {@link MultiLayerNetwork} from the JSON configuration stored in the model's
     *       {@code conf} param, then initialize it and load the model's {@code networkParams}.</li>
     *   <li>Map each row through {@code $$anonfun$4}, which is given the features column index,
     *       and unzip the resulting pairs into two collections. The first collection holds
     *       {@link INDArray}s, as shown by the {@code toArray(ClassTag INDArray)} below. The second
     *       collection's element type is not visible here; presumably it is the original row.
     *       TODO confirm.</li>
     *   <li>Stack all feature arrays with {@code Nd4j.vstack} and run a single batched
     *       {@code output} call. Then map each (element, index) pair of the second collection
     *       through {@code $$anonfun$6} to produce the output rows.</li>
     * </ol>
     *
     * <p>All feature vectors of the partition are collected into one array and stacked. The
     * partition's features are therefore held in memory at the same time.
     *
     * @param iterator the partition's input rows
     * @return an iterator over the produced rows; empty when the partition is empty
     * @throws MatchError if {@code unzip} yields {@code null}. This is a generated pattern-match
     *     guard from the Scala tuple destructuring.
     */
    public final Iterator<Row> apply(Iterator<Row> iterator) {
        // Result collection; assigned by the switch below.
        Iterable apply;
        // A fresh network is built on every call from the JSON conf, then its parameters are
        // overwritten with the model's networkParams. NOTE(review): networkParams().value()
        // looks like a Spark broadcast read; confirm in the Scala source.
        MultiLayerNetwork multiLayerNetwork = new MultiLayerNetwork(MultiLayerConfiguration.fromJson((String) this.$outer.$(this.$outer.conf())));
        multiLayerNetwork.init();
        multiLayerNetwork.setParameters((INDArray) this.$outer.networkParams().value());
        // $$anonfun$4 turns each row into a pair using the features column index. unzip then
        // splits the pairs into (features, second elements).
        Tuple2 unzip = iterator.map(new NeuralNetworkClassificationModel$$anonfun$3$$anonfun$4(this, this.schema$1.fieldIndex(this.$outer.getFeaturesCol()))).toIterable().unzip(Predef$.MODULE$.conforms());
        // Generated by scalac for the `val (a, b) = ...` destructuring.
        if (unzip == null) {
            throw new MatchError(unzip);
        }
        Tuple2 tuple2 = new Tuple2((Iterable) unzip._1(), (Iterable) unzip._2());
        // Feature vectors (INDArray) for every row in the partition.
        Iterable iterable = (Iterable) tuple2._1();
        // Second element of each pair; this is the collection that is mapped into output rows.
        Iterable iterable2 = (Iterable) tuple2._2();
        switch (iterable2.size()) {
            case 0:
                // Empty partition: return an empty collection without calling vstack/output.
                // The original intent was probably to avoid stacking zero arrays.
                apply = Seq$.MODULE$.apply(Nil$.MODULE$);
                break;
            default:
                // Constructor arguments of $$anonfun$6 are evaluated before any mapping happens:
                //   1. multiLayerNetwork.output(vstack(all features), true) runs one batched
                //      inference over the whole partition. NOTE(review): whether `true` means
                //      train mode or test mode depends on the DL4J version; confirm.
                //   2. schema fieldNames are flatMapped through $$anonfun$5 into a Function3[].
                //      Presumably these are per-column writers for the output row; TODO confirm.
                // Each (element, index) pair from zipWithIndex is then mapped by $$anonfun$6.
                // The index presumably selects the matching row of the output matrix.
                apply = (Iterable) ((TraversableLike) iterable2.zipWithIndex(Iterable$.MODULE$.canBuildFrom())).map(new NeuralNetworkClassificationModel$$anonfun$3$$anonfun$6(this, multiLayerNetwork.output(Nd4j.vstack((INDArray[]) iterable.toArray(ClassTag$.MODULE$.apply(INDArray.class))), true), (Function3[]) Predef$.MODULE$.refArrayOps(this.schema$1.fieldNames()).flatMap(new NeuralNetworkClassificationModel$$anonfun$3$$anonfun$5(this), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Function3.class)))), Iterable$.MODULE$.canBuildFrom());
                break;
        }
        return apply.iterator();
    }

    /**
     * Synthetic accessor generated by scalac that exposes the enclosing model instance.
     * Presumably it is used by the nested {@code $$anonfun$4/5/6} closures.
     *
     * @return the enclosing {@code NeuralNetworkClassificationModel}
     */
    public /* synthetic */ NeuralNetworkClassificationModel org$deeplearning4j$spark$ml$classification$NeuralNetworkClassificationModel$$anonfun$$$outer() {
        return this.$outer;
    }

    /**
     * Captures the enclosing model and the input schema.
     *
     * @param neuralNetworkClassificationModel the enclosing model; must not be {@code null}
     * @param structType the schema of the rows this function will process
     * @throws NullPointerException if {@code neuralNetworkClassificationModel} is {@code null}.
     *     This is a scalac-generated guard for inner-class outer references.
     */
    public NeuralNetworkClassificationModel$$anonfun$3(NeuralNetworkClassificationModel neuralNetworkClassificationModel, StructType structType) {
        if (neuralNetworkClassificationModel == null) {
            throw new NullPointerException();
        }
        this.$outer = neuralNetworkClassificationModel;
        this.schema$1 = structType;
    }
}
