package ai.djl.spark.task.text;

import ai.djl.huggingface.tokenizers.Encoding;
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
import org.apache.spark.ml.param.Param;
import org.apache.spark.ml.param.shared.HasInputCol;
import org.apache.spark.ml.param.shared.HasOutputCol;
import org.apache.spark.ml.util.Identifiable$;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.Row$;
import org.apache.spark.sql.types.ArrayType$;
import org.apache.spark.sql.types.LongType$;
import org.apache.spark.sql.types.StringType$;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructField$;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.sql.types.StructType$;
import scala.Predef$;
import scala.collection.Iterator;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.collection.immutable.$colon;
import scala.collection.immutable.Nil$;
import scala.collection.mutable.ArrayOps;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;

/* compiled from: HuggingFaceTextEncoder.scala */
@ScalaSignature(bytes = "\u0006\u0001\u0005]d\u0001\u0002\n\u0014\u0001yA\u0001B\u0013\u0001\u0003\u0006\u0004%\te\u0013\u0005\t\u0019\u0002\u0011\t\u0011)A\u0005G!)Q\n\u0001C\u0001\u001d\")Q\n\u0001C\u0001#\"9!\u000b\u0001b\u0001\n\u000b\u0019\u0006B\u0002-\u0001A\u00035A\u000bC\u0005Z\u0001\u0001\u0007\t\u0019!C\u00055\"Iq\f\u0001a\u0001\u0002\u0004%I\u0001\u0019\u0005\nM\u0002\u0001\r\u0011!Q!\nmCQa\u001a\u0001\u0005\u0002!DQ\u0001\u001c\u0001\u0005\u00025DQa\u001c\u0001\u0005\u0002ADQA\u001d\u0001\u0005\u0002MDq!a\f\u0001\t\u0003\n\t\u0004C\u0004\u0002@\u0001!\t%!\u0011\t\u000f\u0005u\u0003\u0001\"\u0001\u0002`!9\u0011\u0011\u000f\u0001\u0005B\u0005M$A\u0006%vO\u001eLgn\u001a$bG\u0016$V\r\u001f;F]\u000e|G-\u001a:\u000b\u0005Q)\u0012\u0001\u0002;fqRT!AF\f\u0002\tQ\f7o\u001b\u0006\u00031e\tQa\u001d9be.T!AG\u000e\u0002\u0007\u0011TGNC\u0001\u001d\u0003\t\t\u0017n\u0001\u0001\u0014\t\u0001y\u0002h\u0012\t\u0005A\u0005\u001a\u0003'D\u0001\u0014\u0013\t\u00113CA\tCCN,G+\u001a=u!J,G-[2u_J\u0004\"\u0001J\u0017\u000f\u0005\u0015Z\u0003C\u0001\u0014*\u001b\u00059#B\u0001\u0015\u001e\u0003\u0019a$o\\8u})\t!&A\u0003tG\u0006d\u0017-\u0003\u0002-S\u00051\u0001K]3eK\u001aL!AL\u0018\u0003\rM#(/\u001b8h\u0015\ta\u0013\u0006\u0005\u00022m5\t!G\u0003\u00024i\u0005QAo\\6f]&TXM]:\u000b\u0005UJ\u0012a\u00035vO\u001eLgn\u001a4bG\u0016L!a\u000e\u001a\u0003\u0011\u0015s7m\u001c3j]\u001e\u0004\"!O#\u000e\u0003iR!a\u000f\u001f\u0002\rMD\u0017M]3e\u0015\tid(A\u0003qCJ\fWN\u0003\u0002@\u0001\u0006\u0011Q\u000e\u001c\u0006\u00031\u0005S!AQ\"\u0002\r\u0005\u0004\u0018m\u00195f\u0015\u0005!\u0015aA8sO&\u0011aI\u000f\u0002\f\u0011\u0006\u001c\u0018J\u001c9vi\u000e{G\u000e\u0005\u0002:\u0011&\u0011\u0011J\u000f\u0002\r\u0011\u0006\u001cx*\u001e;qkR\u001cu\u000e\\\u0001\u0004k&$W#A\u0012\u0002\tULG\rI\u0001\u0007y%t\u0017\u000e\u001e 
\u0015\u0005=\u0003\u0006C\u0001\u0011\u0001\u0011\u0015Q5\u00011\u0001$)\u0005y\u0015!\u0003;pW\u0016t\u0017N_3s+\u0005!\u0006cA+WG5\tA(\u0003\u0002Xy\t)\u0001+\u0019:b[\u0006QAo\\6f]&TXM\u001d\u0011\u0002\u001b%t\u0007/\u001e;D_2Le\u000eZ3y+\u0005Y\u0006C\u0001/^\u001b\u0005I\u0013B\u00010*\u0005\rIe\u000e^\u0001\u0012S:\u0004X\u000f^\"pY&sG-\u001a=`I\u0015\fHCA1e!\ta&-\u0003\u0002dS\t!QK\\5u\u0011\u001d)\u0007\"!AA\u0002m\u000b1\u0001\u001f\u00132\u00039Ig\u000e];u\u0007>d\u0017J\u001c3fq\u0002\n1b]3u\u0013:\u0004X\u000f^\"pYR\u0011\u0011N[\u0007\u0002\u0001!)1N\u0003a\u0001G\u0005)a/\u00197vK\u0006a1/\u001a;PkR\u0004X\u000f^\"pYR\u0011\u0011N\u001c\u0005\u0006W.\u0001\raI\u0001\rg\u0016$Hk\\6f]&TXM\u001d\u000b\u0003SFDQa\u001b\u0007A\u0002\r\na!\u001a8d_\u0012,Gc\u0001;\u0002\fA\u0019Q/!\u0002\u000f\u0005Y|hBA<~\u001d\tAHP\u0004\u0002zw:\u0011aE_\u0005\u0002\t&\u0011!iQ\u0005\u00031\u0005K!A !\u0002\u0007M\fH.\u0003\u0003\u0002\u0002\u0005\r\u0011a\u00029bG.\fw-\u001a\u0006\u0003}\u0002KA!a\u0002\u0002\n\tIA)\u0019;b\rJ\fW.\u001a\u0006\u0005\u0003\u0003\t\u0019\u0001C\u0004\u0002\u000e5\u0001\r!a\u0004\u0002\u000f\u0011\fG/Y:fiB\"\u0011\u0011CA\u000f!\u0019\t\u0019\"!\u0006\u0002\u001a5\u0011\u00111A\u0005\u0005\u0003/\t\u0019AA\u0004ECR\f7/\u001a;\u0011\t\u0005m\u0011Q\u0004\u0007\u0001\t1\ty\"a\u0003\u0002\u0002\u0003\u0005)\u0011AA\u0011\u0005\ryF%M\t\u0005\u0003G\tI\u0003E\u0002]\u0003KI1!a\n*\u0005\u001dqu\u000e\u001e5j]\u001e\u00042\u0001XA\u0016\u0013\r\ti#\u000b\u0002\u0004\u0003:L\u0018!\u0003;sC:\u001chm\u001c:n)\r!\u00181\u0007\u0005\b\u0003\u001bq\u0001\u0019AA\u001ba\u0011\t9$a\u000f\u0011\r\u0005M\u0011QCA\u001d!\u0011\tY\"a\u000f\u0005\u0019\u0005u\u00121GA\u0001\u0002\u0003\u0015\t!!\t\u0003\u0007}##'A\u0007ue\u0006t7OZ8s[J{wo\u001d\u000b\u0005\u0003\u0007\nI\u0006\u0005\u0004\u0002F\u00055\u00131\u000b\b\u0005\u0003\u000f\nYED\u0002'\u0003\u0013J\u0011AK\u0005\u0004\u0003\u0003I\u0013\u0002BA(\u0003#\u0012\u0001\"\u0013;fe\u0006$xN\u001d\u0006\u
0004\u0003\u0003I\u0003\u0003BA\n\u0003+JA!a\u0016\u0002\u0004\t\u0019!k\\<\t\u000f\u0005ms\u00021\u0001\u0002D\u0005!\u0011\u000e^3s\u0003E1\u0018\r\\5eCR,\u0017J\u001c9viRK\b/\u001a\u000b\u0004C\u0006\u0005\u0004bBA2!\u0001\u0007\u0011QM\u0001\u0007g\u000eDW-\\1\u0011\t\u0005\u001d\u0014QN\u0007\u0003\u0003SRA!a\u001b\u0002\u0004\u0005)A/\u001f9fg&!\u0011qNA5\u0005)\u0019FO];diRK\b/Z\u0001\u0010iJ\fgn\u001d4pe6\u001c6\r[3nCR!\u0011QMA;\u0011\u001d\t\u0019'\u0005a\u0001\u0003K\u0002")
/* loaded from: input_file:ai/djl/spark/task/text/HuggingFaceTextEncoder.class */
/**
 * Spark ML transformer that tokenizes a string column with a HuggingFace tokenizer and appends
 * the resulting encoding as a new struct column.
 *
 * <p>The appended struct column (named by {@link #setOutputCol}) has three {@code array<bigint>}
 * fields: {@code ids}, {@code type_ids} and {@code attention_mask}. The source column is chosen
 * with {@link #setInputCol} and must be of string type; the tokenizer is chosen by name with
 * {@link #setTokenizer}.
 *
 * <p>This class is a Java rendering of a Scala class: the {@code $init$} calls and the
 * {@code org$apache$spark$...$_setter_$..._$eq} methods are the hooks through which the Scala
 * traits {@link HasInputCol} and {@link HasOutputCol} install their params.
 */
public class HuggingFaceTextEncoder extends BaseTextPredictor<String, Encoding> implements HasInputCol, HasOutputCol {
    /** Name of the tokenizer handed to {@link HuggingFaceTokenizer#newInstance(String)}. */
    private final Param<String> tokenizer;

    /** Index of the input column in the dataset schema; resolved in {@link #transform}. */
    private int inputColIndex;

    // Not final: these are assigned through the trait setter callbacks invoked by
    // HasOutputCol.$init$ / HasInputCol.$init$, and Java forbids assigning a final field there.
    private Param<String> outputCol;
    private Param<String> inputCol;

    /** Returns the configured output column name (delegates to the {@link HasOutputCol} trait). */
    public final String getOutputCol() {
        return HasOutputCol.getOutputCol$(this);
    }

    /** Returns the configured input column name (delegates to the {@link HasInputCol} trait). */
    public final String getInputCol() {
        return HasInputCol.getInputCol$(this);
    }

    /** Returns the param holding the output column name. */
    public final Param<String> outputCol() {
        return this.outputCol;
    }

    /** Trait setter callback used by {@code HasOutputCol.$init$} to install the param. */
    public final void org$apache$spark$ml$param$shared$HasOutputCol$_setter_$outputCol_$eq(Param<String> param) {
        this.outputCol = param;
    }

    /** Returns the param holding the input column name. */
    public final Param<String> inputCol() {
        return this.inputCol;
    }

    /** Trait setter callback used by {@code HasInputCol.$init$} to install the param. */
    public final void org$apache$spark$ml$param$shared$HasInputCol$_setter_$inputCol_$eq(Param<String> param) {
        this.inputCol = param;
    }

    /**
     * Returns this stage's unique id.
     *
     * <p>The id is kept by the superclass (passed via {@code super(str)}) instead of a field of
     * this class. Spark's {@code Param(Identifiable, ...)} constructor reads {@code parent.uid()},
     * and params owned by the superclasses are created before any field of this class could be
     * assigned in Java. A field here would still be {@code null} at that point.
     */
    @Override // ai.djl.spark.task.text.BaseTextPredictor, ai.djl.spark.task.BasePredictor
    public String uid() {
        return super.uid();
    }

    /** Returns the param holding the tokenizer name. */
    public final Param<String> tokenizer() {
        return this.tokenizer;
    }

    private int inputColIndex() {
        return this.inputColIndex;
    }

    private void inputColIndex_$eq(int i) {
        this.inputColIndex = i;
    }

    /**
     * Sets the name of the string column to tokenize.
     *
     * @param str the input column name
     * @return this encoder, for chaining
     */
    public HuggingFaceTextEncoder setInputCol(String str) {
        // Params.set is declared as returning `this.type`, which Java sees as Params.
        return (HuggingFaceTextEncoder) set(inputCol(), str);
    }

    /**
     * Sets the name of the struct column that receives the encoding.
     *
     * @param str the output column name
     * @return this encoder, for chaining
     */
    public HuggingFaceTextEncoder setOutputCol(String str) {
        return (HuggingFaceTextEncoder) set(outputCol(), str);
    }

    /**
     * Sets the tokenizer name passed to {@link HuggingFaceTokenizer#newInstance(String)}.
     *
     * @param str the tokenizer name
     * @return this encoder, for chaining
     */
    public HuggingFaceTextEncoder setTokenizer(String str) {
        return (HuggingFaceTextEncoder) set(tokenizer(), str);
    }

    /**
     * Encodes the input column of {@code dataset}; an alias for {@link #transform(Dataset)}.
     *
     * @param dataset the dataset containing the input column
     * @return the dataset with the encoding column appended
     */
    public Dataset<Row> encode(Dataset<?> dataset) {
        return transform(dataset);
    }

    /**
     * Resolves the input column's index from the schema, then delegates to the base transform.
     *
     * @param dataset the dataset containing the input column
     * @return the dataset with the encoding column appended
     */
    @Override // ai.djl.spark.task.BasePredictor
    public Dataset<Row> transform(Dataset<?> dataset) {
        // Resolved here (against the schema) so that transformRows can read it by position.
        inputColIndex_$eq(dataset.schema().fieldIndex((String) $(inputCol())));
        return super.transform(dataset);
    }

    /**
     * Tokenizes the input column of each row and appends the encoding as a nested row.
     *
     * <p>One tokenizer is created per call (that is, per iterator of rows). A null input value
     * yields a null encoding; the output column is declared nullable in {@link #transformSchema}.
     *
     * <p>NOTE(review): the tokenizer instance is never closed here. If {@code HuggingFaceTokenizer}
     * holds native memory, consider releasing it when the task completes.
     *
     * @param iterator the input rows
     * @return the rows with one extra column holding (ids, type_ids, attention_mask)
     */
    @Override // ai.djl.spark.task.BasePredictor
    public Iterator<Row> transformRows(Iterator<Row> iterator) {
        HuggingFaceTokenizer hfTokenizer = HuggingFaceTokenizer.newInstance((String) $(tokenizer()));
        int index = inputColIndex();
        return iterator.map(row -> {
            // Avoid passing null to the native tokenizer; emit a null encoding instead.
            Encoding encoding = row.isNullAt(index) ? null : hfTokenizer.encode(row.getString(index));
            return appendEncoding(row, encoding);
        });
    }

    /** Returns a copy of {@code row} with the encoding (or null) appended as the last column. */
    private static Row appendEncoding(Row row, Encoding encoding) {
        int width = row.length();
        Object[] values = new Object[width + 1];
        for (int i = 0; i < width; i++) {
            values[i] = row.get(i);
        }
        values[width] = encoding == null
                ? null
                : Row$.MODULE$.apply(Predef$.MODULE$.genericWrapArray(new Object[]{
                        encoding.getIds(), encoding.getTypeIds(), encoding.getAttentionMask()}));
        return Row$.MODULE$.fromSeq(Predef$.MODULE$.genericWrapArray(values));
    }

    /**
     * Checks that the input column exists and has string type.
     *
     * @param structType the input schema
     */
    @Override // ai.djl.spark.task.BasePredictor
    public void validateInputType(StructType structType) {
        validateType(structType.apply((String) $(inputCol())), StringType$.MODULE$);
    }

    /**
     * Returns the input schema with the encoding struct column appended.
     *
     * @param structType the input schema
     * @return the input fields followed by a struct of {@code ids}, {@code type_ids} and
     *     {@code attention_mask}, each an array of longs
     */
    @Override
    public StructType transformSchema(StructType structType) {
        StructType encodingType = new StructType(new StructField[]{
                longArrayField("ids"),
                longArrayField("type_ids"),
                longArrayField("attention_mask")});
        // StructType.add returns a new schema with the field appended after the existing ones.
        return structType.add(new StructField((String) $(outputCol()), encodingType,
                StructField$.MODULE$.apply$default$3(), StructField$.MODULE$.apply$default$4()));
    }

    /** Builds a field of type array&lt;long&gt; with Spark's default nullability and metadata. */
    private static StructField longArrayField(String name) {
        return new StructField(name, ArrayType$.MODULE$.apply(LongType$.MODULE$),
                StructField$.MODULE$.apply$default$3(), StructField$.MODULE$.apply$default$4());
    }

    /**
     * Creates an encoder with the given unique id.
     *
     * @param str the unique id of this pipeline stage
     */
    public HuggingFaceTextEncoder(String str) {
        // Pass the id up so uid() is already valid while superclass params are being created.
        super(str);
        HasInputCol.$init$(this);
        HasOutputCol.$init$(this);
        this.tokenizer = new Param<>(this, "tokenizer", "The name of the tokenizer");
        setDefault(inputClass(), String.class);
        setDefault(outputClass(), Encoding.class);
        setDefault(translatorFactory(), null);
    }

    /** Creates an encoder with a random id prefixed with {@code HuggingFaceTextEncoder}. */
    public HuggingFaceTextEncoder() {
        this(Identifiable$.MODULE$.randomUID("HuggingFaceTextEncoder"));
    }
}
