package org.platanios.tensorflow.api.ops.rnn.attention;

import org.platanios.tensorflow.api.core.Indexer;
import org.platanios.tensorflow.api.core.package$exception$InvalidShapeException;
import org.platanios.tensorflow.api.core.package$exception$InvalidShapeException$;
import org.platanios.tensorflow.api.core.types.Cpackage;
import org.platanios.tensorflow.api.core.types.package$TF$;
import org.platanios.tensorflow.api.implicits.Implicits$;
import org.platanios.tensorflow.api.ops.Basic$;
import org.platanios.tensorflow.api.ops.Math$;
import org.platanios.tensorflow.api.ops.Output;
import scala.Float$;
import scala.Function1;
import scala.Predef$;
import scala.collection.Seq$;
import scala.runtime.Nothing$;

/* compiled from: Attention.scala */
/* loaded from: input_file:org/platanios/tensorflow/api/ops/rnn/attention/Attention$.class */
/**
 * Module (singleton) class compiled from {@code Attention.scala}.
 *
 * <p>Holds the default-argument accessors for the {@code Attention} constructor and the
 * shape / sequence-masking helpers used by attention mechanisms in this package.
 */
public final class Attention$ {
    /** Singleton instance; assigned by the private constructor during static initialization. */
    public static Attention$ MODULE$;

    static {
        new Attention$();
    }

    /**
     * Default for the 2nd {@code Attention} constructor argument: {@code null}.
     * NOTE(review): presumably the memory sequence lengths (null = "no masking") — confirm
     * against Attention.scala.
     */
    public <T, State, StateShape> Output<Object> $lessinit$greater$default$2() {
        return null;
    }

    /**
     * Default for the 3rd {@code Attention} constructor argument: {@code true}.
     * NOTE(review): presumably the "check inner dimensions defined" flag that is forwarded to
     * {@link #maybeMaskValues} — confirm against Attention.scala.
     */
    public <T, State, StateShape> boolean $lessinit$greater$default$3() {
        return true;
    }

    /**
     * Default for the 4th {@code Attention} constructor argument: Scala's {@code Float.MinValue}
     * (the most negative finite float, NOT Java's {@code Float.MIN_VALUE}) converted to an output.
     * NOTE(review): presumably the score mask value passed to {@link #maybeMaskScore} — confirm.
     */
    public <T, State, StateShape> Output<Object> $lessinit$greater$default$4() {
        return Implicits$.MODULE$.floatToOutput(Float$.MODULE$.MinValue());
    }

    /** Default for the 5th {@code Attention} constructor argument (the name): {@code "Attention"}. */
    public <T, State, StateShape> String $lessinit$greater$default$5() {
        return "Attention";
    }

    /**
     * Returns the size of dimension {@code i} of {@code output} as an int output.
     *
     * <p>If the rank is known and the static size of dimension {@code i} is not -1, the size is
     * emitted as a constant. Otherwise it is read dynamically as {@code shape(output)[i]}.
     *
     * @param output the tensor whose dimension size is requested
     * @param i      the dimension index
     * @param tf     type evidence for {@code T}
     * @return the (static or dynamic) size of dimension {@code i}
     */
    public <T> Output<Object> dimSize(Output<T> output, int i, Cpackage.TF<T> tf) {
        // rank() == -1 is checked first so shape().apply(i) is never evaluated on an unknown rank.
        return (output.rank() == -1 || output.shape().apply(i) == -1)
                ? Basic$.MODULE$.shape(output, Basic$.MODULE$.shape$default$2(), Basic$.MODULE$.shape$default$3(), tf)
                        .castTo(package$TF$.MODULE$.intEvTF())
                        .slice(Implicits$.MODULE$.intToIndex(i), Predef$.MODULE$.wrapRefArray(new Indexer[0]))
                : Basic$.MODULE$.constant(
                        Implicits$.MODULE$.intToTensor(output.shape().apply(i)),
                        Basic$.MODULE$.constant$default$2(),
                        Basic$.MODULE$.constant$default$3());
    }

    /**
     * Optionally checks the inner dimensions of {@code output} and zeroes out positions beyond each
     * sequence length.
     *
     * <p>If {@code output2} is {@code null}, {@code output} is returned unchanged (after the
     * optional shape check). Otherwise a sequence mask with maximum length
     * {@code shape(output)[1]} is built and cast to {@code T}. It is reshaped to
     * {@code shape(mask) ++ ones(rank(output) - 2)}, so it broadcasts over every trailing
     * dimension, and multiplied element-wise with {@code output}.
     *
     * @param output  the values to mask (called "memory" in the error message)
     * @param output2 per-example sequence lengths, or {@code null} to skip masking
     * @param z       when true, require {@code output.shape()(2::)} to be fully defined
     * @param tf      type evidence for {@code T}
     * @param lessVar implicit evidence forwarded to the multiplication
     * @return the (possibly) masked values
     * @throws package$exception$InvalidShapeException if {@code z} is true and the inner
     *         dimensions of {@code output} are not fully defined
     */
    public <T> Output<T> maybeMaskValues(Output<T> output, Output<Object> output2, boolean z, Cpackage.TF<T> tf, Predef$.less.colon.less<Function1<Function1<T, Nothing$>, Nothing$>, Function1<Function1<Cpackage.TruncatedHalf, Nothing$>, Nothing$>> lessVar) throws package$exception$InvalidShapeException {
        if (z && !output.shape().apply(Implicits$.MODULE$.intToIndexerConstruction(2).$colon$colon()).isFullyDefined()) {
            throw new package$exception$InvalidShapeException(
                    "Expected memory '" + output.name() + "' to have fully defined "
                            + "inner dimensions, but saw shape: " + output.shape() + ".",
                    package$exception$InvalidShapeException$.MODULE$.apply$default$2());
        }
        // No lengths: nothing to mask.
        if (output2 == null) {
            return output;
        }
        // Sequence mask whose maximum length is shape(output)[1], cast to the value type T.
        Output mask = Basic$.MODULE$.sequenceMask(
                output2,
                Basic$.MODULE$.shape(output, Basic$.MODULE$.shape$default$2(), Basic$.MODULE$.shape$default$3(), tf)
                        .castTo(package$TF$.MODULE$.intEvTF())
                        .slice(Implicits$.MODULE$.intToIndex(1), Predef$.MODULE$.wrapRefArray(new Indexer[0])),
                Basic$.MODULE$.sequenceMask$default$3(),
                package$TF$.MODULE$.intEvTF(),
                Predef$.MODULE$.$conforms()).castTo(tf);
        if (mask == null) {
            return output;
        }
        // The ops below are created in the same order as the original single-expression form.
        Output maskShape = Basic$.MODULE$.shape(mask, Basic$.MODULE$.shape$default$2(), Basic$.MODULE$.shape$default$3(), tf)
                .castTo(package$TF$.MODULE$.intEvTF());
        // Static rank when known, otherwise the dynamic rank op.
        Output rank = output.rank() != -1
                ? Basic$.MODULE$.constant(Implicits$.MODULE$.intToTensor(output.rank()), Basic$.MODULE$.constant$default$2(), Basic$.MODULE$.constant$default$3())
                : Basic$.MODULE$.rank(output, Basic$.MODULE$.rank$default$2(), Basic$.MODULE$.rank$default$3(), tf);
        // A vector of (rank - 2) ones, so the mask broadcasts over the trailing dimensions.
        Output trailingOnes = Basic$.MODULE$.ones(
                Basic$.MODULE$.expandDims(
                        rank.$minus(Implicits$.MODULE$.intToOutput(2), Predef$.MODULE$.$conforms()),
                        Implicits$.MODULE$.intToOutput(0),
                        Basic$.MODULE$.expandDims$default$3(),
                        package$TF$.MODULE$.intEvTF(),
                        package$TF$.MODULE$.intEvTF(),
                        Predef$.MODULE$.$conforms()),
                package$TF$.MODULE$.intEvTF(),
                package$TF$.MODULE$.intEvTF(),
                Predef$.MODULE$.$conforms());
        Output broadcastShape = Basic$.MODULE$.concatenate(
                Seq$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new Output[]{maskShape, trailingOnes})),
                Implicits$.MODULE$.intToOutput(0),
                Basic$.MODULE$.concatenate$default$3(),
                package$TF$.MODULE$.intEvTF());
        return output.$times(
                Implicits$.MODULE$.outputBasicOps(mask).reshape(broadcastShape, package$TF$.MODULE$.intEvTF(), Predef$.MODULE$.$conforms()),
                lessVar);
    }

    /**
     * Replaces scores beyond each sequence length with a mask value.
     *
     * <p>If {@code output2} is {@code null}, {@code output} is returned unchanged. Otherwise a
     * sequence mask with maximum length {@code shape(output)[1]} selects {@code output} where the
     * mask is true and {@code output3 * onesLike(output)} where it is false. Per TF
     * {@code sequenceMask} semantics, false marks positions at or beyond each length.
     *
     * @param output  the attention scores
     * @param output2 per-example sequence lengths, or {@code null} to skip masking
     * @param output3 the value written into masked-out positions
     * @param tf      type evidence for {@code T}
     * @param lessVar implicit evidence forwarded to the multiplication
     * @return the (possibly) masked scores
     */
    public <T> Output<T> maybeMaskScore(Output<T> output, Output<Object> output2, Output<T> output3, Cpackage.TF<T> tf, Predef$.less.colon.less<Function1<Function1<T, Nothing$>, Nothing$>, Function1<Function1<Cpackage.TruncatedHalf, Nothing$>, Nothing$>> lessVar) {
        // Fixed: this check was inverted (`!= null`). That skipped masking whenever lengths were
        // given and called sequenceMask(null, ...) when they were absent.
        if (output2 == null) {
            return output;
        }
        Output<Object> maxLength = Basic$.MODULE$.shape(output, Basic$.MODULE$.shape$default$2(), Basic$.MODULE$.shape$default$3(), tf)
                .castTo(package$TF$.MODULE$.intEvTF())
                .slice(Implicits$.MODULE$.intToIndex(1), Predef$.MODULE$.wrapRefArray(new Indexer[0]));
        return Math$.MODULE$.select(
                Basic$.MODULE$.sequenceMask(output2, maxLength, Basic$.MODULE$.sequenceMask$default$3(), package$TF$.MODULE$.intEvTF(), Predef$.MODULE$.$conforms()),
                output,
                output3.$times(Basic$.MODULE$.onesLike(output, Basic$.MODULE$.onesLike$default$2(), Basic$.MODULE$.onesLike$default$3()), lessVar),
                Math$.MODULE$.select$default$4(),
                tf);
    }

    private Attention$() {
        MODULE$ = this;
    }
}
