package ai.djl.spark.task.binary;

import ai.djl.inference.Predictor;
import ai.djl.spark.task.BasePredictor;
import ai.djl.spark.translator.binary.NpBinaryTranslatorFactory;
import org.apache.spark.ml.param.Param;
import org.apache.spark.ml.param.shared.HasInputCol;
import org.apache.spark.ml.param.shared.HasOutputCol;
import org.apache.spark.ml.util.Identifiable$;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.Row$;
import org.apache.spark.sql.types.BinaryType$;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructField$;
import org.apache.spark.sql.types.StructType;
import scala.Predef$;
import scala.collection.Iterator;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.collection.mutable.ArrayOps;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;

/* compiled from: BinaryPredictor.scala */
@ScalaSignature(bytes = "\u0006\u0001\u0005mc\u0001B\b\u0011\u0001mA\u0001b\u000f\u0001\u0003\u0006\u0004%\t\u0005\u0010\u0005\t\u0011\u0002\u0011\t\u0011)A\u0005{!)\u0011\n\u0001C\u0001\u0015\")\u0011\n\u0001C\u0001\u001d\"Iq\n\u0001a\u0001\u0002\u0004%I\u0001\u0015\u0005\n)\u0002\u0001\r\u00111A\u0005\nUC\u0011b\u0017\u0001A\u0002\u0003\u0005\u000b\u0015B)\t\u000bq\u0003A\u0011A/\t\u000b\u0005\u0004A\u0011\u00012\t\u000b\u0011\u0004A\u0011A3\t\u000f\u0005M\u0001\u0001\"\u0011\u0002\u0016!9\u00111\u0005\u0001\u0005R\u0005\u0015\u0002bBA!\u0001\u0011\u0005\u00111\t\u0005\b\u0003+\u0002A\u0011IA,\u0005=\u0011\u0015N\\1ssB\u0013X\rZ5di>\u0014(BA\t\u0013\u0003\u0019\u0011\u0017N\\1ss*\u00111\u0003F\u0001\u0005i\u0006\u001c8N\u0003\u0002\u0016-\u0005)1\u000f]1sW*\u0011q\u0003G\u0001\u0004I*d'\"A\r\u0002\u0005\u0005L7\u0001A\n\u0005\u0001qI\u0003\b\u0005\u0003\u001e=\u0001\u0002S\"\u0001\n\n\u0005}\u0011\"!\u0004\"bg\u0016\u0004&/\u001a3jGR|'\u000fE\u0002\"I\u0019j\u0011A\t\u0006\u0002G\u0005)1oY1mC&\u0011QE\t\u0002\u0006\u0003J\u0014\u0018-\u001f\t\u0003C\u001dJ!\u0001\u000b\u0012\u0003\t\tKH/\u001a\t\u0003UYj\u0011a\u000b\u0006\u0003Y5\naa\u001d5be\u0016$'B\u0001\u00180\u0003\u0015\u0001\u0018M]1n\u0015\t\u0001\u0014'\u0001\u0002nY*\u0011QC\r\u0006\u0003gQ\na!\u00199bG\",'\"A\u001b\u0002\u0007=\u0014x-\u0003\u00028W\tY\u0001*Y:J]B,HoQ8m!\tQ\u0013(\u0003\u0002;W\ta\u0001*Y:PkR\u0004X\u000f^\"pY\u0006\u0019Q/\u001b3\u0016\u0003u\u0002\"AP#\u000f\u0005}\u001a\u0005C\u0001!#\u001b\u0005\t%B\u0001\"\u001b\u0003\u0019a$o\\8u}%\u0011AII\u0001\u0007!J,G-\u001a4\n\u0005\u0019;%AB*ue&twM\u0003\u0002EE\u0005!Q/\u001b3!\u0003\u0019a\u0014N\\5u}Q\u00111*\u0014\t\u0003\u0019\u0002i\u0011\u0001\u0005\u0005\u0006w\r\u0001\r!\u0010\u000b\u0002\u0017\u0006i\u0011N\u001c9vi\u000e{G.\u00138eKb,\u0012!\u0015\t\u0003CIK!a\u0015\u0012\u0003\u0007%sG/A\tj]B,HoQ8m\u0013:$W\r_0%KF$\"AV-\u0011\u0005\u0005:\u0016B\u0001-#\u0005\u0011)f.\u001b;\t\u000fi3\u0011\u0011!a\u0001#\u0006\u0019\u0001\u001
0J\u0019\u0002\u001d%t\u0007/\u001e;D_2Le\u000eZ3yA\u0005Y1/\u001a;J]B,HoQ8m)\tqv,D\u0001\u0001\u0011\u0015\u0001\u0007\u00021\u0001>\u0003\u00151\u0018\r\\;f\u00031\u0019X\r^(viB,HoQ8m)\tq6\rC\u0003a\u0013\u0001\u0007Q(A\u0004qe\u0016$\u0017n\u0019;\u0015\u0005\u0019<\bCA4u\u001d\tA\u0017O\u0004\u0002j_:\u0011!N\u001c\b\u0003W6t!\u0001\u00117\n\u0003UJ!a\r\u001b\n\u0005U\u0011\u0014B\u000192\u0003\r\u0019\u0018\u000f\\\u0005\u0003eN\fq\u0001]1dW\u0006<WM\u0003\u0002qc%\u0011QO\u001e\u0002\n\t\u0006$\u0018M\u0012:b[\u0016T!A]:\t\u000baT\u0001\u0019A=\u0002\u000f\u0011\fG/Y:fiB\u001a!0!\u0001\u0011\u0007mdh0D\u0001t\u0013\ti8OA\u0004ECR\f7/\u001a;\u0011\u0007}\f\t\u0001\u0004\u0001\u0005\u0017\u0005\rq/!A\u0001\u0002\u000b\u0005\u0011Q\u0001\u0002\u0004?\u0012\n\u0014\u0003BA\u0004\u0003\u001b\u00012!IA\u0005\u0013\r\tYA\t\u0002\b\u001d>$\b.\u001b8h!\r\t\u0013qB\u0005\u0004\u0003#\u0011#aA!os\u0006IAO]1og\u001a|'/\u001c\u000b\u0004M\u0006]\u0001B\u0002=\f\u0001\u0004\tI\u0002\r\u0003\u0002\u001c\u0005}\u0001\u0003B>}\u0003;\u00012a`A\u0010\t1\t\t#a\u0006\u0002\u0002\u0003\u0005)\u0011AA\u0003\u0005\ryFEM\u0001\u000eiJ\fgn\u001d4pe6\u0014vn^:\u0015\t\u0005\u001d\u0012Q\b\t\u0007\u0003S\t\t$a\u000e\u000f\t\u0005-\u0012q\u0006\b\u0004\u0001\u00065\u0012\"A\u0012\n\u0005I\u0014\u0013\u0002BA\u001a\u0003k\u0011\u0001\"\u0013;fe\u0006$xN\u001d\u0006\u0003e\n\u00022a_A\u001d\u0013\r\tYd\u001d\u0002\u0004%><\bbBA \u0019\u0001\u0007\u0011qE\u0001\u0005SR,'/A\twC2LG-\u0019;f\u0013:\u0004X\u000f\u001e+za\u0016$2AVA#\u0011\u001d\t9%\u0004a\u0001\u0003\u0013\naa]2iK6\f\u0007\u0003BA&\u0003#j!!!\u0014\u000b\u0007\u0005=3/A\u0003usB,7/\u0003\u0003\u0002T\u00055#AC*ueV\u001cG\u000fV=qK\u0006yAO]1og\u001a|'/\\*dQ\u0016l\u0017\r\u0006\u0003\u0002J\u0005e\u0003bBA$\u001d\u0001\u0007\u0011\u0011\n")
/* loaded from: input_file:ai/djl/spark/task/binary/BinaryPredictor.class */
/**
 * Spark ML predictor that runs a DJL model over a binary ({@code byte[]}) input column and
 * appends the model's {@code byte[]} output as a new binary column.
 *
 * <p>NOTE(review): this is decompiler output of {@code BinaryPredictor.scala}. Several members
 * (the {@code final} fields assigned from {@code _setter_$} methods, {@code $init$} calls,
 * {@code MODULE$} singletons) are Scala compiler encodings. The authoritative source to edit is
 * the Scala file, not this Java rendering.
 */
public class BinaryPredictor extends BasePredictor<byte[], byte[]> implements HasInputCol, HasOutputCol {
    /** Unique identifier of this pipeline stage; fixed at construction. */
    private final String uid;
    /**
     * Position of the input column in the schema of the dataset passed to the most recent
     * {@link #transform(Dataset)} call; read by {@link #transformRows(Iterator)}.
     */
    private int inputColIndex;
    /** Spark param naming the output column (initialized through the HasOutputCol trait setter). */
    private final Param<String> outputCol;
    /** Spark param naming the input column (initialized through the HasInputCol trait setter). */
    private final Param<String> inputCol;

    /** Returns the configured output column name; delegates to the HasOutputCol trait implementation. */
    public final String getOutputCol() {
        return HasOutputCol.getOutputCol$(this);
    }

    /** Returns the configured input column name; delegates to the HasInputCol trait implementation. */
    public final String getInputCol() {
        return HasInputCol.getInputCol$(this);
    }

    /** Returns the {@code outputCol} param object. */
    public final Param<String> outputCol() {
        return this.outputCol;
    }

    /**
     * Scala trait field setter for {@code outputCol}; presumably invoked from
     * {@code HasOutputCol.$init$} during construction. Not intended to be called directly.
     */
    public final void org$apache$spark$ml$param$shared$HasOutputCol$_setter_$outputCol_$eq(Param<String> param) {
        this.outputCol = param;
    }

    /** Returns the {@code inputCol} param object. */
    public final Param<String> inputCol() {
        return this.inputCol;
    }

    /**
     * Scala trait field setter for {@code inputCol}; presumably invoked from
     * {@code HasInputCol.$init$} during construction. Not intended to be called directly.
     */
    public final void org$apache$spark$ml$param$shared$HasInputCol$_setter_$inputCol_$eq(Param<String> param) {
        this.inputCol = param;
    }

    /** Returns the unique identifier supplied at construction. */
    @Override // ai.djl.spark.task.BasePredictor
    public String uid() {
        return this.uid;
    }

    /** Scala-generated accessor for the cached input column index. */
    private int inputColIndex() {
        return this.inputColIndex;
    }

    /** Scala-generated mutator for the cached input column index. */
    private void inputColIndex_$eq(int i) {
        this.inputColIndex = i;
    }

    /**
     * Sets the name of the binary input column.
     *
     * @param str input column name
     * @return this predictor (Scala {@code this.type} returned by Spark's {@code Params.set}), for chaining
     */
    public BinaryPredictor setInputCol(String str) {
        return set(inputCol(), str);
    }

    /**
     * Sets the name of the binary output column that will be appended.
     *
     * @param str output column name
     * @return this predictor (Scala {@code this.type} returned by Spark's {@code Params.set}), for chaining
     */
    public BinaryPredictor setOutputCol(String str) {
        return set(outputCol(), str);
    }

    /**
     * Runs prediction on {@code dataset}; equivalent to {@link #transform(Dataset)}.
     *
     * @param dataset input dataset containing the configured binary input column
     * @return the input rows with the prediction column appended
     */
    public Dataset<Row> predict(Dataset<?> dataset) {
        return transform(dataset);
    }

    /**
     * Prepares per-dataset state and delegates to {@link BasePredictor#transform}.
     *
     * <p>Records the current {@code batchifier} param value in {@code arguments()}. Then resolves
     * the input column's position in {@code dataset}'s schema so that
     * {@link #transformRows(Iterator)} can read it by index.
     *
     * <p>NOTE(review): both steps mutate this instance before delegating. Concurrent
     * {@code transform} calls on one instance with differently-shaped datasets could
     * interfere — confirm this is never done.
     *
     * @param dataset input dataset; presumably Spark's {@code StructType.fieldIndex} throws if the
     *     input column is absent — verify against the Spark version in use
     * @return the transformed dataset
     */
    @Override // ai.djl.spark.task.BasePredictor
    public Dataset<Row> transform(Dataset<?> dataset) {
        arguments().put("batchifier", $(batchifier()));
        inputColIndex_$eq(dataset.schema().fieldIndex((String) $(inputCol())));
        return super.transform(dataset);
    }

    /**
     * Maps each incoming row to a new row consisting of all original values followed by the
     * model's prediction for the row's input bytes.
     *
     * <p>One {@link Predictor} is created per call and shared by every row of {@code iterator}.
     * Predictions are made lazily, one row at a time, as the returned iterator is consumed.
     *
     * <p>NOTE(review): {@code newPredictor} is never closed here. DJL's {@code Predictor} is
     * {@code AutoCloseable}, so this may leak model resources per call (presumably per
     * partition). Consider closing it when the iterator is exhausted or on task completion,
     * in the Scala source.
     *
     * @param iterator rows of one partition (assumed — TODO confirm how BasePredictor invokes this)
     * @return iterator of rows with the prediction appended as the last field
     */
    @Override // ai.djl.spark.task.BasePredictor
    public Iterator<Row> transformRows(Iterator<Row> iterator) {
        Predictor<byte[], byte[]> newPredictor = model().newPredictor();
        return iterator.map(row -> {
            // Read the input bytes by the index cached in transform(), predict, and append the result.
            return Row$.MODULE$.fromSeq((Seq) row.toSeq().$colon$plus(newPredictor.predict(row.getAs(this.inputColIndex())), Seq$.MODULE$.canBuildFrom()));
        });
    }

    /**
     * Checks, via {@code validateType} from BasePredictor, that the configured input column in
     * {@code structType} is {@code BinaryType}.
     *
     * @param structType schema to validate
     */
    @Override // ai.djl.spark.task.BasePredictor
    public void validateInputType(StructType structType) {
        validateType(structType.apply((String) $(inputCol())), BinaryType$.MODULE$);
    }

    /**
     * Returns the output schema: every field of {@code structType} followed by a
     * {@code BinaryType} field named by {@code outputCol}. That field uses
     * {@code StructField}'s default {@code nullable}/{@code metadata} arguments.
     *
     * <p>NOTE(review): no check is made here for an existing field with the same name as
     * {@code outputCol}; a collision would yield a duplicate column name.
     *
     * @param structType input schema
     * @return input schema with the output column appended
     */
    public StructType transformSchema(StructType structType) {
        return new StructType((StructField[]) new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps(structType.fields())).$colon$plus(new StructField((String) $(outputCol()), BinaryType$.MODULE$, StructField$.MODULE$.apply$default$3(), StructField$.MODULE$.apply$default$4()), ClassTag$.MODULE$.apply(StructField.class)));
    }

    /**
     * Creates a predictor with the given uid and binary defaults. The input and output classes
     * default to {@code byte[]}, and the translator factory defaults to NpBinaryTranslatorFactory.
     *
     * @param str unique identifier for this pipeline stage
     */
    public BinaryPredictor(String str) {
        this.uid = str;
        // Trait initializers must run before setDefault so the inputCol/outputCol params exist.
        HasInputCol.$init$(this);
        HasOutputCol.$init$(this);
        setDefault(inputClass(), byte[].class);
        setDefault(outputClass(), byte[].class);
        setDefault(translatorFactory(), new NpBinaryTranslatorFactory());
    }

    /** Creates a predictor with a random uid prefixed {@code "BinaryPredictor"}. */
    public BinaryPredictor() {
        this(Identifiable$.MODULE$.randomUID("BinaryPredictor"));
    }
}
