package ai.djl.spark.task;

import ai.djl.spark.ModelLoader;
import ai.djl.translate.TranslatorFactory;
import java.io.Serializable;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import org.apache.spark.ml.Transformer;
import org.apache.spark.ml.param.Param;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.ml.util.Identifiable$;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.catalyst.encoders.RowEncoder$;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import scala.Function1;
import scala.Predef$;
import scala.collection.Iterator;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxedUnit;

/* compiled from: BasePredictor.scala */
@ScalaSignature(bytes = "\u0006\u0001\t5a!B\u0013'\u0003\u0003y\u0003\u0002\u0003\u001f\u0001\u0005\u000b\u0007I\u0011I\u001f\t\u0011-\u0003!\u0011!Q\u0001\nyBQ\u0001\u0014\u0001\u0005\u00025CQ\u0001\u0014\u0001\u0005\u0002\u0001Dq!\u0019\u0001C\u0002\u0013\u0015!\r\u0003\u0004j\u0001\u0001\u0006ia\u0019\u0005\bU\u0002\u0011\r\u0011\"\u0002c\u0011\u0019Y\u0007\u0001)A\u0007G\"9A\u000e\u0001b\u0001\n\u000bi\u0007B\u0002:\u0001A\u00035a\u000eC\u0004t\u0001\t\u0007IQ\u0001;\t\r]\u0004\u0001\u0015!\u0004v\u0011\u001dA\bA1A\u0005\u0006\tDa!\u001f\u0001!\u0002\u001b\u0019\u0007b\u0002>\u0001\u0005\u0004%)a\u001f\u0005\b\u0003\u000f\u0001\u0001\u0015!\u0004}\u0011-\tI\u0001\u0001a\u0001\u0002\u0004%\t\"a\u0003\t\u0017\u0005U\u0001\u00011AA\u0002\u0013E\u0011q\u0003\u0005\f\u0003G\u0001\u0001\u0019!A!B\u0013\ti\u0001C\u0005\u0002&\u0001\u0001\r\u0011\"\u0005\u0002(!I\u0011q\b\u0001A\u0002\u0013E\u0011\u0011\t\u0005\t\u0003\u000b\u0002\u0001\u0015)\u0003\u0002*!Y\u0011q\t\u0001A\u0002\u0003\u0007I\u0011CA%\u0011-\tY\u0006\u0001a\u0001\u0002\u0004%\t\"!\u0018\t\u0017\u0005\u0005\u0004\u00011A\u0001B\u0003&\u00111\n\u0005\b\u0003G\u0002A\u0011AA3\u0011\u001d\ti\u0007\u0001C\u0001\u0003_Bq!a\u001d\u0001\t\u0003\t)\bC\u0004\u0002z\u0001!\t!a\u001f\t\u000f\u0005}\u0004\u0001\"\u0001\u0002\u0002\"9\u0011Q\u0011\u0001\u0005\u0002\u0005\u001d\u0005bBAF\u0001\u0011\u0005\u0013Q\u0012\u0005\b\u0003\u0003\u0004A\u0011IAb\u0011\u001d\ty\r\u0001D\t\u0003#Dq!!<\u0001\r\u0003\ty\u000fC\u0004\u0002v\u0002!\t!a>\u0003\u001b\t\u000b7/\u001a)sK\u0012L7\r^8s\u0015\t9\u0003&\u0001\u0003uCN\\'BA\u0015+\u0003\u0015\u0019\b/\u0019:l\u0015\tYC&A\u0002eU2T\u0011!L\u0001\u0003C&\u001c\u0001!F\u00021%v\u001b\"\u0001A\u0019\u0011\u0005IRT\"A\u001a\u000b\u0005Q*\u0014AA7m\u0015\tIcG\u0003\u00028q\u00051\u0011\r]1dQ\u0016T\u0011!O\u0001\u0004_J<\u0017BA\u001e4\u0005-!&/\u00198tM>\u0014X.\u001a:\u0002\u0007ULG-F\u0001?!\ty\u0004J\u0004\u0002A\rB\u0011\u0011\tR\u0007\u0002\u0005*\u00111IL\u0001\u0007yI|w\u0
00e\u001e \u000b\u0003\u0015\u000bQa]2bY\u0006L!a\u0012#\u0002\rA\u0013X\rZ3g\u0013\tI%J\u0001\u0004TiJLgn\u001a\u0006\u0003\u000f\u0012\u000bA!^5eA\u00051A(\u001b8jiz\"\"AT0\u0011\t=\u0003\u0001\u000bX\u0007\u0002MA\u0011\u0011K\u0015\u0007\u0001\t\u0015\u0019\u0006A1\u0001U\u0005\u0005\t\u0015CA+Z!\t1v+D\u0001E\u0013\tAFIA\u0004O_RD\u0017N\\4\u0011\u0005YS\u0016BA.E\u0005\r\te.\u001f\t\u0003#v#QA\u0018\u0001C\u0002Q\u0013\u0011A\u0011\u0005\u0006y\r\u0001\rA\u0010\u000b\u0002\u001d\u00061QM\\4j]\u0016,\u0012a\u0019\t\u0004I\u001etT\"A3\u000b\u0005\u0019\u001c\u0014!\u00029be\u0006l\u0017B\u00015f\u0005\u0015\u0001\u0016M]1n\u0003\u001d)gnZ5oK\u0002\n\u0001\"\\8eK2,&\u000f\\\u0001\n[>$W\r\\+sY\u0002\n!\"\u001b8qkR\u001cE.Y:t+\u0005q\u0007c\u00013h_B\u0019q\b\u001d)\n\u0005ET%!B\"mCN\u001c\u0018aC5oaV$8\t\\1tg\u0002\n1b\\;uaV$8\t\\1tgV\tQ\u000fE\u0002eOZ\u00042a\u00109]\u00031yW\u000f\u001e9vi\u000ec\u0017m]:!\u0003)\u0011\u0017\r^2iS\u001aLWM]\u0001\fE\u0006$8\r[5gS\u0016\u0014\b%A\tue\u0006t7\u000f\\1u_J4\u0015m\u0019;pef,\u0012\u0001 
\t\u0004I\u001el\bc\u0001@\u0002\u00045\tqPC\u0002\u0002\u0002)\n\u0011\u0002\u001e:b]Nd\u0017\r^3\n\u0007\u0005\u0015qPA\tUe\u0006t7\u000f\\1u_J4\u0015m\u0019;pef\f!\u0003\u001e:b]Nd\u0017\r^8s\r\u0006\u001cGo\u001c:zA\u0005)Qn\u001c3fYV\u0011\u0011Q\u0002\t\u0007\u0003\u001f\t\t\u0002\u0015/\u000e\u0003!J1!a\u0005)\u0005-iu\u000eZ3m\u0019>\fG-\u001a:\u0002\u00135|G-\u001a7`I\u0015\fH\u0003BA\r\u0003?\u00012AVA\u000e\u0013\r\ti\u0002\u0012\u0002\u0005+:LG\u000fC\u0005\u0002\"I\t\t\u00111\u0001\u0002\u000e\u0005\u0019\u0001\u0010J\u0019\u0002\r5|G-\u001a7!\u0003%\t'oZ;nK:$8/\u0006\u0002\u0002*A9\u00111FA\u001b}\u0005eRBAA\u0017\u0015\u0011\ty#!\r\u0002\tU$\u0018\u000e\u001c\u0006\u0003\u0003g\tAA[1wC&!\u0011qGA\u0017\u0005\ri\u0015\r\u001d\t\u0004-\u0006m\u0012bAA\u001f\t\n1\u0011I\\=SK\u001a\fQ\"\u0019:hk6,g\u000e^:`I\u0015\fH\u0003BA\r\u0003\u0007B\u0011\"!\t\u0016\u0003\u0003\u0005\r!!\u000b\u0002\u0015\u0005\u0014x-^7f]R\u001c\b%\u0001\u0007pkR\u0004X\u000f^*dQ\u0016l\u0017-\u0006\u0002\u0002LA!\u0011QJA,\u001b\t\tyE\u0003\u0003\u0002R\u0005M\u0013!\u0002;za\u0016\u001c(bAA+k\u0005\u00191/\u001d7\n\t\u0005e\u0013q\n\u0002\u000b'R\u0014Xo\u0019;UsB,\u0017\u0001E8viB,HoU2iK6\fw\fJ3r)\u0011\tI\"a\u0018\t\u0013\u0005\u0005\u0002$!AA\u0002\u0005-\u0013!D8viB,HoU2iK6\f\u0007%A\u0005tKR,enZ5oKR!\u0011qMA5\u001b\u0005\u0001\u0001BBA65\u0001\u0007a(A\u0003wC2,X-A\u0006tKRlu\u000eZ3m+JdG\u0003BA4\u0003cBa!a\u001b\u001c\u0001\u0004q\u0014!D:fi&s\u0007/\u001e;DY\u0006\u001c8\u000f\u0006\u0003\u0002h\u0005]\u0004BBA69\u0001\u0007q.\u0001\btKR|U\u000f\u001e9vi\u000ec\u0017m]:\u0015\t\u0005\u001d\u0014Q\u0010\u0005\u0007\u0003Wj\u0002\u0019\u0001<\u0002\u001bM,GOQ1uG\"Lg-[3s)\u0011\t9'a!\t\r\u0005-d\u00041\u0001?\u0003Q\u0019X\r\u001e+sC:\u001cH.\u0019;pe\u001a\u000b7\r^8ssR!\u0011qMAE\u0011\u0019\tYg\ba\u0001{\u0006IAO]1og\u001a|'/\u001c\u000b\u0005\u0003\u001f\u000bi\u000b\u0005\u0003\u0002\u0012\u0006\u001df\u0002BAJ\u0003GsA!!&\u0002\":!\u0011qSAP\u001d\u0011\tI*!(\u000f\u
0007\u0005\u000bY*C\u0001:\u0013\t9\u0004(\u0003\u0002*m%\u0019\u0011QK\u001b\n\t\u0005\u0015\u00161K\u0001\ba\u0006\u001c7.Y4f\u0013\u0011\tI+a+\u0003\u0013\u0011\u000bG/\u0019$sC6,'\u0002BAS\u0003'Bq!a,!\u0001\u0004\t\t,A\u0004eCR\f7/\u001a;1\t\u0005M\u0016Q\u0018\t\u0007\u0003k\u000b9,a/\u000e\u0005\u0005M\u0013\u0002BA]\u0003'\u0012q\u0001R1uCN,G\u000fE\u0002R\u0003{#1\"a0\u0002.\u0006\u0005\t\u0011!B\u0001)\n\u0019q\fJ\u0019\u0002\t\r|\u0007/\u001f\u000b\u0004\u001d\u0006\u0015\u0007bBAdC\u0001\u0007\u0011\u0011Z\u0001\u0006Kb$(/\u0019\t\u0004I\u0006-\u0017bAAgK\nA\u0001+\u0019:b[6\u000b\u0007/A\u0007ue\u0006t7OZ8s[J{wo\u001d\u000b\u0005\u0003'\fI\u000f\u0005\u0004\u0002V\u0006u\u00171\u001d\b\u0005\u0003/\fYND\u0002B\u00033L\u0011!R\u0005\u0004\u0003K#\u0015\u0002BAp\u0003C\u0014\u0001\"\u0013;fe\u0006$xN\u001d\u0006\u0004\u0003K#\u0005\u0003BA[\u0003KLA!a:\u0002T\t\u0019!k\\<\t\u000f\u0005-(\u00051\u0001\u0002T\u0006!\u0011\u000e^3s\u0003E1\u0018\r\\5eCR,\u0017J\u001c9viRK\b/\u001a\u000b\u0005\u00033\t\t\u0010C\u0004\u0002t\u000e\u0002\r!a\u0013\u0002\rM\u001c\u0007.Z7b\u000311\u0018\r\\5eCR,G+\u001f9f)\u0019\tI\"!?\u0003\u0004!9\u00111 \u0013A\u0002\u0005u\u0018!\u00024jK2$\u0007\u0003BA'\u0003\u007fLAA!\u0001\u0002P\tY1\u000b\u001e:vGR4\u0015.\u001a7e\u0011\u001d\u0011)\u0001\na\u0001\u0005\u000f\t!\u0001\u001e9\u0011\t\u00055#\u0011B\u0005\u0005\u0005\u0017\tyE\u0001\u0005ECR\fG+\u001f9f\u0001")
/* loaded from: input_file:ai/djl/spark/task/BasePredictor.class */
/**
 * Base Spark ML {@link Transformer} that runs a DJL model over every partition of a dataset.
 *
 * <p>Subclasses decide which input schemas are accepted ({@link #validateInputType}) and how one
 * partition of rows is turned into output rows ({@link #transformRows}). This class holds the
 * model-related params and, on each {@link #transform} call, builds a {@link ModelLoader} from
 * them that subclasses can reach through {@link #model()}.
 *
 * @param <A> the model input type (see {@link #inputClass()})
 * @param <B> the model output type (see {@link #outputClass()})
 */
public abstract class BasePredictor<A, B> extends Transformer {
    private final String uid;
    private final Param<String> engine;
    private final Param<String> modelUrl;
    private final Param<Class<A>> inputClass;
    private final Param<Class<B>> outputClass;
    private final Param<String> batchifier;
    private final Param<TranslatorFactory> translatorFactory;

    /** Loader rebuilt from the current param values on every {@link #transform} call. */
    private ModelLoader<A, B> model;
    /** Extra arguments passed to the {@link ModelLoader}; {@code transform} may add "batchifier". */
    private Map<String, Object> arguments;
    /** Schema produced by {@link #transformSchema} during the most recent {@link #transform}. */
    private StructType outputSchema;

    /** Returns the unique identifier of this stage. */
    @Override
    public String uid() {
        return uid;
    }

    /** Param holding the engine name; defaults to {@code null}. */
    public final Param<String> engine() {
        return engine;
    }

    /** Param holding the model URL; defaults to {@code null}. */
    public final Param<String> modelUrl() {
        return modelUrl;
    }

    /** Param holding the model input class. */
    public final Param<Class<A>> inputClass() {
        return inputClass;
    }

    /** Param holding the model output class. */
    public final Param<Class<B>> outputClass() {
        return outputClass;
    }

    /** Param holding the batchifier name (none, stack or padding per its description). */
    public final Param<String> batchifier() {
        return batchifier;
    }

    /** Param holding the translator factory. */
    public final Param<TranslatorFactory> translatorFactory() {
        return translatorFactory;
    }

    /** Returns the loader built by the last {@link #transform} call, or {@code null} before it. */
    public ModelLoader<A, B> model() {
        return model;
    }

    /** Scala-style setter for {@link #model()}. */
    public void model_$eq(ModelLoader<A, B> modelLoader) {
        model = modelLoader;
    }

    /** Returns the mutable argument map handed to the {@link ModelLoader}. */
    public Map<String, Object> arguments() {
        return arguments;
    }

    /** Scala-style setter for {@link #arguments()}. */
    public void arguments_$eq(Map<String, Object> map) {
        arguments = map;
    }

    /** Returns the output schema computed by the last {@link #transform} call. */
    public StructType outputSchema() {
        return outputSchema;
    }

    /** Scala-style setter for {@link #outputSchema()}. */
    public void outputSchema_$eq(StructType structType) {
        outputSchema = structType;
    }

    // Params.set is declared to return Scala's this.type, which Java sees as plain Params, so the
    // setters below call set(...) and then return this to keep the chaining return type.

    /**
     * Sets {@link #engine()}.
     *
     * @param str the engine name
     * @return this predictor, for chaining
     */
    public BasePredictor<A, B> setEngine(String str) {
        set(engine(), str);
        return this;
    }

    /**
     * Sets {@link #modelUrl()}.
     *
     * @param str the model URL
     * @return this predictor, for chaining
     */
    public BasePredictor<A, B> setModelUrl(String str) {
        set(modelUrl(), str);
        return this;
    }

    /**
     * Sets {@link #inputClass()}.
     *
     * @param cls the model input class
     * @return this predictor, for chaining
     */
    public BasePredictor<A, B> setInputClass(Class<A> cls) {
        set(inputClass(), cls);
        return this;
    }

    /**
     * Sets {@link #outputClass()}.
     *
     * @param cls the model output class
     * @return this predictor, for chaining
     */
    public BasePredictor<A, B> setOutputClass(Class<B> cls) {
        set(outputClass(), cls);
        return this;
    }

    /**
     * Sets {@link #batchifier()}.
     *
     * @param str the batchifier name
     * @return this predictor, for chaining
     */
    public BasePredictor<A, B> setBatchifier(String str) {
        set(batchifier(), str);
        return this;
    }

    /**
     * Sets {@link #translatorFactory()}.
     *
     * @param translatorFactory the translator factory
     * @return this predictor, for chaining
     */
    public BasePredictor<A, B> setTranslatorFactory(TranslatorFactory translatorFactory) {
        set(translatorFactory(), translatorFactory);
        return this;
    }

    /**
     * Runs the model over {@code dataset}.
     *
     * <p>Builds a new {@link ModelLoader} from the current param values and {@link #arguments()},
     * validates the input schema, records the output schema, and maps every partition through
     * {@link #transformRows}.
     *
     * @param dataset the input dataset
     * @return a DataFrame whose rows are encoded with {@link #outputSchema()}
     */
    @Override
    public Dataset<Row> transform(Dataset<?> dataset) {
        if (isDefined(batchifier())) {
            // A value stored here by an earlier call is not removed if the param is later cleared.
            arguments().put("batchifier", $(batchifier()));
        }
        model_$eq(new ModelLoader<>(
                $(engine()),
                $(modelUrl()),
                $(inputClass()),
                $(outputClass()),
                $(translatorFactory()),
                arguments()));
        validateInputType(dataset.schema());
        outputSchema_$eq(transformSchema(dataset.schema()));
        // Spark ships the partition function to executors, so it must be serializable. A Java
        // lambda/method reference is serializable only if its target type is, and scala.Function1
        // is not, hence the intersection cast. The typed reference also unambiguously selects the
        // Scala Function1 overload of mapPartitions (transformRows takes a scala Iterator).
        Function1<Iterator<Row>, Iterator<Row>> perPartition =
                (Function1<Iterator<Row>, Iterator<Row>> & Serializable) this::transformRows;
        return dataset.toDF().mapPartitions(perPartition, RowEncoder$.MODULE$.apply(outputSchema()));
    }

    /**
     * Copies this predictor with {@code paramMap} as extra params (delegates to {@code defaultCopy}).
     *
     * @param paramMap extra param values for the copy
     * @return the copy
     */
    @Override
    public BasePredictor<A, B> copy(ParamMap paramMap) {
        return defaultCopy(paramMap);
    }

    /**
     * Decompiler-generated alias of {@link #copy(ParamMap)}, kept so existing callers still link.
     *
     * @deprecated use {@link #copy(ParamMap)} instead.
     */
    @Deprecated
    public BasePredictor<A, B> m2copy(ParamMap paramMap) {
        return copy(paramMap);
    }

    /**
     * Transforms one partition of input rows into output rows. Called by {@link #transform} for
     * every partition; the returned rows are encoded with {@link #outputSchema()}.
     *
     * @param iterator the rows of one partition
     * @return the output rows for that partition
     */
    public abstract Iterator<Row> transformRows(Iterator<Row> iterator);

    /**
     * Checks that the input schema is acceptable. Called by {@link #transform} before the output
     * schema is computed.
     *
     * @param structType the input dataset's schema
     */
    public abstract void validateInputType(StructType structType);

    /**
     * Requires that {@code structField} has exactly {@code dataType}.
     *
     * @param structField the input column to check
     * @param dataType the expected type
     * @throws IllegalArgumentException (via Scala's {@code require}) when the types differ
     */
    public void validateType(StructField structField, DataType dataType) {
        Predef$.MODULE$.require(Objects.equals(structField.dataType(), dataType), () ->
                "Input column " + structField.name() + " type must be " + dataType
                        + " but got " + structField.dataType() + ".");
    }

    /**
     * Creates a predictor with the given uid.
     *
     * @param str the uid
     */
    public BasePredictor(String str) {
        // uid is assigned before the params because each Param is created with this as parent
        // (Spark's Param(Params, ...) constructor reads parent.uid()).
        this.uid = str;
        this.engine = new Param<>(this, "engine", "The engine");
        this.modelUrl = new Param<>(this, "modelUrl", "The model URL");
        this.inputClass = new Param<>(this, "inputClass", "The input class");
        this.outputClass = new Param<>(this, "outputClass", "The output class");
        this.batchifier = new Param<>(this, "batchifier", "The batchifier. Valid values include none (default), stack, and padding.");
        this.translatorFactory = new Param<>(this, "translatorFactory", "The translator factory");
        this.arguments = new HashMap<>();
        setDefault(engine(), null);
        setDefault(modelUrl(), null);
    }

    /** Creates a predictor with a random uid prefixed "BasePredictor". */
    public BasePredictor() {
        this(Identifiable$.MODULE$.randomUID("BasePredictor"));
    }
}
