package org.platanios.tensorflow.api.ops.rnn.attention;

import org.platanios.tensorflow.api.core.Indexer;
import org.platanios.tensorflow.api.core.package$exception$InvalidShapeException;
import org.platanios.tensorflow.api.core.types.Cpackage;
import org.platanios.tensorflow.api.core.types.package$TF$;
import org.platanios.tensorflow.api.implicits.Implicits$;
import org.platanios.tensorflow.api.ops.Output;
import org.platanios.tensorflow.api.ops.basic.Basic$;
import org.platanios.tensorflow.api.ops.math.Math$;
import scala.Function1;
import scala.MatchError;
import scala.None$;
import scala.Option;
import scala.Predef$;
import scala.Serializable;
import scala.Some;
import scala.collection.Seq$;
import scala.runtime.Nothing$;

/* compiled from: Attention.scala */
/* loaded from: input_file:org/platanios/tensorflow/api/ops/rnn/attention/Attention$.class */
/**
 * Module singleton holding tensor-masking helpers for the {@code ops.rnn.attention} package.
 *
 * <p>NOTE(review): this class is JADX-decompiled output of the Scala source {@code Attention.scala}
 * (see the "compiled from" header). Identifiers such as {@code Attention$}, {@code MODULE$},
 * {@code $times}, {@code $minus}, {@code $default$N} and {@code Predef$.less.colon.less} are
 * Scala-compiler / decompiler encodings. This file may therefore not compile as hand-written Java;
 * prefer editing the Scala source.
 */
public final class Attention$ {
    /** The single instance of this module, assigned by the private constructor. */
    public static Attention$ MODULE$;

    static {
        // Constructing the instance is enough: the constructor stores itself into MODULE$.
        new Attention$();
    }

    /**
     * Returns {@code "Attention"}.
     *
     * <p>The method name encodes {@code <init>$default$2}, i.e. the default value of the second
     * constructor parameter of an {@code Attention} class. That parameter is presumably its name;
     * TODO confirm against Attention.scala.
     *
     * @return the constant string {@code "Attention"}
     */
    public <T, State, StateShape> String $lessinit$greater$default$2() {
        return "Attention";
    }

    /**
     * Optionally masks a tensor by multiplying it with a sequence mask built from lengths.
     *
     * <p>If {@code option} is {@code None}, {@code output} is returned unchanged. Otherwise a
     * sequence mask is built from the contained lengths tensor. Its maximum length is taken from
     * {@code shape(output)(1)}. The mask is cast to {@code T} and reshaped to
     * {@code shape(mask) ++ [1] * (rank(output) - 2)}, and {@code output} is multiplied by it.
     * Masked-out positions therefore become zero, assuming the usual false-to-0 cast.
     *
     * <p>NOTE(review): presumably {@code output} is {@code [batch, maxTime, ...]} and the lengths
     * tensor is {@code [batch]}, so that the reshaped mask broadcasts against the trailing
     * dimensions. TODO confirm against callers.
     *
     * @param output  the tensor to mask; its dimension 1 is used as the mask length
     * @param option  optional lengths tensor; {@code None} disables masking
     * @param tf      type evidence for {@code T}, used by the shape/rank/cast ops
     * @param lessVar implicit evidence forwarded to the multiplication ({@code $times})
     * @return {@code output} itself, or {@code output} multiplied by the reshaped mask
     * @throws package$exception$InvalidShapeException declared by this method. Presumably it
     *         originates from the ops used here (e.g. the reshape); TODO confirm.
     * @throws MatchError if {@code option} is neither {@code None} nor {@code Some}
     */
    /* JADX WARN: Multi-variable type inference failed */
    public <T> Output<T> maybeMaskValues(Output<T> output, Option<Output<Object>> option, Cpackage.TF<T> tf, Predef$.less.colon.less<Function1<Function1<T, Nothing$>, Nothing$>, Function1<Function1<Cpackage.TruncatedHalf, Nothing$>, Nothing$>> lessVar) throws package$exception$InvalidShapeException {
        Output<T> $times;
        if (None$.MODULE$.equals(option)) {
            // No lengths supplied: nothing to mask.
            $times = output;
        } else {
            if (!(option instanceof Some)) {
                throw new MatchError(option);
            }
            // Build sequenceMask(lengths, maxLength = shape(output)(1)), then cast it from its mask dtype to T.
            Serializable castTo = Basic$.MODULE$.sequenceMask((Output) ((Some) option).value(), Basic$.MODULE$.shape(output, Basic$.MODULE$.shape$default$2(), Basic$.MODULE$.shape$default$3(), tf).slice(Implicits$.MODULE$.intToIndex(1), Predef$.MODULE$.wrapRefArray(new Indexer[0])), Basic$.MODULE$.sequenceMask$default$3(), package$TF$.MODULE$.intEvTF(), Predef$.MODULE$.$conforms()).castTo(tf);
            // Reshape the mask to concat([shape(mask), ones([rank(output) - 2])], axis 0) and multiply it into output.
            // The rank is a constant when statically known (rank() != -1); otherwise a dynamic rank op is used.
            $times = output.$times(Implicits$.MODULE$.outputBasicOps(castTo).reshape(Basic$.MODULE$.concatenate(Seq$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new Output[]{Basic$.MODULE$.shape(castTo, Basic$.MODULE$.shape$default$2(), Basic$.MODULE$.shape$default$3(), tf), Basic$.MODULE$.ones((Output<Object>) Basic$.MODULE$.expandDims((output.rank() != -1 ? Basic$.MODULE$.constant(Implicits$.MODULE$.intToTensor(output.rank()), Basic$.MODULE$.constant$default$2(), Basic$.MODULE$.constant$default$3()) : Basic$.MODULE$.rank(output, Basic$.MODULE$.rank$default$2(), Basic$.MODULE$.rank$default$3(), tf)).$minus(Implicits$.MODULE$.intToOutput(2), Predef$.MODULE$.$conforms()), Implicits$.MODULE$.intToOutput(0), Basic$.MODULE$.expandDims$default$3(), package$TF$.MODULE$.intEvTF(), package$TF$.MODULE$.intEvTF(), Predef$.MODULE$.$conforms()), package$TF$.MODULE$.intEvTF())})), Implicits$.MODULE$.intToOutput(0), Basic$.MODULE$.concatenate$default$3(), package$TF$.MODULE$.intEvTF()), package$TF$.MODULE$.intEvTF(), Predef$.MODULE$.$conforms()), lessVar);
        }
        return $times;
    }

    /**
     * Optionally replaces masked-out entries of a score tensor with a fill value.
     *
     * <p>If {@code option} is {@code None}, {@code output} is returned unchanged. Otherwise a
     * sequence mask is built from the contained lengths tensor, with its maximum length taken from
     * {@code shape(output)(1)}. The result is {@code select(mask, output, output2 * onesLike(output))}.
     * Entries where the mask is true keep {@code output}; the others take {@code output2}, expanded
     * to {@code output}'s shape by multiplication with ones.
     *
     * <p>NOTE(review): unlike {@link #maybeMaskValues}, the mask is not reshaped before the select.
     * Presumably {@code output} has the mask's shape (e.g. {@code [batch, maxTime]}) and
     * {@code output2} is a scalar; TODO confirm.
     *
     * @param output  the score tensor to mask; its dimension 1 is used as the mask length
     * @param output2 the fill value for masked-out positions
     * @param option  optional lengths tensor; {@code None} disables masking
     * @param tf      type evidence for {@code T}, used by the shape and select ops
     * @param lessVar implicit evidence forwarded to the multiplication ({@code $times})
     * @return {@code output} itself, or the selected (masked) scores
     * @throws MatchError if {@code option} is neither {@code None} nor {@code Some}
     */
    /* JADX WARN: Multi-variable type inference failed */
    public <T> Output<T> maybeMaskScore(Output<T> output, Output<T> output2, Option<Output<Object>> option, Cpackage.TF<T> tf, Predef$.less.colon.less<Function1<Function1<T, Nothing$>, Nothing$>, Function1<Function1<Cpackage.TruncatedHalf, Nothing$>, Nothing$>> lessVar) {
        Output<T> select;
        if (None$.MODULE$.equals(option)) {
            // No lengths supplied: scores are left untouched.
            select = output;
        } else {
            if (!(option instanceof Some)) {
                throw new MatchError(option);
            }
            // select(sequenceMask(lengths, shape(output)(1)), output, output2 * onesLike(output))
            select = Math$.MODULE$.select(Basic$.MODULE$.sequenceMask((Output) ((Some) option).value(), Basic$.MODULE$.shape(output, Basic$.MODULE$.shape$default$2(), Basic$.MODULE$.shape$default$3(), tf).slice(Implicits$.MODULE$.intToIndex(1), Predef$.MODULE$.wrapRefArray(new Indexer[0])), Basic$.MODULE$.sequenceMask$default$3(), package$TF$.MODULE$.intEvTF(), Predef$.MODULE$.$conforms()), output, output2.$times(Basic$.MODULE$.onesLike(output, Basic$.MODULE$.onesLike$default$2(), Basic$.MODULE$.onesLike$default$3()), lessVar), Math$.MODULE$.select$default$4(), tf);
        }
        return select;
    }

    /**
     * Default value ({@code None}) of the third parameter of {@link #maybeMaskScore}, i.e. the
     * optional lengths tensor, following Scala's {@code $default$N} naming.
     *
     * @return {@code None}
     */
    public <T> Option<Output<Object>> maybeMaskScore$default$3() {
        return None$.MODULE$;
    }

    /** Registers this instance as the module singleton {@link #MODULE$}. */
    private Attention$() {
        MODULE$ = this;
    }
}
