package com.omega.engine.nn.network;

import com.omega.common.data.Tensor;
import com.omega.engine.gpu.BaseKernel;
import com.omega.engine.loss.LossFactory;
import com.omega.engine.loss.LossType;
import com.omega.engine.nn.layer.BaseRNNLayer;
import com.omega.engine.nn.layer.EmbeddingLayer;
import com.omega.engine.nn.layer.FullyLayer;
import com.omega.engine.nn.layer.InputLayer;
import com.omega.engine.nn.layer.LayerType;
import com.omega.engine.nn.layer.RNNBlockLayer;
import com.omega.engine.nn.layer.SoftmaxWithCrossEntropyLayer;
import com.omega.engine.nn.model.RNNCellType;
import com.omega.engine.updater.UpdaterType;

/* loaded from: input_file:com/omega/engine/nn/network/Seq2Seq.class */
/**
 * Sequence-to-sequence network assembled from six layers, registered in this order:
 * an input layer, an encoder embedding, an encoder RNN block, a decoder embedding,
 * a decoder RNN block and a fully connected output layer.
 *
 * <p>Training-style use goes through {@link #forward(Tensor, Tensor)} and
 * {@link #back(Tensor)}; step-wise inference goes through {@link #encoder(Tensor)}
 * and {@link #decoder(Tensor, Tensor, Tensor)}.
 *
 * <p>{@link #init()} must be called before {@link #back(Tensor)},
 * {@link #initEnRNNLayerDelta(Tensor)} or {@link #decoder(Tensor, Tensor, Tensor)},
 * because those methods use the kernel helper that {@code init()} creates.
 */
public class Seq2Seq extends Network {
    private final RNNCellType cellType;

    /** Time argument handed to the encoder RNN block (presumably the source sequence length). */
    public int en_time;

    /** Time argument handed to the decoder RNN block (presumably the target sequence length). */
    public int de_time;

    /** First argument of the encoder embedding and width of the input layer (presumably the source vocabulary size). */
    public int en_len;

    /** First argument of the decoder embedding and output size of the fully connected layer (presumably the target vocabulary size). */
    public int de_len;

    private final InputLayer inputLayer;
    private final EmbeddingLayer en_emLayer;
    private final BaseRNNLayer en_rnnLayer;
    private final EmbeddingLayer de_emLayer;
    private final BaseRNNLayer de_rnnLayer;
    private final FullyLayer fullyLayer;

    /** Gradient buffer fed to the encoder RNN during {@link #back(Tensor)}; allocated lazily. */
    private Tensor en_delta;

    /** Kernel helper used for device copies; created by {@link #init()}. */
    private BaseKernel baseKernel;

    /**
     * Builds the six layers and registers them with the network.
     *
     * @param rNNCellType cell type used by both the encoder and the decoder RNN block
     * @param lossType    loss type passed to {@code LossFactory.create}
     * @param updaterType updater stored on the network
     * @param i           stored as {@link #en_time}; first argument of the encoder RNN block
     * @param i2          stored as {@link #de_time}; first argument of the decoder RNN block
     * @param i3          second argument of the encoder embedding and of the encoder RNN block
     *                    (presumably the encoder embedding size)
     * @param i4          third argument of the encoder RNN block (presumably the encoder hidden size)
     * @param i5          stored as {@link #en_len}; input layer width and first argument of the encoder embedding
     * @param i6          second argument of the decoder embedding and of the decoder RNN block
     *                    (presumably the decoder embedding size)
     * @param i7          third argument of the decoder RNN block and input size of the fully connected layer
     *                    (presumably the decoder hidden size)
     * @param i8          stored as {@link #de_len}; first argument of the decoder embedding and output size
     *                    of the fully connected layer
     * @throws IllegalArgumentException if {@code rNNCellType} is not RNN, LSTM or GRU
     * @throws NullPointerException     if {@code rNNCellType} is {@code null}
     */
    public Seq2Seq(RNNCellType rNNCellType, LossType lossType, UpdaterType updaterType, int i, int i2, int i3, int i4, int i5, int i6, int i7, int i8) {
        this.cellType = rNNCellType;
        this.lossFunction = LossFactory.create(lossType);
        this.updater = updaterType;
        this.en_time = i;
        this.de_time = i2;
        this.en_len = i5;
        this.de_len = i8;
        this.inputLayer = new InputLayer(1, 1, i5);
        this.en_emLayer = new EmbeddingLayer(i5, i3, this);
        this.de_emLayer = new EmbeddingLayer(i8, i6, this);
        this.fullyLayer = new FullyLayer(i7, i8, true, this);

        // Encoder and decoder always share the same cell type, so one mode code serves both.
        final int rnnMode = rnnModeOf(this.cellType);
        this.en_rnnLayer = new RNNBlockLayer(i, i3, i4, rnnMode, false, false, 0.0f);
        this.de_rnnLayer = new RNNBlockLayer(i2, i6, i7, rnnMode, false, false, 0.0f);

        addLayer(this.inputLayer);
        addLayer(this.en_emLayer);
        addLayer(this.en_rnnLayer);
        addLayer(this.de_emLayer);
        addLayer(this.de_rnnLayer);
        addLayer(this.fullyLayer);
    }

    /**
     * Maps a cell type to the integer code passed as the fourth argument of
     * {@code RNNBlockLayer}: RNN is 1, LSTM is 2, GRU is 3.
     *
     * @throws IllegalArgumentException for any other cell type, instead of silently
     *                                  leaving the RNN layers unset
     */
    private static int rnnModeOf(RNNCellType type) {
        switch (type) {
            case RNN:
                return 1;
            case LSTM:
                return 2;
            case GRU:
                return 3;
            default:
                throw new IllegalArgumentException("Unsupported RNN cell type: " + type);
        }
    }

    /**
     * Validates the layer list, copies the input and output dimensions from the first
     * and last layers, and creates the kernel helper if it does not exist yet.
     *
     * @throws Exception if there are no layers, if the first layer is not an input layer,
     *                   or if the last layer is a softmax variant while the loss is not cross entropy
     */
    @Override // com.omega.engine.nn.network.Network
    public void init() throws Exception {
        // NOTE(review): the guard only rejects an empty list although the message mentions 2;
        // both are kept as-is because the intended minimum is not visible here.
        if (this.layerList.size() <= 0) {
            throw new Exception("layer size must greater than 2.");
        }
        this.layerCount = this.layerList.size();
        this.channel = this.layerList.get(0).channel;
        this.height = this.layerList.get(0).height;
        this.width = this.layerList.get(0).width;
        this.oChannel = getLastLayer().oChannel;
        this.oHeight = getLastLayer().oHeight;
        this.oWidth = getLastLayer().oWidth;
        if (this.layerList.get(0).getLayerType() != LayerType.input) {
            throw new Exception("first layer must be input layer.");
        }
        if ((this.layerList.get(this.layerList.size() - 1).getLayerType() == LayerType.softmax || this.layerList.get(this.layerList.size() - 1).getLayerType() == LayerType.softmax_cross_entropy) && this.lossFunction.getLossType() != LossType.cross_entropy) {
            throw new Exception("The softmax function support only cross entropy loss function now.");
        }
        if (this.baseKernel == null) {
            this.baseKernel = new BaseKernel();
        }
        System.out.println("the network is ready.");
    }

    /** @return {@code NetworkType.SEQ2SEQ} */
    @Override // com.omega.engine.nn.network.Network
    public NetworkType getNetworkType() {
        return NetworkType.SEQ2SEQ;
    }

    /**
     * Switches the network to test mode and returns the current output.
     *
     * <p>NOTE(review): this delegates to {@link #forward(Tensor)}, which does not run any
     * layer, so {@code tensor} has no effect on the returned value.
     */
    @Override // com.omega.engine.nn.network.Network
    public Tensor predict(Tensor tensor) {
        this.RUN_MODEL = RunModel.TEST;
        forward(tensor);
        return getOutput();
    }

    /**
     * Single-input forward required by {@code Network}. It performs no computation and
     * just returns {@code getOutput()}; use {@link #forward(Tensor, Tensor)} instead.
     */
    @Override // com.omega.engine.nn.network.Network
    public Tensor forward(Tensor tensor) {
        return getOutput();
    }

    /**
     * Runs the full encoder-decoder forward pass.
     *
     * @param tensor  encoder input, installed through {@code setInputData}
     * @param tensor2 decoder input, fed to the decoder embedding
     * @return the network output as reported by {@code getOutput()}
     */
    public Tensor forward(Tensor tensor, Tensor tensor2) {
        // Encoder: input -> embedding -> RNN.
        setInputData(tensor);
        this.inputLayer.forward();
        this.en_emLayer.forward();
        this.en_rnnLayer.forward(this.en_time, this.en_emLayer.output.number);

        // Decoder: embedding -> RNN started from the encoder's hy/cy -> fully connected.
        this.de_emLayer.forward(tensor2);
        this.de_rnnLayer.forward(this.de_emLayer.getOutput(), this.en_rnnLayer.getHy(), this.en_rnnLayer.getCy(), this.de_time);
        this.fullyLayer.forward(this.de_rnnLayer.getOutput());
        return getOutput();
    }

    /**
     * Prepares {@code en_delta} for the encoder backward pass.
     *
     * <p>The buffer is (re)created with the encoder RNN output's dimensions when it is
     * missing or its {@code number} differs, then cleared, and finally {@code tensor} is
     * copied into it at offset {@code (en_time - 1) * tensor.getDataLength()}
     * (presumably the slot of the last encoder time step).
     *
     * @param tensor gradient to place into the buffer; {@link #back(Tensor)} passes the decoder's {@code getDhx()}
     */
    public void initEnRNNLayerDelta(Tensor tensor) {
        if (this.en_delta == null || this.en_delta.number != this.en_rnnLayer.getOutput().number) {
            this.en_delta = Tensor.createTensor(this.en_delta, this.en_rnnLayer.getOutput().number, this.en_rnnLayer.getOutput().channel, this.en_rnnLayer.getOutput().height, this.en_rnnLayer.getOutput().width, true);
        }
        this.en_delta.clearGPU();
        this.baseKernel.copy_gpu(tensor, this.en_delta, tensor.getDataLength(), 0, 1, (this.en_time - 1) * tensor.getDataLength(), 1);
    }

    /**
     * Runs the backward pass: fully connected layer, decoder RNN, decoder embedding,
     * then the encoder RNN (fed with {@code en_delta} and the decoder's {@code getDcx()})
     * and finally the encoder embedding.
     *
     * @param tensor loss gradient, installed through {@code setLossDiff}
     */
    @Override // com.omega.engine.nn.network.Network
    public void back(Tensor tensor) {
        setLossDiff(tensor);
        this.fullyLayer.back();
        this.de_rnnLayer.back();
        this.de_emLayer.back();
        initEnRNNLayerDelta(this.de_rnnLayer.getDhx());
        this.en_rnnLayer.back(this.en_delta, this.en_rnnLayer.getHy(), this.en_rnnLayer.getCy(), null, this.de_rnnLayer.getDcx());
        this.en_emLayer.back(this.en_rnnLayer.diff);
    }

    /**
     * Computes the loss. When the last layer is a softmax-with-cross-entropy layer,
     * {@code tensor2} is first handed to it as the current label.
     *
     * @param tensor  first argument forwarded to the loss function
     * @param tensor2 second argument forwarded to the loss function
     * @return the result of the loss function
     */
    @Override // com.omega.engine.nn.network.Network
    public Tensor loss(Tensor tensor, Tensor tensor2) {
        switch (getLastLayer().getLayerType()) {
            case softmax_cross_entropy:
                ((SoftmaxWithCrossEntropyLayer) getLastLayer()).setCurrentLabel(tensor2);
                break;
            default:
                // Other layer types need no label hand-off.
                break;
        }
        return this.lossFunction.loss(tensor, tensor2);
    }

    /** Delegates to the loss function's two-argument {@code diff}. */
    @Override // com.omega.engine.nn.network.Network
    public Tensor lossDiff(Tensor tensor, Tensor tensor2) {
        return this.lossFunction.diff(tensor, tensor2);
    }

    /** Intentionally a no-op in this network. */
    @Override // com.omega.engine.nn.network.Network
    public void clearGrad() {
    }

    /** Delegates to the loss function's three-argument {@code loss}. */
    @Override // com.omega.engine.nn.network.Network
    public Tensor loss(Tensor tensor, Tensor tensor2, Tensor tensor3) {
        return this.lossFunction.loss(tensor, tensor2, tensor3);
    }

    /** Delegates to the loss function's three-argument {@code diff}. */
    @Override // com.omega.engine.nn.network.Network
    public Tensor lossDiff(Tensor tensor, Tensor tensor2, Tensor tensor3) {
        return this.lossFunction.diff(tensor, tensor2, tensor3);
    }

    /**
     * Runs only the encoder half: input layer, encoder embedding and encoder RNN.
     *
     * @param tensor encoder input, installed through {@code setInputData}
     * @return a three-element array holding the encoder RNN's output, {@code getHy()} and {@code getCy()}, in that order
     */
    public Tensor[] encoder(Tensor tensor) {
        setInputData(tensor);
        this.inputLayer.forward();
        this.en_emLayer.forward();
        this.en_rnnLayer.forward(this.en_time, this.en_emLayer.getOutput().number);
        return new Tensor[]{this.en_rnnLayer.getOutput(), this.en_rnnLayer.getHy(), this.en_rnnLayer.getCy()};
    }

    /**
     * Runs one decoder step (the RNN is invoked with a time argument of 1) and returns
     * the fully connected layer's output.
     *
     * <p>After the step, the decoder's {@code getHy()} and {@code getCy()} are copied into
     * {@code tensor} and {@code tensor2} respectively, so the caller's tensors are
     * overwritten (presumably to carry the state into the next step).
     *
     * @param tensor  state passed to the decoder RNN in the same position as the encoder's {@code getHy()}
     * @param tensor2 state passed to the decoder RNN in the same position as the encoder's {@code getCy()}
     * @param tensor3 decoder input for this step, fed to the decoder embedding
     * @return the output of the fully connected layer
     */
    public Tensor decoder(Tensor tensor, Tensor tensor2, Tensor tensor3) {
        this.de_emLayer.forward(tensor3);
        this.de_rnnLayer.forward(this.de_emLayer.getOutput(), tensor, tensor2, 1);
        this.baseKernel.copy_gpu(this.de_rnnLayer.getHy(), tensor, tensor.getDataLength(), 0, 1, 0, 1);
        this.baseKernel.copy_gpu(this.de_rnnLayer.getCy(), tensor2, tensor2.getDataLength(), 0, 1, 0, 1);
        this.fullyLayer.forward(this.de_rnnLayer.getOutput());
        return this.fullyLayer.getOutput();
    }
}
