package org.platanios.tensorflow.api.ops.seq2seq.decoders;

import com.typesafe.scalalogging.Logger;
import com.typesafe.scalalogging.Logger$;
import org.platanios.tensorflow.api.core.Indexer;
import org.platanios.tensorflow.api.core.IndexerConstructionWithTwoNumbers$;
import org.platanios.tensorflow.api.core.NewAxis$;
import org.platanios.tensorflow.api.core.Shape;
import org.platanios.tensorflow.api.core.Shape$;
import org.platanios.tensorflow.api.core.package$exception$;
import org.platanios.tensorflow.api.core.package$exception$InvalidShapeException;
import org.platanios.tensorflow.api.core.package$exception$InvalidShapeException$;
import org.platanios.tensorflow.api.implicits.Implicits$;
import org.platanios.tensorflow.api.ops.Basic$;
import org.platanios.tensorflow.api.ops.Checks$;
import org.platanios.tensorflow.api.ops.Math$;
import org.platanios.tensorflow.api.ops.Op;
import org.platanios.tensorflow.api.ops.Op$;
import org.platanios.tensorflow.api.ops.Output;
import org.platanios.tensorflow.api.ops.Output$;
import org.platanios.tensorflow.api.ops.OutputConvertible;
import org.platanios.tensorflow.api.ops.TensorArray;
import org.platanios.tensorflow.api.ops.control_flow.WhileLoopVariable;
import org.platanios.tensorflow.api.ops.rnn.cell.RNNCell;
import org.platanios.tensorflow.api.ops.seq2seq.decoders.BeamSearchDecoder;
import org.platanios.tensorflow.api.tensors.Tensor;
import org.platanios.tensorflow.api.tensors.Tensor$;
import org.platanios.tensorflow.api.tensors.TensorConvertible$;
import org.platanios.tensorflow.api.types.DataType;
import org.platanios.tensorflow.api.types.SupportedType$;
import org.platanios.tensorflow.jni.InvalidArgumentException;
import org.slf4j.LoggerFactory;
import scala.Function1;
import scala.Predef$;
import scala.Tuple2;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.collection.immutable.Nil$;
import scala.collection.mutable.ArrayBuffer;
import scala.collection.mutable.ArrayBuffer$;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;

/* compiled from: BeamSearchDecoder.scala */
/* loaded from: input_file:org/platanios/tensorflow/api/ops/seq2seq/decoders/BeamSearchDecoder$.class */
/**
 * Singleton module ({@code MODULE$}) accompanying {@link BeamSearchDecoder}.
 *
 * <p>It holds:
 * <ul>
 *   <li>the default values of the decoder constructor / {@code apply} parameters;</li>
 *   <li>a factory method;</li>
 *   <li>the graph-construction helpers used during beam search: masking, tiling, batch/beam
 *       reshaping, gathering, and tensor-array reordering checks.</li>
 * </ul>
 */
public final class BeamSearchDecoder$ {
    /** The singleton instance; assigned by the private constructor, which runs from the static initializer. */
    public static BeamSearchDecoder$ MODULE$;
    private final Logger logger;

    static {
        new BeamSearchDecoder$();
    }

    /** Default for the 7th constructor parameter (length penalty): {@code NoPenalty}. */
    public <S, SS> LengthPenalty $lessinit$greater$default$7() {
        return NoPenalty$.MODULE$;
    }

    /** Default for the 8th constructor parameter: the identity function on {@code BeamSearchDecoder.Output}. */
    public <S, SS> Function1<BeamSearchDecoder.Output, BeamSearchDecoder.Output> $lessinit$greater$default$8() {
        return output -> {
            return output;
        };
    }

    /** Default for the 9th constructor parameter: {@code true}. */
    public <S, SS> boolean $lessinit$greater$default$9() {
        return true;
    }

    /** Default for the 10th constructor parameter (name): {@code "BeamSearchRNNDecoder"}. */
    public <S, SS> String $lessinit$greater$default$10() {
        return "BeamSearchRNNDecoder";
    }

    /** Returns the logger named {@code "Ops / Beam Search Decoder"} used for warnings emitted by this module. */
    public Logger logger() {
        return this.logger;
    }

    /**
     * Creates a {@link BeamSearchDecoder}, forwarding every argument positionally to its constructor.
     *
     * <p>NOTE(review): the decompiler lost the parameter names. Judging from the types and defaults they are
     * presumably: cell, initial state, embedding function, begin tokens, end token, beam width, length
     * penalty, output function, whether to reorder tensor arrays, name, and the while-loop evidence for the
     * outputs and the state. Confirm against {@code BeamSearchDecoder.scala}.
     */
    public <S, SS> BeamSearchDecoder<S, SS> apply(RNNCell<Output, Shape, S, SS> rNNCell, S s, Function1<Output, Output> function1, Output output, Output output2, int i, LengthPenalty lengthPenalty, Function1<Output, Output> function12, boolean z, String str, WhileLoopVariable<Output> whileLoopVariable, WhileLoopVariable<S> whileLoopVariable2) {
        return new BeamSearchDecoder<>(rNNCell, s, function1, output, output2, i, lengthPenalty, function12, z, str, whileLoopVariable, whileLoopVariable2);
    }

    /** Default for the 7th {@code apply} parameter (length penalty): {@code NoPenalty}. */
    public <S, SS> LengthPenalty apply$default$7() {
        return NoPenalty$.MODULE$;
    }

    /** Default for the 8th {@code apply} parameter: the identity function on {@link Output}. */
    public <S, SS> Function1<Output, Output> apply$default$8() {
        return output -> {
            return output;
        };
    }

    /** Default for the 9th {@code apply} parameter: {@code true}. */
    public <S, SS> boolean apply$default$9() {
        return true;
    }

    /** Default for the 10th {@code apply} parameter (name): {@code "BeamSearchRNNDecoder"}. */
    public <S, SS> String apply$default$10() {
        return "BeamSearchRNNDecoder";
    }

    /**
     * Replaces the entries of {@code output} selected by the boolean mask {@code output3} with a fixed row.
     *
     * <p>The mask is expanded at axis 2 and tiled to {@code shape(output)[2]}. Where it is true, the result
     * is a one-hot row over axis 2 with value {@code 0} at index {@code output2} and the minimum value of
     * {@code output}'s data type everywhere else. All other entries are copied from {@code output}
     * unchanged.
     *
     * <p>NOTE(review): presumably {@code output} is the [batch, beam, vocab] log-probability tensor,
     * {@code output2} the end-of-sequence token id and {@code output3} a [batch, beam] "finished" mask.
     * Confirm against callers.
     */
    public Output maskLogProbabilities(Output output, Output output2, Output output3) {
        Output apply = Basic$.MODULE$.shape(output, Basic$.MODULE$.shape$default$2(), Basic$.MODULE$.shape$default$3(), Basic$.MODULE$.shape$default$4()).apply(Implicits$.MODULE$.intToIndex(2), Predef$.MODULE$.wrapRefArray(new Indexer[0]));
        DataType dataType = output.dataType();
        return Math$.MODULE$.select(Basic$.MODULE$.tile(Implicits$.MODULE$.outputToBasicOps(output3).expandDims(Implicits$.MODULE$.tensorConvertibleToOutput(BoxesRunTime.boxToInteger(2), TensorConvertible$.MODULE$.fromSupportedType(SupportedType$.MODULE$.intIsSupported()))), Basic$.MODULE$.stack((Seq) Seq$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new Output[]{Implicits$.MODULE$.tensorConvertibleToOutput(BoxesRunTime.boxToInteger(1), TensorConvertible$.MODULE$.fromSupportedType(SupportedType$.MODULE$.intIsSupported())), Implicits$.MODULE$.tensorConvertibleToOutput(BoxesRunTime.boxToInteger(1), TensorConvertible$.MODULE$.fromSupportedType(SupportedType$.MODULE$.intIsSupported())), apply})), Basic$.MODULE$.stack$default$2(), Basic$.MODULE$.stack$default$3()), Basic$.MODULE$.tile$default$3()), Basic$.MODULE$.tile(Implicits$.MODULE$.outputToBasicOps(Basic$.MODULE$.oneHot(output2, apply, Basic$.MODULE$.zeros(dataType, Implicits$.MODULE$.tensorConvertibleToOutput(Shape$.MODULE$.apply((Seq<Object>) Nil$.MODULE$), TensorConvertible$.MODULE$.fromShape()), Basic$.MODULE$.zeros$default$3()), Basic$.MODULE$.constant(Tensor$.MODULE$.apply(dataType.mo752min(), Predef$.MODULE$.genericWrapArray(new Object[0]), TensorConvertible$.MODULE$.fromSupportedType(dataType.evSupportedType())).slice(Implicits$.MODULE$.intToIndex(0), Predef$.MODULE$.wrapRefArray(new Indexer[0])), Basic$.MODULE$.constant$default$2(), Basic$.MODULE$.constant$default$3(), Basic$.MODULE$.constant$default$4()), Basic$.MODULE$.oneHot$default$5(), Basic$.MODULE$.oneHot$default$6(), Basic$.MODULE$.oneHot$default$7())).reshape(Implicits$.MODULE$.tensorConvertibleToOutput(Shape$.MODULE$.apply((Seq<Object>) Predef$.MODULE$.wrapIntArray(new int[]{1, 1, -1})), TensorConvertible$.MODULE$.fromShape())), Basic$.MODULE$.concatenate((Seq) Seq$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new Output[]{Basic$.MODULE$.shape(output3, Basic$.MODULE$.shape$default$2(), Basic$.MODULE$.shape$default$3(), 
Basic$.MODULE$.shape$default$4()), Implicits$.MODULE$.tensorToOutput(Tensor$.MODULE$.apply(BoxesRunTime.boxToInteger(1), Predef$.MODULE$.wrapIntArray(new int[0]), TensorConvertible$.MODULE$.fromSupportedType(SupportedType$.MODULE$.intIsSupported())))})), Implicits$.MODULE$.tensorConvertibleToOutput(BoxesRunTime.boxToInteger(0), TensorConvertible$.MODULE$.fromSupportedType(SupportedType$.MODULE$.intIsSupported())), Basic$.MODULE$.concatenate$default$3()), Basic$.MODULE$.tile$default$3()), output, Math$.MODULE$.select$default$4());
    }

    /**
     * Builds a {@code "GatherTree"} op named {@code str} and returns its first output.
     *
     * <p>The four inputs are added in the order they are given.
     */
    public Output gatherTree(Output output, Output output2, Output output3, Output output4, String str) {
        return new Op.Builder("GatherTree", str).addInput(output).addInput(output2).addInput(output3).addInput(output4).build().outputs()[0];
    }

    /**
     * Repeats every entry along the first axis of an {@link Output} {@code i} consecutive times.
     *
     * <ul>
     *   <li>Scalar input: the value is expanded to shape [1] and tiled to a rank-1 tensor of static
     *       shape [i].</li>
     *   <li>Rank r &gt;= 1: the tensor is expanded at axis 1, tiled {@code i} times along that axis, and
     *       reshaped to [shape(0) * i, shape(1:)...]. The static first dimension is {@code shape(0) * i},
     *       or -1 if {@code shape(0)} is unknown.</li>
     *   <li>{@link TensorArray} arguments are returned unchanged.</li>
     * </ul>
     *
     * @throws InvalidArgumentException if the tensor rank is statically unknown, or the argument is neither
     *     an {@link Output} nor a {@link TensorArray}
     */
    public OutputConvertible tileBatch(OutputConvertible outputConvertible, int i) throws InvalidArgumentException {
        OutputConvertible outputConvertible2;
        Output output;
        if (outputConvertible instanceof Output) {
            Output output2 = (Output) outputConvertible;
            if (output2.rank() == -1) {
                throw package$exception$.MODULE$.InvalidArgumentException().apply("The provided tensor must have statically known rank.");
            }
            if (output2.rank() == 0) {
                Output tile = Basic$.MODULE$.tile(Implicits$.MODULE$.outputToBasicOps(output2).expandDims(Implicits$.MODULE$.tensorConvertibleToOutput(BoxesRunTime.boxToInteger(0), TensorConvertible$.MODULE$.fromSupportedType(SupportedType$.MODULE$.intIsSupported()))), Implicits$.MODULE$.tensorToOutput(Tensor$.MODULE$.apply(BoxesRunTime.boxToInteger(i), Predef$.MODULE$.wrapIntArray(new int[0]), TensorConvertible$.MODULE$.fromSupportedType(SupportedType$.MODULE$.intIsSupported()))), Basic$.MODULE$.tile$default$3());
                tile.setShape(Shape$.MODULE$.apply((Seq<Object>) Predef$.MODULE$.wrapIntArray(new int[]{i})));
                output = tile;
            } else {
                Output shape = Basic$.MODULE$.shape(output2, Basic$.MODULE$.shape$default$2(), Basic$.MODULE$.shape$default$3(), Basic$.MODULE$.shape$default$4());
                // Tile multiples [1, i, 1, ..., 1]: repeat only along the newly inserted axis 1.
                ArrayBuffer fill = ArrayBuffer$.MODULE$.fill(output2.rank() + 1, () -> {
                    return 1;
                });
                fill.update(1, BoxesRunTime.boxToInteger(i));
                int apply = output2.shape().apply(0) != -1 ? output2.shape().apply(0) * i : -1;
                Output reshape = Implicits$.MODULE$.outputToBasicOps(Basic$.MODULE$.tile(Implicits$.MODULE$.outputToBasicOps(output2).expandDims(Implicits$.MODULE$.tensorConvertibleToOutput(BoxesRunTime.boxToInteger(1), TensorConvertible$.MODULE$.fromSupportedType(SupportedType$.MODULE$.intIsSupported()))), Implicits$.MODULE$.tensorConvertibleToOutput(fill, TensorConvertible$.MODULE$.fromTraversable(TensorConvertible$.MODULE$.fromSupportedType(SupportedType$.MODULE$.intIsSupported()))), Basic$.MODULE$.tile$default$3())).reshape(Basic$.MODULE$.concatenate((Seq) Seq$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new Output[]{Implicits$.MODULE$.outputToBasicOps(Implicits$.MODULE$.outputToMathOps(shape.apply(Implicits$.MODULE$.intToIndex(0), Predef$.MODULE$.wrapRefArray(new Indexer[0]))).$times(Implicits$.MODULE$.tensorConvertibleToOutput(BoxesRunTime.boxToInteger(i), TensorConvertible$.MODULE$.fromSupportedType(SupportedType$.MODULE$.intIsSupported())))).expandDims(Implicits$.MODULE$.tensorConvertibleToOutput(BoxesRunTime.boxToInteger(0), TensorConvertible$.MODULE$.fromSupportedType(SupportedType$.MODULE$.intIsSupported()))), shape.apply(Implicits$.MODULE$.intToIndexerConstruction(1).$colon$colon(), Predef$.MODULE$.wrapRefArray(new Indexer[0]))})), Basic$.MODULE$.concatenate$default$2(), Basic$.MODULE$.concatenate$default$3()));
                if (output2.rank() > 1) {
                    reshape.setShape(Shape$.MODULE$.apply((Seq<Object>) Predef$.MODULE$.wrapIntArray(new int[]{apply})).$plus$plus(output2.shape().apply(Implicits$.MODULE$.intToIndexerConstruction(1).$colon$colon())));
                } else {
                    reshape.setShape(Shape$.MODULE$.apply((Seq<Object>) Predef$.MODULE$.wrapIntArray(new int[]{apply})));
                }
                output = reshape;
            }
            outputConvertible2 = output;
        } else {
            if (!(outputConvertible instanceof TensorArray)) {
                throw package$exception$.MODULE$.InvalidArgumentException().apply("Unsupported argument type for use with the beam search decoder.");
            }
            outputConvertible2 = outputConvertible;
        }
        return outputConvertible2;
    }

    /**
     * Applies {@link #tileBatch(OutputConvertible, int)} with multiplier {@code i} to every element of the
     * structure {@code s}.
     *
     * <p>The structure is traversed via {@code whileLoopVariable.map}.
     */
    public <S> S tileForBeamSearch(S s, int i, WhileLoopVariable<S> whileLoopVariable) throws InvalidArgumentException {
        return whileLoopVariable.map(s, outputConvertible -> {
            return MODULE$.tileBatch(outputConvertible, i);
        });
    }

    /**
     * Calls {@link #splitBatchBeams} for {@link Output}s of rank &gt;= 1.
     *
     * <p>Scalars and {@link TensorArray}s are returned unchanged.
     *
     * @throws InvalidArgumentException if an {@link Output} has unknown rank, or the argument is neither an
     *     {@link Output} nor a {@link TensorArray}
     */
    public OutputConvertible maybeSplitBatchBeams(OutputConvertible outputConvertible, Shape shape, Output output, int i) throws InvalidArgumentException {
        OutputConvertible outputConvertible2;
        if (outputConvertible instanceof Output) {
            Output output2 = (Output) outputConvertible;
            if (output2.rank() == -1) {
                throw package$exception$.MODULE$.InvalidArgumentException().apply(new StringBuilder(58).append("Expected tensor (").append(output2).append(") to have known rank, but it was unknown.").toString());
            }
            outputConvertible2 = output2.rank() == 0 ? outputConvertible : splitBatchBeams(outputConvertible, shape, output, i);
        } else {
            if (!(outputConvertible instanceof TensorArray)) {
                throw package$exception$.MODULE$.InvalidArgumentException().apply("Unsupported argument type for use with the beam search decoder.");
            }
            outputConvertible2 = outputConvertible;
        }
        return outputConvertible2;
    }

    /**
     * Reshapes an {@link Output} from a merged leading dimension to the split shape.
     *
     * <p>The target dynamic shape is [output, i, shape(input)[1:]...], where {@code output} is the batch
     * size and {@code i} the beam width. The expected static shape is [constant batch size or -1, i] ++
     * {@code shape}. {@link TensorArray}s are returned unchanged.
     *
     * @throws package$exception$InvalidShapeException if the reshaped static shape is incompatible with the
     *     expected one
     * @throws InvalidArgumentException if the argument is neither an {@link Output} nor a {@link TensorArray}
     */
    public OutputConvertible splitBatchBeams(OutputConvertible outputConvertible, Shape shape, Output output, int i) throws InvalidArgumentException, package$exception$InvalidShapeException {
        OutputConvertible outputConvertible2;
        Tuple2 tuple2 = new Tuple2(outputConvertible, shape);
        if (tuple2 != null) {
            OutputConvertible outputConvertible3 = (OutputConvertible) tuple2._1();
            Shape shape2 = (Shape) tuple2._2();
            if (outputConvertible3 instanceof Output) {
                Output output2 = (Output) outputConvertible3;
                if (shape2 != null) {
                    Output reshape = Basic$.MODULE$.reshape(output2, Basic$.MODULE$.concatenate((Seq) Seq$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new Output[]{output.apply(NewAxis$.MODULE$, Predef$.MODULE$.wrapRefArray(new Indexer[0])), Tensor$.MODULE$.apply(output.dataType(), BoxesRunTime.boxToInteger(i), Predef$.MODULE$.wrapIntArray(new int[0]), TensorConvertible$.MODULE$.fromSupportedType(SupportedType$.MODULE$.intIsSupported())).toOutput(), Implicits$.MODULE$.outputToCastOps(Basic$.MODULE$.shape(output2, Basic$.MODULE$.shape$default$2(), Basic$.MODULE$.shape$default$3(), Basic$.MODULE$.shape$default$4()).apply(Implicits$.MODULE$.intToIndexerConstruction(1).$colon$colon(), Predef$.MODULE$.wrapRefArray(new Indexer[0]))).cast(output.dataType())})), Implicits$.MODULE$.tensorConvertibleToOutput(BoxesRunTime.boxToInteger(0), TensorConvertible$.MODULE$.fromSupportedType(SupportedType$.MODULE$.intIsSupported())), Basic$.MODULE$.concatenate$default$3()), Basic$.MODULE$.reshape$default$3());
                    Shape $plus$plus = Shape$.MODULE$.apply((Seq<Object>) Predef$.MODULE$.wrapIntArray(new int[]{BoxesRunTime.unboxToInt(Output$.MODULE$.constantValue(output).map(tensor -> {
                        return BoxesRunTime.boxToInteger($anonfun$splitBatchBeams$1(tensor));
                    }).getOrElse(() -> {
                        return -1;
                    })), i})).$plus$plus(shape2);
                    if (!reshape.shape().isCompatibleWith($plus$plus)) {
                        throw new package$exception$InvalidShapeException(new StringBuilder(158).append("Unexpected behavior when reshaping between beam width and batch size. ").append(new StringBuilder(33).append("The reshaped tensor has shape: ").append(reshape.shape()).append(". ").toString()).append(new StringBuilder(64).append("We expected it to have shape [batchSize, beamWidth, depth] == ").append($plus$plus).append(". ").toString()).append("Perhaps you forgot to create a zero state with batchSize = encoderBatchSize * beamWidth?").toString(), package$exception$InvalidShapeException$.MODULE$.apply$default$2());
                    }
                    reshape.setShape($plus$plus);
                    outputConvertible2 = reshape;
                    return outputConvertible2;
                }
            }
        }
        if (tuple2 == null || !(tuple2._1() instanceof TensorArray)) {
            throw package$exception$.MODULE$.InvalidArgumentException().apply("Unsupported argument type for use with the beam search decoder.");
        }
        outputConvertible2 = outputConvertible;
        return outputConvertible2;
    }

    /**
     * Calls {@link #mergeBatchBeams} for {@link Output}s of rank &gt;= 1.
     *
     * <p>Scalars and {@link TensorArray}s are returned unchanged.
     *
     * @throws InvalidArgumentException if an {@link Output} has unknown rank, or the argument is neither an
     *     {@link Output} nor a {@link TensorArray}
     */
    public OutputConvertible maybeMergeBatchBeams(OutputConvertible outputConvertible, Shape shape, Output output, int i) throws InvalidArgumentException {
        OutputConvertible outputConvertible2;
        if (outputConvertible instanceof Output) {
            Output output2 = (Output) outputConvertible;
            if (output2.rank() == -1) {
                throw package$exception$.MODULE$.InvalidArgumentException().apply(new StringBuilder(58).append("Expected tensor (").append(output2).append(") to have known rank, but it was unknown.").toString());
            }
            outputConvertible2 = output2.rank() == 0 ? outputConvertible : mergeBatchBeams(outputConvertible, shape, output, i);
        } else {
            if (!(outputConvertible instanceof TensorArray)) {
                throw package$exception$.MODULE$.InvalidArgumentException().apply("Unsupported argument type for use with the beam search decoder.");
            }
            outputConvertible2 = outputConvertible;
        }
        return outputConvertible2;
    }

    /**
     * Reshapes an {@link Output} from the split shape to a merged leading dimension.
     *
     * <p>The target dynamic shape is [output * i, shape(input)[2:]...]. The expected static shape is
     * [constant batch size * i, or -1] ++ {@code shape}. {@link TensorArray}s are returned unchanged.
     *
     * @throws package$exception$InvalidShapeException if the reshaped static shape is incompatible with the
     *     expected one
     * @throws InvalidArgumentException if the argument is neither an {@link Output} nor a {@link TensorArray}
     */
    public OutputConvertible mergeBatchBeams(OutputConvertible outputConvertible, Shape shape, Output output, int i) throws InvalidArgumentException, package$exception$InvalidShapeException {
        OutputConvertible outputConvertible2;
        Tuple2 tuple2 = new Tuple2(outputConvertible, shape);
        if (tuple2 != null) {
            OutputConvertible outputConvertible3 = (OutputConvertible) tuple2._1();
            Shape shape2 = (Shape) tuple2._2();
            if (outputConvertible3 instanceof Output) {
                Output output2 = (Output) outputConvertible3;
                if (shape2 != null) {
                    Output reshape = Basic$.MODULE$.reshape(output2, Basic$.MODULE$.concatenate((Seq) Seq$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new Output[]{Implicits$.MODULE$.outputToMathOps(output.apply(NewAxis$.MODULE$, Predef$.MODULE$.wrapRefArray(new Indexer[0]))).$times(Tensor$.MODULE$.apply(output.dataType(), BoxesRunTime.boxToInteger(i), Predef$.MODULE$.wrapIntArray(new int[0]), TensorConvertible$.MODULE$.fromSupportedType(SupportedType$.MODULE$.intIsSupported())).toOutput()), Implicits$.MODULE$.outputToCastOps(Basic$.MODULE$.shape(output2, Basic$.MODULE$.shape$default$2(), Basic$.MODULE$.shape$default$3(), Basic$.MODULE$.shape$default$4()).apply(Implicits$.MODULE$.intToIndexerConstruction(2).$colon$colon(), Predef$.MODULE$.wrapRefArray(new Indexer[0]))).cast(output.dataType())})), Implicits$.MODULE$.tensorConvertibleToOutput(BoxesRunTime.boxToInteger(0), TensorConvertible$.MODULE$.fromSupportedType(SupportedType$.MODULE$.intIsSupported())), Basic$.MODULE$.concatenate$default$3()), Basic$.MODULE$.reshape$default$3());
                    int unboxToInt = BoxesRunTime.unboxToInt(Output$.MODULE$.constantValue(output).map(tensor -> {
                        return BoxesRunTime.boxToInteger($anonfun$mergeBatchBeams$1(tensor));
                    }).getOrElse(() -> {
                        return -1;
                    }));
                    Shape $plus$plus = Shape$.MODULE$.apply((Seq<Object>) Predef$.MODULE$.wrapIntArray(new int[]{unboxToInt != -1 ? unboxToInt * i : -1})).$plus$plus(shape2);
                    if (!reshape.shape().isCompatibleWith($plus$plus)) {
                        throw new package$exception$InvalidShapeException(new StringBuilder(158).append("Unexpected behavior when reshaping between beam width and batch size. ").append(new StringBuilder(33).append("The reshaped tensor has shape: ").append(reshape.shape()).append(". ").toString()).append(new StringBuilder(64).append("We expected it to have shape [batchSize, beamWidth, depth] == ").append($plus$plus).append(". ").toString()).append("Perhaps you forgot to create a zero state with batchSize = encoderBatchSize * beamWidth?").toString(), package$exception$InvalidShapeException$.MODULE$.apply$default$2());
                    }
                    reshape.setShape($plus$plus);
                    outputConvertible2 = reshape;
                    return outputConvertible2;
                }
            }
        }
        if (tuple2 == null || !(tuple2._1() instanceof TensorArray)) {
            throw package$exception$.MODULE$.InvalidArgumentException().apply("Unsupported argument type for use with the beam search decoder.");
        }
        outputConvertible2 = outputConvertible;
        return outputConvertible2;
    }

    /**
     * Calls {@link #gather} for {@link Output}s whose rank is at least {@code seq.size()}.
     *
     * <p>Lower-rank {@link Output}s and {@link TensorArray}s are returned unchanged.
     *
     * @throws InvalidArgumentException if an {@link Output} has unknown rank, or the argument is neither an
     *     {@link Output} nor a {@link TensorArray}
     */
    public OutputConvertible maybeGather(Output output, OutputConvertible outputConvertible, Output output2, Output output3, Seq<Output> seq, String str) throws InvalidArgumentException {
        OutputConvertible outputConvertible2;
        if (outputConvertible instanceof Output) {
            Output output4 = (Output) outputConvertible;
            if (output4.rank() == -1) {
                throw package$exception$.MODULE$.InvalidArgumentException().apply(new StringBuilder(58).append("Expected tensor (").append(output4).append(") to have known rank, but it was unknown.").toString());
            }
            outputConvertible2 = output4.rank() < seq.size() ? outputConvertible : gather(output, outputConvertible, output2, output3, seq, str);
        } else {
            if (!(outputConvertible instanceof TensorArray)) {
                throw package$exception$.MODULE$.InvalidArgumentException().apply("Unsupported argument type for use with the beam search decoder.");
            }
            outputConvertible2 = outputConvertible;
        }
        return outputConvertible2;
    }

    /** Default name scope for {@link #maybeGather}: {@code "GatherTensorHelper"}. */
    public String maybeGather$default$6() {
        return "GatherTensorHelper";
    }

    /**
     * Gathers entries of an {@link Output} under the name scope {@code str}. Steps:
     * <ol>
     *   <li>The input is reshaped to {@code stack(seq)}.</li>
     *   <li>It is gathered at the flattened indices {@code output + range(0, output2)[:, newaxis] * output3}.
     *       Each row of {@code output} is offset by its row number times {@code output3}.</li>
     *   <li>The result is reshaped (op name {@code "Output"}) to {@code shape(input)[0 : 1 + seq.size()]}.</li>
     *   <li>Its static shape is set to [constant value of {@code output2} or -1] ++
     *       {@code input.shape[1 : 1 + seq.size()]}.</li>
     * </ol>
     * {@link TensorArray}s are returned unchanged.
     *
     * @throws InvalidArgumentException if the argument is neither an {@link Output} nor a {@link TensorArray}
     */
    public OutputConvertible gather(Output output, OutputConvertible outputConvertible, Output output2, Output output3, Seq<Output> seq, String str) throws InvalidArgumentException {
        OutputConvertible outputConvertible2;
        if (outputConvertible instanceof Output) {
            Output output4 = (Output) outputConvertible;
            outputConvertible2 = (OutputConvertible) Op$.MODULE$.createWithNameScope(str, Op$.MODULE$.createWithNameScope$default$2(), () -> {
                Output gather = Basic$.MODULE$.gather(Implicits$.MODULE$.outputToBasicOps(output4).reshape(Basic$.MODULE$.stack(seq, Basic$.MODULE$.stack$default$2(), Basic$.MODULE$.stack$default$3())), Implicits$.MODULE$.outputToBasicOps(Implicits$.MODULE$.outputToMathOps(output).$plus(Implicits$.MODULE$.outputToBasicOps(Implicits$.MODULE$.outputToMathOps(Math$.MODULE$.range(Implicits$.MODULE$.tensorConvertibleToOutput(BoxesRunTime.boxToInteger(0), TensorConvertible$.MODULE$.fromSupportedType(SupportedType$.MODULE$.intIsSupported())), output2, Math$.MODULE$.range$default$3(), Math$.MODULE$.range$default$4(), Math$.MODULE$.range$default$5())).$times(output3)).expandDims(Implicits$.MODULE$.tensorConvertibleToOutput(BoxesRunTime.boxToInteger(1), TensorConvertible$.MODULE$.fromSupportedType(SupportedType$.MODULE$.intIsSupported()))))).reshape(Implicits$.MODULE$.tensorConvertibleToOutput(Shape$.MODULE$.apply((Seq<Object>) Predef$.MODULE$.wrapIntArray(new int[]{-1})), TensorConvertible$.MODULE$.fromShape())), Basic$.MODULE$.gather$default$3(), Basic$.MODULE$.gather$default$4());
                Output apply = Basic$.MODULE$.shape(output4, Basic$.MODULE$.shape$default$2(), Basic$.MODULE$.shape$default$3(), Basic$.MODULE$.shape$default$4()).apply(IndexerConstructionWithTwoNumbers$.MODULE$.indexerConstructionToIndex(Implicits$.MODULE$.intToIndexerConstruction(1 + seq.size()).$colon$colon(Implicits$.MODULE$.intToIndexerConstruction(0))), Predef$.MODULE$.wrapRefArray(new Indexer[0]));
                Shape $plus$plus = Shape$.MODULE$.apply((Seq<Object>) Predef$.MODULE$.wrapIntArray(new int[]{BoxesRunTime.unboxToInt(Output$.MODULE$.constantValue(output2).map(tensor -> {
                    return BoxesRunTime.boxToInteger($anonfun$gather$2(tensor));
                }).getOrElse(() -> {
                    return -1;
                }))})).$plus$plus(output4.shape().apply(IndexerConstructionWithTwoNumbers$.MODULE$.indexerConstructionToIndex(Implicits$.MODULE$.intToIndexerConstruction(1 + seq.size()).$colon$colon(Implicits$.MODULE$.intToIndexerConstruction(1)))));
                Output reshape = Basic$.MODULE$.reshape(gather, apply, "Output");
                reshape.setShape($plus$plus);
                return reshape;
            });
        } else {
            if (!(outputConvertible instanceof TensorArray)) {
                throw package$exception$.MODULE$.InvalidArgumentException().apply("Unsupported argument type for use with the beam search decoder.");
            }
            outputConvertible2 = outputConvertible;
        }
        return outputConvertible2;
    }

    /** Default name scope for {@link #gather}: {@code "GatherTensorHelper"}. */
    public String gather$default$6() {
        return "GatherTensorHelper";
    }

    /** Default op name for {@link #gatherTree}: {@code "GatherTree"}. */
    public String gatherTree$default$5() {
        return "GatherTree";
    }

    /**
     * Statically checks whether elements of static shape {@code shape} can be reshaped to [i, i2, -1].
     *
     * <p>Here {@code i} is the batch size and {@code i2} the beam width. The check is the static counterpart
     * of {@link #checkBatchBeam}. A shape is accepted when any of the following holds:
     * <ul>
     *   <li>the batch size or the leading dimension is unknown (-1), so nothing can be checked;</li>
     *   <li>it uses the merged layout [i * i2, ...];</li>
     *   <li>it uses the split layout [i, i2, ...], with a known second dimension.</li>
     * </ul>
     *
     * <p>Otherwise a warning is logged and {@code false} is returned. The previous implementation required
     * both layouts to hold at once. That is impossible unless {@code i2 == 1}, so the common merged shape
     * [i * i2, depth] was wrongly rejected.
     *
     * @return {@code true} if the shape is (possibly) compatible with [i, i2, -1]
     */
    public boolean checkStaticBatchBeam(Shape shape, int i, int i2) {
        // Unknown batch size or unknown leading dimension: nothing to verify statically.
        if (i == -1 || shape.apply(0) == -1) {
            return true;
        }
        boolean mergedLayout = shape.apply(0) == i * i2;
        // Short-circuiting guards the shape.apply(1) calls, so they only run when rank > 1.
        boolean splitLayout = shape.rank() > 1 && shape.apply(1) != -1 && shape.apply(0) == i && shape.apply(1) == i2;
        if (mergedLayout || splitLayout) {
            return true;
        }
        Shape apply = Shape$.MODULE$.apply((Seq<Object>) Predef$.MODULE$.wrapIntArray(new int[]{i, i2, -1}));
        if (logger().underlying().isWarnEnabled()) {
            logger().underlying().warn(new StringBuilder(125).append("Tensor array reordering expects elements to be reshapable to '").append(apply).append("' which is incompatible with ").append(new StringBuilder(96).append("the current shape '").append(shape).append("'. Consider setting `reorderTensorArrays` to `false` to disable tensor array ").toString()).append("reordering during the beam search.").toString());
        }
        return false;
    }

    /**
     * Builds an assert op checking that the elements of {@code output} are reshapable to
     * [output2, output3, -1].
     *
     * <p>The asserted condition depends on the static rank of {@code output}:
     * <ul>
     *   <li>rank 2: {@code shape(output)[1] == output2 * output3};</li>
     *   <li>otherwise: {@code shape(output)[1] == output2 * output3}, or
     *       ({@code shape(output)[1] == output2} and {@code shape(output)[2] == output3}).</li>
     * </ul>
     * The assert carries a message naming {@code output}.
     *
     * <p>NOTE(review): presumably {@code output} is a stacked tensor array (time-major, so axis 1 is the
     * batch axis), {@code output2} the batch size and {@code output3} the beam width. Confirm against the
     * caller.
     */
    public Op checkBatchBeam(Output output, Output output2, Output output3) {
        Checks$ checks$ = Checks$.MODULE$;
        Output shape = Basic$.MODULE$.shape(output, Basic$.MODULE$.shape$default$2(), Basic$.MODULE$.shape$default$3(), Basic$.MODULE$.shape$default$4());
        return checks$.mo327assert(output.rank() == 2 ? Math$.MODULE$.equal(shape.apply(Implicits$.MODULE$.intToIndex(1), Predef$.MODULE$.wrapRefArray(new Indexer[0])), Implicits$.MODULE$.outputToMathOps(output2).$times(output3), Math$.MODULE$.equal$default$3()) : Math$.MODULE$.logicalOr(Math$.MODULE$.equal(shape.apply(Implicits$.MODULE$.intToIndex(1), Predef$.MODULE$.wrapRefArray(new Indexer[0])), Implicits$.MODULE$.outputToMathOps(output2).$times(output3), Math$.MODULE$.equal$default$3()), Math$.MODULE$.logicalAnd(Math$.MODULE$.equal(shape.apply(Implicits$.MODULE$.intToIndex(1), Predef$.MODULE$.wrapRefArray(new Indexer[0])), output2, Math$.MODULE$.equal$default$3()), Math$.MODULE$.equal(shape.apply(Implicits$.MODULE$.intToIndex(2), Predef$.MODULE$.wrapRefArray(new Indexer[0])), output3, Math$.MODULE$.equal$default$3()), Math$.MODULE$.logicalAnd$default$3()), Math$.MODULE$.logicalOr$default$3()), (Seq) Seq$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new Output[]{Implicits$.MODULE$.tensorConvertibleToOutput(new StringBuilder(167).append("Tensor array reordering expects elements to be reshapable to '[batchSize, beamSize, -1]' which is ").append(new StringBuilder(91).append("incompatible with the dynamic shape of '").append(output.name()).append("' elements. Consider setting `reorderTensorArrays` ").toString()).append("to `false` to disable tensor array reordering during the beam search.").toString(), TensorConvertible$.MODULE$.fromSupportedType(SupportedType$.MODULE$.stringIsSupported()))})), Checks$.MODULE$.assert$default$3(), Checks$.MODULE$.assert$default$4());
    }

    /**
     * The decompiler could not recover this method's body.
     *
     * <p>As written, calling it always throws {@link UnsupportedOperationException}. Refer to the original
     * {@code BeamSearchDecoder.scala} for the real implementation.
     */
    /* JADX WARN: Removed duplicated region for block: B:18:0x00ca  */
    /* JADX WARN: Removed duplicated region for block: B:23:0x0127  */
    /*
        Code decompiled incorrectly, please refer to instructions dump.
        To view partially-correct add '--show-bad-code' argument
    */
    public org.platanios.tensorflow.api.ops.OutputConvertible maybeSortTensorArrayBeams(org.platanios.tensorflow.api.ops.OutputConvertible r14, org.platanios.tensorflow.api.ops.Output r15, org.platanios.tensorflow.api.ops.Output r16, org.platanios.tensorflow.api.ops.Output r17, int r18) {
        /*
            Method dump skipped, instructions count: 585
            To view this dump add '--comments-level debug' option
        */
        throw new UnsupportedOperationException("Method not decompiled: org.platanios.tensorflow.api.ops.seq2seq.decoders.BeamSearchDecoder$.maybeSortTensorArrayBeams(org.platanios.tensorflow.api.ops.OutputConvertible, org.platanios.tensorflow.api.ops.Output, org.platanios.tensorflow.api.ops.Output, org.platanios.tensorflow.api.ops.Output, int):org.platanios.tensorflow.api.ops.OutputConvertible");
    }

    /** Lambda body used by {@link #splitBatchBeams}: unboxes the tensor's scalar value as an {@code int}. */
    public static final /* synthetic */ int $anonfun$splitBatchBeams$1(Tensor tensor) {
        return BoxesRunTime.unboxToInt(tensor.scalar());
    }

    /** Lambda body used by {@link #mergeBatchBeams}: unboxes the tensor's scalar value as an {@code int}. */
    public static final /* synthetic */ int $anonfun$mergeBatchBeams$1(Tensor tensor) {
        return BoxesRunTime.unboxToInt(tensor.scalar());
    }

    /** Lambda body used by {@link #gather}: unboxes the tensor's scalar value as an {@code int}. */
    public static final /* synthetic */ int $anonfun$gather$2(Tensor tensor) {
        return BoxesRunTime.unboxToInt(tensor.scalar());
    }

    /** Synthetic lambda body (for the undecompiled sort method): unboxes the tensor's scalar as an {@code int}. */
    public static final /* synthetic */ int $anonfun$maybeSortTensorArrayBeams$1(Tensor tensor) {
        return BoxesRunTime.unboxToInt(tensor.scalar());
    }

    /** Synthetic lambda body (for the undecompiled sort method): converts an {@code int} to an {@link Output}. */
    public static final /* synthetic */ Output $anonfun$maybeSortTensorArrayBeams$4(int i) {
        return Implicits$.MODULE$.tensorConvertibleToOutput(BoxesRunTime.boxToInteger(i), TensorConvertible$.MODULE$.fromSupportedType(SupportedType$.MODULE$.intIsSupported()));
    }

    /** Registers this instance as {@link #MODULE$} and creates the module logger. */
    private BeamSearchDecoder$() {
        MODULE$ = this;
        this.logger = Logger$.MODULE$.apply(LoggerFactory.getLogger("Ops / Beam Search Decoder"));
    }
}
