package ai.djl.spark.task.text;

import ai.djl.huggingface.translator.TextClassificationTranslatorFactory;
import ai.djl.inference.Predictor;
import ai.djl.modality.Classifications;
import org.apache.spark.ml.param.Param;
import org.apache.spark.ml.param.shared.HasInputCol;
import org.apache.spark.ml.param.shared.HasOutputCol;
import org.apache.spark.ml.util.Identifiable$;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.Row$;
import org.apache.spark.sql.types.ArrayType$;
import org.apache.spark.sql.types.DoubleType$;
import org.apache.spark.sql.types.StringType$;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructField$;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.sql.types.StructType$;
import scala.MatchError;
import scala.Predef$;
import scala.collection.GenTraversableOnce;
import scala.collection.Iterable$;
import scala.collection.Iterator;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.collection.TraversableLike;
import scala.collection.TraversableOnce;
import scala.collection.immutable.$colon;
import scala.collection.immutable.Nil$;
import scala.collection.mutable.ArrayOps;
import scala.jdk.CollectionConverters$;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;

/* compiled from: TextClassifier.scala */
@ScalaSignature(bytes = "\u0006\u0001\u0005md\u0001\u0002\n\u0014\u0001yA\u0001\"\u0014\u0001\u0003\u0006\u0004%\tE\u0014\u0005\t\u001f\u0002\u0011\t\u0011)A\u0005S!)\u0001\u000b\u0001C\u0001#\")\u0001\u000b\u0001C\u0001)\"9Q\u000b\u0001b\u0001\n\u000b1\u0006B\u00020\u0001A\u00035q\u000bC\u0005`\u0001\u0001\u0007\t\u0019!C\u0005A\"I\u0011\r\u0001a\u0001\u0002\u0004%IA\u0019\u0005\nQ\u0002\u0001\r\u0011!Q!\nmCQ!\u001b\u0001\u0005\u0002)DQA\u001c\u0001\u0005\u0002=DQ!\u001d\u0001\u0005\u0002IDQ\u0001\u001e\u0001\u0005\u0002UDq!a\r\u0001\t\u0003\n)\u0004C\u0004\u0002D\u0001!\t&!\u0012\t\u000f\u0005\u0005\u0004\u0001\"\u0015\u0002d!9\u0011Q\u000f\u0001\u0005B\u0005]$A\u0004+fqR\u001cE.Y:tS\u001aLWM\u001d\u0006\u0003)U\tA\u0001^3yi*\u0011acF\u0001\u0005i\u0006\u001c8N\u0003\u0002\u00193\u0005)1\u000f]1sW*\u0011!dG\u0001\u0004I*d'\"\u0001\u000f\u0002\u0005\u0005L7\u0001A\n\u0005\u0001}Y$\n\u0005\u0003!C\r\"T\"A\n\n\u0005\t\u001a\"!\u0005\"bg\u0016$V\r\u001f;Qe\u0016$\u0017n\u0019;peB\u0019AeJ\u0015\u000e\u0003\u0015R\u0011AJ\u0001\u0006g\u000e\fG.Y\u0005\u0003Q\u0015\u0012Q!\u0011:sCf\u0004\"AK\u0019\u000f\u0005-z\u0003C\u0001\u0017&\u001b\u0005i#B\u0001\u0018\u001e\u0003\u0019a$o\\8u}%\u0011\u0001'J\u0001\u0007!J,G-\u001a4\n\u0005I\u001a$AB*ue&twM\u0003\u00021KA\u0019AeJ\u001b\u0011\u0005YJT\"A\u001c\u000b\u0005aJ\u0012\u0001C7pI\u0006d\u0017\u000e^=\n\u0005i:$aD\"mCN\u001c\u0018NZ5dCRLwN\\:\u0011\u0005qBU\"A\u001f\u000b\u0005yz\u0014AB:iCJ,GM\u0003\u0002A\u0003\u0006)\u0001/\u0019:b[*\u0011!iQ\u0001\u0003[2T!\u0001\u0007#\u000b\u0005\u00153\u0015AB1qC\u000eDWMC\u0001H\u0003\ry'oZ\u0005\u0003\u0013v\u00121\u0002S1t\u0013:\u0004X\u000f^\"pYB\u0011AhS\u0005\u0003\u0019v\u0012A\u0002S1t\u001fV$\b/\u001e;D_2\f1!^5e+\u0005I\u0013\u0001B;jI\u0002\na\u0001P5oSRtDC\u0001*T!\t\u0001\u0003\u0001C\u0003N\u0007\u0001\u0007\u0011\u0006F\u0001S\u0003\u0011!x\u000e]&\u0016\u0003]\u00032\u0001W-\\\u001b\u0005y\u0014B\u0001.@\u0005\u0015\u0001\u0016M]1n!\t!C,\u0003\u0002^K\t\u0019\u0011J
\u001c;\u0002\u000bQ|\u0007o\u0013\u0011\u0002\u001b%t\u0007/\u001e;D_2Le\u000eZ3y+\u0005Y\u0016!E5oaV$8i\u001c7J]\u0012,\u0007p\u0018\u0013fcR\u00111M\u001a\t\u0003I\u0011L!!Z\u0013\u0003\tUs\u0017\u000e\u001e\u0005\bO\"\t\t\u00111\u0001\\\u0003\rAH%M\u0001\u000fS:\u0004X\u000f^\"pY&sG-\u001a=!\u0003-\u0019X\r^%oaV$8i\u001c7\u0015\u0005-dW\"\u0001\u0001\t\u000b5T\u0001\u0019A\u0015\u0002\u000bY\fG.^3\u0002\u0019M,GoT;uaV$8i\u001c7\u0015\u0005-\u0004\b\"B7\f\u0001\u0004I\u0013aB:fiR{\u0007o\u0013\u000b\u0003WNDQ!\u001c\u0007A\u0002m\u000b\u0001b\u00197bgNLg-\u001f\u000b\u0004m\u0006=\u0001cA<\u0002\n9\u0019\u00010a\u0001\u000f\u0005e|hB\u0001>\u007f\u001d\tYXP\u0004\u0002-y&\tq)\u0003\u0002F\r&\u0011\u0001\u0004R\u0005\u0004\u0003\u0003\u0019\u0015aA:rY&!\u0011QAA\u0004\u0003\u001d\u0001\u0018mY6bO\u0016T1!!\u0001D\u0013\u0011\tY!!\u0004\u0003\u0013\u0011\u000bG/\u0019$sC6,'\u0002BA\u0003\u0003\u000fAq!!\u0005\u000e\u0001\u0004\t\u0019\"A\u0004eCR\f7/\u001a;1\t\u0005U\u0011\u0011\u0005\t\u0007\u0003/\tI\"!\b\u000e\u0005\u0005\u001d\u0011\u0002BA\u000e\u0003\u000f\u0011q\u0001R1uCN,G\u000f\u0005\u0003\u0002 
\u0005\u0005B\u0002\u0001\u0003\r\u0003G\ty!!A\u0001\u0002\u000b\u0005\u0011Q\u0005\u0002\u0004?\u0012\n\u0014\u0003BA\u0014\u0003[\u00012\u0001JA\u0015\u0013\r\tY#\n\u0002\b\u001d>$\b.\u001b8h!\r!\u0013qF\u0005\u0004\u0003c)#aA!os\u0006IAO]1og\u001a|'/\u001c\u000b\u0004m\u0006]\u0002bBA\t\u001d\u0001\u0007\u0011\u0011\b\u0019\u0005\u0003w\ty\u0004\u0005\u0004\u0002\u0018\u0005e\u0011Q\b\t\u0005\u0003?\ty\u0004\u0002\u0007\u0002B\u0005]\u0012\u0011!A\u0001\u0006\u0003\t)CA\u0002`II\nQ\u0002\u001e:b]N4wN]7S_^\u001cH\u0003BA$\u0003;\u0002b!!\u0013\u0002R\u0005]c\u0002BA&\u0003\u001fr1\u0001LA'\u0013\u00051\u0013bAA\u0003K%!\u00111KA+\u0005!IE/\u001a:bi>\u0014(bAA\u0003KA!\u0011qCA-\u0013\u0011\tY&a\u0002\u0003\u0007I{w\u000fC\u0004\u0002`=\u0001\r!a\u0012\u0002\t%$XM]\u0001\u0012m\u0006d\u0017\u000eZ1uK&s\u0007/\u001e;UsB,GcA2\u0002f!9\u0011q\r\tA\u0002\u0005%\u0014AB:dQ\u0016l\u0017\r\u0005\u0003\u0002l\u0005ETBAA7\u0015\u0011\ty'a\u0002\u0002\u000bQL\b/Z:\n\t\u0005M\u0014Q\u000e\u0002\u000b'R\u0014Xo\u0019;UsB,\u0017a\u0004;sC:\u001chm\u001c:n'\u000eDW-\\1\u0015\t\u0005%\u0014\u0011\u0010\u0005\b\u0003O\n\u0002\u0019AA5\u0001")
/* loaded from: input_file:ai/djl/spark/task/text/TextClassifier.class */
/**
 * Text classification predictor for Spark datasets.
 *
 * <p>Reads a string column (the {@code inputCol} param), feeds its values in batches to a
 * {@code Predictor<String[], Classifications[]>} and appends one struct column (named by the
 * {@code outputCol} param) holding the fields {@code class_names}, {@code probabilities} and
 * {@code top_k}.
 *
 * <p>This file is decompiler output for {@code TextClassifier.scala}; several members exist only
 * because of how Scala traits and {@code var}s are encoded in bytecode.
 */
public class TextClassifier extends BaseTextPredictor<String[], Classifications[]> implements HasInputCol, HasOutputCol {
    /** Identifier of this instance, as passed to (or generated for) the constructor. */
    private final String uid;

    /** Param "topK": "The number of classes to return". */
    private final Param<Object> topK;

    /**
     * Index of the input column in the schema of the dataset most recently passed to
     * {@link #transform(Dataset)}; read by the per-row lambda in {@link #transformRows(Iterator)}.
     */
    private int inputColIndex;

    // Not final: both fields are assigned through the trait setter methods below rather than
    // directly in the constructor, which Java does not allow for final fields.
    private Param<String> outputCol;
    private Param<String> inputCol;

    /** Returns the current output column name (delegates to the {@code HasOutputCol} trait). */
    public final String getOutputCol() {
        return HasOutputCol.getOutputCol$(this);
    }

    /** Returns the current input column name (delegates to the {@code HasInputCol} trait). */
    public final String getInputCol() {
        return HasInputCol.getInputCol$(this);
    }

    /** Returns the param object describing the output column. */
    public final Param<String> outputCol() {
        return this.outputCol;
    }

    /**
     * Trait-field setter for {@code outputCol}.
     * NOTE(review): presumably invoked by {@code HasOutputCol.$init$(this)} in the constructor;
     * confirm against the Spark trait encoding.
     */
    public final void org$apache$spark$ml$param$shared$HasOutputCol$_setter_$outputCol_$eq(Param<String> param) {
        this.outputCol = param;
    }

    /** Returns the param object describing the input column. */
    public final Param<String> inputCol() {
        return this.inputCol;
    }

    /**
     * Trait-field setter for {@code inputCol}.
     * NOTE(review): presumably invoked by {@code HasInputCol.$init$(this)} in the constructor;
     * confirm against the Spark trait encoding.
     */
    public final void org$apache$spark$ml$param$shared$HasInputCol$_setter_$inputCol_$eq(Param<String> param) {
        this.inputCol = param;
    }

    /** Returns the identifier given to this instance at construction time. */
    @Override // ai.djl.spark.task.text.BaseTextPredictor, ai.djl.spark.task.BasePredictor
    public String uid() {
        return this.uid;
    }

    /** Returns the "topK" param object. */
    public final Param<Object> topK() {
        return this.topK;
    }

    private int inputColIndex() {
        return this.inputColIndex;
    }

    private void inputColIndex_$eq(int i) {
        this.inputColIndex = i;
    }

    /**
     * Sets the input column param.
     *
     * @param str the value to store in the {@code inputCol} param
     * @return the result of {@code set(...)}, cast to this class
     */
    public TextClassifier setInputCol(String str) {
        // Explicit cast: the Scala `this.type` return of set(...) is erased on the Java side.
        return (TextClassifier) set(inputCol(), str);
    }

    /**
     * Sets the output column param.
     *
     * @param str the value to store in the {@code outputCol} param
     * @return the result of {@code set(...)}, cast to this class
     */
    public TextClassifier setOutputCol(String str) {
        return (TextClassifier) set(outputCol(), str);
    }

    /**
     * Sets the "topK" param.
     *
     * @param i the value to store (boxed before being handed to {@code set})
     * @return the result of {@code set(...)}, cast to this class
     */
    public TextClassifier setTopK(int i) {
        return (TextClassifier) set(topK(), BoxesRunTime.boxToInteger(i));
    }

    /**
     * Alias for {@link #transform(Dataset)}.
     *
     * @param dataset the dataset to classify
     * @return whatever {@link #transform(Dataset)} returns for {@code dataset}
     */
    public Dataset<Row> classify(Dataset<?> dataset) {
        return transform(dataset);
    }

    /**
     * Prepares per-call state and delegates to the superclass transform.
     *
     * <p>When the "topK" param is defined, its string form is stored under the key
     * {@code "topK"} in {@code arguments()}. The position of the input column in the dataset
     * schema is then recorded for later use by {@link #transformRows(Iterator)}.
     *
     * @param dataset the dataset to transform
     * @return the result of {@code super.transform(dataset)}
     */
    @Override // ai.djl.spark.task.BasePredictor
    public Dataset<Row> transform(Dataset<?> dataset) {
        if (isDefined(topK())) {
            arguments().put("topK", $(topK()).toString());
        }
        inputColIndex_$eq(dataset.schema().fieldIndex((String) $(inputCol())));
        return super.transform(dataset);
    }

    /**
     * Classifies the rows of an iterator in batches.
     *
     * <p>Rows are grouped into chunks of {@code batchSize}. For each chunk the strings at
     * {@code inputColIndex} are collected into a {@code String[]} and passed to the predictor in
     * a single {@code predict} call. Each input row is then paired with its
     * {@link Classifications} result and extended with one nested row containing, in order: the
     * class names array, the probabilities array, and the string form of every entry returned
     * by {@code classifications.topK()}.
     *
     * <p>NOTE(review): the predictor obtained from {@code model().newPredictor()} is not closed
     * anywhere in this method; confirm whether its cleanup is handled elsewhere.
     *
     * @param iterator the rows to classify
     * @return an iterator over the input rows, each with the nested result row appended
     */
    @Override // ai.djl.spark.task.BasePredictor
    public Iterator<Row> transformRows(Iterator<Row> iterator) {
        Predictor<String[], Classifications[]> newPredictor = model().newPredictor();
        return iterator.grouped(BoxesRunTime.unboxToInt($(batchSize()))).flatMap(seq -> {
            return (GenTraversableOnce) ((TraversableLike) seq.zip(Predef$.MODULE$.wrapRefArray((Classifications[]) newPredictor.predict((String[]) ((TraversableOnce) seq.map(row -> {
                return row.getString(this.inputColIndex());
            }, Seq$.MODULE$.canBuildFrom())).toArray(ClassTag$.MODULE$.apply(String.class)))), Seq$.MODULE$.canBuildFrom())).map(tuple2 -> {
                // Decompiled Scala pattern match on the (row, classifications) pair.
                if (tuple2 == null) {
                    throw new MatchError(tuple2);
                }
                Row row2 = (Row) tuple2._1();
                Classifications classifications = (Classifications) tuple2._2();
                return Row$.MODULE$.fromSeq((Seq) row2.toSeq().$colon$plus(Row$.MODULE$.apply(Predef$.MODULE$.genericWrapArray(new Object[]{classifications.getClassNames().toArray(), classifications.getProbabilities().toArray(), ((TraversableLike) CollectionConverters$.MODULE$.collectionAsScalaIterableConverter(classifications.topK()).asScala()).map(classification -> {
                    return classification.toString();
                }, Iterable$.MODULE$.canBuildFrom())})), Seq$.MODULE$.canBuildFrom()));
            }, Seq$.MODULE$.canBuildFrom());
        });
    }

    /**
     * Checks the input column against {@code StringType} by delegating to {@code validateType}.
     *
     * @param structType the schema containing the input column
     */
    @Override // ai.djl.spark.task.BasePredictor
    public void validateInputType(StructType structType) {
        validateType(structType.apply((String) $(inputCol())), StringType$.MODULE$);
    }

    /**
     * Builds the output schema: the input fields followed by one struct field named after the
     * {@code outputCol} param, whose members are {@code class_names} (array of string),
     * {@code probabilities} (array of double) and {@code top_k} (array of string).
     *
     * @param structType the input schema
     * @return a new schema with the result struct field appended
     */
    public StructType transformSchema(StructType structType) {
        return new StructType((StructField[]) new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps(structType.fields())).$colon$plus(new StructField((String) $(outputCol()), StructType$.MODULE$.apply(new $colon.colon(new StructField("class_names", ArrayType$.MODULE$.apply(StringType$.MODULE$), StructField$.MODULE$.apply$default$3(), StructField$.MODULE$.apply$default$4()), new $colon.colon(new StructField("probabilities", ArrayType$.MODULE$.apply(DoubleType$.MODULE$), StructField$.MODULE$.apply$default$3(), StructField$.MODULE$.apply$default$4()), new $colon.colon(new StructField("top_k", ArrayType$.MODULE$.apply(StringType$.MODULE$), StructField$.MODULE$.apply$default$3(), StructField$.MODULE$.apply$default$4()), Nil$.MODULE$)))), StructField$.MODULE$.apply$default$3(), StructField$.MODULE$.apply$default$4()), ClassTag$.MODULE$.apply(StructField.class)));
    }

    /**
     * Creates a classifier with the given identifier.
     *
     * <p>Runs the {@code HasInputCol} and {@code HasOutputCol} trait initializers, creates the
     * "topK" param, and registers defaults: {@code String[]} as input class,
     * {@code Classifications[]} as output class and a
     * {@link TextClassificationTranslatorFactory} as translator factory.
     *
     * @param str the identifier returned by {@link #uid()}
     */
    public TextClassifier(String str) {
        this.uid = str;
        HasInputCol.$init$(this);
        HasOutputCol.$init$(this);
        this.topK = new Param<>(this, "topK", "The number of classes to return");
        setDefault(inputClass(), String[].class);
        setDefault(outputClass(), Classifications[].class);
        setDefault(translatorFactory(), new TextClassificationTranslatorFactory());
    }

    /** Creates a classifier with a random identifier generated from the prefix "TextClassifier". */
    public TextClassifier() {
        this(Identifiable$.MODULE$.randomUID("TextClassifier"));
    }
}
