package ai.djl.spark.task;

import ai.djl.spark.ModelLoader;
import ai.djl.translate.TranslatorFactory;
import java.io.Serializable;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import org.apache.spark.ml.Transformer;
import org.apache.spark.ml.param.IntParam;
import org.apache.spark.ml.param.Param;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.ml.util.Identifiable$;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoders$;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import scala.Function1;
import scala.Predef$;
import scala.collection.Iterator;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;

/* compiled from: BasePredictor.scala */
@ScalaSignature(bytes = "\u0006\u0001\t-b!\u0002\u0015*\u0003\u0003\u0011\u0004\u0002C \u0001\u0005\u000b\u0007I\u0011\t!\t\u00119\u0003!\u0011!Q\u0001\n\u0005CQa\u0014\u0001\u0005\u0002ACQa\u0014\u0001\u0005\u0002\rDq\u0001\u001a\u0001C\u0002\u0013\u0015Q\r\u0003\u0004m\u0001\u0001\u0006iA\u001a\u0005\b[\u0002\u0011\r\u0011\"\u0002f\u0011\u0019q\u0007\u0001)A\u0007M\"9q\u000e\u0001b\u0001\n\u000b\u0001\bB\u0002;\u0001A\u00035\u0011\u000fC\u0004v\u0001\t\u0007IQ\u0001<\t\rm\u0004\u0001\u0015!\u0004x\u0011\u001da\bA1A\u0005\u0006uDq!!\u0001\u0001A\u00035a\u0010C\u0005\u0002\u0004\u0001\u0011\r\u0011\"\u0002\u0002\u0006!A\u0011Q\u0003\u0001!\u0002\u001b\t9\u0001\u0003\u0005\u0002\u0018\u0001\u0011\r\u0011\"\u0002f\u0011\u001d\tI\u0002\u0001Q\u0001\u000e\u0019D1\"a\u0007\u0001\u0001\u0004\u0005\r\u0011\"\u0005\u0002\u001e!Y\u0011q\u0005\u0001A\u0002\u0003\u0007I\u0011CA\u0015\u0011-\t)\u0004\u0001a\u0001\u0002\u0003\u0006K!a\b\t\u0013\u0005]\u0002\u00011A\u0005\u0012\u0005e\u0002\"CA)\u0001\u0001\u0007I\u0011CA*\u0011!\t9\u0006\u0001Q!\n\u0005m\u0002bCA-\u0001\u0001\u0007\t\u0019!C\t\u00037B1\"!\u001c\u0001\u0001\u0004\u0005\r\u0011\"\u0005\u0002p!Y\u00111\u000f\u0001A\u0002\u0003\u0005\u000b\u0015BA/\u0011\u001d\t)\b\u0001C\u0001\u0003oBq!a 
\u0001\t\u0003\t\t\tC\u0004\u0002\u0006\u0002!\t!a\"\t\u000f\u0005E\u0005\u0001\"\u0001\u0002\u0014\"9\u0011q\u0013\u0001\u0005\u0002\u0005e\u0005bBAO\u0001\u0011\u0005\u0011q\u0014\u0005\b\u0003G\u0003A\u0011AAS\u0011\u001d\tI\u000b\u0001C!\u0003WCq!a8\u0001\t\u0003\n\t\u000fC\u0004\u0002n\u00021\t\"a<\t\u000f\t-\u0001A\"\u0005\u0003\u000e!9!1\u0003\u0001\u0005\u0002\tU!!\u0004\"bg\u0016\u0004&/\u001a3jGR|'O\u0003\u0002+W\u0005!A/Y:l\u0015\taS&A\u0003ta\u0006\u00148N\u0003\u0002/_\u0005\u0019AM\u001b7\u000b\u0003A\n!!Y5\u0004\u0001U\u00191'\u00161\u0014\u0005\u0001!\u0004CA\u001b>\u001b\u00051$BA\u001c9\u0003\tiGN\u0003\u0002-s)\u0011!hO\u0001\u0007CB\f7\r[3\u000b\u0003q\n1a\u001c:h\u0013\tqdGA\u0006Ue\u0006t7OZ8s[\u0016\u0014\u0018aA;jIV\t\u0011\t\u0005\u0002C\u0017:\u00111)\u0013\t\u0003\t\u001ek\u0011!\u0012\u0006\u0003\rF\na\u0001\u0010:p_Rt$\"\u0001%\u0002\u000bM\u001c\u0017\r\\1\n\u0005);\u0015A\u0002)sK\u0012,g-\u0003\u0002M\u001b\n11\u000b\u001e:j]\u001eT!AS$\u0002\tULG\rI\u0001\u0007y%t\u0017\u000e\u001e 
\u0015\u0005E\u0013\u0007\u0003\u0002*\u0001'~k\u0011!\u000b\t\u0003)Vc\u0001\u0001B\u0003W\u0001\t\u0007qKA\u0001B#\tAF\f\u0005\u0002Z56\tq)\u0003\u0002\\\u000f\n9aj\u001c;iS:<\u0007CA-^\u0013\tqvIA\u0002B]f\u0004\"\u0001\u00161\u0005\u000b\u0005\u0004!\u0019A,\u0003\u0003\tCQaP\u0002A\u0002\u0005#\u0012!U\u0001\t[>$W\r\\+sYV\ta\rE\u0002hU\u0006k\u0011\u0001\u001b\u0006\u0003SZ\nQ\u0001]1sC6L!a\u001b5\u0003\u000bA\u000b'/Y7\u0002\u00135|G-\u001a7Ve2\u0004\u0013AB3oO&tW-A\u0004f]\u001eLg.\u001a\u0011\u0002\u0013\t\fGo\u00195TSj,W#A9\u0011\u0005\u001d\u0014\u0018BA:i\u0005!Ie\u000e\u001e)be\u0006l\u0017A\u00032bi\u000eD7+\u001b>fA\u0005Q\u0011N\u001c9vi\u000ec\u0017m]:\u0016\u0003]\u00042a\u001a6y!\r\u0011\u0015pU\u0005\u0003u6\u0013Qa\u00117bgN\f1\"\u001b8qkR\u001cE.Y:tA\u0005Yq.\u001e;qkR\u001cE.Y:t+\u0005q\bcA4k\u007fB\u0019!)_0\u0002\u0019=,H\u000f];u\u00072\f7o\u001d\u0011\u0002#Q\u0014\u0018M\\:mCR|'OR1di>\u0014\u00180\u0006\u0002\u0002\bA!qM[A\u0005!\u0011\tY!!\u0005\u000e\u0005\u00055!bAA\b[\u0005IAO]1og2\fG/Z\u0005\u0005\u0003'\tiAA\tUe\u0006t7\u000f\\1u_J4\u0015m\u0019;pef\f!\u0003\u001e:b]Nd\u0017\r^8s\r\u0006\u001cGo\u001c:zA\u0005Q!-\u0019;dQ&4\u0017.\u001a:\u0002\u0017\t\fGo\u00195jM&,'\u000fI\u0001\u0006[>$W\r\\\u000b\u0003\u0003?\u0001b!!\t\u0002$M{V\"A\u0016\n\u0007\u0005\u00152FA\u0006N_\u0012,G\u000eT8bI\u0016\u0014\u0018!C7pI\u0016dw\fJ3r)\u0011\tY#!\r\u0011\u0007e\u000bi#C\u0002\u00020\u001d\u0013A!\u00168ji\"I\u00111\u0007\u000b\u0002\u0002\u0003\u0007\u0011qD\u0001\u0004q\u0012\n\u0014AB7pI\u0016d\u0007%A\u0005be\u001e,X.\u001a8ugV\u0011\u00111\b\t\b\u0003{\t9%QA&\u001b\t\tyD\u0003\u0003\u0002B\u0005\r\u0013\u0001B;uS2T!!!\u0012\u0002\t)\fg/Y\u0005\u0005\u0003\u0013\nyDA\u0002NCB\u00042!WA'\u0013\r\tye\u0012\u0002\u0007\u0003:L(+\u001a4\u0002\u001b\u0005\u0014x-^7f]R\u001cx\fJ3r)\u0011\tY#!\u0016\t\u0013\u0005Mr#!AA\u0002\u0005m\u0012AC1sOVlWM\u001c;tA\u0005aq.\u001e;qkR\u001c6\r[3nCV\u0011\u0011Q\f\t\u0005\u0003?\nI'\u0004\u0002\u0002b)!\u0011
1MA3\u0003\u0015!\u0018\u0010]3t\u0015\r\t9\u0007O\u0001\u0004gFd\u0017\u0002BA6\u0003C\u0012!b\u0015;sk\u000e$H+\u001f9f\u0003AyW\u000f\u001e9viN\u001b\u0007.Z7b?\u0012*\u0017\u000f\u0006\u0003\u0002,\u0005E\u0004\"CA\u001a5\u0005\u0005\t\u0019AA/\u00035yW\u000f\u001e9viN\u001b\u0007.Z7bA\u0005Y1/\u001a;N_\u0012,G.\u0016:m)\u0011\tI(a\u001f\u000e\u0003\u0001Aa!! \u001d\u0001\u0004\t\u0015!\u0002<bYV,\u0017!C:fi\u0016sw-\u001b8f)\u0011\tI(a!\t\r\u0005uT\u00041\u0001B\u00031\u0019X\r\u001e\"bi\u000eD7+\u001b>f)\u0011\tI(!#\t\u000f\u0005ud\u00041\u0001\u0002\fB\u0019\u0011,!$\n\u0007\u0005=uIA\u0002J]R\fQb]3u\u0013:\u0004X\u000f^\"mCN\u001cH\u0003BA=\u0003+Ca!!  \u0001\u0004A\u0018AD:fi>+H\u000f];u\u00072\f7o\u001d\u000b\u0005\u0003s\nY\n\u0003\u0004\u0002~\u0001\u0002\ra`\u0001\u0015g\u0016$HK]1og2\fGo\u001c:GC\u000e$xN]=\u0015\t\u0005e\u0014\u0011\u0015\u0005\b\u0003{\n\u0003\u0019AA\u0005\u00035\u0019X\r\u001e\"bi\u000eD\u0017NZ5feR!\u0011\u0011PAT\u0011\u0019\tiH\ta\u0001\u0003\u0006IAO]1og\u001a|'/\u001c\u000b\u0005\u0003[\u000bY\r\u0005\u0003\u00020\u0006\u0015g\u0002BAY\u0003\u0003tA!a-\u0002@:!\u0011QWA_\u001d\u0011\t9,a/\u000f\u0007\u0011\u000bI,C\u0001=\u0013\tQ4(\u0003\u0002-s%\u0019\u0011q\r\u001d\n\t\u0005\r\u0017QM\u0001\ba\u0006\u001c7.Y4f\u0013\u0011\t9-!3\u0003\u0013\u0011\u000bG/\u0019$sC6,'\u0002BAb\u0003KBq!!4$\u0001\u0004\ty-A\u0004eCR\f7/\u001a;1\t\u0005E\u00171\u001c\t\u0007\u0003'\f).!7\u000e\u0005\u0005\u0015\u0014\u0002BAl\u0003K\u0012q\u0001R1uCN,G\u000fE\u0002U\u00037$1\"!8\u0002L\u0006\u0005\t\u0011!B\u0001/\n\u0019q\fJ\u0019\u0002\t\r|\u0007/\u001f\u000b\u0004#\u0006\r\bbBAsI\u0001\u0007\u0011q]\u0001\u0006Kb$(/\u0019\t\u0004O\u0006%\u0018bAAvQ\nA\u0001+\u0019:b[6\u000b\u0007/A\u0007ue\u0006t7OZ8s[J{wo\u001d\u000b\u0005\u0003c\u00149\u0001\u0005\u0004\u0002t\u0006m(\u0011\u0001\b\u0005\u0003k\fIPD\u0002E\u0003oL\u0011\u0001S\u0005\u0004\u0003\u0007<\u0015\u0002BA\u007f\u0003\u007f\u0014\u0001\"\u0013;fe\u0006$xN\u001d\u0006\u0004\u0003\u0
007<\u0005\u0003BAj\u0005\u0007IAA!\u0002\u0002f\t\u0019!k\\<\t\u000f\t%Q\u00051\u0001\u0002r\u0006!\u0011\u000e^3s\u0003E1\u0018\r\\5eCR,\u0017J\u001c9viRK\b/\u001a\u000b\u0005\u0003W\u0011y\u0001C\u0004\u0003\u0012\u0019\u0002\r!!\u0018\u0002\rM\u001c\u0007.Z7b\u000311\u0018\r\\5eCR,G+\u001f9f)\u0019\tYCa\u0006\u0003\"!9!\u0011D\u0014A\u0002\tm\u0011!\u00024jK2$\u0007\u0003BA0\u0005;IAAa\b\u0002b\tY1\u000b\u001e:vGR4\u0015.\u001a7e\u0011\u001d\u0011\u0019c\na\u0001\u0005K\t!\u0001\u001e9\u0011\t\u0005}#qE\u0005\u0005\u0005S\t\tG\u0001\u0005ECR\fG+\u001f9f\u0001")
/* loaded from: input_file:ai/djl/spark/task/BasePredictor.class */
/**
 * Base Spark ML {@link Transformer} that applies a DJL model to a {@link Dataset}.
 *
 * <p>{@link #transform(Dataset)} builds a {@link ModelLoader} from the configured params,
 * checks the input schema with {@link #validateInputType(StructType)}, computes the output
 * schema with {@code transformSchema} (not defined in this class), and maps every partition
 * through {@link #transformRows(Iterator)}. Subclasses supply those task-specific hooks.
 *
 * @param <A> the type described by the {@link #inputClass()} param
 * @param <B> the type described by the {@link #outputClass()} param
 */
public abstract class BasePredictor<A, B> extends Transformer {
    private final String uid;
    private final Param<String> modelUrl;
    private final Param<String> engine;
    // Defined with a default of 10; not read anywhere in this class (presumably used by subclasses).
    private final IntParam batchSize;
    private final Param<Class<A>> inputClass;
    private final Param<Class<B>> outputClass;
    private final Param<TranslatorFactory> translatorFactory;
    private final Param<String> batchifier;
    // Rebuilt on every transform() call.
    private ModelLoader<A, B> model;
    // Extra options handed to the ModelLoader; transform() may add a "batchifier" entry.
    private Map<String, Object> arguments;
    // Output schema computed by the most recent transform() call.
    private StructType outputSchema;

    @Override
    public String uid() {
        return this.uid;
    }

    /** Returns the param holding the model URL (default {@code null}). */
    public final Param<String> modelUrl() {
        return this.modelUrl;
    }

    /** Returns the param holding the engine name (default {@code null}). */
    public final Param<String> engine() {
        return this.engine;
    }

    /** Returns the param holding the batch size (default 10). */
    public final IntParam batchSize() {
        return this.batchSize;
    }

    /** Returns the param holding the model input class. */
    public final Param<Class<A>> inputClass() {
        return this.inputClass;
    }

    /** Returns the param holding the model output class. */
    public final Param<Class<B>> outputClass() {
        return this.outputClass;
    }

    /** Returns the param holding the translator factory. */
    public final Param<TranslatorFactory> translatorFactory() {
        return this.translatorFactory;
    }

    /** Returns the param holding the batchifier name (none, stack or padding). */
    public final Param<String> batchifier() {
        return this.batchifier;
    }

    /** Returns the model loader built by the last {@link #transform(Dataset)} call, or null. */
    public ModelLoader<A, B> model() {
        return this.model;
    }

    /** Replaces the current model loader (Scala-style setter name kept for compatibility). */
    public void model_$eq(ModelLoader<A, B> modelLoader) {
        this.model = modelLoader;
    }

    /** Returns the mutable map of extra arguments passed to the {@link ModelLoader}. */
    public Map<String, Object> arguments() {
        return this.arguments;
    }

    /** Replaces the extra-arguments map (Scala-style setter name kept for compatibility). */
    public void arguments_$eq(Map<String, Object> map) {
        this.arguments = map;
    }

    /** Returns the output schema computed by the last {@link #transform(Dataset)} call. */
    public StructType outputSchema() {
        return this.outputSchema;
    }

    /** Replaces the cached output schema (Scala-style setter name kept for compatibility). */
    public void outputSchema_$eq(StructType structType) {
        this.outputSchema = structType;
    }

    /** Sets the model URL and returns this predictor for chaining. */
    public BasePredictor<A, B> setModelUrl(String str) {
        // Params.set is typed as Params from Java (Scala this.type erases), so return this.
        set(modelUrl(), str);
        return this;
    }

    /** Sets the engine name and returns this predictor for chaining. */
    public BasePredictor<A, B> setEngine(String str) {
        set(engine(), str);
        return this;
    }

    /** Sets the batch size and returns this predictor for chaining. */
    public BasePredictor<A, B> setBatchSize(int i) {
        set(batchSize(), Integer.valueOf(i));
        return this;
    }

    /** Sets the model input class and returns this predictor for chaining. */
    public BasePredictor<A, B> setInputClass(Class<A> cls) {
        set(inputClass(), cls);
        return this;
    }

    /** Sets the model output class and returns this predictor for chaining. */
    public BasePredictor<A, B> setOutputClass(Class<B> cls) {
        set(outputClass(), cls);
        return this;
    }

    /** Sets the translator factory and returns this predictor for chaining. */
    public BasePredictor<A, B> setTranslatorFactory(TranslatorFactory translatorFactory) {
        set(translatorFactory(), translatorFactory);
        return this;
    }

    /** Sets the batchifier name and returns this predictor for chaining. */
    public BasePredictor<A, B> setBatchifier(String str) {
        set(batchifier(), str);
        return this;
    }

    /**
     * Applies the model to every partition of {@code dataset}.
     *
     * <p>If {@link #batchifier()} is set, its value is stored in {@link #arguments()} under the
     * key {@code "batchifier"} before the {@link ModelLoader} is built. The entry remains in the
     * map afterwards.
     *
     * @param dataset the input data; its schema is checked by {@link #validateInputType(StructType)}
     * @return a DataFrame whose rows come from {@link #transformRows(Iterator)} and whose schema
     *     is {@code transformSchema(dataset.schema())}
     */
    @Override
    public Dataset<Row> transform(Dataset<?> dataset) {
        if (isDefined(batchifier())) {
            arguments().put("batchifier", $(batchifier()));
        }
        ModelLoader<A, B> loader =
                new ModelLoader<>(
                        $(engine()),
                        $(modelUrl()),
                        $(inputClass()),
                        $(outputClass()),
                        $(translatorFactory()),
                        arguments());
        model_$eq(loader);
        validateInputType(dataset.schema());
        outputSchema_$eq(transformSchema(dataset.schema()));
        // Explicitly typed so javac selects the scala.Function1 overload of mapPartitions (an
        // implicitly typed lambda is ambiguous against the MapPartitionsFunction overload), and
        // intersected with Serializable so Spark can ship the closure to executors. The method
        // reference captures this, so the predictor itself travels with the closure.
        Function1<Iterator<Row>, Iterator<Row>> partitionFn =
                (Function1<Iterator<Row>, Iterator<Row>> & Serializable) this::transformRows;
        return dataset.toDF().mapPartitions(partitionFn, Encoders$.MODULE$.row(outputSchema()));
    }

    /**
     * Creates a copy of this predictor with {@code extra} applied on top of its params.
     *
     * @param extra params overriding the copied values
     * @return the copy, created through {@code defaultCopy}
     */
    @Override
    public BasePredictor<A, B> copy(ParamMap extra) {
        return defaultCopy(extra);
    }

    /**
     * Alias of {@link #copy(ParamMap)} kept for callers of the previously exposed name.
     *
     * @deprecated use {@link #copy(ParamMap)}
     */
    @Deprecated
    public BasePredictor<A, B> m2copy(ParamMap paramMap) {
        return copy(paramMap);
    }

    /**
     * Converts one partition of input rows into output rows matching {@link #outputSchema()}.
     *
     * @param iterator the input rows of a partition
     * @return the output rows for that partition
     */
    public abstract Iterator<Row> transformRows(Iterator<Row> iterator);

    /**
     * Checks that {@code structType} contains the input columns this task needs.
     *
     * @param structType the schema of the dataset passed to {@link #transform(Dataset)}
     */
    public abstract void validateInputType(StructType structType);

    /**
     * Requires {@code structField} to have exactly the data type {@code dataType}.
     *
     * @param structField the column to check
     * @param dataType the expected type
     * @throws IllegalArgumentException if the types differ; the message matches Scala's
     *     {@code require} output, including the {@code "requirement failed: "} prefix
     */
    public void validateType(StructField structField, DataType dataType) {
        DataType actual = structField.dataType();
        if (!Objects.equals(actual, dataType)) {
            throw new IllegalArgumentException(
                    "requirement failed: Input column "
                            + structField.name()
                            + " type must be "
                            + dataType
                            + " but got "
                            + actual
                            + ".");
        }
    }

    /**
     * Creates a predictor with the given uid and registers its params.
     *
     * <p>{@code modelUrl} and {@code engine} default to {@code null} and {@code batchSize}
     * defaults to 10. The other params have no default and must be set before
     * {@link #transform(Dataset)}.
     *
     * @param str the unique id of this stage
     */
    public BasePredictor(String str) {
        this.uid = str;
        this.modelUrl = new Param<>(this, "modelUrl", "The model URL");
        this.engine = new Param<>(this, "engine", "The engine");
        this.batchSize = new IntParam(this, "batchSize", "The batch size");
        this.inputClass = new Param<>(this, "inputClass", "The input class");
        this.outputClass = new Param<>(this, "outputClass", "The output class");
        this.translatorFactory = new Param<>(this, "translatorFactory", "The translator factory");
        this.batchifier = new Param<>(this, "batchifier", "The batchifier. Valid values include none (default), stack, and padding.");
        this.arguments = new HashMap<>();
        setDefault(modelUrl(), null);
        setDefault(engine(), null);
        setDefault(batchSize(), Integer.valueOf(10));
    }

    /** Creates a predictor with a random uid prefixed {@code "BasePredictor"}. */
    public BasePredictor() {
        this(Identifiable$.MODULE$.randomUID("BasePredictor"));
    }
}
