package ai.djl.spark.task.text;

import ai.djl.huggingface.translator.QuestionAnsweringTranslatorFactory;
import ai.djl.inference.Predictor;
import ai.djl.modality.nlp.qa.QAInput;
import java.util.List;
import org.apache.spark.ml.param.Param;
import org.apache.spark.ml.param.StringArrayParam;
import org.apache.spark.ml.param.shared.HasInputCols;
import org.apache.spark.ml.param.shared.HasOutputCol;
import org.apache.spark.ml.util.Identifiable$;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.Row$;
import org.apache.spark.sql.types.StringType$;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructField$;
import org.apache.spark.sql.types.StructType;
import scala.MatchError;
import scala.Predef$;
import scala.collection.GenTraversableOnce;
import scala.collection.Iterable;
import scala.collection.Iterator;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.collection.TraversableLike;
import scala.collection.mutable.ArrayOps;
import scala.jdk.CollectionConverters$;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxesRunTime;

/* compiled from: QuestionAnswerer.scala */
@ScalaSignature(bytes = "\u0006\u0001\u0005}c\u0001\u0002\b\u0010\u0001iA\u0001\u0002\u0013\u0001\u0003\u0006\u0004%\t%\u0013\u0005\t\u0015\u0002\u0011\t\u0011)A\u0005S!)1\n\u0001C\u0001\u0019\")1\n\u0001C\u0001\u001f\"9\u0001\u000b\u0001b\u0001\n\u0013\t\u0006BB-\u0001A\u0003%!\u000bC\u0003[\u0001\u0011\u00051\fC\u0003a\u0001\u0011\u0005\u0011\rC\u0003d\u0001\u0011\u0005A\rC\u0004\u0002\u0012\u0001!\t%a\u0005\t\u000f\u0005\u0005\u0002\u0001\"\u0015\u0002$!9\u0011q\b\u0001\u0005R\u0005\u0005\u0003bBA-\u0001\u0011\u0005\u00131\f\u0002\u0011#V,7\u000f^5p]\u0006s7o^3sKJT!\u0001E\t\u0002\tQ,\u0007\u0010\u001e\u0006\u0003%M\tA\u0001^1tW*\u0011A#F\u0001\u0006gB\f'o\u001b\u0006\u0003-]\t1\u0001\u001a6m\u0015\u0005A\u0012AA1j\u0007\u0001\u0019B\u0001A\u000e7\u000bB!A$H\u0010*\u001b\u0005y\u0011B\u0001\u0010\u0010\u0005E\u0011\u0015m]3UKb$\bK]3eS\u000e$xN\u001d\t\u0003A\u001dj\u0011!\t\u0006\u0003E\r\n!!]1\u000b\u0005\u0011*\u0013a\u00018ma*\u0011a%F\u0001\t[>$\u0017\r\\5us&\u0011\u0001&\t\u0002\b#\u0006Ke\u000e];u!\tQ3G\u0004\u0002,cA\u0011AfL\u0007\u0002[)\u0011a&G\u0001\u0007yI|w\u000e\u001e 
\u000b\u0003A\nQa]2bY\u0006L!AM\u0018\u0002\rA\u0013X\rZ3g\u0013\t!TG\u0001\u0004TiJLgn\u001a\u0006\u0003e=\u0002\"aN\"\u000e\u0003aR!!\u000f\u001e\u0002\rMD\u0017M]3e\u0015\tYD(A\u0003qCJ\fWN\u0003\u0002>}\u0005\u0011Q\u000e\u001c\u0006\u0003)}R!\u0001Q!\u0002\r\u0005\u0004\u0018m\u00195f\u0015\u0005\u0011\u0015aA8sO&\u0011A\t\u000f\u0002\r\u0011\u0006\u001c\u0018J\u001c9vi\u000e{Gn\u001d\t\u0003o\u0019K!a\u0012\u001d\u0003\u0019!\u000b7oT;uaV$8i\u001c7\u0002\u0007ULG-F\u0001*\u0003\u0011)\u0018\u000e\u001a\u0011\u0002\rqJg.\u001b;?)\tie\n\u0005\u0002\u001d\u0001!)\u0001j\u0001a\u0001SQ\tQ*A\bj]B,HoQ8m\u0013:$\u0017nY3t+\u0005\u0011\u0006cA*U-6\tq&\u0003\u0002V_\t)\u0011I\u001d:bsB\u00111kV\u0005\u00031>\u00121!\u00138u\u0003AIg\u000e];u\u0007>d\u0017J\u001c3jG\u0016\u001c\b%\u0001\u0007tKRLe\u000e];u\u0007>d7\u000f\u0006\u0002];6\t\u0001\u0001C\u0003_\u000f\u0001\u0007q,A\u0003wC2,X\rE\u0002T)&\nAb]3u\u001fV$\b/\u001e;D_2$\"\u0001\u00182\t\u000byC\u0001\u0019A\u0015\u0002\r\u0005t7o^3s)\t)g\u000f\u0005\u0002gg:\u0011q\r\u001d\b\u0003Q:t!![7\u000f\u0005)dgB\u0001\u0017l\u0013\u0005\u0011\u0015B\u0001!B\u0013\t!r(\u0003\u0002p}\u0005\u00191/\u001d7\n\u0005E\u0014\u0018a\u00029bG.\fw-\u001a\u0006\u0003_zJ!\u0001^;\u0003\u0013\u0011\u000bG/\u0019$sC6,'BA9s\u0011\u00159\u0018\u00021\u0001y\u0003\u001d!\u0017\r^1tKR\u0004$!_@\u0011\u0007i\\X0D\u0001s\u0013\ta(OA\u0004ECR\f7/\u001a;\u0011\u0005y|H\u0002\u0001\u0003\f\u0003\u00031\u0018\u0011!A\u0001\u0006\u0003\t\u0019AA\u0002`IE\nB!!\u0002\u0002\fA\u00191+a\u0002\n\u0007\u0005%qFA\u0004O_RD\u0017N\\4\u0011\u0007M\u000bi!C\u0002\u0002\u0010=\u00121!\u00118z\u0003%!(/\u00198tM>\u0014X\u000eF\u0002f\u0003+Aaa\u001e\u0006A\u0002\u0005]\u0001\u0007BA\r\u0003;\u0001BA_>\u0002\u001cA\u0019a0!\b\u0005\u0019\u0005}\u0011QCA\u0001\u0002\u0003\u0015\t!a\u0001\u0003\u0007}##'A\u0007ue\u0006t7OZ8s[J{wo\u001d\u000b\u0005\u0003K\tY\u0004\u0005\u0004\u0002(\u0005=\u0012Q\u0007\b\u0005\u0003S\tiCD\u0002-\u0003WI\u0011\u0001M\u0005\u0003
c>JA!!\r\u00024\tA\u0011\n^3sCR|'O\u0003\u0002r_A\u0019!0a\u000e\n\u0007\u0005e\"OA\u0002S_^Dq!!\u0010\f\u0001\u0004\t)#\u0001\u0003ji\u0016\u0014\u0018!\u0005<bY&$\u0017\r^3J]B,H\u000fV=qKR!\u00111IA%!\r\u0019\u0016QI\u0005\u0004\u0003\u000fz#\u0001B+oSRDq!a\u0013\r\u0001\u0004\ti%\u0001\u0004tG\",W.\u0019\t\u0005\u0003\u001f\n)&\u0004\u0002\u0002R)\u0019\u00111\u000b:\u0002\u000bQL\b/Z:\n\t\u0005]\u0013\u0011\u000b\u0002\u000b'R\u0014Xo\u0019;UsB,\u0017a\u0004;sC:\u001chm\u001c:n'\u000eDW-\\1\u0015\t\u00055\u0013Q\f\u0005\b\u0003\u0017j\u0001\u0019AA'\u0001")
/* loaded from: input_file:ai/djl/spark/task/text/QuestionAnswerer.class */
/**
 * Spark ML transformer that runs a DJL question-answering model over a DataFrame.
 *
 * <p>Reads exactly two string input columns (configured via {@code inputCols}). From each row it
 * builds a {@link QAInput} from the first and second column values; presumably these are question
 * then context, in the order of the {@code QAInput} constructor (TODO confirm). It appends the
 * predicted answer as a new string column named by {@code outputCol}.
 *
 * <p>NOTE(review): this file is decompiled from Scala bytecode. Several constructs, such as
 * {@code final} fields assigned from the trait setter methods and the returns of {@code set(...)}
 * typed as {@code QuestionAnswerer}, reflect the bytecode rather than compilable Java source.
 */
public class QuestionAnswerer extends BaseTextPredictor<QAInput, String> implements HasInputCols, HasOutputCol {
    /** Unique identifier of this pipeline stage. */
    private final String uid;
    /**
     * Field positions of the two input columns in the incoming schema. Allocated with length 2 in
     * the constructor and filled in by {@link #transform(Dataset)}.
     */
    private final int[] inputColIndices;
    /** Name of the output column that receives the answer string. */
    private final Param<String> outputCol;
    /** Names of the two input columns. */
    private final StringArrayParam inputCols;

    /** @return the configured output column name (delegates to {@code HasOutputCol}). */
    public final String getOutputCol() {
        return HasOutputCol.getOutputCol$(this);
    }

    /** @return the configured input column names (delegates to {@code HasInputCols}). */
    public final String[] getInputCols() {
        return HasInputCols.getInputCols$(this);
    }

    /** @return the {@code outputCol} param. */
    public final Param<String> outputCol() {
        return this.outputCol;
    }

    // Installs the outputCol param; presumably invoked from HasOutputCol.$init$ during construction.
    public final void org$apache$spark$ml$param$shared$HasOutputCol$_setter_$outputCol_$eq(Param<String> param) {
        this.outputCol = param;
    }

    /** @return the {@code inputCols} param. */
    public final StringArrayParam inputCols() {
        return this.inputCols;
    }

    // Installs the inputCols param; presumably invoked from HasInputCols.$init$ during construction.
    public final void org$apache$spark$ml$param$shared$HasInputCols$_setter_$inputCols_$eq(StringArrayParam stringArrayParam) {
        this.inputCols = stringArrayParam;
    }

    /** @return the unique identifier of this stage. */
    @Override // ai.djl.spark.task.text.BaseTextPredictor, ai.djl.spark.task.BasePredictor
    public String uid() {
        return this.uid;
    }

    /** @return the resolved schema positions of the two input columns. */
    private int[] inputColIndices() {
        return this.inputColIndices;
    }

    /**
     * Sets the input columns.
     *
     * @param strArr column names; {@link #validateInputType} requires exactly two StringType columns
     * @return this stage
     */
    public QuestionAnswerer setInputCols(String[] strArr) {
        return set(inputCols(), strArr);
    }

    /**
     * Sets the output column name.
     *
     * @param str name of the string column to append with the answers
     * @return this stage
     */
    public QuestionAnswerer setOutputCol(String str) {
        return set(outputCol(), str);
    }

    /**
     * Answers the questions in {@code dataset}. This is an alias for {@link #transform(Dataset)}.
     *
     * @param dataset input data containing the two configured input columns
     * @return the input rows with the answer column appended
     */
    public Dataset<Row> answer(Dataset<?> dataset) {
        return transform(dataset);
    }

    /**
     * Resolves the schema positions of the two input columns, then delegates to the base transform.
     *
     * <p>NOTE(review): the indices are written into per-instance mutable state. The executors read
     * them through {@code this}, which the lambdas in {@link #transformRows} capture. Concurrent
     * {@code transform} calls on the same instance with differently-shaped datasets could therefore
     * interfere.
     *
     * @param dataset input data
     * @return transformed data with the answer column appended
     */
    @Override // ai.djl.spark.task.BasePredictor
    public Dataset<Row> transform(Dataset<?> dataset) {
        inputColIndices()[0] = dataset.schema().fieldIndex(((String[]) $(inputCols()))[0]);
        inputColIndices()[1] = dataset.schema().fieldIndex(((String[]) $(inputCols()))[1]);
        return super.transform(dataset);
    }

    /**
     * Runs batched prediction over a partition's rows.
     *
     * <p>Rows are grouped into batches of {@code batchSize}. For each batch, one {@link QAInput} is
     * built per row from the two input columns and passed to {@code batchPredict}. Each row is then
     * zipped with its answer, and the answer is appended as the last field of a new row.
     *
     * <p>NOTE(review): the predictor created here is never closed in this method, which may leak
     * native resources per partition; verify whether the base class handles this. Also, {@code zip}
     * truncates to the shorter side. If {@code batchPredict} returned fewer results than inputs,
     * the trailing rows of the batch would be silently dropped.
     *
     * @param iterator rows of one partition
     * @return rows with the answer string appended
     */
    @Override // ai.djl.spark.task.BasePredictor
    public Iterator<Row> transformRows(Iterator<Row> iterator) {
        Predictor<QAInput, String> newPredictor = model().newPredictor();
        // Per batch: rows -> QAInputs -> batchPredict -> zip(row, answer) -> row :+ answer.
        return iterator.grouped(BoxesRunTime.unboxToInt($(batchSize()))).flatMap(seq -> {
            return (GenTraversableOnce) ((TraversableLike) seq.zip((Iterable) CollectionConverters$.MODULE$.collectionAsScalaIterableConverter(newPredictor.batchPredict((List) CollectionConverters$.MODULE$.seqAsJavaListConverter((Seq) seq.map(row -> {
                return new QAInput(row.getString(this.inputColIndices()[0]), row.getString(this.inputColIndices()[1]));
            }, Seq$.MODULE$.canBuildFrom())).asJava())).asScala(), Seq$.MODULE$.canBuildFrom())).map(tuple2 -> {
                // Decompiled Scala pattern match on (row, answer); a null tuple raises MatchError.
                if (tuple2 == null) {
                    throw new MatchError(tuple2);
                }
                Row row2 = (Row) tuple2._1();
                return Row$.MODULE$.fromSeq((Seq) row2.toSeq().$colon$plus((String) tuple2._2(), Seq$.MODULE$.canBuildFrom()));
            }, Seq$.MODULE$.canBuildFrom());
        });
    }

    /**
     * Checks that exactly two input columns are configured and that both are StringType in the
     * given schema.
     *
     * <p>NOTE(review): the column-count check uses Scala {@code Predef.assert}, which may be
     * elidable at compile time. If it is elided, a wrong column count would surface later as an
     * index or lookup failure instead.
     *
     * @param structType schema of the incoming dataset
     */
    @Override // ai.djl.spark.task.BasePredictor
    public void validateInputType(StructType structType) {
        Predef$.MODULE$.assert(((String[]) $(inputCols())).length == 2, () -> {
            return "inputCols must have 2 columns";
        });
        validateType(structType.apply(((String[]) $(inputCols()))[0]), StringType$.MODULE$);
        validateType(structType.apply(((String[]) $(inputCols()))[1]), StringType$.MODULE$);
    }

    /**
     * Returns the input schema with a StringType field named by {@code outputCol} appended. The
     * new field uses {@code StructField}'s default nullability and metadata.
     *
     * @param structType input schema
     * @return output schema
     */
    public StructType transformSchema(StructType structType) {
        return new StructType((StructField[]) new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps(structType.fields())).$colon$plus(new StructField((String) $(outputCol()), StringType$.MODULE$, StructField$.MODULE$.apply$default$3(), StructField$.MODULE$.apply$default$4()), ClassTag$.MODULE$.apply(StructField.class)));
    }

    /**
     * Creates a stage with the given uid. It registers the HasInputCols/HasOutputCol params, then
     * defaults the input class to {@code QAInput}, the output class to {@code String}, and the
     * translator factory to {@link QuestionAnsweringTranslatorFactory}.
     *
     * @param str unique identifier for this stage
     */
    public QuestionAnswerer(String str) {
        this.uid = str;
        HasInputCols.$init$(this);
        HasOutputCol.$init$(this);
        this.inputColIndices = new int[2];
        setDefault(inputClass(), QAInput.class);
        setDefault(outputClass(), String.class);
        setDefault(translatorFactory(), new QuestionAnsweringTranslatorFactory());
    }

    /** Creates a stage with a random uid prefixed by {@code "QuestionAnswerer"}. */
    public QuestionAnswerer() {
        this(Identifiable$.MODULE$.randomUID("QuestionAnswerer"));
    }
}
