package ai.djl.spark.task.text;

import ai.djl.huggingface.translator.TextClassificationTranslatorFactory;
import ai.djl.inference.Predictor;
import ai.djl.modality.Classifications;
import java.util.List;
import org.apache.spark.ml.param.Param;
import org.apache.spark.ml.param.shared.HasInputCol;
import org.apache.spark.ml.param.shared.HasOutputCol;
import org.apache.spark.ml.util.Identifiable$;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.Row$;
import org.apache.spark.sql.types.ArrayType$;
import org.apache.spark.sql.types.DoubleType$;
import org.apache.spark.sql.types.StringType$;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructField$;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.sql.types.StructType$;
import scala.MatchError;
import scala.Predef$;
import scala.collection.GenTraversableOnce;
import scala.collection.Iterable;
import scala.collection.Iterable$;
import scala.collection.Iterator;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.collection.TraversableLike;
import scala.collection.immutable.$colon;
import scala.collection.immutable.Nil$;
import scala.collection.mutable.ArrayOps;
import scala.jdk.CollectionConverters$;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;

/* compiled from: TextClassifier.scala */
@ScalaSignature(bytes = "\u0006\u0001\u0005Md\u0001\u0002\n\u0014\u0001yA\u0001\u0002\u0013\u0001\u0003\u0006\u0004%\t%\u0013\u0005\t\u0015\u0002\u0011\t\u0011)A\u0005G!)1\n\u0001C\u0001\u0019\")1\n\u0001C\u0001\u001f\"9\u0001\u000b\u0001b\u0001\n\u000b\t\u0006B\u0002.\u0001A\u00035!\u000bC\u0005\\\u0001\u0001\u0007\t\u0019!C\u00059\"IQ\f\u0001a\u0001\u0002\u0004%IA\u0018\u0005\nI\u0002\u0001\r\u0011!Q!\nYCQ!\u001a\u0001\u0005\u0002\u0019DQA\u001b\u0001\u0005\u0002-DQ!\u001c\u0001\u0005\u00029DQ\u0001\u001d\u0001\u0005\u0002EDq!a\u000b\u0001\t\u0003\ni\u0003C\u0004\u0002<\u0001!\t&!\u0010\t\u000f\u0005e\u0003\u0001\"\u0015\u0002\\!9\u0011Q\u000e\u0001\u0005B\u0005=$A\u0004+fqR\u001cE.Y:tS\u001aLWM\u001d\u0006\u0003)U\tA\u0001^3yi*\u0011acF\u0001\u0005i\u0006\u001c8N\u0003\u0002\u00193\u0005)1\u000f]1sW*\u0011!dG\u0001\u0004I*d'\"\u0001\u000f\u0002\u0005\u0005L7\u0001A\n\u0005\u0001}1T\t\u0005\u0003!C\r\u0002T\"A\n\n\u0005\t\u001a\"!\u0005\"bg\u0016$V\r\u001f;Qe\u0016$\u0017n\u0019;peB\u0011A%\f\b\u0003K-\u0002\"AJ\u0015\u000e\u0003\u001dR!\u0001K\u000f\u0002\rq\u0012xn\u001c;?\u0015\u0005Q\u0013!B:dC2\f\u0017B\u0001\u0017*\u0003\u0019\u0001&/\u001a3fM&\u0011af\f\u0002\u0007'R\u0014\u0018N\\4\u000b\u00051J\u0003CA\u00195\u001b\u0005\u0011$BA\u001a\u001a\u0003!iw\u000eZ1mSRL\u0018BA\u001b3\u0005=\u0019E.Y:tS\u001aL7-\u0019;j_:\u001c\bCA\u001cD\u001b\u0005A$BA\u001d;\u0003\u0019\u0019\b.\u0019:fI*\u00111\bP\u0001\u0006a\u0006\u0014\u0018-\u001c\u0006\u0003{y\n!!\u001c7\u000b\u0005ay$B\u0001!B\u0003\u0019\t\u0007/Y2iK*\t!)A\u0002pe\u001eL!\u0001\u0012\u001d\u0003\u0017!\u000b7/\u00138qkR\u001cu\u000e\u001c\t\u0003o\u0019K!a\u0012\u001d\u0003\u0019!\u000b7oT;uaV$8i\u001c7\u0002\u0007ULG-F\u0001$\u0003\u0011)\u0018\u000e\u001a\u0011\u0002\rqJg.\u001b;?)\tie\n\u0005\u0002!\u0001!)\u0001j\u0001a\u0001GQ\tQ*\u0001\u0003u_B\\U#\u0001*\u0011\u0007M#f+D\u0001;\u0013\t)&HA\u0003QCJ\fW\u000e\u0005\u0002X16\t\u0011&\u0003\u0002ZS\t\u0019\u0011J\u001c;\u0002\u000bQ|\u0007o\u0013\u0
011\u0002\u001b%t\u0007/\u001e;D_2Le\u000eZ3y+\u00051\u0016!E5oaV$8i\u001c7J]\u0012,\u0007p\u0018\u0013fcR\u0011qL\u0019\t\u0003/\u0002L!!Y\u0015\u0003\tUs\u0017\u000e\u001e\u0005\bG\"\t\t\u00111\u0001W\u0003\rAH%M\u0001\u000fS:\u0004X\u000f^\"pY&sG-\u001a=!\u0003-\u0019X\r^%oaV$8i\u001c7\u0015\u0005\u001dDW\"\u0001\u0001\t\u000b%T\u0001\u0019A\u0012\u0002\u000bY\fG.^3\u0002\u0019M,GoT;uaV$8i\u001c7\u0015\u0005\u001dd\u0007\"B5\f\u0001\u0004\u0019\u0013aB:fiR{\u0007o\u0013\u000b\u0003O>DQ!\u001b\u0007A\u0002Y\u000b\u0001b\u00197bgNLg-\u001f\u000b\u0004e\u0006\u001d\u0001cA:\u0002\u00029\u0011A/ \b\u0003knt!A\u001e>\u000f\u0005]LhB\u0001\u0014y\u0013\u0005\u0011\u0015B\u0001!B\u0013\tAr(\u0003\u0002}}\u0005\u00191/\u001d7\n\u0005y|\u0018a\u00029bG.\fw-\u001a\u0006\u0003yzJA!a\u0001\u0002\u0006\tIA)\u0019;b\rJ\fW.\u001a\u0006\u0003}~Dq!!\u0003\u000e\u0001\u0004\tY!A\u0004eCR\f7/\u001a;1\t\u00055\u0011\u0011\u0004\t\u0007\u0003\u001f\t\t\"!\u0006\u000e\u0003}L1!a\u0005��\u0005\u001d!\u0015\r^1tKR\u0004B!a\u0006\u0002\u001a1\u0001A\u0001DA\u000e\u0003\u000f\t\t\u0011!A\u0003\u0002\u0005u!aA0%cE!\u0011qDA\u0013!\r9\u0016\u0011E\u0005\u0004\u0003GI#a\u0002(pi\"Lgn\u001a\t\u0004/\u0006\u001d\u0012bAA\u0015S\t\u0019\u0011I\\=\u0002\u0013Q\u0014\u0018M\\:g_JlGc\u0001:\u00020!9\u0011\u0011\u0002\bA\u0002\u0005E\u0002\u0007BA\u001a\u0003o\u0001b!a\u0004\u0002\u0012\u0005U\u0002\u0003BA\f\u0003o!A\"!\u000f\u00020\u0005\u0005\t\u0011!B\u0001\u0003;\u00111a\u0018\u00133\u00035!(/\u00198tM>\u0014XNU8xgR!\u0011qHA+!\u0019\t\t%!\u0013\u0002P9!\u00111IA$\u001d\r1\u0013QI\u0005\u0002U%\u0011a0K\u0005\u0005\u0003\u0017\niE\u0001\u0005Ji\u0016\u0014\u0018\r^8s\u0015\tq\u0018\u0006\u0005\u0003\u0002\u0010\u0005E\u0013bAA*\u007f\n\u0019!k\\<\t\u000f\u0005]s\u00021\u0001\u0002@\u0005!\u0011\u000e^3s\u0003E1\u0018\r\\5eCR,\u0017J\u001c9viRK\b/\u001a\u000b\u0004?\u0006u\u0003bBA0!\u0001\u0007\u0011\u0011M\u0001\u0007g\u000eDW-\\1\u0011\t\u0005\r\u0014\u0011N\u0007\u0003\u0003KR1!a\u001a��\u0
003\u0015!\u0018\u0010]3t\u0013\u0011\tY'!\u001a\u0003\u0015M#(/^2u)f\u0004X-A\bue\u0006t7OZ8s[N\u001b\u0007.Z7b)\u0011\t\t'!\u001d\t\u000f\u0005}\u0013\u00031\u0001\u0002b\u0001")
/* loaded from: input_file:ai/djl/spark/task/text/TextClassifier.class */
/**
 * Text classification stage that reads a string column, runs it through a
 * {@link Predictor} in batches and appends one struct column holding the
 * prediction.
 *
 * <p>The appended struct has three fields (see {@link #transformSchema}):
 * {@code class_names} (array of strings), {@code probabilities} (array of
 * doubles) and {@code top_k} (array of strings).
 *
 * <p>NOTE(review): this source is decompiled from {@code TextClassifier.scala};
 * members such as {@code model()}, {@code batchSize()}, {@code arguments()},
 * {@code set}, {@code setDefault}, {@code isDefined}, {@code $} and
 * {@code validateType} are inherited and not visible here.
 */
public class TextClassifier extends BaseTextPredictor<String, Classifications> implements HasInputCol, HasOutputCol {
    /** Identifier of this stage, returned by {@link #uid()}. */
    private final String uid;

    /** Parameter described as "The number of classes to return". */
    private final Param<Object> topK;

    /** Position of the input column in the schema of the dataset last passed to {@link #transform}. */
    private int inputColIndex;

    // Not final: these two fields are assigned by the trait setter methods below rather than
    // directly in the constructor, which Java does not allow for final fields.
    private Param<String> outputCol;
    private Param<String> inputCol;

    /** Returns the output column name as resolved by {@code HasOutputCol}. */
    public final String getOutputCol() {
        return HasOutputCol.getOutputCol$(this);
    }

    /** Returns the input column name as resolved by {@code HasInputCol}. */
    public final String getInputCol() {
        return HasInputCol.getInputCol$(this);
    }

    /** Returns the parameter object that holds the output column name. */
    public final Param<String> outputCol() {
        return this.outputCol;
    }

    /** Trait setter for the {@code outputCol} parameter field; presumably called from {@code HasOutputCol.$init$}. */
    public final void org$apache$spark$ml$param$shared$HasOutputCol$_setter_$outputCol_$eq(Param<String> param) {
        this.outputCol = param;
    }

    /** Returns the parameter object that holds the input column name. */
    public final Param<String> inputCol() {
        return this.inputCol;
    }

    /** Trait setter for the {@code inputCol} parameter field; presumably called from {@code HasInputCol.$init$}. */
    public final void org$apache$spark$ml$param$shared$HasInputCol$_setter_$inputCol_$eq(Param<String> param) {
        this.inputCol = param;
    }

    /** Returns the identifier given to (or generated by) the constructor. */
    @Override // ai.djl.spark.task.text.BaseTextPredictor, ai.djl.spark.task.BasePredictor
    public String uid() {
        return this.uid;
    }

    /** Returns the {@code topK} parameter object. */
    public final Param<Object> topK() {
        return this.topK;
    }

    private int inputColIndex() {
        return this.inputColIndex;
    }

    private void inputColIndex_$eq(int i) {
        this.inputColIndex = i;
    }

    /**
     * Sets the name of the column to read text from.
     *
     * @param str the input column name
     * @return the value returned by {@code set}, narrowed to this class
     */
    public TextClassifier setInputCol(String str) {
        return (TextClassifier) set(inputCol(), str);
    }

    /**
     * Sets the name of the column the prediction struct is written to.
     *
     * @param str the output column name
     * @return the value returned by {@code set}, narrowed to this class
     */
    public TextClassifier setOutputCol(String str) {
        return (TextClassifier) set(outputCol(), str);
    }

    /**
     * Sets the {@code topK} parameter.
     *
     * @param i the number of classes to return
     * @return the value returned by {@code set}, narrowed to this class
     */
    public TextClassifier setTopK(int i) {
        return (TextClassifier) set(topK(), BoxesRunTime.boxToInteger(i));
    }

    /**
     * Classifies the text in the input column; simply delegates to {@link #transform}.
     *
     * @param dataset the dataset to classify
     * @return the dataset returned by {@link #transform}
     */
    public Dataset<Row> classify(Dataset<?> dataset) {
        return transform(dataset);
    }

    /**
     * Prepares per-call state and then delegates to the superclass transform.
     *
     * <p>When {@code topK} is defined its string form is stored under the key
     * {@code "topK"} in {@code arguments()}. The index of the input column in the
     * dataset schema is cached for use by {@link #transformRows}.
     *
     * @param dataset the dataset to transform
     * @return the result of {@code super.transform(dataset)}
     */
    @Override // ai.djl.spark.task.BasePredictor
    public Dataset<Row> transform(Dataset<?> dataset) {
        if (isDefined(topK())) {
            arguments().put("topK", $(topK()).toString());
        }
        inputColIndex_$eq(dataset.schema().fieldIndex((String) $(inputCol())));
        return super.transform(dataset);
    }

    /**
     * Runs batch prediction over the rows of an iterator.
     *
     * <p>Rows are grouped into batches of {@code batchSize}; the input column of each
     * batch is passed to {@code batchPredict}, and every row is paired with its
     * prediction and extended by {@link #appendPrediction}.
     *
     * @param iterator the rows to process
     * @return an iterator over the input rows, each with the prediction struct appended
     */
    @Override // ai.djl.spark.task.BasePredictor
    public Iterator<Row> transformRows(Iterator<Row> iterator) {
        // One predictor is created per call and shared by all batches of this iterator.
        Predictor<String, Classifications> predictor = model().newPredictor();
        return iterator.grouped(BoxesRunTime.unboxToInt($(batchSize()))).flatMap(seq -> {
            return (GenTraversableOnce) ((TraversableLike) seq.zip(
                    (Iterable) CollectionConverters$.MODULE$.collectionAsScalaIterableConverter(
                            predictor.batchPredict((List) CollectionConverters$.MODULE$.seqAsJavaListConverter(
                                    (Seq) seq.map(row -> {
                                        return row.getString(this.inputColIndex());
                                    }, Seq$.MODULE$.canBuildFrom())).asJava())).asScala(),
                    Seq$.MODULE$.canBuildFrom())).map(tuple2 -> {
                if (tuple2 == null) {
                    throw new MatchError(tuple2);
                }
                return appendPrediction((Row) tuple2._1(), (Classifications) tuple2._2());
            }, Seq$.MODULE$.canBuildFrom());
        });
    }

    /**
     * Builds a copy of {@code row} with one extra trailing value: a nested row holding
     * the class names, the probabilities and the string form of each top-k entry.
     */
    private Row appendPrediction(Row row, Classifications classifications) {
        Object[] classNames = classifications.getClassNames().toArray();
        Object[] probabilities = classifications.getProbabilities().toArray();
        Object topKLabels = ((TraversableLike) CollectionConverters$.MODULE$
                .collectionAsScalaIterableConverter(classifications.topK()).asScala())
                .map(classification -> {
                    return classification.toString();
                }, Iterable$.MODULE$.canBuildFrom());
        Row prediction = Row$.MODULE$.apply(
                Predef$.MODULE$.genericWrapArray(new Object[]{classNames, probabilities, topKLabels}));
        return Row$.MODULE$.fromSeq((Seq) row.toSeq().$colon$plus(prediction, Seq$.MODULE$.canBuildFrom()));
    }

    /**
     * Checks that the input column of {@code structType} passes {@code validateType}
     * against the string type.
     *
     * @param structType the schema to validate
     */
    @Override // ai.djl.spark.task.BasePredictor
    public void validateInputType(StructType structType) {
        validateType(structType.apply((String) $(inputCol())), StringType$.MODULE$);
    }

    /**
     * Returns the input schema with the output column appended.
     *
     * @param structType the input schema
     * @return a new schema whose last field is the output column, a struct of
     *     {@code class_names}, {@code probabilities} and {@code top_k}
     */
    public StructType transformSchema(StructType structType) {
        StructField classNamesField = new StructField("class_names", ArrayType$.MODULE$.apply(StringType$.MODULE$),
                StructField$.MODULE$.apply$default$3(), StructField$.MODULE$.apply$default$4());
        StructField probabilitiesField = new StructField("probabilities", ArrayType$.MODULE$.apply(DoubleType$.MODULE$),
                StructField$.MODULE$.apply$default$3(), StructField$.MODULE$.apply$default$4());
        StructField topKField = new StructField("top_k", ArrayType$.MODULE$.apply(StringType$.MODULE$),
                StructField$.MODULE$.apply$default$3(), StructField$.MODULE$.apply$default$4());

        // Scala cons list: class_names :: probabilities :: top_k :: Nil
        StructType predictionType = StructType$.MODULE$.apply(
                new $colon.colon(classNamesField, new $colon.colon(probabilitiesField, new $colon.colon(topKField, Nil$.MODULE$))));
        StructField outputField = new StructField((String) $(outputCol()), predictionType,
                StructField$.MODULE$.apply$default$3(), StructField$.MODULE$.apply$default$4());

        return new StructType((StructField[]) new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps(structType.fields()))
                .$colon$plus(outputField, ClassTag$.MODULE$.apply(StructField.class)));
    }

    /**
     * Creates a classifier with the given identifier.
     *
     * <p>Initializes the input/output column traits, creates the {@code topK}
     * parameter, and sets defaults for the input class ({@code String}), the output
     * class ({@code Classifications}) and the translator factory.
     *
     * @param str the identifier of this stage
     */
    public TextClassifier(String str) {
        this.uid = str;
        HasInputCol.$init$(this);
        HasOutputCol.$init$(this);
        this.topK = new Param<>(this, "topK", "The number of classes to return");
        setDefault(inputClass(), String.class);
        setDefault(outputClass(), Classifications.class);
        setDefault(translatorFactory(), new TextClassificationTranslatorFactory());
    }

    /** Creates a classifier whose identifier comes from {@code randomUID("TextClassifier")}. */
    public TextClassifier() {
        this(Identifiable$.MODULE$.randomUID("TextClassifier"));
    }
}
