package ai.djl.spark.task.audio;

import ai.djl.inference.Predictor;
import ai.djl.modality.audio.Audio;
import ai.djl.modality.audio.AudioFactory;
import ai.djl.modality.audio.translator.SpeechRecognitionTranslatorFactory;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.UncheckedIOException;
import org.apache.spark.ml.param.IntParam;
import org.apache.spark.ml.param.Param;
import org.apache.spark.ml.param.shared.HasInputCol;
import org.apache.spark.ml.param.shared.HasOutputCol;
import org.apache.spark.ml.util.Identifiable$;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.Row$;
import org.apache.spark.sql.types.BinaryType$;
import org.apache.spark.sql.types.StringType$;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructField$;
import org.apache.spark.sql.types.StructType;
import scala.Predef$;
import scala.collection.AbstractIterator;
import scala.collection.Iterator;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.collection.mutable.ArrayOps;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.BoxesRunTime;

/* compiled from: SpeechRecognizer.scala */
@ScalaSignature(bytes = "\u0006\u0001\u0005\u001de\u0001\u0002\r\u001a\u0001\u0011B\u0001\u0002\u0013\u0001\u0003\u0006\u0004%\t%\u0013\u0005\t\u0015\u0002\u0011\t\u0011)A\u0005S!)1\n\u0001C\u0001\u0019\")1\n\u0001C\u0001\u001f\"9\u0001\u000b\u0001b\u0001\n\u000b\t\u0006B\u0002,\u0001A\u00035!\u000bC\u0004X\u0001\t\u0007IQA)\t\ra\u0003\u0001\u0015!\u0004S\u0011\u001dI\u0006A1A\u0005\u0006ECaA\u0017\u0001!\u0002\u001b\u0011\u0006\"C.\u0001\u0001\u0004\u0005\r\u0011\"\u0005]\u0011%\t\u0007\u00011AA\u0002\u0013E!\rC\u0005i\u0001\u0001\u0007\t\u0011)Q\u0005;\")\u0011\u000e\u0001C\u0001U\")a\u000e\u0001C\u0001_\")\u0011\u000f\u0001C\u0001e\")A\u000f\u0001C\u0001k\")q\u000f\u0001C\u0001q\")!\u0010\u0001C\u0001w\"9\u0011q\b\u0001\u0005B\u0005\u0005\u0003bBA(\u0001\u0011E\u0013\u0011\u000b\u0005\b\u0003[\u0002A\u0011AA8\u0011\u001d\t\t\t\u0001C!\u0003\u0007\u0013\u0001c\u00159fK\u000eD'+Z2pO:L'0\u001a:\u000b\u0005iY\u0012!B1vI&|'B\u0001\u000f\u001e\u0003\u0011!\u0018m]6\u000b\u0005yy\u0012!B:qCJ\\'B\u0001\u0011\"\u0003\r!'\u000e\u001c\u0006\u0002E\u0005\u0011\u0011-[\u0002\u0001'\u0011\u0001QEN#\u0011\u0007\u0019:\u0013&D\u0001\u001a\u0013\tA\u0013D\u0001\nCCN,\u0017)\u001e3j_B\u0013X\rZ5di>\u0014\bC\u0001\u00164\u001d\tY\u0013\u0007\u0005\u0002-_5\tQF\u0003\u0002/G\u00051AH]8pizR\u0011\u0001M\u0001\u0006g\u000e\fG.Y\u0005\u0003e=\na\u0001\u0015:fI\u00164\u0017B\u0001\u001b6\u0005\u0019\u0019FO]5oO*\u0011!g\f\t\u0003o\rk\u0011\u0001\u000f\u0006\u0003si\naa\u001d5be\u0016$'BA\u001e=\u0003\u0015\u0001\u0018M]1n\u0015\tid(\u0001\u0002nY*\u0011ad\u0010\u0006\u0003\u0001\u0006\u000ba!\u00199bG\",'\"\u0001\"\u0002\u0007=\u0014x-\u0003\u0002Eq\tY\u0001*Y:J]B,HoQ8m!\t9d)\u0003\u0002Hq\ta\u0001*Y:PkR\u0004X\u000f^\"pY\u0006\u0019Q/\u001b3\u0016\u0003%\nA!^5eA\u00051A(\u001b8jiz\"\"!\u0014(\u0011\u0005\u0019\u0002\u0001\"\u0002%\u0004\u0001\u0004IC#A'\u0002\u0011\rD\u0017M\u001c8fYN,\u0012A\u0015\t\u0003'Rk\u0011AO\u0005\u0003+j\u0012\u0001\"\u00138u!\u0006\u0014\u0018-\\\u0001\nG\"\
fgN\\3mg\u0002\n!b]1na2,'+\u0019;f\u0003-\u0019\u0018-\u001c9mKJ\u000bG/\u001a\u0011\u0002\u0019M\fW\u000e\u001d7f\r>\u0014X.\u0019;\u0002\u001bM\fW\u000e\u001d7f\r>\u0014X.\u0019;!\u00035Ig\u000e];u\u0007>d\u0017J\u001c3fqV\tQ\f\u0005\u0002_?6\tq&\u0003\u0002a_\t\u0019\u0011J\u001c;\u0002#%t\u0007/\u001e;D_2Le\u000eZ3y?\u0012*\u0017\u000f\u0006\u0002dMB\u0011a\fZ\u0005\u0003K>\u0012A!\u00168ji\"9q\rDA\u0001\u0002\u0004i\u0016a\u0001=%c\u0005q\u0011N\u001c9vi\u000e{G.\u00138eKb\u0004\u0013aC:fi&s\u0007/\u001e;D_2$\"a\u001b7\u000e\u0003\u0001AQ!\u001c\bA\u0002%\nQA^1mk\u0016\fAb]3u\u001fV$\b/\u001e;D_2$\"a\u001b9\t\u000b5|\u0001\u0019A\u0015\u0002\u0017M,Go\u00115b]:,Gn\u001d\u000b\u0003WNDQ!\u001c\tA\u0002u\u000bQb]3u'\u0006l\u0007\u000f\\3SCR,GCA6w\u0011\u0015i\u0017\u00031\u0001^\u0003=\u0019X\r^*b[BdWMR8s[\u0006$HCA6z\u0011\u0015i'\u00031\u0001^\u0003%\u0011XmY8h]&TX\rF\u0002}\u00037\u00012!`A\u000b\u001d\rq\u0018q\u0002\b\u0004\u007f\u0006-a\u0002BA\u0001\u0003\u0013qA!a\u0001\u0002\b9\u0019A&!\u0002\n\u0003\tK!\u0001Q!\n\u0005yy\u0014bAA\u0007}\u0005\u00191/\u001d7\n\t\u0005E\u00111C\u0001\ba\u0006\u001c7.Y4f\u0015\r\tiAP\u0005\u0005\u0003/\tIBA\u0005ECR\fgI]1nK*!\u0011\u0011CA\n\u0011\u001d\tib\u0005a\u0001\u0003?\tq\u0001Z1uCN,G\u000f\r\u0003\u0002\"\u00055\u0002CBA\u0012\u0003K\tI#\u0004\u0002\u0002\u0014%!\u0011qEA\n\u0005\u001d!\u0015\r^1tKR\u0004B!a\u000b\u0002.1\u0001A\u0001DA\u0018\u00037\t\t\u0011!A\u0003\u0002\u0005E\"aA0%cE!\u00111GA\u001d!\rq\u0016QG\u0005\u0004\u0003oy#a\u0002(pi\"Lgn\u001a\t\u0004=\u0006m\u0012bAA\u001f_\t\u0019\u0011I\\=\u0002\u0013Q\u0014\u0018M\\:g_JlGc\u0001?\u0002D!9\u0011Q\u0004\u000bA\u0002\u0005\u0015\u0003\u0007BA$\u0003\u0017\u0002b!a\t\u0002&\u0005%\u0003\u0003BA\u0016\u0003\u0017\"A\"!\u0014\u0002D\u0005\u0005\t\u0011!B\u0001\u0003c\u00111a\u0018\u00133\u00035!(/\u00198tM>\u0014XNU8xgR!\u00111KA5!\u0019\t)&!\u0018\u0002d9!\u0011qKA.\u001d\ra\u0013\u0011L\u0005\u0002a%\u0019\u0011\u0011C\u0018\n\t\u0005}\u0013\u0011\r\u
0002\t\u0013R,'/\u0019;pe*\u0019\u0011\u0011C\u0018\u0011\t\u0005\r\u0012QM\u0005\u0005\u0003O\n\u0019BA\u0002S_^Dq!a\u001b\u0016\u0001\u0004\t\u0019&\u0001\u0003ji\u0016\u0014\u0018!\u0005<bY&$\u0017\r^3J]B,H\u000fV=qKR\u00191-!\u001d\t\u000f\u0005Md\u00031\u0001\u0002v\u000511o\u00195f[\u0006\u0004B!a\u001e\u0002~5\u0011\u0011\u0011\u0010\u0006\u0005\u0003w\n\u0019\"A\u0003usB,7/\u0003\u0003\u0002��\u0005e$AC*ueV\u001cG\u000fV=qK\u0006yAO]1og\u001a|'/\\*dQ\u0016l\u0017\r\u0006\u0003\u0002v\u0005\u0015\u0005bBA:/\u0001\u0007\u0011Q\u000f")
/* loaded from: input_file:ai/djl/spark/task/audio/SpeechRecognizer.class */
/**
 * Spark ML transformer that runs a DJL speech-recognition model over a column of raw audio bytes.
 *
 * <p>The input column (see {@link #setInputCol(String)}) must be {@code BinaryType}. Each value is
 * decoded with an {@link AudioFactory}, optionally configured through {@link #setChannels(int)},
 * {@link #setSampleRate(int)} and {@link #setSampleFormat(int)}, and fed to the model. The
 * recognized text is appended to every row as a {@code StringType} column named by
 * {@link #setOutputCol(String)}.
 *
 * <p>This source mirrors the compiled form of {@code SpeechRecognizer.scala}. The Scala-mangled
 * members ({@code inputColIndex_$eq} and the {@code _setter_$} methods) are part of that compiled
 * interface and are kept for compatibility.
 */
public class SpeechRecognizer extends BaseAudioPredictor<String> implements HasInputCol, HasOutputCol {
    private final String uid;
    private final IntParam channels;
    private final IntParam sampleRate;
    private final IntParam sampleFormat;
    /** Position of the input column in the incoming schema; resolved in {@link #transform}. */
    private int inputColIndex;
    // Not final: these are assigned by the HasOutputCol / HasInputCol trait initializers through
    // the generated setter methods below, which Java does not allow for final fields.
    private Param<String> outputCol;
    private Param<String> inputCol;

    /** Returns the configured output column name. */
    public final String getOutputCol() {
        return HasOutputCol.getOutputCol$(this);
    }

    /** Returns the configured input column name. */
    public final String getInputCol() {
        return HasInputCol.getInputCol$(this);
    }

    /** The param holding the output column name. */
    public final Param<String> outputCol() {
        return this.outputCol;
    }

    /** Trait-initializer setter generated for {@code HasOutputCol}; do not call directly. */
    public final void org$apache$spark$ml$param$shared$HasOutputCol$_setter_$outputCol_$eq(Param<String> param) {
        this.outputCol = param;
    }

    /** The param holding the input column name. */
    public final Param<String> inputCol() {
        return this.inputCol;
    }

    /** Trait-initializer setter generated for {@code HasInputCol}; do not call directly. */
    public final void org$apache$spark$ml$param$shared$HasInputCol$_setter_$inputCol_$eq(Param<String> param) {
        this.inputCol = param;
    }

    @Override // ai.djl.spark.task.audio.BaseAudioPredictor, ai.djl.spark.task.BasePredictor
    public String uid() {
        return this.uid;
    }

    /** Optional number of audio channels passed to {@link AudioFactory#setChannels}. */
    public final IntParam channels() {
        return this.channels;
    }

    /** Optional sample rate passed to {@link AudioFactory#setSampleRate}. */
    public final IntParam sampleRate() {
        return this.sampleRate;
    }

    /** Optional sample format passed to {@link AudioFactory#setSampleFormat}. */
    public final IntParam sampleFormat() {
        return this.sampleFormat;
    }

    /** Returns the index of the input column resolved by the last {@link #transform} call. */
    public int inputColIndex() {
        return this.inputColIndex;
    }

    /** Scala-style setter for {@link #inputColIndex()}. */
    public void inputColIndex_$eq(int i) {
        this.inputColIndex = i;
    }

    /**
     * Sets the name of the binary input column holding the encoded audio.
     *
     * @param str the input column name
     * @return this transformer
     */
    public SpeechRecognizer setInputCol(String str) {
        return (SpeechRecognizer) set(inputCol(), str);
    }

    /**
     * Sets the name of the string output column that receives the recognized text.
     *
     * @param str the output column name
     * @return this transformer
     */
    public SpeechRecognizer setOutputCol(String str) {
        return (SpeechRecognizer) set(outputCol(), str);
    }

    /**
     * Sets the number of channels used when decoding audio.
     *
     * @param i the channel count
     * @return this transformer
     */
    public SpeechRecognizer setChannels(int i) {
        return (SpeechRecognizer) set(channels(), BoxesRunTime.boxToInteger(i));
    }

    /**
     * Sets the sample rate used when decoding audio.
     *
     * @param i the sample rate
     * @return this transformer
     */
    public SpeechRecognizer setSampleRate(int i) {
        return (SpeechRecognizer) set(sampleRate(), BoxesRunTime.boxToInteger(i));
    }

    /**
     * Sets the sample format used when decoding audio.
     *
     * @param i the sample format code understood by {@link AudioFactory#setSampleFormat}
     * @return this transformer
     */
    public SpeechRecognizer setSampleFormat(int i) {
        return (SpeechRecognizer) set(sampleFormat(), BoxesRunTime.boxToInteger(i));
    }

    /**
     * Alias for {@link #transform(Dataset)}.
     *
     * @param dataset the input dataset containing the binary audio column
     * @return the dataset with the recognized-text column appended
     */
    public Dataset<Row> recognize(Dataset<?> dataset) {
        return transform(dataset);
    }

    /**
     * Resolves the input column position in {@code dataset}'s schema, then delegates to the base
     * predictor's transform.
     */
    @Override // ai.djl.spark.task.BasePredictor
    public Dataset<Row> transform(Dataset<?> dataset) {
        inputColIndex_$eq(dataset.schema().fieldIndex((String) $(inputCol())));
        return super.transform(dataset);
    }

    /**
     * Transcribes every row of one partition.
     *
     * <p>A single {@link Predictor} and a single configured {@link AudioFactory} are shared by all
     * rows of the partition. The predictor is closed once the input iterator is exhausted or a row
     * fails.
     * NOTE(review): if the consumer stops iterating early, the predictor is not released here.
     *
     * @param iterator the partition's input rows
     * @return a lazy iterator over the input rows with the transcription appended
     */
    @Override // ai.djl.spark.task.BasePredictor
    public Iterator<Row> transformRows(Iterator<Row> iterator) {
        final Predictor<Audio, String> predictor = model().newPredictor();
        final AudioFactory factory = newAudioFactory();
        return new AbstractIterator<Row>() {
            private boolean open = true;

            @Override
            public boolean hasNext() {
                if (open && !iterator.hasNext()) {
                    release();
                }
                return open;
            }

            @Override
            public Row next() {
                Row row = iterator.next();
                try {
                    return appendTranscription(row, predictor, factory);
                } catch (RuntimeException e) {
                    release();
                    throw e;
                }
            }

            // Closes the predictor exactly once.
            private void release() {
                if (open) {
                    open = false;
                    predictor.close();
                }
            }
        };
    }

    /** Creates an {@link AudioFactory} carrying whichever of channels/sampleRate/sampleFormat are set. */
    private AudioFactory newAudioFactory() {
        AudioFactory factory = AudioFactory.newInstance();
        if (isDefined(channels())) {
            factory.setChannels(BoxesRunTime.unboxToInt($(channels())));
        }
        if (isDefined(sampleRate())) {
            factory.setSampleRate(BoxesRunTime.unboxToInt($(sampleRate())));
        }
        if (isDefined(sampleFormat())) {
            factory.setSampleFormat(BoxesRunTime.unboxToInt($(sampleFormat())));
        }
        return factory;
    }

    /**
     * Decodes the audio bytes of {@code row}, runs recognition, and returns a new row with the text
     * appended.
     *
     * @throws UncheckedIOException if the audio bytes cannot be decoded
     * @throws IllegalStateException wrapping any checked exception raised by the predictor
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    private Row appendTranscription(Row row, Predictor<Audio, String> predictor, AudioFactory factory) {
        byte[] bytes = (byte[]) row.getAs(inputColIndex());
        String text;
        try {
            text = predictor.predict(factory.fromInputStream(new ByteArrayInputStream(bytes)));
        } catch (IOException e) {
            throw new UncheckedIOException("Failed to decode audio for speech recognition", e);
        } catch (RuntimeException e) {
            throw e;
        } catch (Exception e) {
            // Predictor.predict declares a checked DJL exception that Java lambdas/iterators cannot
            // propagate; keep the original as the cause.
            throw new IllegalStateException("Speech recognition failed", e);
        }
        return Row$.MODULE$.fromSeq((Seq) row.toSeq().$colon$plus(text, Seq$.MODULE$.canBuildFrom()));
    }

    /** Ensures the input column exists and is {@code BinaryType}. */
    @Override // ai.djl.spark.task.BasePredictor
    public void validateInputType(StructType structType) {
        validateType(structType.apply((String) $(inputCol())), BinaryType$.MODULE$);
    }

    /**
     * Returns the input schema with a nullable {@code StringType} output column appended.
     *
     * @param structType the input schema
     * @return the output schema
     */
    public StructType transformSchema(StructType structType) {
        return structType.add((String) $(outputCol()), StringType$.MODULE$);
    }

    /**
     * Creates a recognizer with the given uid, defaulting to a {@code String} output class and a
     * {@link SpeechRecognitionTranslatorFactory}.
     *
     * @param str the unique identifier of this transformer
     */
    public SpeechRecognizer(String str) {
        this.uid = str;
        HasInputCol.$init$(this);
        HasOutputCol.$init$(this);
        this.channels = new IntParam(this, "channels", "The number of channels");
        this.sampleRate = new IntParam(this, "sampleRate", "The audio sample rate");
        this.sampleFormat = new IntParam(this, "sampleFormat", "The audio sample format");
        setDefault(outputClass(), String.class);
        setDefault(translatorFactory(), new SpeechRecognitionTranslatorFactory());
    }

    /** Creates a recognizer with a random uid prefixed by {@code "SpeechRecognizer"}. */
    public SpeechRecognizer() {
        this(Identifiable$.MODULE$.randomUID("SpeechRecognizer"));
    }
}
