package ai.djl.basicmodelzoo.nlp;

import ai.djl.modality.nlp.Decoder;
import ai.djl.modality.nlp.embedding.TrainableTextEmbedding;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDArrays;
import ai.djl.ndarray.NDList;
import ai.djl.nn.Block;
import ai.djl.nn.SequentialBlock;
import ai.djl.nn.core.Linear;
import ai.djl.nn.recurrent.RecurrentBlock;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;
import java.util.Objects;

/* loaded from: input_file:ai/djl/basicmodelzoo/nlp/SimpleSequenceDecoder.class */
public class SimpleSequenceDecoder extends Decoder {
    private RecurrentBlock recurrentBlock;

    public SimpleSequenceDecoder(RecurrentBlock recurrentBlock, int i) {
        this(null, recurrentBlock, i);
    }

    public SimpleSequenceDecoder(TrainableTextEmbedding trainableTextEmbedding, RecurrentBlock recurrentBlock, int i) {
        super(getBlock(trainableTextEmbedding, recurrentBlock, i));
        this.recurrentBlock = recurrentBlock;
    }

    private static Block getBlock(TrainableTextEmbedding trainableTextEmbedding, RecurrentBlock recurrentBlock, int i) {
        SequentialBlock sequentialBlock = new SequentialBlock();
        sequentialBlock.add(trainableTextEmbedding).add(recurrentBlock).add(Linear.builder().setOutChannels(i).optFlatten(false).build());
        return sequentialBlock;
    }

    public void initState(NDList nDList) {
        this.recurrentBlock.setBeginStates(nDList);
    }

    public NDList forward(ParameterStore parameterStore, NDList nDList, boolean z, PairList<String, Object> pairList) {
        if (z) {
            return this.block.forward(parameterStore, nDList, true, pairList);
        }
        if (((NDArray) nDList.get(0)).getShape().get(1) != 1) {
            throw new IllegalArgumentException("Input sequence length must be 1 during prediction");
        }
        NDList nDList2 = new NDList();
        for (int i = 0; i < 10; i++) {
            nDList = new NDList(new NDArray[]{this.block.forward(parameterStore, nDList, false).head().argMax(2)});
            nDList2.add(nDList.head().transpose(new int[]{1, 0}));
        }
        return new NDList(new NDArray[]{NDArrays.stack(nDList2).transpose(new int[]{2, 1, 0})});
    }
}
