package org.platanios.tensorflow.api.ops.seq2seq.decoders;

import com.typesafe.scalalogging.Logger;
import com.typesafe.scalalogging.Logger$;
import org.platanios.tensorflow.api.core.Indexer;
import org.platanios.tensorflow.api.core.IndexerConstructionWithTwoNumbers$;
import org.platanios.tensorflow.api.core.Shape;
import org.platanios.tensorflow.api.core.Shape$;
import org.platanios.tensorflow.api.core.types.Cpackage;
import org.platanios.tensorflow.api.core.types.package$TF$;
import org.platanios.tensorflow.api.implicits.Implicits$;
import org.platanios.tensorflow.api.implicits.helpers.OutputStructure;
import org.platanios.tensorflow.api.implicits.helpers.OutputToShape;
import org.platanios.tensorflow.api.ops.Basic$;
import org.platanios.tensorflow.api.ops.Checks$;
import org.platanios.tensorflow.api.ops.Math$;
import org.platanios.tensorflow.api.ops.Op;
import org.platanios.tensorflow.api.ops.Op$;
import org.platanios.tensorflow.api.ops.Op$Builder$;
import org.platanios.tensorflow.api.ops.Op$OpInput$;
import org.platanios.tensorflow.api.ops.Op$OpInputPrimitive$;
import org.platanios.tensorflow.api.ops.Op$OpOutput$;
import org.platanios.tensorflow.api.ops.Op$OpOutputPrimitive$;
import org.platanios.tensorflow.api.ops.Output;
import org.platanios.tensorflow.api.ops.Output$;
import org.platanios.tensorflow.api.ops.rnn.cell.RNNCell;
import org.platanios.tensorflow.api.tensors.Tensor;
import org.platanios.tensorflow.api.tensors.Tensor$;
import org.platanios.tensorflow.api.utilities.DefaultsTo$;
import org.platanios.tensorflow.jni.InvalidArgumentException;
import org.slf4j.LoggerFactory;
import scala.Float$;
import scala.Function1;
import scala.Predef$;
import scala.Tuple4;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.collection.immutable.Nil$;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;

/* compiled from: BeamSearchDecoder.scala */
/* loaded from: input_file:org/platanios/tensorflow/api/ops/seq2seq/decoders/BeamSearchDecoder$.class */
/**
 * Java view of the Scala companion object {@code BeamSearchDecoder}: default arguments for the
 * decoder constructor and {@code apply} factory, plus the tensor helpers used by beam search
 * (log-probability masking, gathering, and tensor-array shape checks).
 */
public final class BeamSearchDecoder$ {
    /** Singleton instance, assigned in the private constructor run by the static initializer. */
    public static BeamSearchDecoder$ MODULE$;
    /** SLF4J-backed logger named "Ops / Beam Search Decoder". */
    private final Logger logger;

    static {
        new BeamSearchDecoder$();
    }

    /** Default for the 7th constructor argument (the length penalty): {@code NoPenalty}. */
    public <T, State, StateShape> LengthPenalty $lessinit$greater$default$7() {
        return NoPenalty$.MODULE$;
    }

    /** Default for the 8th constructor argument: the identity function on outputs. */
    public <T, State, StateShape> Function1<Output<T>, Output<T>> $lessinit$greater$default$8() {
        return output -> {
            return output;
        };
    }

    /**
     * Default for the 9th constructor argument: {@code true} (presumably the
     * {@code reorderTensorArrays} flag referenced in the warnings below — confirm).
     */
    public <T, State, StateShape> boolean $lessinit$greater$default$9() {
        return true;
    }

    /** Default for the 10th constructor argument (the name): {@code "BeamSearchRNNDecoder"}. */
    public <T, State, StateShape> String $lessinit$greater$default$10() {
        return "BeamSearchRNNDecoder";
    }

    /** Returns the logger used for static-shape warnings. */
    public Logger logger() {
        return this.logger;
    }

    /**
     * Creates a {@link BeamSearchDecoder}, forwarding every argument unchanged to its constructor.
     */
    public <T, State, StateShape> BeamSearchDecoder<T, State, StateShape> apply(RNNCell<Output<T>, State, Shape, StateShape> rNNCell, State state, Function1<Output<Object>, Output<T>> function1, Output<Object> output, Output<Object> output2, int i, LengthPenalty lengthPenalty, Function1<Output<T>, Output<T>> function12, boolean z, String str, Cpackage.TF<T> tf, OutputStructure<State> outputStructure, OutputToShape<State> outputToShape) {
        return new BeamSearchDecoder<>(rNNCell, state, function1, output, output2, i, lengthPenalty, function12, z, str, tf, outputStructure, outputToShape);
    }

    /** Default for the 7th {@code apply} argument (the length penalty): {@code NoPenalty}. */
    public <T, State, StateShape> LengthPenalty apply$default$7() {
        return NoPenalty$.MODULE$;
    }

    /** Default for the 8th {@code apply} argument: the identity function on outputs. */
    public <T, State, StateShape> Function1<Output<T>, Output<T>> apply$default$8() {
        return output -> {
            return output;
        };
    }

    /** Default for the 9th {@code apply} argument: {@code true}. */
    public <T, State, StateShape> boolean apply$default$9() {
        return true;
    }

    /** Default for the 10th {@code apply} argument (the name): {@code "BeamSearchRNNDecoder"}. */
    public <T, State, StateShape> String apply$default$10() {
        return "BeamSearchRNNDecoder";
    }

    /**
     * Replaces the rows of {@code output} selected by {@code output3} with a fixed one-hot-like row.
     *
     * <p>Let {@code n = shape(output)[2]}. The replacement row has value {@code 0} at index
     * {@code output2} and Scala's {@code Float.MinValue} (the most negative finite float, not Java's
     * {@code Float.MIN_VALUE}) everywhere else. It is tiled to the shape of {@code output3} plus a
     * trailing axis. {@code output3} is expanded on axis 2 and tiled by {@code [1, 1, n]}, and then
     * {@code select} picks the replacement where the mask is true and {@code output} elsewhere.
     *
     * @param output float tensor of rank >= 3 (presumably log-probabilities
     *     {@code [batch, beam, vocab]} — confirm)
     * @param output2 index receiving the value 0 (presumably the end-of-sequence token id — confirm)
     * @param output3 boolean mask over the first two axes of {@code output} (presumably "finished")
     * @return the masked tensor
     */
    public Output<Object> maskLogProbabilities(Output<Object> output, Output<Object> output2, Output<Object> output3) {
        Output<Object> slice = Basic$.MODULE$.shape(output, Basic$.MODULE$.shape$default$2(), Basic$.MODULE$.shape$default$3(), package$TF$.MODULE$.floatEvTF()).castTo(package$TF$.MODULE$.intEvTF()).slice(Implicits$.MODULE$.intToIndex(2), Predef$.MODULE$.wrapRefArray(new Indexer[0]));
        return Math$.MODULE$.select(Basic$.MODULE$.tile(Implicits$.MODULE$.booleanOutputBasicOps(output3).expandDims(Implicits$.MODULE$.intToOutput(2)), Basic$.MODULE$.stack(Seq$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new Output[]{Implicits$.MODULE$.intToOutput(1), Implicits$.MODULE$.intToOutput(1), slice})), Basic$.MODULE$.stack$default$2(), Basic$.MODULE$.stack$default$3(), package$TF$.MODULE$.intEvTF()), Basic$.MODULE$.tile$default$3(), package$TF$.MODULE$.booleanEvTF(), package$TF$.MODULE$.intEvTF(), Predef$.MODULE$.$conforms()), Basic$.MODULE$.tile(Implicits$.MODULE$.floatOutputBasicOps(Basic$.MODULE$.oneHot(output2, slice, Basic$.MODULE$.zeros(Implicits$.MODULE$.shapeToOutput(Shape$.MODULE$.apply((Seq<Object>) Nil$.MODULE$)), package$TF$.MODULE$.floatEvTF()), Basic$.MODULE$.constant(Implicits$.MODULE$.floatToTensor(Float$.MODULE$.MinValue()), Basic$.MODULE$.constant$default$2(), Basic$.MODULE$.constant$default$3()), Basic$.MODULE$.oneHot$default$5(), Basic$.MODULE$.oneHot$default$6(), package$TF$.MODULE$.floatEvTF(), package$TF$.MODULE$.intEvTF(), Predef$.MODULE$.$conforms())).reshape(Implicits$.MODULE$.shapeToOutput(Shape$.MODULE$.apply((Seq<Object>) Predef$.MODULE$.wrapIntArray(new int[]{1, 1, -1}))), package$TF$.MODULE$.intEvTF(), Predef$.MODULE$.$conforms()), Basic$.MODULE$.concatenate(Seq$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new Output[]{Basic$.MODULE$.shape(output3, Basic$.MODULE$.shape$default$2(), Basic$.MODULE$.shape$default$3(), package$TF$.MODULE$.booleanEvTF()).castTo(package$TF$.MODULE$.intEvTF()), Implicits$.MODULE$.outputFromTensor(Tensor$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new Tensor[]{Implicits$.MODULE$.intToTensor(1)}), package$TF$.MODULE$.intEvTF()), package$TF$.MODULE$.intEvTF())})), Implicits$.MODULE$.intToOutput(0), Basic$.MODULE$.concatenate$default$3(), package$TF$.MODULE$.intEvTF()), Basic$.MODULE$.tile$default$3(), package$TF$.MODULE$.floatEvTF(), package$TF$.MODULE$.intEvTF(), Predef$.MODULE$.$conforms()), output, 
Math$.MODULE$.select$default$4(), package$TF$.MODULE$.floatEvTF());
    }

    /**
     * Builds a {@code "GatherTree"} op named {@code str} with the four given inputs (in order) and
     * returns its single output. The op's semantics are defined by the native kernel, not here.
     */
    public Output<Object> gatherTree(Output<Object> output, Output<Object> output2, Output<Object> output3, Output<Object> output4, String str) {
        return (Output) new Op.Builder("GatherTree", str, new Tuple4(output, output2, output3, output4), Op$Builder$.MODULE$.apply$default$4(), Op$OpInput$.MODULE$.opInputPrimitiveTuple4Evidence(Op$OpInputPrimitive$.MODULE$.outputEvidence(), Op$OpInputPrimitive$.MODULE$.outputEvidence(), Op$OpInputPrimitive$.MODULE$.outputEvidence(), Op$OpInputPrimitive$.MODULE$.outputEvidence()), Op$OpOutput$.MODULE$.opOutputPrimitiveEvidence(Op$OpOutputPrimitive$.MODULE$.outputEvidence())).build().output();
    }

    /**
     * Gathers rows of {@code output2} using per-row offset indices, inside name scope {@code str}.
     *
     * <p>The flat indices are {@code reshape(output + expandDims(range(0, output3) * output4, 1), [-1])}.
     * They are used to gather along axis 0 of {@code output2} reshaped to {@code stack(seq)}. The
     * result is reshaped (op name "Output") to {@code shape(output2)[0 :: 1 + seq.size()]}. Its
     * static shape is set to {@code [c] ++ output2.shape[1 :: 1 + seq.size()]}, where {@code c} is
     * the constant value of {@code output3}, or -1 when it is not constant.
     *
     * @param output indices to offset and gather (presumably beam indices — confirm)
     * @param output2 tensor to gather from
     * @param output3 length of the offset range (presumably the batch size — confirm)
     * @param output4 per-row offset multiplier (presumably the beam width / range size — confirm)
     * @param seq shape {@code output2} is reshaped to before gathering
     * @param str name scope
     */
    public <T> Output<T> gather(Output<Object> output, Output<T> output2, Output<Object> output3, Output<Object> output4, Seq<Output<Object>> seq, String str) throws InvalidArgumentException {
        Cpackage.TF<T> fromDataType = package$TF$.MODULE$.fromDataType(output2.dataType());
        return (Output) Op$.MODULE$.nameScope(str, () -> {
            Implicits$ implicits$ = Implicits$.MODULE$;
            Math$ math$ = Math$.MODULE$;
            Output<Object> intToOutput = Implicits$.MODULE$.intToOutput(0);
            Math$.MODULE$.range$default$3();
            Output gather = Basic$.MODULE$.gather(Implicits$.MODULE$.outputBasicOps(output2).reshape(Basic$.MODULE$.stack(seq, Basic$.MODULE$.stack$default$2(), Basic$.MODULE$.stack$default$3(), package$TF$.MODULE$.intEvTF()), package$TF$.MODULE$.intEvTF(), Predef$.MODULE$.$conforms()), Implicits$.MODULE$.intOutputBasicOps(output.$plus(implicits$.intOutputBasicOps(math$.range(intToOutput, output3, null, Math$.MODULE$.range$default$4(), package$TF$.MODULE$.intEvTF(), Predef$.MODULE$.$conforms()).$times(output4, Predef$.MODULE$.$conforms())).expandDims(Implicits$.MODULE$.intToOutput(1)), Predef$.MODULE$.$conforms())).reshape(Implicits$.MODULE$.shapeToOutput(Shape$.MODULE$.apply((Seq<Object>) Predef$.MODULE$.wrapIntArray(new int[]{-1}))), package$TF$.MODULE$.intEvTF(), Predef$.MODULE$.$conforms()), Implicits$.MODULE$.intToOutput(0), Basic$.MODULE$.gather$default$4(), fromDataType, package$TF$.MODULE$.intEvTF(), Predef$.MODULE$.$conforms(), DefaultsTo$.MODULE$.defaultDefaultsTo(), package$TF$.MODULE$.intEvTF(), Predef$.MODULE$.$conforms());
            // Dynamic target shape: the leading dimensions of output2.
            Output slice = Basic$.MODULE$.shape(output2, Basic$.MODULE$.shape$default$2(), Basic$.MODULE$.shape$default$3(), fromDataType).castTo(package$TF$.MODULE$.intEvTF()).slice(IndexerConstructionWithTwoNumbers$.MODULE$.indexerConstructionToIndex(Implicits$.MODULE$.intToIndexerConstruction(1 + seq.size()).$colon$colon(Implicits$.MODULE$.intToIndexerConstruction(0))), Predef$.MODULE$.wrapRefArray(new Indexer[0]));
            // Static shape: constant value of output3 (or -1 if unknown) followed by output2's static dims.
            Shape $plus$plus = Shape$.MODULE$.apply((Seq<Object>) Predef$.MODULE$.wrapIntArray(new int[]{BoxesRunTime.unboxToInt(Output$.MODULE$.constantValue(output3).map(tensor -> {
                return BoxesRunTime.boxToInteger($anonfun$gather$2(tensor));
            }).getOrElse(() -> {
                return -1;
            }))})).$plus$plus(output2.shape().apply(IndexerConstructionWithTwoNumbers$.MODULE$.indexerConstructionToIndex(Implicits$.MODULE$.intToIndexerConstruction(1 + seq.size()).$colon$colon(Implicits$.MODULE$.intToIndexerConstruction(1)))));
            Output reshape = Basic$.MODULE$.reshape(gather, slice, "Output", fromDataType, package$TF$.MODULE$.intEvTF(), Predef$.MODULE$.$conforms());
            reshape.setShape($plus$plus);
            return reshape;
        });
    }

    /** Default name scope for {@link #gather}: {@code "GatherTensorHelper"}. */
    public <T> String gather$default$6() {
        return "GatherTensorHelper";
    }

    /** Default op name for {@link #gatherTree}: {@code "GatherTree"}. */
    public String gatherTree$default$5() {
        return "GatherTree";
    }

    /**
     * Checks whether a statically known tensor-array element shape can be reshaped to
     * {@code [i, i2, -1]}, and logs a warning when it cannot.
     *
     * <p>The shape is accepted when any of the following holds:
     * <ul>
     *   <li>{@code i} or {@code shape(0)} is unknown (-1), so there is nothing to check statically;</li>
     *   <li>the layout is flattened, i.e. {@code shape(0) == i * i2};</li>
     *   <li>the layout is non-flattened, i.e. {@code shape(0) == i} and {@code shape(1) == i2}.</li>
     * </ul>
     *
     * <p>The previous single compound condition required {@code shape(0) == i * i2} and additionally
     * rejected every flattened shape whose second dimension was statically known (unless
     * {@code i2 == 1}). It also never accepted the non-flattened layout.
     *
     * @param shape static element shape; must have rank >= 1 because {@code shape(0)} is read
     * @param i batch size, or -1 if unknown
     * @param i2 beam width
     * @return {@code true} if the shape is compatible or cannot be checked statically
     */
    public boolean checkStaticBatchBeam(Shape shape, int i, int i2) {
        int leadingDim = shape.apply(0);
        if (i == -1 || leadingDim == -1) {
            return true;
        }
        if (leadingDim == i * i2) {
            return true;
        }
        if (shape.rank() > 1 && shape.apply(1) != -1 && leadingDim == i && shape.apply(1) == i2) {
            return true;
        }
        Shape apply = Shape$.MODULE$.apply((Seq<Object>) Predef$.MODULE$.wrapIntArray(new int[]{i, i2, -1}));
        // Parameterized SLF4J call: the message is only formatted when WARN is enabled.
        logger().underlying().warn(
                "Tensor array reordering expects elements to be reshapable to '{}' which is incompatible with "
                        + "the current shape '{}'. Consider setting `reorderTensorArrays` to `false` to disable tensor array "
                        + "reordering during the beam search.",
                apply, shape);
        return false;
    }

    /**
     * Builds an assert op that checks, at run time, whether {@code output} can be viewed as
     * {@code [batchSize, beamSize, -1]} elements.
     *
     * <p>With {@code d = shape(output)}: if {@code output}'s static rank is 2, the condition is
     * {@code d[1] == output2 * output3}. Otherwise it is
     * {@code d[1] == output2 * output3 || (d[1] == output2 && d[2] == output3)}. On failure the
     * assert reports a message naming {@code output}.
     *
     * @param output tensor to check (presumably a stacked tensor array, so axis 0 is time — confirm)
     * @param output2 batch size
     * @param output3 beam width
     * @return the untyped assert op
     */
    /* JADX WARN: Multi-variable type inference failed */
    public <T> Op<Seq<Output<Object>>, Seq<Output<Object>>> checkBatchBeam(Output<T> output, Output<Object> output2, Output<Object> output3) {
        Cpackage.TF<T> fromDataType = package$TF$.MODULE$.fromDataType(output.dataType());
        Implicits$ implicits$ = Implicits$.MODULE$;
        Checks$ checks$ = Checks$.MODULE$;
        Output<R> castTo = Basic$.MODULE$.shape(output, Basic$.MODULE$.shape$default$2(), Basic$.MODULE$.shape$default$3(), fromDataType).castTo(package$TF$.MODULE$.intEvTF());
        return implicits$.opAsUntyped(checks$.mo362assert(output.rank() == 2 ? Math$.MODULE$.equal(castTo.apply(Implicits$.MODULE$.intToIndex(1), Predef$.MODULE$.wrapRefArray(new Indexer[0])), output2.$times(output3, Predef$.MODULE$.$conforms()), Math$.MODULE$.equal$default$3(), package$TF$.MODULE$.intEvTF(), Predef$.MODULE$.$conforms()) : Math$.MODULE$.logicalOr(Math$.MODULE$.equal(castTo.apply(Implicits$.MODULE$.intToIndex(1), Predef$.MODULE$.wrapRefArray(new Indexer[0])), output2.$times(output3, Predef$.MODULE$.$conforms()), Math$.MODULE$.equal$default$3(), package$TF$.MODULE$.intEvTF(), Predef$.MODULE$.$conforms()), Math$.MODULE$.logicalAnd(Math$.MODULE$.equal(castTo.apply(Implicits$.MODULE$.intToIndex(1), Predef$.MODULE$.wrapRefArray(new Indexer[0])), output2, Math$.MODULE$.equal$default$3(), package$TF$.MODULE$.intEvTF(), Predef$.MODULE$.$conforms()), Math$.MODULE$.equal(castTo.apply(Implicits$.MODULE$.intToIndex(2), Predef$.MODULE$.wrapRefArray(new Indexer[0])), output3, Math$.MODULE$.equal$default$3(), package$TF$.MODULE$.intEvTF(), Predef$.MODULE$.$conforms()), Math$.MODULE$.logicalAnd$default$3()), Math$.MODULE$.logicalOr$default$3()), (Seq) Seq$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new Output[]{Implicits$.MODULE$.outputAsUntyped(Tensor$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new Tensor[]{Implicits$.MODULE$.tensorFromSupportedType(new StringBuilder(189).append("Tensor array reordering expects elements to be reshapable to '[batchSize, beamSize, -1]' which is ").append(new StringBuilder(69).append("incompatible with the dynamic shape of '").append(output.name()).append("' elements. Consider setting ").toString()).append("`reorderTensorArrays` to `false` to disable tensor array reordering during the beam search.").toString(), package$TF$.MODULE$.stringEvTF())}), package$TF$.MODULE$.stringEvTF()).toOutput())})), Checks$.MODULE$.assert$default$3(), Checks$.MODULE$.assert$default$4()));
    }

    /** Lambda body used by {@link #gather}: unboxes a scalar tensor to an int. */
    public static final /* synthetic */ int $anonfun$gather$2(Tensor tensor) {
        return BoxesRunTime.unboxToInt(tensor.scalar());
    }

    /** Publishes the singleton and creates the "Ops / Beam Search Decoder" logger. */
    private BeamSearchDecoder$() {
        MODULE$ = this;
        this.logger = Logger$.MODULE$.apply(LoggerFactory.getLogger("Ops / Beam Search Decoder"));
    }
}
