package com.omega.engine.nn.layer;

import com.omega.common.data.Tensor;
import com.omega.common.utils.MatrixUtils;
import com.omega.common.utils.RandomUtils;
import com.omega.engine.ad.op.TensorOP;
import com.omega.engine.gpu.BaseKernel;
import com.omega.engine.nn.layer.active.ReluLayer;
import com.omega.engine.nn.layer.normalization.LNLayer;
import com.omega.engine.nn.network.CNN;
import com.omega.engine.nn.network.Network;

/* loaded from: input_file:com/omega/engine/nn/layer/PoswiseFeedForwardLayer.class */
/**
 * Position-wise feed-forward block with a residual connection and optional layer normalization.
 *
 * <p>Forward data flow, as implemented in {@link #output()}:
 * <ol>
 *   <li>the input is permuted with axes 1 and 3 swapped into a scratch tensor of shape
 *       {@code (number, embedDim, 1, time)};</li>
 *   <li>{@code conv1 -> relu1 -> conv2} is applied to the permuted tensor
 *       (both convolutions are built with 1x1 kernel arguments);</li>
 *   <li>the result is permuted back to {@code (number, time, 1, embedDim)} and the original
 *       input is added to it (residual connection);</li>
 *   <li>if layer normalization was requested, the sum is passed through an {@link LNLayer}.</li>
 * </ol>
 *
 * <p>NOTE(review): the input is assumed to be laid out as {@code (number, time, 1, embedDim)},
 * which is what {@link #main(String[])} feeds in — confirm against other callers.
 */
public class PoswiseFeedForwardLayer extends Layer {
    /** Sequence length; used as the width argument of both convolutions. */
    private int time;
    /** Embedding size; input channels of {@code conv1} and output channels of {@code conv2}. */
    private int embedDim;
    /** Hidden size between the two convolutions. */
    private int nChannel;
    /** Whether the convolutions are created with a bias term. */
    private boolean bias;
    /** Whether the residual sum is followed by layer normalization. */
    private boolean layer_norm;
    private ConvolutionLayer conv1;
    private ReluLayer relu1;
    private ConvolutionLayer conv2;
    /** Only created when {@link #layer_norm} is true; {@code null} otherwise. */
    private LNLayer lnLayer;
    private BaseKernel baseKernel;
    /** Scratch tensor holding the permuted input, shape {@code (number, embedDim, 1, time)}. */
    private Tensor it;
    /** Scratch tensor holding the residual sum (forward) and the input gradient (backward). */
    private Tensor ro;

    /**
     * Creates the layer without attaching it to a network.
     *
     * @param time      sequence length
     * @param embedDim  embedding size
     * @param nChannel  hidden size between the two convolutions
     * @param bias      whether the convolutions use a bias term
     * @param layerNorm whether to apply layer normalization after the residual sum
     */
    public PoswiseFeedForwardLayer(int time, int embedDim, int nChannel, boolean bias, boolean layerNorm) {
        configure(time, embedDim, nChannel, bias, layerNorm);
        initLayers();
    }

    /**
     * Creates the layer and attaches it to {@code network}.
     *
     * @param time      sequence length
     * @param embedDim  embedding size
     * @param nChannel  hidden size between the two convolutions
     * @param bias      whether the convolutions use a bias term
     * @param layerNorm whether to apply layer normalization after the residual sum
     * @param network   owning network, forwarded to the inner convolution layers
     */
    public PoswiseFeedForwardLayer(int time, int embedDim, int nChannel, boolean bias, boolean layerNorm,
            Network network) {
        // The network must be set before initLayers() because the sub-layers receive it.
        this.network = network;
        configure(time, embedDim, nChannel, bias, layerNorm);
        initLayers();
    }

    /** Stores the hyper-parameters shared by both constructors. */
    private void configure(int time, int embedDim, int nChannel, boolean bias, boolean layerNorm) {
        this.time = time;
        this.embedDim = embedDim;
        this.nChannel = nChannel;
        this.bias = bias;
        this.layer_norm = layerNorm;
    }

    /**
     * Returns a fresh copy of the axis order that swaps axes 1 and 3.
     * Applying this permutation twice restores the original order.
     */
    private static int[] swapAxes1And3() {
        return new int[]{0, 3, 2, 1};
    }

    /**
     * Builds the sub-layers: {@code conv1}, {@code relu1}, {@code conv2}, and the optional
     * {@link LNLayer}. Also creates the {@link BaseKernel} if it does not exist yet.
     */
    public void initLayers() {
        // Presumably (channel, kernelNum, width, height, kW, kH, padding, stride, bias, network)
        // — confirm against the ConvolutionLayer constructor.
        this.conv1 = new ConvolutionLayer(this.embedDim, this.nChannel, this.time, 1, 1, 1, 0, 1, this.bias, this.network);
        this.relu1 = new ReluLayer(this.conv1);
        this.conv2 = new ConvolutionLayer(this.nChannel, this.embedDim, this.time, 1, 1, 1, 0, 1, this.bias, this.network);
        if (this.layer_norm) {
            this.lnLayer = new LNLayer(this.conv2);
        }
        if (this.baseKernel == null) {
            this.baseKernel = new BaseKernel();
        }
    }

    /**
     * Picks up the batch size from the current input and (re)allocates the scratch tensors
     * whenever the batch size changed.
     */
    @Override // com.omega.engine.nn.layer.Layer
    public void init() {
        this.number = this.input.number;
        if (this.ro == null || this.ro.number != this.number) {
            this.it = Tensor.createTensor(this.it, this.number, this.embedDim, 1, this.time, true);
            this.ro = Tensor.createTensor(this.ro, this.number, this.time, 1, this.embedDim, true);
        }
        resize();
    }

    /** Resets {@code ro} to its original view; {@link #diff()} temporarily re-views it. */
    public void resize() {
        this.ro.viewOrg();
    }

    @Override // com.omega.engine.nn.layer.Layer
    public void initBack() {
    }

    @Override // com.omega.engine.nn.layer.Layer
    public void initParam() {
    }

    /**
     * Forward pass: permute, conv1, ReLU, conv2, permute back, add the residual input, and
     * optionally apply layer normalization. The result is stored in {@code output}.
     */
    @Override // com.omega.engine.nn.layer.Layer
    public void output() {
        TensorOP.permute(this.input, this.it, swapAxes1And3());
        // Feed the permuted tensor, whose channel axis is embedDim, which is what conv1 was built for.
        this.conv1.forward(this.it);
        this.relu1.forward(this.conv1.getOutput());
        this.conv2.forward(this.relu1.getOutput());
        TensorOP.permute(this.conv2.getOutput(), this.ro, swapAxes1And3());
        // Residual connection.
        TensorOP.add(this.ro, this.input, this.ro);
        if (this.layer_norm) {
            this.lnLayer.forward(this.ro);
            this.output = this.lnLayer.getOutput();
        } else {
            this.output = this.ro;
        }
    }

    /** @return the tensor produced by the most recent forward pass */
    @Override // com.omega.engine.nn.layer.Layer
    public Tensor getOutput() {
        return this.output;
    }

    /**
     * Backward pass. Propagates the gradient through the optional layer norm, then through
     * {@code conv2 -> relu1 -> conv1}, and finally adds the gradient of the residual branch.
     * The result is stored in {@code diff}; {@code ro} is reused as the working buffer.
     */
    @Override // com.omega.engine.nn.layer.Layer
    public void diff() {
        // Re-view ro in the convolution layout so it can hold the permuted gradient.
        this.ro.view(this.number, this.embedDim, 1, this.time);

        // Gradient with respect to the residual sum. With layer norm this is the norm's input
        // gradient, without it the incoming delta applies directly.
        Tensor sumGrad = this.delta;
        if (this.layer_norm) {
            this.lnLayer.back(this.delta);
            sumGrad = this.lnLayer.diff;
        }

        TensorOP.permute(sumGrad, this.ro, swapAxes1And3());
        this.conv2.back(this.ro);
        this.relu1.back(this.conv2.diff);
        this.conv1.back(this.relu1.diff);

        // Back to the input layout before writing the final gradient.
        this.ro.view(this.number, this.time, 1, this.embedDim);
        TensorOP.permute(this.conv1.diff, this.ro, swapAxes1And3());
        // The residual branch passes the sum gradient through unchanged.
        TensorOP.add(this.ro, sumGrad, this.ro);
        this.diff = this.ro;
    }

    @Override // com.omega.engine.nn.layer.Layer
    public void forward() {
        setInput();
        init();
        output();
    }

    @Override // com.omega.engine.nn.layer.Layer
    public void back() {
        initBack();
        setDelta();
        diff();
        gradientCheckIfEnabled();
    }

    /**
     * Runs a forward pass on an explicitly supplied input.
     *
     * @param tensor input tensor
     */
    @Override // com.omega.engine.nn.layer.Layer
    public void forward(Tensor tensor) {
        setInput(tensor);
        init();
        output();
    }

    /**
     * Runs a backward pass on an explicitly supplied output gradient.
     *
     * @param tensor gradient with respect to this layer's output
     */
    @Override // com.omega.engine.nn.layer.Layer
    public void back(Tensor tensor) {
        initBack();
        setDelta(tensor);
        diff();
        gradientCheckIfEnabled();
    }

    /** Runs the gradient check when a network is attached and has it switched on. */
    private void gradientCheckIfEnabled() {
        // The network-less constructor leaves network unset, so guard against null.
        if (this.network != null && this.network.GRADIENT_CHECK) {
            gradientCheck();
        }
    }

    /** Applies the parameter update of every trainable sub-layer. */
    @Override // com.omega.engine.nn.layer.Layer
    public void update() {
        this.conv1.update();
        this.conv2.update();
        // lnLayer only exists when layer normalization was requested.
        if (this.lnLayer != null) {
            this.lnLayer.update();
        }
    }

    @Override // com.omega.engine.nn.layer.Layer
    public void showDiff() {
    }

    @Override // com.omega.engine.nn.layer.Layer
    public LayerType getLayerType() {
        return LayerType.poswise_feed_forward;
    }

    /** Array-based forward pass is not supported by this layer; always returns {@code null}. */
    @Override // com.omega.engine.nn.layer.Layer
    public float[][][][] output(float[][][][] fArr) {
        return null;
    }

    @Override // com.omega.engine.nn.layer.Layer
    public void initCache() {
    }

    @Override // com.omega.engine.nn.layer.Layer
    public void backTemp() {
    }

    /** Smoke test: one forward and one backward pass on small synthetic tensors. */
    public static void main(String[] strArr) {
        final int batch = 5;
        final int time = 10;
        final int embedDim = 8;
        final int nChannel = 4;
        final int size = batch * time * embedDim;

        CNN cnn = new CNN(null);
        cnn.CUDNN = true;
        cnn.number = batch;

        Tensor input = new Tensor(batch, time, 1, embedDim, RandomUtils.order(size, 0.1f, 0.1f), true);
        input.showShape();
        input.showDM();
        Tensor delta = new Tensor(batch, time, 1, embedDim, MatrixUtils.val(size, 1.0f), true);

        PoswiseFeedForwardLayer layer = new PoswiseFeedForwardLayer(time, embedDim, nChannel, false, true, cnn);
        layer.forward(input);
        layer.getOutput().showDM();
        layer.back(delta);
        layer.diff.showDM();
    }
}
