package org.platanios.tensorflow.api.ops.seq2seq.decoders;

import org.platanios.tensorflow.api.core.Indexer;
import org.platanios.tensorflow.api.core.IndexerConstructionWithTwoNumbers$;
import org.platanios.tensorflow.api.core.NewAxis$;
import org.platanios.tensorflow.api.core.Shape;
import org.platanios.tensorflow.api.core.Shape$;
import org.platanios.tensorflow.api.core.package$exception$;
import org.platanios.tensorflow.api.core.package$exception$InvalidShapeException;
import org.platanios.tensorflow.api.core.package$exception$InvalidShapeException$;
import org.platanios.tensorflow.api.implicits.Implicits$;
import org.platanios.tensorflow.api.ops.Basic;
import org.platanios.tensorflow.api.ops.Basic$;
import org.platanios.tensorflow.api.ops.Math;
import org.platanios.tensorflow.api.ops.Math$;
import org.platanios.tensorflow.api.ops.Op;
import org.platanios.tensorflow.api.ops.Op$;
import org.platanios.tensorflow.api.ops.Output;
import org.platanios.tensorflow.api.ops.Output$;
import org.platanios.tensorflow.api.ops.Symbol;
import org.platanios.tensorflow.api.ops.TensorArray;
import org.platanios.tensorflow.api.ops.TensorArray$;
import org.platanios.tensorflow.api.ops.control_flow.ControlFlow$;
import org.platanios.tensorflow.api.ops.control_flow.WhileLoopVariable;
import org.platanios.tensorflow.api.ops.control_flow.WhileLoopVariable$;
import org.platanios.tensorflow.api.ops.rnn.cell.RNNCell;
import org.platanios.tensorflow.api.ops.seq2seq.decoders.BeamSearchDecoder;
import org.platanios.tensorflow.api.tensors.Tensor;
import org.platanios.tensorflow.api.tensors.Tensor$;
import org.platanios.tensorflow.api.tensors.TensorConvertible$;
import org.platanios.tensorflow.api.types.DataType;
import org.platanios.tensorflow.api.types.SupportedType$;
import org.platanios.tensorflow.jni.InvalidArgumentException;
import scala.Function1;
import scala.MatchError;
import scala.None$;
import scala.Predef$;
import scala.StringContext;
import scala.Tuple2;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.collection.immutable.Nil$;
import scala.collection.mutable.ArrayBuffer;
import scala.collection.mutable.ArrayBuffer$;
import scala.runtime.BoxesRunTime;
import scala.runtime.ObjectRef;
import shapeless.$colon;
import shapeless.Generic;
import shapeless.HNil;
import shapeless.HNil$;
import shapeless.Lazy;
import shapeless.Lazy$;
import shapeless.ops.hlist$Tupler$;

/* compiled from: BeamSearchDecoder.scala */
/* loaded from: input_file:org/platanios/tensorflow/api/ops/seq2seq/decoders/BeamSearchDecoder$.class */
/**
 * Module (singleton) class of the Scala companion object {@code BeamSearchDecoder}, as compiled from
 * {@code BeamSearchDecoder.scala} and recovered by a decompiler (see the header comments above this class).
 *
 * <p>It holds the default arguments of the {@code BeamSearchDecoder} constructor and of {@link #apply}, plus the
 * graph-building helpers used by beam search: log-probability masking, {@code GatherTree} op construction, batch
 * tiling, batch/beam split and merge reshapes, per-batch beam gathering, and re-ordering of beams stored in a
 * {@link TensorArray}.
 *
 * <p>NOTE(review): parameter names such as {@code output2} or {@code i} are decompiler-generated. The meanings given
 * in the method docs below are inferred from how each value is used in this file; confirm them against the Scala
 * source.
 */
public final class BeamSearchDecoder$ {
    /** The singleton instance; assigned by the private constructor, which the static initializer invokes. */
    public static BeamSearchDecoder$ MODULE$;

    static {
        new BeamSearchDecoder$();
    }

    /** Default value of the 7th {@code BeamSearchDecoder} constructor argument (a {@link LengthPenalty}): {@code NoPenalty}. */
    public <S, SS> LengthPenalty $lessinit$greater$default$7() {
        return NoPenalty$.MODULE$;
    }

    /** Default value of the 8th {@code BeamSearchDecoder} constructor argument: an identity function. */
    public <S, SS> Function1<BeamSearchDecoder.Output, BeamSearchDecoder.Output> $lessinit$greater$default$8() {
        return output -> {
            return output;
        };
    }

    /** Default value of the 9th {@code BeamSearchDecoder} constructor argument (a {@code String}): {@code "BeamSearchRNNDecoder"}. */
    public <S, SS> String $lessinit$greater$default$9() {
        return "BeamSearchRNNDecoder";
    }

    /**
     * Creates a new {@link BeamSearchDecoder}, forwarding every argument unchanged and in order to its constructor.
     *
     * <p>Defaults for the 7th-9th arguments are {@link #apply$default$7()}, {@link #apply$default$8()} and
     * {@link #apply$default$9()}. NOTE(review): the roles of the other arguments (RNN cell, initial cell state,
     * {@code Output => Output} function, two {@code Output}s, an {@code int} that is presumably the beam width, and the
     * two {@link WhileLoopVariable} evidences) are inferred from their types only.
     */
    public <S, SS> BeamSearchDecoder<S, SS> apply(RNNCell<Output, Shape, S, SS> rNNCell, S s, Function1<Output, Output> function1, Output output, Output output2, int i, LengthPenalty lengthPenalty, Function1<Output, Output> function12, String str, WhileLoopVariable<Output> whileLoopVariable, WhileLoopVariable<S> whileLoopVariable2) {
        return new BeamSearchDecoder<>(rNNCell, s, function1, output, output2, i, lengthPenalty, function12, str, whileLoopVariable, whileLoopVariable2);
    }

    /** Default value of the 7th {@link #apply} argument: {@code NoPenalty}. */
    public <S, SS> LengthPenalty apply$default$7() {
        return NoPenalty$.MODULE$;
    }

    /** Default value of the 8th {@link #apply} argument: the identity function on {@link Output}. */
    public <S, SS> Function1<Output, Output> apply$default$8() {
        return output -> {
            return output;
        };
    }

    /** Default value of the 9th {@link #apply} argument: {@code "BeamSearchRNNDecoder"}. */
    public <S, SS> String apply$default$9() {
        return "BeamSearchRNNDecoder";
    }

    /**
     * Replaces the log-probability rows selected by a mask with a row that puts all the mass on a single token.
     *
     * <p>Axis 2 of {@code output} is treated as the vocabulary axis; its size is read as {@code shape(output)(2)}. At
     * every position where {@code output3} is set, the result row is a one-hot row. It holds {@code 0} at index
     * {@code output2} and {@code dataType.min} (the minimum value of {@code output}'s data type) elsewhere. All other
     * rows are taken from {@code output}.
     *
     * @param output  log-probabilities of rank at least 3 (presumably [batch, beam, vocab] — confirm)
     * @param output2 index of the token that receives all the mass (presumably the end token)
     * @param output3 selection mask. It is expanded on axis 2 and tiled over the vocabulary, and its dynamic shape
     *                gives the leading dimensions of the replacement rows (presumably boolean "finished" flags of
     *                shape [batch, beam]).
     * @return the masked log-probabilities, with the data type of {@code output}
     */
    public Output maskLogProbabilities(Output output, Output output2, Output output3) {
        Output apply = Basic$.MODULE$.shape(output, Basic$.MODULE$.shape$default$2(), Basic$.MODULE$.shape$default$3(), Basic$.MODULE$.shape$default$4()).apply(Predef$.MODULE$.wrapRefArray(new Indexer[]{Implicits$.MODULE$.intToIndex(2)}));
        DataType dataType = output.dataType();
        // select(tile(mask.expandDims(2), [1, 1, vocab]), tile(oneHotRow.reshape([1, 1, -1]), concat(shape(mask), [1])), output)
        return Math$.MODULE$.select(Basic$.MODULE$.tile(Implicits$.MODULE$.outputToBasicOps(output3).expandDims(Implicits$.MODULE$.tensorConvertibleToOutput(BoxesRunTime.boxToInteger(2), TensorConvertible$.MODULE$.supportedTypeTensorConvertible(SupportedType$.MODULE$.intIsSupportedType()))), Basic$.MODULE$.stack((Seq) Seq$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new Output[]{Implicits$.MODULE$.tensorConvertibleToOutput(BoxesRunTime.boxToInteger(1), TensorConvertible$.MODULE$.supportedTypeTensorConvertible(SupportedType$.MODULE$.intIsSupportedType())), Implicits$.MODULE$.tensorConvertibleToOutput(BoxesRunTime.boxToInteger(1), TensorConvertible$.MODULE$.supportedTypeTensorConvertible(SupportedType$.MODULE$.intIsSupportedType())), apply})), Basic$.MODULE$.stack$default$2(), Basic$.MODULE$.stack$default$3()), Basic$.MODULE$.tile$default$3()), Basic$.MODULE$.tile(Implicits$.MODULE$.outputToBasicOps(Basic$.MODULE$.oneHot(output2, apply, Basic$.MODULE$.zeros(dataType, Implicits$.MODULE$.tensorConvertibleToOutput(Shape$.MODULE$.apply((Seq<Object>) Nil$.MODULE$), TensorConvertible$.MODULE$.shapeTensorConvertible()), Basic$.MODULE$.zeros$default$3()), Basic$.MODULE$.constant(Tensor$.MODULE$.apply(dataType.mo682min(), Predef$.MODULE$.genericWrapArray(new Object[0]), TensorConvertible$.MODULE$.supportedTypeTensorConvertible(dataType.supportedType())).slice(Predef$.MODULE$.wrapRefArray(new Indexer[]{Implicits$.MODULE$.intToIndex(0)})), Basic$.MODULE$.constant$default$2(), Basic$.MODULE$.constant$default$3(), Basic$.MODULE$.constant$default$4(), org.platanios.tensorflow.api.package$.MODULE$.tensorEagerExecutionContext()), Basic$.MODULE$.oneHot$default$5(), Basic$.MODULE$.oneHot$default$6(), Basic$.MODULE$.oneHot$default$7())).reshape(Implicits$.MODULE$.tensorConvertibleToOutput(Shape$.MODULE$.apply((Seq<Object>) Predef$.MODULE$.wrapIntArray(new int[]{1, 1, -1})), TensorConvertible$.MODULE$.shapeTensorConvertible())), Basic$.MODULE$.concatenate((Seq) 
Seq$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new Output[]{Basic$.MODULE$.shape(output3, Basic$.MODULE$.shape$default$2(), Basic$.MODULE$.shape$default$3(), Basic$.MODULE$.shape$default$4()), Implicits$.MODULE$.tensorConvertibleToOutput(Tensor$.MODULE$.apply(BoxesRunTime.boxToInteger(1), Predef$.MODULE$.wrapIntArray(new int[0]), TensorConvertible$.MODULE$.supportedTypeTensorConvertible(SupportedType$.MODULE$.intIsSupportedType())), TensorConvertible$.MODULE$.tensorLikeTensorConvertible())})), Implicits$.MODULE$.tensorConvertibleToOutput(BoxesRunTime.boxToInteger(0), TensorConvertible$.MODULE$.supportedTypeTensorConvertible(SupportedType$.MODULE$.intIsSupportedType())), Basic$.MODULE$.concatenate$default$3()), Basic$.MODULE$.tile$default$3()), output, Math$.MODULE$.select$default$4());
    }

    /**
     * Builds a {@code GatherTree} op with the four given inputs, added in argument order, and returns its first
     * output.
     *
     * <p>In this file it is called from {@link #maybeSortTensorArrayBeams} with the masked beam ids, the parent ids,
     * the per-batch maximum sequence lengths and an end token of {@code W + 1}.
     *
     * @param str name of the created op (default {@code "GatherTree"}, see {@link #gatherTree$default$5()})
     * @return output 0 of the created op
     */
    public Output gatherTree(Output output, Output output2, Output output3, Output output4, String str) {
        return new Op.Builder("GatherTree", str, org.platanios.tensorflow.api.package$.MODULE$.opCreationContext()).addInput(output).addInput(output2).addInput(output3).addInput(output4).build().outputs()[0];
    }

    /**
     * Repeats every batch entry of an {@link Output} {@code i} times along the leading axis.
     *
     * <p>A scalar becomes a vector of shape [i]. A tensor of shape [b, d1, ...] is expanded on axis 1 and tiled
     * {@code i} times there, then reshaped to [b * i, d1, ...], so the copies of each entry are adjacent. The static
     * leading dimension is set to {@code b * i} when {@code b} is statically known, and to -1 otherwise.
     * {@link TensorArray} values are returned unchanged.
     *
     * @param symbol value to tile
     * @param i      number of copies of each batch entry
     * @return the tiled value, or {@code symbol} itself for a {@link TensorArray}
     * @throws InvalidArgumentException if the tensor's rank is statically unknown, or if {@code symbol} is neither an
     *                                  {@link Output} nor a {@link TensorArray}
     */
    public Symbol tileBatch(Symbol symbol, int i) throws InvalidArgumentException {
        Symbol symbol2;
        Output output;
        if (symbol instanceof Output) {
            Output output2 = (Output) symbol;
            if (output2.rank() == -1) {
                throw package$exception$.MODULE$.InvalidArgumentException().apply("The provided tensor must have statically known rank.");
            }
            if (output2.rank() == 0) {
                // Scalar: expand to [1] and tile to [i].
                Output tile = Basic$.MODULE$.tile(Implicits$.MODULE$.outputToBasicOps(output2).expandDims(Implicits$.MODULE$.tensorConvertibleToOutput(BoxesRunTime.boxToInteger(0), TensorConvertible$.MODULE$.supportedTypeTensorConvertible(SupportedType$.MODULE$.intIsSupportedType()))), Implicits$.MODULE$.tensorConvertibleToOutput(Tensor$.MODULE$.apply(BoxesRunTime.boxToInteger(i), Predef$.MODULE$.wrapIntArray(new int[0]), TensorConvertible$.MODULE$.supportedTypeTensorConvertible(SupportedType$.MODULE$.intIsSupportedType())), TensorConvertible$.MODULE$.tensorLikeTensorConvertible()), Basic$.MODULE$.tile$default$3());
                tile.setShape(Shape$.MODULE$.apply((Seq<Object>) Predef$.MODULE$.wrapIntArray(new int[]{i})));
                output = tile;
            } else {
                Output shape = Basic$.MODULE$.shape(output2, Basic$.MODULE$.shape$default$2(), Basic$.MODULE$.shape$default$3(), Basic$.MODULE$.shape$default$4());
                // Tiling multiples [1, i, 1, ..., 1] for the tensor after expandDims(1).
                ArrayBuffer fill = ArrayBuffer$.MODULE$.fill(output2.rank() + 1, () -> {
                    return 1;
                });
                fill.update(1, BoxesRunTime.boxToInteger(i));
                // Static leading dimension after tiling; -1 means unknown.
                int apply = output2.shape().apply(0) != -1 ? output2.shape().apply(0) * i : -1;
                // tile(v.expandDims(1), multiples).reshape(concat([shape(v)(0) * i], shape(v)(1 ::)))
                Output reshape = Implicits$.MODULE$.outputToBasicOps(Basic$.MODULE$.tile(Implicits$.MODULE$.outputToBasicOps(output2).expandDims(Implicits$.MODULE$.tensorConvertibleToOutput(BoxesRunTime.boxToInteger(1), TensorConvertible$.MODULE$.supportedTypeTensorConvertible(SupportedType$.MODULE$.intIsSupportedType()))), Implicits$.MODULE$.tensorConvertibleToOutput(fill, TensorConvertible$.MODULE$.traversableTensorConvertible(TensorConvertible$.MODULE$.supportedTypeTensorConvertible(SupportedType$.MODULE$.intIsSupportedType()))), Basic$.MODULE$.tile$default$3())).reshape(Basic$.MODULE$.concatenate((Seq) Seq$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new Output[]{Implicits$.MODULE$.outputToBasicOps(Implicits$.MODULE$.outputToMathOps(shape.apply(Predef$.MODULE$.wrapRefArray(new Indexer[]{Implicits$.MODULE$.intToIndex(0)}))).$times(Implicits$.MODULE$.tensorConvertibleToOutput(BoxesRunTime.boxToInteger(i), TensorConvertible$.MODULE$.supportedTypeTensorConvertible(SupportedType$.MODULE$.intIsSupportedType())))).expandDims(Implicits$.MODULE$.tensorConvertibleToOutput(BoxesRunTime.boxToInteger(0), TensorConvertible$.MODULE$.supportedTypeTensorConvertible(SupportedType$.MODULE$.intIsSupportedType()))), shape.apply(Predef$.MODULE$.wrapRefArray(new Indexer[]{Implicits$.MODULE$.intToIndexerConstruction(1).$colon$colon()}))})), Basic$.MODULE$.concatenate$default$2(), Basic$.MODULE$.concatenate$default$3()));
                if (output2.rank() > 1) {
                    reshape.setShape(Shape$.MODULE$.apply((Seq<Object>) Predef$.MODULE$.wrapIntArray(new int[]{apply})).$plus$plus(output2.shape().apply(Implicits$.MODULE$.intToIndexerConstruction(1).$colon$colon())));
                } else {
                    reshape.setShape(Shape$.MODULE$.apply((Seq<Object>) Predef$.MODULE$.wrapIntArray(new int[]{apply})));
                }
                output = reshape;
            }
            symbol2 = output;
        } else {
            if (!(symbol instanceof TensorArray)) {
                throw package$exception$.MODULE$.InvalidArgumentException().apply("Unsupported argument type for use with the beam search decoder.");
            }
            symbol2 = symbol;
        }
        return symbol2;
    }

    /**
     * Applies {@link #tileBatch} with multiplier {@code i} to every symbol in the structure {@code s}.
     *
     * @param s                 structure whose symbols are tiled
     * @param i                 number of copies of each batch entry
     * @param whileLoopVariable evidence used to map over the symbols of {@code s}
     * @return a structure of the same type with every symbol tiled
     */
    public <S> S tileForBeamSearch(S s, int i, WhileLoopVariable<S> whileLoopVariable) throws InvalidArgumentException {
        return whileLoopVariable.map(s, symbol -> {
            return MODULE$.tileBatch(symbol, i);
        });
    }

    /**
     * Calls {@link #splitBatchBeams} on {@link Output}s of rank at least 1. Scalars and {@link TensorArray}s are
     * returned unchanged.
     *
     * @throws InvalidArgumentException if the tensor's rank is statically unknown, or if {@code symbol} is neither an
     *                                  {@link Output} nor a {@link TensorArray}
     */
    public Symbol maybeSplitBatchBeams(Symbol symbol, Shape shape, Output output, int i) throws InvalidArgumentException {
        Symbol symbol2;
        if (symbol instanceof Output) {
            Output output2 = (Output) symbol;
            if (output2.rank() == -1) {
                throw package$exception$.MODULE$.InvalidArgumentException().apply(new StringContext(Predef$.MODULE$.wrapRefArray(new String[]{"Expected tensor (", ") to have known rank, but it was unknown."})).s(Predef$.MODULE$.genericWrapArray(new Object[]{output2})));
            }
            symbol2 = output2.rank() == 0 ? symbol : splitBatchBeams(symbol, shape, output, i);
        } else {
            if (!(symbol instanceof TensorArray)) {
                throw package$exception$.MODULE$.InvalidArgumentException().apply("Unsupported argument type for use with the beam search decoder.");
            }
            symbol2 = symbol;
        }
        return symbol2;
    }

    /**
     * Reshapes a tensor whose leading axis has size {@code batchSize * i} into one with leading axes
     * [batchSize, i].
     *
     * <p>The dynamic target shape is {@code concat([output], [i], shape(tensor)(1 ::))}, built in
     * {@code output.dataType()}. The static shape is set to {@code [c, i] ++ shape}. Here {@code c} is the constant
     * value of {@code output} if known, and -1 otherwise.
     *
     * @param symbol tensor to reshape. A {@link TensorArray} is returned as-is. An {@link Output} paired with a
     *               {@code null} shape falls through to the "Unsupported argument type" error.
     * @param shape  static shape of the dimensions that follow the batch and beam axes
     * @param output batch size (indexed with a new axis and read via {@code scalar()}, so presumably a scalar)
     * @param i      presumably the beam width
     * @throws package$exception$InvalidShapeException if the reshaped static shape is incompatible with the expected one
     * @throws InvalidArgumentException                if {@code symbol} is not a supported type
     */
    public Symbol splitBatchBeams(Symbol symbol, Shape shape, Output output, int i) throws InvalidArgumentException, package$exception$InvalidShapeException {
        Symbol symbol2;
        // Decompiled Scala pattern match on the (symbol, shape) pair.
        Tuple2 tuple2 = new Tuple2(symbol, shape);
        if (tuple2 != null) {
            Symbol symbol3 = (Symbol) tuple2._1();
            Shape shape2 = (Shape) tuple2._2();
            if (symbol3 instanceof Output) {
                Output output2 = (Output) symbol3;
                if (shape2 != null) {
                    Output reshape = Basic$.MODULE$.reshape(output2, Basic$.MODULE$.concatenate((Seq) Seq$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new Output[]{output.apply(Predef$.MODULE$.wrapRefArray(new Indexer[]{NewAxis$.MODULE$})), Tensor$.MODULE$.apply(output.dataType(), BoxesRunTime.boxToInteger(i), Predef$.MODULE$.wrapIntArray(new int[0]), TensorConvertible$.MODULE$.supportedTypeTensorConvertible(SupportedType$.MODULE$.intIsSupportedType())).toOutput(), Implicits$.MODULE$.outputToMathOps(Basic$.MODULE$.shape(output2, Basic$.MODULE$.shape$default$2(), Basic$.MODULE$.shape$default$3(), Basic$.MODULE$.shape$default$4()).apply(Predef$.MODULE$.wrapRefArray(new Indexer[]{Implicits$.MODULE$.intToIndexerConstruction(1).$colon$colon()}))).cast(output.dataType())})), Implicits$.MODULE$.tensorConvertibleToOutput(BoxesRunTime.boxToInteger(0), TensorConvertible$.MODULE$.supportedTypeTensorConvertible(SupportedType$.MODULE$.intIsSupportedType())), Basic$.MODULE$.concatenate$default$3()), Basic$.MODULE$.reshape$default$3());
                    // Expected static shape [constant(batchSize) or -1, i] ++ shape.
                    Shape $plus$plus = Shape$.MODULE$.apply((Seq<Object>) Predef$.MODULE$.wrapIntArray(new int[]{BoxesRunTime.unboxToInt(Output$.MODULE$.constantValue(output).map(tensor -> {
                        return BoxesRunTime.boxToInteger($anonfun$splitBatchBeams$1(tensor));
                    }).getOrElse(() -> {
                        return -1;
                    })), i})).$plus$plus(shape2);
                    if (!reshape.shape().isCompatibleWith($plus$plus)) {
                        throw new package$exception$InvalidShapeException("Unexpected behavior when reshaping between beam width and batch size. " + new StringContext(Predef$.MODULE$.wrapRefArray(new String[]{"The reshaped tensor has shape: ", ". "})).s(Predef$.MODULE$.genericWrapArray(new Object[]{reshape.shape()})) + new StringContext(Predef$.MODULE$.wrapRefArray(new String[]{"We expected it to have shape [batchSize, beamWidth, depth] == ", ". "})).s(Predef$.MODULE$.genericWrapArray(new Object[]{$plus$plus})) + "Perhaps you forgot to create a zero state with batchSize = encoderBatchSize * beamWidth?", package$exception$InvalidShapeException$.MODULE$.apply$default$2());
                    }
                    reshape.setShape($plus$plus);
                    symbol2 = reshape;
                    return symbol2;
                }
            }
        }
        if (tuple2 == null || !(tuple2._1() instanceof TensorArray)) {
            throw package$exception$.MODULE$.InvalidArgumentException().apply("Unsupported argument type for use with the beam search decoder.");
        }
        symbol2 = symbol;
        return symbol2;
    }

    /**
     * Calls {@link #mergeBatchBeams} on {@link Output}s of rank at least 1. Scalars and {@link TensorArray}s are
     * returned unchanged.
     *
     * @throws InvalidArgumentException if the tensor's rank is statically unknown, or if {@code symbol} is neither an
     *                                  {@link Output} nor a {@link TensorArray}
     */
    public Symbol maybeMergeBatchBeams(Symbol symbol, Shape shape, Output output, int i) throws InvalidArgumentException {
        Symbol symbol2;
        if (symbol instanceof Output) {
            Output output2 = (Output) symbol;
            if (output2.rank() == -1) {
                throw package$exception$.MODULE$.InvalidArgumentException().apply(new StringContext(Predef$.MODULE$.wrapRefArray(new String[]{"Expected tensor (", ") to have known rank, but it was unknown."})).s(Predef$.MODULE$.genericWrapArray(new Object[]{output2})));
            }
            symbol2 = output2.rank() == 0 ? symbol : mergeBatchBeams(symbol, shape, output, i);
        } else {
            if (!(symbol instanceof TensorArray)) {
                throw package$exception$.MODULE$.InvalidArgumentException().apply("Unsupported argument type for use with the beam search decoder.");
            }
            symbol2 = symbol;
        }
        return symbol2;
    }

    /**
     * Reshapes a tensor with leading axes [batchSize, i] into one with a single leading axis of size
     * {@code batchSize * i}. This is the inverse of {@link #splitBatchBeams}.
     *
     * <p>The dynamic target shape is {@code concat([output * i], shape(tensor)(2 ::))}, built in
     * {@code output.dataType()}. The static shape is set to {@code [c * i] ++ shape}. Here {@code c} is the constant
     * value of {@code output} if known; otherwise the leading static dimension is -1.
     *
     * @param symbol tensor to reshape. A {@link TensorArray} is returned as-is. An {@link Output} paired with a
     *               {@code null} shape falls through to the "Unsupported argument type" error.
     * @param shape  static shape of the dimensions that follow the batch and beam axes
     * @param output batch size (indexed with a new axis and read via {@code scalar()}, so presumably a scalar)
     * @param i      presumably the beam width
     * @throws package$exception$InvalidShapeException if the reshaped static shape is incompatible with the expected one
     * @throws InvalidArgumentException                if {@code symbol} is not a supported type
     */
    public Symbol mergeBatchBeams(Symbol symbol, Shape shape, Output output, int i) throws InvalidArgumentException, package$exception$InvalidShapeException {
        Symbol symbol2;
        // Decompiled Scala pattern match on the (symbol, shape) pair.
        Tuple2 tuple2 = new Tuple2(symbol, shape);
        if (tuple2 != null) {
            Symbol symbol3 = (Symbol) tuple2._1();
            Shape shape2 = (Shape) tuple2._2();
            if (symbol3 instanceof Output) {
                Output output2 = (Output) symbol3;
                if (shape2 != null) {
                    Output reshape = Basic$.MODULE$.reshape(output2, Basic$.MODULE$.concatenate((Seq) Seq$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new Output[]{Implicits$.MODULE$.outputToMathOps(output.apply(Predef$.MODULE$.wrapRefArray(new Indexer[]{NewAxis$.MODULE$}))).$times(Tensor$.MODULE$.apply(output.dataType(), BoxesRunTime.boxToInteger(i), Predef$.MODULE$.wrapIntArray(new int[0]), TensorConvertible$.MODULE$.supportedTypeTensorConvertible(SupportedType$.MODULE$.intIsSupportedType())).toOutput()), Implicits$.MODULE$.outputToMathOps(Basic$.MODULE$.shape(output2, Basic$.MODULE$.shape$default$2(), Basic$.MODULE$.shape$default$3(), Basic$.MODULE$.shape$default$4()).apply(Predef$.MODULE$.wrapRefArray(new Indexer[]{Implicits$.MODULE$.intToIndexerConstruction(2).$colon$colon()}))).cast(output.dataType())})), Implicits$.MODULE$.tensorConvertibleToOutput(BoxesRunTime.boxToInteger(0), TensorConvertible$.MODULE$.supportedTypeTensorConvertible(SupportedType$.MODULE$.intIsSupportedType())), Basic$.MODULE$.concatenate$default$3()), Basic$.MODULE$.reshape$default$3());
                    int unboxToInt = BoxesRunTime.unboxToInt(Output$.MODULE$.constantValue(output).map(tensor -> {
                        return BoxesRunTime.boxToInteger($anonfun$mergeBatchBeams$1(tensor));
                    }).getOrElse(() -> {
                        return -1;
                    }));
                    // Expected static shape [constant(batchSize) * i or -1] ++ shape.
                    Shape $plus$plus = Shape$.MODULE$.apply((Seq<Object>) Predef$.MODULE$.wrapIntArray(new int[]{unboxToInt != -1 ? unboxToInt * i : -1})).$plus$plus(shape2);
                    if (!reshape.shape().isCompatibleWith($plus$plus)) {
                        // The expected layout here is the merged one, [batchSize * beamWidth, ...]. The split-side
                        // wording "[batchSize, beamWidth, depth]" would misdescribe it.
                        throw new package$exception$InvalidShapeException("Unexpected behavior when reshaping between beam width and batch size. " + new StringContext(Predef$.MODULE$.wrapRefArray(new String[]{"The reshaped tensor has shape: ", ". "})).s(Predef$.MODULE$.genericWrapArray(new Object[]{reshape.shape()})) + new StringContext(Predef$.MODULE$.wrapRefArray(new String[]{"We expected it to have shape [batchSize * beamWidth, depth] == ", ". "})).s(Predef$.MODULE$.genericWrapArray(new Object[]{$plus$plus})) + "Perhaps you forgot to create a zero state with batchSize = encoderBatchSize * beamWidth?", package$exception$InvalidShapeException$.MODULE$.apply$default$2());
                    }
                    reshape.setShape($plus$plus);
                    symbol2 = reshape;
                    return symbol2;
                }
            }
        }
        if (tuple2 == null || !(tuple2._1() instanceof TensorArray)) {
            throw package$exception$.MODULE$.InvalidArgumentException().apply("Unsupported argument type for use with the beam search decoder.");
        }
        symbol2 = symbol;
        return symbol2;
    }

    /**
     * Calls {@link #gather} when {@code symbol} is an {@link Output} whose rank is at least {@code seq.size()}.
     * Lower-rank {@link Output}s and {@link TensorArray}s are returned unchanged.
     *
     * @throws InvalidArgumentException if the tensor's rank is statically unknown, or if {@code symbol} is neither an
     *                                  {@link Output} nor a {@link TensorArray}
     */
    public Symbol maybeGather(Output output, Symbol symbol, Output output2, Output output3, Seq<Output> seq, String str) throws InvalidArgumentException {
        Symbol symbol2;
        if (symbol instanceof Output) {
            Output output4 = (Output) symbol;
            if (output4.rank() == -1) {
                throw package$exception$.MODULE$.InvalidArgumentException().apply(new StringContext(Predef$.MODULE$.wrapRefArray(new String[]{"Expected tensor (", ") to have known rank, but it was unknown."})).s(Predef$.MODULE$.genericWrapArray(new Object[]{output4})));
            }
            symbol2 = output4.rank() < seq.size() ? symbol : gather(output, symbol, output2, output3, seq, str);
        } else {
            if (!(symbol instanceof TensorArray)) {
                throw package$exception$.MODULE$.InvalidArgumentException().apply("Unsupported argument type for use with the beam search decoder.");
            }
            symbol2 = symbol;
        }
        return symbol2;
    }

    /** Default name scope of {@link #maybeGather}: {@code "GatherTensorHelper"}. */
    public String maybeGather$default$6() {
        return "GatherTensorHelper";
    }

    /**
     * Gathers entries of a tensor after offsetting per-row indices by the row's position times a range size.
     *
     * <p>Runs inside the name scope {@code str}:
     * <ol>
     *   <li>the tensor is reshaped to {@code stack(seq)};</li>
     *   <li>the indices become {@code output + (range(0, output2) * output3).expandDims(1)}, flattened to [-1];</li>
     *   <li>{@code Basic.gather} is applied;</li>
     *   <li>the result is reshaped (op name {@code "Output"}) to the first {@code 1 + seq.size()} dimensions of the
     *       tensor's dynamic shape;</li>
     *   <li>its static shape is set to {@code [c] ++ staticShape(1 :: 1 + seq.size())}, where {@code c} is the
     *       constant value of {@code output2}, or -1 if unknown.</li>
     * </ol>
     *
     * @param output  indices per row (presumably [batch, beam] beam indices)
     * @param symbol  tensor to gather from; a {@link TensorArray} is returned unchanged
     * @param output2 number of rows (presumably the batch size)
     * @param output3 index offset per row (presumably the beam width)
     * @param seq     shape the tensor is reshaped to before gathering
     * @param str     name scope (default {@code "GatherTensorHelper"}, see {@link #gather$default$6()})
     * @throws InvalidArgumentException if {@code symbol} is neither an {@link Output} nor a {@link TensorArray}
     */
    public Symbol gather(Output output, Symbol symbol, Output output2, Output output3, Seq<Output> seq, String str) throws InvalidArgumentException {
        Symbol symbol2;
        if (symbol instanceof Output) {
            Output output4 = (Output) symbol;
            symbol2 = (Symbol) Op$.MODULE$.createWithNameScope(str, Op$.MODULE$.createWithNameScope$default$2(), () -> {
                Output gather = Basic$.MODULE$.gather(Implicits$.MODULE$.outputToBasicOps(output4).reshape(Basic$.MODULE$.stack(seq, Basic$.MODULE$.stack$default$2(), Basic$.MODULE$.stack$default$3())), Implicits$.MODULE$.outputToBasicOps(Implicits$.MODULE$.outputToMathOps(output).$plus(Implicits$.MODULE$.outputToBasicOps(Implicits$.MODULE$.outputToMathOps(Math$.MODULE$.range(Implicits$.MODULE$.tensorConvertibleToOutput(BoxesRunTime.boxToInteger(0), TensorConvertible$.MODULE$.supportedTypeTensorConvertible(SupportedType$.MODULE$.intIsSupportedType())), output2, Math$.MODULE$.range$default$3(), Math$.MODULE$.range$default$4(), Math$.MODULE$.range$default$5())).$times(output3)).expandDims(Implicits$.MODULE$.tensorConvertibleToOutput(BoxesRunTime.boxToInteger(1), TensorConvertible$.MODULE$.supportedTypeTensorConvertible(SupportedType$.MODULE$.intIsSupportedType()))))).reshape(Implicits$.MODULE$.tensorConvertibleToOutput(Shape$.MODULE$.apply((Seq<Object>) Predef$.MODULE$.wrapIntArray(new int[]{-1})), TensorConvertible$.MODULE$.shapeTensorConvertible())), Basic$.MODULE$.gather$default$3(), Basic$.MODULE$.gather$default$4());
                // Dynamic final shape: shape(v)(0 :: 1 + seq.size()).
                Output apply = Basic$.MODULE$.shape(output4, Basic$.MODULE$.shape$default$2(), Basic$.MODULE$.shape$default$3(), Basic$.MODULE$.shape$default$4()).apply(Predef$.MODULE$.wrapRefArray(new Indexer[]{IndexerConstructionWithTwoNumbers$.MODULE$.indexerConstructionToIndex(Implicits$.MODULE$.intToIndexerConstruction(1 + seq.size()).$colon$colon(Implicits$.MODULE$.intToIndexerConstruction(0)))}));
                // Static final shape: [constant(output2) or -1] ++ v.shape(1 :: 1 + seq.size()).
                Shape $plus$plus = Shape$.MODULE$.apply((Seq<Object>) Predef$.MODULE$.wrapIntArray(new int[]{BoxesRunTime.unboxToInt(Output$.MODULE$.constantValue(output2).map(tensor -> {
                    return BoxesRunTime.boxToInteger($anonfun$gather$2(tensor));
                }).getOrElse(() -> {
                    return -1;
                }))})).$plus$plus(output4.shape().apply(IndexerConstructionWithTwoNumbers$.MODULE$.indexerConstructionToIndex(Implicits$.MODULE$.intToIndexerConstruction(1 + seq.size()).$colon$colon(Implicits$.MODULE$.intToIndexerConstruction(1)))));
                Output reshape = Basic$.MODULE$.reshape(gather, apply, "Output");
                reshape.setShape($plus$plus);
                return reshape;
            }, org.platanios.tensorflow.api.package$.MODULE$.opCreationContext());
        } else {
            if (!(symbol instanceof TensorArray)) {
                throw package$exception$.MODULE$.InvalidArgumentException().apply("Unsupported argument type for use with the beam search decoder.");
            }
            symbol2 = symbol;
        }
        return symbol2;
    }

    /** Default name scope of {@link #gather}: {@code "GatherTensorHelper"}. */
    public String gather$default$6() {
        return "GatherTensorHelper";
    }

    /** Default op name of {@link #gatherTree}: {@code "GatherTree"}. */
    public String gatherTree$default$5() {
        return "GatherTree";
    }

    /**
     * Re-orders the beams stored in a {@link TensorArray} by following parent ids. Any other symbol is returned
     * unchanged.
     *
     * <p>With T = shape(output2)(0), B = shape(output2)(1) and W = shape(output2)(2):
     * <ol>
     *   <li>beam ids {@code range(0, W)} are tiled to [T, B, W];</li>
     *   <li>{@code sequenceMask(output, T)} (INT32) is transposed with permutation [2, 0, 1];</li>
     *   <li>positions inside the mask keep their beam id, all others are set to {@code W + 1};</li>
     *   <li>{@link #gatherTree} is run with the masked ids, {@code output2}, {@code max(output, axis 1)} cast to INT32,
     *       and {@code W + 1};</li>
     *   <li>positions outside the mask fall back to the untouched tiled beam id;</li>
     *   <li>a while loop (5th argument {@code 1}, presumably {@code parallelIterations}) iterates
     *       {@code t = 0 .. size - 1}. For each step it writes {@link #gather} of {@code tensorArray.read(t)} into a
     *       new TensorArray of the same size and data type, using the sorted ids of step t, B, W and gather shape
     *       [B * W, -1].</li>
     * </ol>
     *
     * @param symbol  value to re-order
     * @param output  sequence lengths (masked against T and reduced with max over axis 1; presumably [B, W])
     * @param output2 parent ids of rank at least 3 (presumably [T, B, W])
     * @return the re-ordered TensorArray, or {@code symbol} unchanged when it is not a TensorArray
     */
    public Symbol maybeSortTensorArrayBeams(Symbol symbol, Output output, Output output2) {
        Symbol symbol2;
        if (symbol instanceof TensorArray) {
            TensorArray tensorArray = (TensorArray) symbol;
            Output apply = Basic$.MODULE$.shape(output2, Basic$.MODULE$.shape$default$2(), Basic$.MODULE$.shape$default$3(), Basic$.MODULE$.shape$default$4()).apply(Predef$.MODULE$.wrapRefArray(new Indexer[]{Implicits$.MODULE$.intToIndex(0)}));
            Output apply2 = Basic$.MODULE$.shape(output2, Basic$.MODULE$.shape$default$2(), Basic$.MODULE$.shape$default$3(), Basic$.MODULE$.shape$default$4()).apply(Predef$.MODULE$.wrapRefArray(new Indexer[]{Implicits$.MODULE$.intToIndex(1)}));
            Output apply3 = Basic$.MODULE$.shape(output2, Basic$.MODULE$.shape$default$2(), Basic$.MODULE$.shape$default$3(), Basic$.MODULE$.shape$default$4()).apply(Predef$.MODULE$.wrapRefArray(new Indexer[]{Implicits$.MODULE$.intToIndex(2)}));
            // Beam ids range(0, W) tiled to [T, B, W].
            Output tile = Basic$.MODULE$.tile(Math$.MODULE$.range(Implicits$.MODULE$.tensorConvertibleToOutput(BoxesRunTime.boxToInteger(0), TensorConvertible$.MODULE$.supportedTypeTensorConvertible(SupportedType$.MODULE$.intIsSupportedType())), apply3, Math$.MODULE$.range$default$3(), Math$.MODULE$.range$default$4(), Math$.MODULE$.range$default$5()).apply(Predef$.MODULE$.wrapRefArray(new Indexer[]{NewAxis$.MODULE$, NewAxis$.MODULE$})), Basic$.MODULE$.stack((Seq) Seq$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new Output[]{apply, apply2, Implicits$.MODULE$.tensorConvertibleToOutput(BoxesRunTime.boxToInteger(1), TensorConvertible$.MODULE$.supportedTypeTensorConvertible(SupportedType$.MODULE$.intIsSupportedType()))})), Basic$.MODULE$.stack$default$2(), Basic$.MODULE$.stack$default$3()), Basic$.MODULE$.tile$default$3());
            Basic.BasicOps outputToBasicOps = Implicits$.MODULE$.outputToBasicOps(Basic$.MODULE$.sequenceMask(output, apply, org.platanios.tensorflow.api.types.package$.MODULE$.INT32(), Basic$.MODULE$.sequenceMask$default$4()));
            Output transpose = outputToBasicOps.transpose(Implicits$.MODULE$.tensorConvertibleToOutput(Seq$.MODULE$.apply(Predef$.MODULE$.wrapIntArray(new int[]{2, 0, 1})), TensorConvertible$.MODULE$.traversableTensorConvertible(TensorConvertible$.MODULE$.supportedTypeTensorConvertible(SupportedType$.MODULE$.intIsSupportedType()))), outputToBasicOps.transpose$default$2());
            // Masked beam ids: tile * mask + (1 - mask) * (W + 1).
            Output $plus = Implicits$.MODULE$.outputToMathOps(Implicits$.MODULE$.outputToMathOps(tile).$times(transpose)).$plus(Implicits$.MODULE$.outputToMathOps(Implicits$.MODULE$.outputConvertibleToMathOps(BoxesRunTime.boxToInteger(1), obj -> {
                return $anonfun$maybeSortTensorArrayBeams$1(BoxesRunTime.unboxToInt(obj));
            }).$minus(transpose)).$times(Implicits$.MODULE$.outputToMathOps(apply3).$plus(Implicits$.MODULE$.tensorConvertibleToOutput(BoxesRunTime.boxToInteger(1), TensorConvertible$.MODULE$.supportedTypeTensorConvertible(SupportedType$.MODULE$.intIsSupportedType())))));
            Implicits$ implicits$ = Implicits$.MODULE$;
            Math.MathOps outputToMathOps = Implicits$.MODULE$.outputToMathOps(output);
            // ObjectRef is how the compiler boxes a Scala `var` captured by the loop-body lambda below.
            ObjectRef create = ObjectRef.create(gatherTree($plus, output2, implicits$.outputToMathOps(outputToMathOps.max(Implicits$.MODULE$.tensorConvertibleToOutput(Tensor$.MODULE$.apply(BoxesRunTime.boxToInteger(1), Predef$.MODULE$.wrapIntArray(new int[0]), TensorConvertible$.MODULE$.supportedTypeTensorConvertible(SupportedType$.MODULE$.intIsSupportedType())), TensorConvertible$.MODULE$.tensorLikeTensorConvertible()), outputToMathOps.max$default$2())).cast(org.platanios.tensorflow.api.types.package$.MODULE$.INT32()), Implicits$.MODULE$.outputToMathOps(apply3).$plus(Implicits$.MODULE$.tensorConvertibleToOutput(BoxesRunTime.boxToInteger(1), TensorConvertible$.MODULE$.supportedTypeTensorConvertible(SupportedType$.MODULE$.intIsSupportedType()))), gatherTree$default$5()));
            // Outside the mask, keep the untouched beam id.
            create.elem = Math$.MODULE$.select(Implicits$.MODULE$.outputToMathOps(transpose).cast(org.platanios.tensorflow.api.types.package$.MODULE$.BOOLEAN()), (Output) create.elem, tile, Math$.MODULE$.select$default$4());
            Output size = tensorArray.size(tensorArray.size$default$1());
            TensorArray create2 = TensorArray$.MODULE$.create(size, tensorArray.dataType(), false, TensorArray$.MODULE$.create$default$4(), TensorArray$.MODULE$.create$default$5(), TensorArray$.MODULE$.create$default$6(), TensorArray$.MODULE$.create$default$7(), TensorArray$.MODULE$.create$default$8(), TensorArray$.MODULE$.create$default$9());
            // Loop predicate: step index < size.
            Function1 function1 = tuple2 -> {
                return Implicits$.MODULE$.outputToMathOps((Output) tuple2._1()).$less(size);
            };
            // Loop body: write the re-ordered element for step t and advance t by 1.
            Function1 function12 = tuple22 -> {
                Output output3 = (Output) tuple22._1();
                BeamSearchDecoder$ beamSearchDecoder$ = MODULE$;
                Basic.BasicOps outputToBasicOps2 = Implicits$.MODULE$.outputToBasicOps((Output) create.elem);
                return new Tuple2(Implicits$.MODULE$.outputToMathOps(output3).$plus(Implicits$.MODULE$.tensorConvertibleToOutput(BoxesRunTime.boxToInteger(1), TensorConvertible$.MODULE$.supportedTypeTensorConvertible(SupportedType$.MODULE$.intIsSupportedType()))), ((TensorArray) tuple22._2()).write(output3, (Output) beamSearchDecoder$.gather(outputToBasicOps2.gather(output3, outputToBasicOps2.gather$default$2()), tensorArray.read(output3, tensorArray.read$default$2()), apply2, apply3, (Seq) Seq$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new Output[]{Implicits$.MODULE$.outputToMathOps(apply2).$times(apply3), Implicits$.MODULE$.tensorConvertibleToOutput(BoxesRunTime.boxToInteger(-1), TensorConvertible$.MODULE$.supportedTypeTensorConvertible(SupportedType$.MODULE$.intIsSupportedType()))})), MODULE$.gather$default$6()), ((TensorArray) tuple22._2()).write$default$3()));
            };
            // Initial loop variables: (step index 0, empty output TensorArray).
            Tuple2 tuple23 = new Tuple2(Basic$.MODULE$.constant(Implicits$.MODULE$.tensorConvertibleToTensor(BoxesRunTime.boxToInteger(0), TensorConvertible$.MODULE$.supportedTypeTensorConvertible(SupportedType$.MODULE$.intIsSupportedType())), Basic$.MODULE$.constant$default$2(), Basic$.MODULE$.constant$default$3(), Basic$.MODULE$.constant$default$4(), org.platanios.tensorflow.api.package$.MODULE$.tensorEagerExecutionContext()), create2);
            None$ whileLoop$default$4 = ControlFlow$.MODULE$.whileLoop$default$4();
            boolean whileLoop$default$6 = ControlFlow$.MODULE$.whileLoop$default$6();
            boolean whileLoop$default$7 = ControlFlow$.MODULE$.whileLoop$default$7();
            Output whileLoop$default$8 = ControlFlow$.MODULE$.whileLoop$default$8();
            String whileLoop$default$9 = ControlFlow$.MODULE$.whileLoop$default$9();
            ControlFlow$ controlFlow$ = ControlFlow$.MODULE$;
            WhileLoopVariable$ whileLoopVariable$ = WhileLoopVariable$.MODULE$;
            // Compiler-generated (shapeless macro) Generic instance converting Tuple2 <-> HList for the loop variables.
            Generic<Tuple2<Output, TensorArray>> generic = new Generic<Tuple2<Output, TensorArray>>() { // from class: org.platanios.tensorflow.api.ops.seq2seq.decoders.BeamSearchDecoder$anon$macro$2619$1
                public $colon.colon<Output, $colon.colon<TensorArray, HNil>> to(Tuple2<Output, TensorArray> tuple24) {
                    if (tuple24 != null) {
                        return new $colon.colon<>((Output) tuple24._1(), new $colon.colon((TensorArray) tuple24._2(), HNil$.MODULE$));
                    }
                    throw new MatchError(tuple24);
                }

                public Tuple2<Output, TensorArray> from($colon.colon<Output, $colon.colon<TensorArray, HNil>> colonVar) {
                    if (colonVar != null) {
                        Output output3 = (Output) colonVar.head();
                        $colon.colon tail = colonVar.tail();
                        if (tail != null) {
                            TensorArray tensorArray2 = (TensorArray) tail.head();
                            if (HNil$.MODULE$.equals(tail.tail())) {
                                return new Tuple2<>(output3, tensorArray2);
                            }
                        }
                    }
                    throw new MatchError(colonVar);
                }
            };
            WhileLoopVariable$ whileLoopVariable$2 = WhileLoopVariable$.MODULE$;
            WhileLoopVariable<Output> outputWhileLoopVariable = WhileLoopVariable$.MODULE$.outputWhileLoopVariable();
            Lazy apply4 = Lazy$.MODULE$.apply(() -> {
                return outputWhileLoopVariable;
            });
            WhileLoopVariable$ whileLoopVariable$3 = WhileLoopVariable$.MODULE$;
            WhileLoopVariable<TensorArray> tensorArrayWhileLoopVariable = WhileLoopVariable$.MODULE$.tensorArrayWhileLoopVariable();
            // Run the loop and keep the second loop variable (the filled TensorArray).
            symbol2 = (Symbol) ((Tuple2) controlFlow$.whileLoop(function1, function12, tuple23, whileLoop$default$4, 1, whileLoop$default$6, whileLoop$default$7, whileLoop$default$8, whileLoop$default$9, whileLoopVariable$.productConstructor(generic, whileLoopVariable$2.recursiveConstructor(apply4, whileLoopVariable$3.recursiveConstructor(Lazy$.MODULE$.apply(() -> {
                return tensorArrayWhileLoopVariable;
            }), WhileLoopVariable$.MODULE$.hnil())), hlist$Tupler$.MODULE$.hlistTupler2(), hlist$Tupler$.MODULE$.hlistTupler2(), hlist$Tupler$.MODULE$.hlistTupler2(), new Generic<Tuple2<Shape, Shape>>() { // from class: org.platanios.tensorflow.api.ops.seq2seq.decoders.BeamSearchDecoder$anon$macro$2666$1
                public $colon.colon<Shape, $colon.colon<Shape, HNil>> to(Tuple2<Shape, Shape> tuple24) {
                    if (tuple24 != null) {
                        return new $colon.colon<>((Shape) tuple24._1(), new $colon.colon((Shape) tuple24._2(), HNil$.MODULE$));
                    }
                    throw new MatchError(tuple24);
                }

                public Tuple2<Shape, Shape> from($colon.colon<Shape, $colon.colon<Shape, HNil>> colonVar) {
                    if (colonVar != null) {
                        Shape shape = (Shape) colonVar.head();
                        $colon.colon tail = colonVar.tail();
                        if (tail != null) {
                            Shape shape2 = (Shape) tail.head();
                            if (HNil$.MODULE$.equals(tail.tail())) {
                                return new Tuple2<>(shape, shape2);
                            }
                        }
                    }
                    throw new MatchError(colonVar);
                }
            })))._2();
        } else {
            symbol2 = symbol;
        }
        return symbol2;
    }

    // Lambda bodies lifted by the Scala compiler: read a constant tensor's scalar value as an int.
    public static final /* synthetic */ int $anonfun$splitBatchBeams$1(Tensor tensor) {
        return BoxesRunTime.unboxToInt(tensor.scalar());
    }

    public static final /* synthetic */ int $anonfun$mergeBatchBeams$1(Tensor tensor) {
        return BoxesRunTime.unboxToInt(tensor.scalar());
    }

    public static final /* synthetic */ int $anonfun$gather$2(Tensor tensor) {
        return BoxesRunTime.unboxToInt(tensor.scalar());
    }

    // Lifted conversion of an int literal to an Output (used for the `1 - mask` term above).
    public static final /* synthetic */ Output $anonfun$maybeSortTensorArrayBeams$1(int i) {
        return Implicits$.MODULE$.tensorConvertibleToOutput(BoxesRunTime.boxToInteger(i), TensorConvertible$.MODULE$.supportedTypeTensorConvertible(SupportedType$.MODULE$.intIsSupportedType()));
    }

    /** Scala module constructor: publishes this instance as {@link #MODULE$}. */
    private BeamSearchDecoder$() {
        MODULE$ = this;
    }
}
