package org.platanios.tensorflow.api.ops.rnn.attention;

import org.platanios.tensorflow.api.core.Indexer;
import org.platanios.tensorflow.api.core.IndexerConstructionWithTwoNumbers$;
import org.platanios.tensorflow.api.core.NewAxis$;
import org.platanios.tensorflow.api.core.Shape;
import org.platanios.tensorflow.api.core.package$exception$;
import org.platanios.tensorflow.api.core.types.Cpackage;
import org.platanios.tensorflow.api.core.types.package$TF$;
import org.platanios.tensorflow.api.implicits.Implicits$;
import org.platanios.tensorflow.api.ops.Output;
import org.platanios.tensorflow.api.ops.basic.Basic$;
import org.platanios.tensorflow.api.ops.math.Math$;
import org.platanios.tensorflow.api.ops.rnn.attention.Attention;
import org.platanios.tensorflow.jni.InvalidArgumentException;
import scala.$less;
import scala.$less$colon$less$;
import scala.Function1;
import scala.collection.immutable.Nil$;
import scala.collection.immutable.Seq;
import scala.collection.immutable.Seq$;
import scala.reflect.ScalaSignature;
import scala.runtime.Nothing$;
import scala.runtime.ScalaRunTime$;

/* compiled from: LuongAttention.scala */
@ScalaSignature(bytes = "\u0006\u0005\t}b\u0001B\u0010!\u0001=B\u0001b\u0011\u0001\u0003\u0006\u0004%\t\u0005\u0012\u0005\t\u0019\u0002\u0011\t\u0011)A\u0005\u000b\"AQ\n\u0001BC\u0002\u0013\u0005a\n\u0003\u0005Q\u0001\t\u0005\t\u0015!\u0003P\u0011!\t\u0006A!b\u0001\n\u0003\u0011\u0006\u0002\u0003,\u0001\u0005\u0003\u0005\u000b\u0011B*\t\u0011]\u0003!Q1A\u0005\u00029C\u0001\u0002\u0017\u0001\u0003\u0002\u0003\u0006Ia\u0014\u0005\t3\u0002\u0011)\u0019!C!5\"Iq\f\u0001B\u0001B\u0003%1\f\u0019\u0005\tC\u0002\u0011)\u0019!C!E\"Aa\u000e\u0001B\u0001B\u0003%1\r\u0003\u0005p\u0001\t\r\t\u0015a\u0003q\u0011)\ty\u0001\u0001B\u0002B\u0003-\u0011\u0011\u0003\u0005\b\u0003/\u0001A\u0011AA\r\u0011\u001d\ty\u0003\u0001C!\u0003cAq!a\u0010\u0001\t#\n\t\u0005C\u0004\u0002Z\u0001!\t&a\u0017\t\u000f\u0005E\u0006\u0001\"\u0015\u00024\u001e9\u0011\u0011\u0018\u0011\t\u0002\u0005mfAB\u0010!\u0011\u0003\ti\fC\u0004\u0002\u0018U!\t!!2\t\u000f\u0005\u001dW\u0003\"\u0001\u0002J\"I\u00111_\u000b\u0012\u0002\u0013\u0005\u0011Q\u001f\u0005\n\u0005')\u0012\u0013!C\u0001\u0005+A\u0011B!\u0007\u0016#\u0003%\tAa\u0007\t\u0013\t\rR#%A\u0005\u0002\t\u0015\u0002\"\u0003B\u0017+E\u0005I\u0011\u0001B\u0018\u0011%\u0011\u0019$FI\u0001\n\u0003\u0011)\u0004C\u0005\u0003:U\t\n\u0011\"\u0001\u0003<\tqA*^8oO\u0006#H/\u001a8uS>t'BA\u0011#\u0003%\tG\u000f^3oi&|gN\u0003\u0002$I\u0005\u0019!O\u001c8\u000b\u0005\u00152\u0013aA8qg*\u0011q\u0005K\u0001\u0004CBL'BA\u0015+\u0003)!XM\\:pe\u001adwn\u001e\u0006\u0003W1\n\u0011\u0002\u001d7bi\u0006t\u0017n\\:\u000b\u00035\n1a\u001c:h\u0007\u0001)\"\u0001M\u001c\u0014\u0005\u0001\t\u0004c\u0001\u001a4k5\t\u0001%\u0003\u00025A\ty1+[7qY\u0016\fE\u000f^3oi&|g\u000e\u0005\u00027o1\u0001A!\u0002\u001d\u0001\u0005\u0004I$!\u0001+\u0012\u0005i\u0002\u0005CA\u001e?\u001b\u0005a$\"A\u001f\u0002\u000bM\u001c\u0017\r\\1\n\u0005}b$a\u0002(pi\"Lgn\u001a\t\u0003w\u0005K!A\u0011\u001f\u0003\u0007\u0005s\u00170\u0001\u0006nK6|'/_*ju\u0016,\u0012!\u0012\t\u0004\r\u001eKU\"\u0001\u0013\n\u000
5!##AB(viB,H\u000f\u0005\u0002<\u0015&\u00111\n\u0010\u0002\u0004\u0013:$\u0018aC7f[>\u0014\u0018pU5{K\u0002\nQ\"\\3n_JLx+Z5hQR\u001cX#A(\u0011\u0007\u0019;U'\u0001\bnK6|'/_,fS\u001eDGo\u001d\u0011\u0002\u001bA\u0014xNY1cS2LG/\u001f$o+\u0005\u0019\u0006\u0003B\u001eU\u001f>K!!\u0016\u001f\u0003\u0013\u0019+hn\u0019;j_:\f\u0014A\u00049s_\n\f'-\u001b7jif4e\u000eI\u0001\fg\u000e\fG.\u001a$bGR|'/\u0001\u0007tG\u0006dWMR1di>\u0014\b%\u0001\btG>\u0014X-T1tWZ\u000bG.^3\u0016\u0003m\u00032AR$]!\tYT,\u0003\u0002_y\t)a\t\\8bi\u0006y1oY8sK6\u000b7o\u001b,bYV,\u0007%\u0003\u0002Zg\u0005!a.Y7f+\u0005\u0019\u0007C\u00013l\u001d\t)\u0017\u000e\u0005\u0002gy5\tqM\u0003\u0002i]\u00051AH]8pizJ!A\u001b\u001f\u0002\rA\u0013X\rZ3g\u0013\taWN\u0001\u0004TiJLgn\u001a\u0006\u0003Ur\nQA\\1nK\u0002\n!\"\u001a<jI\u0016t7-\u001a\u00132!\u0011\t\u0018\u0011B\u001b\u000f\u0007I\f\u0019A\u0004\u0002t}:\u0011A\u000f \b\u0003knt!A\u001e>\u000f\u0005]LhB\u00014y\u0013\u0005i\u0013BA\u0016-\u0013\tI#&\u0003\u0002(Q%\u0011QPJ\u0001\u0005G>\u0014X-C\u0002��\u0003\u0003\tQ\u0001^=qKNT!! 
\u0014\n\t\u0005\u0015\u0011qA\u0001\ba\u0006\u001c7.Y4f\u0015\ry\u0018\u0011A\u0005\u0005\u0003\u0017\tiA\u0001\u0002U\r*!\u0011QAA\u0004\u0003))g/\u001b3f]\u000e,GE\r\t\u0005c\u0006MQ'\u0003\u0003\u0002\u0016\u00055!!C%t\t\u0016\u001c\u0017.\\1m\u0003\u0019a\u0014N\\5u}Qq\u00111DA\u0012\u0003K\t9#!\u000b\u0002,\u00055BCBA\u000f\u0003?\t\t\u0003E\u00023\u0001UBQa\\\bA\u0004ADq!a\u0004\u0010\u0001\b\t\t\u0002C\u0003D\u001f\u0001\u0007Q\tC\u0003N\u001f\u0001\u0007q\nC\u0003R\u001f\u0001\u00071\u000bC\u0004X\u001fA\u0005\t\u0019A(\t\u000fe{\u0001\u0013!a\u00017\"9\u0011m\u0004I\u0001\u0002\u0004\u0019\u0017!C6fsN\u001c\u0006.\u00199f)\u0011\t\u0019$a\u000f\u0011\t\u0005U\u0012qG\u0007\u0003\u0003\u0003IA!!\u000f\u0002\u0002\t)1\u000b[1qK\"9\u0011Q\b\tA\u0002\u0005M\u0012a\u0003<bYV,7o\u00155ba\u0016\fAa[3zgR)q*a\u0011\u0002V!9\u0011QI\tA\u0002\u0005\u001d\u0013AB7f[>\u0014\u0018\u0010E\u0003\u0002J\u0005=SGD\u00023\u0003\u0017J1!!\u0014!\u0003%\tE\u000f^3oi&|g.\u0003\u0003\u0002R\u0005M#AB'f[>\u0014\u0018PC\u0002\u0002N\u0001Ba!a\u0016\u0012\u0001\u0004y\u0015A\u0002<bYV,7/A\u0003tG>\u0014X\rF\u0003P\u0003;\n\t\u0007\u0003\u0004\u0002`I\u0001\raT\u0001\u0006cV,'/\u001f\u0005\b\u0003G\u0012\u0002\u0019AA3\u0003\u0015\u0019H/\u0019;f!\u0019\tI%a\u001a6\u001f&!\u0011\u0011NA*\u0005\u0015\u0019F/\u0019;fQ\u0015\u0011\u0012QNAD!\u0015Y\u0014qNA:\u0013\r\t\t\b\u0010\u0002\u0007i\"\u0014xn^:\u0011\t\u0005U\u0014\u0011\u0011\b\u0005\u0003o\nYHD\u0002t\u0003sJA!!\u0002\u0002\u0002%!\u0011QPA@\u0003%)\u0007pY3qi&|gN\u0003\u0003\u0002\u0006\u0005\u0005\u0011\u0002BAB\u0003\u000b\u0013\u0001$\u00138wC2LG-\u0011:hk6,g\u000e^#yG\u0016\u0004H/[8o\u0015\u0011\ti(a 
2\ry\u0019\u0017\u0011RAXc%\u0019\u00131RAH\u0003K\u000b\t*F\u0002c\u0003\u001b#a\u0001\u000f\u0018C\u0002\u0005]\u0015\u0002BAI\u0003'\u000b1\u0004\n7fgNLg.\u001b;%OJ,\u0017\r^3sI\u0011,g-Y;mi\u0012\n$bAAKy\u00051A\u000f\u001b:poN\f2AOAM!\u0011\tY*a(\u000f\u0007m\ni*C\u0002\u0002\u0006qJA!!)\u0002$\nIA\u000b\u001b:po\u0006\u0014G.\u001a\u0006\u0004\u0003\u000ba\u0014'C\u0012\u0002(\u0006%\u00161VAK\u001d\rY\u0014\u0011V\u0005\u0004\u0003+c\u0014'\u0002\u0012<y\u00055&!B:dC2\f\u0017g\u0001\u0014\u0002t\u0005Y\u0001O]8cC\nLG.\u001b;z)\u0015y\u0015QWA\\\u0011\u0019\tIf\u0005a\u0001\u001f\"9\u00111M\nA\u0002\u0005\u0015\u0014A\u0004'v_:<\u0017\t\u001e;f]RLwN\u001c\t\u0003eU\u00192!FA`!\rY\u0014\u0011Y\u0005\u0004\u0003\u0007d$AB!osJ+g\r\u0006\u0002\u0002<\u0006)\u0011\r\u001d9msV!\u00111ZAj)9\ti-!9\u0002d\u0006\u001d\u00181^Ax\u0003c$b!a4\u0002V\u0006m\u0007\u0003\u0002\u001a\u0001\u0003#\u00042ANAj\t\u0015AtC1\u0001:\u0011%\t9nFA\u0001\u0002\b\tI.\u0001\u0006fm&$WM\\2fIM\u0002R!]A\u0005\u0003#D\u0011\"!8\u0018\u0003\u0003\u0005\u001d!a8\u0002\u0015\u00154\u0018\u000eZ3oG\u0016$C\u0007E\u0003r\u0003'\t\t\u000eC\u0003D/\u0001\u0007Q\t\u0003\u0004N/\u0001\u0007\u0011Q\u001d\t\u0005\r\u001e\u000b\t\u000e\u0003\u0005R/A\u0005\t\u0019AAu!\u0019YD+!:\u0002f\"I\u0011Q^\f\u0011\u0002\u0003\u0007\u0011Q]\u0001\rg\u000e\fG.Z,fS\u001eDGo\u001d\u0005\b3^\u0001\n\u00111\u0001\\\u0011\u001d\tw\u0003%AA\u0002\r\fq\"\u00199qYf$C-\u001a4bk2$HeM\u000b\u0005\u0003o\u0014\t\"\u0006\u0002\u0002z*\"\u00111`A��\u001f\t\tiP\t\u0001,\u0005\t\u0005\u0001\u0003\u0002B\u0002\u0005\u001bi!A!\u0002\u000b\t\t\u001d!\u0011B\u0001\nk:\u001c\u0007.Z2lK\u0012T1Aa\u0003=\u0003)\tgN\\8uCRLwN\\\u0005\u0005\u0005\u001f\u0011)AA\tv]\u000eDWmY6fIZ\u000b'/[1oG\u0016$Q\u0001\u000f\rC\u0002e\nq\"\u00199qYf$C-\u001a4bk2$H\u0005N\u000b\u0005\u0003o\u00149\u0002B\u000393\t\u0007\u0011(A\bbaBd\u0017\u0010\n3fM\u0006,H\u000e\u001e\u00136+\u0011\u0011iB!\t\u0016\u0005\t}!fA.\u0002��\u0012)\u0001H\u0007b\u0001s\u0005y\
u0011\r\u001d9ms\u0012\"WMZ1vYR$c'\u0006\u0003\u0003(\t-RC\u0001B\u0015U\r\u0019\u0017q \u0003\u0006qm\u0011\r!O\u0001\u001cI1,7o]5oSR$sM]3bi\u0016\u0014H\u0005Z3gCVdG\u000f\n\u001b\u0016\t\u0005](\u0011\u0007\u0003\u0006qq\u0011\r!O\u0001\u001cI1,7o]5oSR$sM]3bi\u0016\u0014H\u0005Z3gCVdG\u000fJ\u001b\u0016\t\tu!q\u0007\u0003\u0006qu\u0011\r!O\u0001\u001cI1,7o]5oSR$sM]3bi\u0016\u0014H\u0005Z3gCVdG\u000f\n\u001c\u0016\t\t\u001d\"Q\b\u0003\u0006qy\u0011\r!\u000f")
/* loaded from: input_file:org/platanios/tensorflow/api/ops/rnn/attention/LuongAttention.class */
/*
 * Luong-style (multiplicative) attention mechanism: decompiled Java view of the Scala class
 * compiled from LuongAttention.scala.
 *
 * - Keys are computed by right-multiplying the memory values with memoryWeights (see keys()).
 * - Scores are the batched product of the query with the keys, optionally multiplied by
 *   scaleFactor when it is non-null (see score()).
 * - Probabilities are obtained by applying probabilityFn to the scores (see probability()).
 *
 * NOTE(review): this file is decompiler output (note the JADX warnings below). Changes should
 * normally be made in the Scala source, not here.
 */
public class LuongAttention<T> extends SimpleAttention<T> {
    // Memory size tensor; returned by memorySize() and also forwarded to the superclass constructor.
    private final Output<Object> memorySize;
    // Weight matrix that memory values are multiplied by in keys(); its last dimension is the key depth.
    private final Output<T> memoryWeights;
    // Function applied to the scores in probability().
    private final Function1<Output<T>, Output<T>> probabilityFn;
    // Optional multiplier applied to the scores in score(); null means "no scaling".
    private final Output<T> scaleFactor;
    // Name of this attention mechanism; also forwarded to the superclass constructor.
    private final String name;
    // Type evidence for T; passed to every shape/reshape/matmul/squeeze op below.
    private final Cpackage.TF<T> evidence$1;
    // Scala evidence passed to matmul and $times. Presumably the compiled form of a context bound that
    // restricts T via the library's union-type encoding (TruncatedHalf appears in it) -- confirm against
    // the Scala source.
    private final $less.colon.less<Function1<Function1<T, Nothing$>, Nothing$>, Function1<Function1<Cpackage.TruncatedHalf, Nothing$>, Nothing$>> evidence$2;

    /**
     * Static forwarder to the companion object's factory.
     *
     * <p>Passes every argument unchanged to {@code LuongAttention$.MODULE$.apply}.
     */
    public static <T> LuongAttention<T> apply(Output<Object> output, Output<T> output2, Function1<Output<T>, Output<T>> function1, Output<T> output3, Output<Object> output4, String str, Cpackage.TF<T> tf, $less.colon.less<Function1<Function1<T, Nothing$>, Nothing$>, Function1<Function1<Cpackage.TruncatedHalf, Nothing$>, Nothing$>> lessVar) {
        return LuongAttention$.MODULE$.apply(output, output2, function1, output3, output4, str, tf, lessVar);
    }

    /** Returns the memory size tensor given at construction. */
    @Override // org.platanios.tensorflow.api.ops.rnn.attention.SimpleAttention, org.platanios.tensorflow.api.ops.rnn.attention.Attention
    public Output<Object> memorySize() {
        return this.memorySize;
    }

    /** Returns the weight matrix used to project memory values into keys. */
    public Output<T> memoryWeights() {
        return this.memoryWeights;
    }

    /** Returns the function that maps scores to probabilities. */
    public Function1<Output<T>, Output<T>> probabilityFn() {
        return this.probabilityFn;
    }

    /** Returns the score multiplier, or {@code null} if scores are not scaled. */
    public Output<T> scaleFactor() {
        return this.scaleFactor;
    }

    /**
     * Returns the superclass's score mask value unchanged.
     *
     * <p>Presumably a compiler-generated forwarder for a constructor parameter shared with the
     * superclass (the value is only passed to {@code super}, never stored here).
     */
    @Override // org.platanios.tensorflow.api.ops.rnn.attention.SimpleAttention
    public Output<Object> scoreMaskValue() {
        return super.scoreMaskValue();
    }

    /** Returns the name given at construction. */
    @Override // org.platanios.tensorflow.api.ops.rnn.attention.SimpleAttention, org.platanios.tensorflow.api.ops.rnn.attention.Attention
    public String name() {
        return this.name;
    }

    /**
     * Computes the static shape of the keys from the static shape of the memory values.
     *
     * <p>The index built from {@code intToIndexerConstruction(-1).$colon$colon(intToIndexerConstruction(0))}
     * is the compiled form of the Scala slice {@code 0 :: -1} (Scala's {@code ::} is right-associative),
     * presumably "all but the last dimension". The last static dimension of {@link #memoryWeights()} is
     * then combined with it via {@code Shape.$plus}, presumably appending it as a new trailing dimension.
     *
     * @param shape static shape of the memory values
     * @return static shape of the keys
     */
    @Override // org.platanios.tensorflow.api.ops.rnn.attention.Attention
    public Shape keysShape(Shape shape) {
        return shape.apply(IndexerConstructionWithTwoNumbers$.MODULE$.indexerConstructionToIndex(Implicits$.MODULE$.intToIndexerConstruction(-1).$colon$colon(Implicits$.MODULE$.intToIndexerConstruction(0)))).$plus(memoryWeights().shape().apply(-1));
    }

    /**
     * Projects memory values into keys by right-multiplying them with {@link #memoryWeights()}.
     *
     * <p>If the values are not rank 3, they are passed to {@code matmul} directly.
     *
     * <p>If they are rank 3:
     * <ol>
     *   <li>they are reshaped to a matrix of shape {@code [-1, lastDim]};</li>
     *   <li>the matrix is multiplied with the weights;</li>
     *   <li>the product is reshaped to (values dynamic shape sliced {@code 0 :: -1}) concatenated with
     *       (the last entry of the weights' dynamic shape, kept as a 1-element vector via
     *       {@code NewAxis}).</li>
     * </ol>
     * The static shape is then set to the same value {@link #keysShape(Shape)} would compute.
     *
     * @param memory the attention memory (not read by this implementation)
     * @param output the memory values
     * @return the keys
     */
    /* JADX WARN: Multi-variable type inference failed */
    @Override // org.platanios.tensorflow.api.ops.rnn.attention.SimpleAttention
    public Output<T> keys(Attention.Memory<T> memory, Output<T> output) {
        if (output.rank() != 3) {
            return Math$.MODULE$.matmul(output, memoryWeights(), Math$.MODULE$.matmul$default$3(), Math$.MODULE$.matmul$default$4(), Math$.MODULE$.matmul$default$5(), Math$.MODULE$.matmul$default$6(), Math$.MODULE$.matmul$default$7(), Math$.MODULE$.matmul$default$8(), Math$.MODULE$.matmul$default$9(), this.evidence$1, this.evidence$2);
        }
        // Dynamic shape of the rank-3 values.
        Output<Object> shape = Basic$.MODULE$.shape(output, Basic$.MODULE$.shape$default$2(), Basic$.MODULE$.shape$default$3(), this.evidence$1);
        // reshape(matmul(reshape(values, stack([-1, shape(-1)])), memoryWeights),
        //         concatenate([shape(0 :: -1), shape(memoryWeights).slice(-1, NewAxis)], axis = 0))
        Output<T> reshape = Basic$.MODULE$.reshape(Math$.MODULE$.matmul(Basic$.MODULE$.reshape(output, Basic$.MODULE$.stack(Seq$.MODULE$.apply(ScalaRunTime$.MODULE$.wrapRefArray(new Output[]{Basic$.MODULE$.constant(Implicits$.MODULE$.intToTensor(-1), Basic$.MODULE$.constant$default$2(), Basic$.MODULE$.constant$default$3()), shape.apply(Implicits$.MODULE$.intToIndex(-1), Nil$.MODULE$)})), Basic$.MODULE$.stack$default$2(), Basic$.MODULE$.stack$default$3(), package$TF$.MODULE$.intEvTF()), Basic$.MODULE$.reshape$default$3(), this.evidence$1, package$TF$.MODULE$.intEvTF(), $less$colon$less$.MODULE$.refl()), memoryWeights(), Math$.MODULE$.matmul$default$3(), Math$.MODULE$.matmul$default$4(), Math$.MODULE$.matmul$default$5(), Math$.MODULE$.matmul$default$6(), Math$.MODULE$.matmul$default$7(), Math$.MODULE$.matmul$default$8(), Math$.MODULE$.matmul$default$9(), this.evidence$1, this.evidence$2), Basic$.MODULE$.concatenate(Seq$.MODULE$.apply(ScalaRunTime$.MODULE$.wrapRefArray(new Output[]{shape.apply(IndexerConstructionWithTwoNumbers$.MODULE$.indexerConstructionToIndex(Implicits$.MODULE$.intToIndexerConstruction(-1).$colon$colon(Implicits$.MODULE$.intToIndexerConstruction(0))), Nil$.MODULE$), Basic$.MODULE$.shape(memoryWeights(), Basic$.MODULE$.shape$default$2(), Basic$.MODULE$.shape$default$3(), this.evidence$1).slice(Implicits$.MODULE$.intToIndex(-1), ScalaRunTime$.MODULE$.wrapRefArray(new Indexer[]{NewAxis$.MODULE$}))})), Implicits$.MODULE$.intToOutput(0), Basic$.MODULE$.concatenate$default$3(), package$TF$.MODULE$.intEvTF()), Basic$.MODULE$.reshape$default$3(), this.evidence$1, package$TF$.MODULE$.intEvTF(), $less$colon$less$.MODULE$.refl());
        // The dynamic reshape loses static shape information, so restore it:
        // static values shape (0 :: -1) followed by the weights' last static dimension.
        reshape.setShape(output.shape().apply(IndexerConstructionWithTwoNumbers$.MODULE$.indexerConstructionToIndex(Implicits$.MODULE$.intToIndexerConstruction(-1).$colon$colon(Implicits$.MODULE$.intToIndexerConstruction(0)))).$plus(memoryWeights().shape().apply(-1)));
        return reshape;
    }

    /**
     * Computes attention scores as the batched product of the query with the keys.
     *
     * <p>Steps:
     * <ol>
     *   <li>The query is expanded at axis 1.</li>
     *   <li>It is multiplied with {@code state.keys()}. The literal {@code true} in the fourth
     *       {@code matmul} argument is presumably {@code transposeB}; confirm against
     *       {@code Math.matmul}'s signature.</li>
     *   <li>Axis 1 is squeezed out of the product.</li>
     *   <li>If {@link #scaleFactor()} is non-null, the result is multiplied by it.</li>
     * </ol>
     *
     * <p>NOTE(review): the error message mentions "unknown" inner dimensions. However, the check only
     * compares the two last static dimensions for equality, so if both are unknown (whatever value
     * {@code Shape.apply(-1)} returns for an unknown dimension) no exception is thrown.
     *
     * @param output the query
     * @param state the attention state whose {@code keys()} are scored against the query
     * @return the (optionally scaled) scores
     * @throws InvalidArgumentException if the last static dimension of the query differs from that of
     *     the keys (the thrown value is built by {@code package$exception$.InvalidArgumentException()})
     */
    @Override // org.platanios.tensorflow.api.ops.rnn.attention.SimpleAttention
    public Output<T> score(Output<T> output, Attention.State<T, Output<T>> state) throws InvalidArgumentException {
        // Last static dimension of the query and of the keys.
        int apply = output.shape().apply(-1);
        int apply2 = state.keys().shape().apply(-1);
        if (apply != apply2) {
            throw package$exception$.MODULE$.InvalidArgumentException().apply(new StringBuilder(161).append("Incompatible or unknown inner dimensions between query and keys. ").append(new StringBuilder(21).append("Query (").append(output.name()).append(") has ").append(apply).append(" units. ").toString()).append(new StringBuilder(21).append("Keys (").append(state.keys().name()).append(") have ").append(apply2).append(" units. ").toString()).append("Perhaps you need to set the number of units of the attention model ").append("to the keys' number of units.").toString());
        }
        // squeeze(matmul(expandDims(query, 1), keys, <4th arg = true>), axes = Seq(1))
        Output<T> squeeze = Basic$.MODULE$.squeeze(Math$.MODULE$.matmul(Implicits$.MODULE$.outputBasicOps(output).expandDims(Implicits$.MODULE$.intToOutput(1)), state.keys(), Math$.MODULE$.matmul$default$3(), true, Math$.MODULE$.matmul$default$5(), Math$.MODULE$.matmul$default$6(), Math$.MODULE$.matmul$default$7(), Math$.MODULE$.matmul$default$8(), Math$.MODULE$.matmul$default$9(), this.evidence$1, this.evidence$2), (Seq) Seq$.MODULE$.apply(ScalaRunTime$.MODULE$.wrapIntArray(new int[]{1})), Basic$.MODULE$.squeeze$default$3(), this.evidence$1);
        // A null scaleFactor means the scores are returned unscaled.
        return scaleFactor() == null ? squeeze : scaleFactor().$times(squeeze, this.evidence$2);
    }

    /**
     * Converts scores into probabilities by applying {@link #probabilityFn()}.
     *
     * @param output the scores
     * @param state the attention state (not read by this implementation)
     * @return {@code probabilityFn().apply(output)}
     */
    @Override // org.platanios.tensorflow.api.ops.rnn.attention.SimpleAttention
    public Output<T> probability(Output<T> output, Attention.State<T, Output<T>> state) {
        return (Output) probabilityFn().apply(output);
    }

    /**
     * Creates the attention mechanism.
     *
     * @param output memory size (stored and forwarded to the superclass)
     * @param output2 weight matrix used by {@link #keys}
     * @param function1 function applied to scores by {@link #probability}
     * @param output3 score multiplier; may be {@code null} (see {@link #score})
     * @param output4 forwarded only to the superclass. Presumably the score mask value, since
     *     {@link #scoreMaskValue()} reads it back from {@code super} -- confirm.
     * @param str name (stored and forwarded to the superclass)
     * @param tf type evidence for {@code T}
     * @param lessVar additional type evidence for {@code T} (see the {@code evidence$2} field)
     */
    /* JADX WARN: 'super' call moved to the top of the method (can break code semantics) */
    public LuongAttention(Output<Object> output, Output<T> output2, Function1<Output<T>, Output<T>> function1, Output<T> output3, Output<Object> output4, String str, Cpackage.TF<T> tf, $less.colon.less<Function1<Function1<T, Nothing$>, Nothing$>, Function1<Function1<Cpackage.TruncatedHalf, Nothing$>, Nothing$>> lessVar) {
        super(output, output4, str, tf, lessVar);
        this.memorySize = output;
        this.memoryWeights = output2;
        this.probabilityFn = function1;
        this.scaleFactor = output3;
        this.name = str;
        this.evidence$1 = tf;
        this.evidence$2 = lessVar;
    }
}
