package com.omega.engine.nn.layer.normalization;

import com.omega.common.data.Tensor;
import com.omega.common.utils.MatrixUtils;
import com.omega.engine.nn.layer.Layer;
import com.omega.engine.nn.layer.LayerType;
import com.omega.engine.nn.layer.normalization.gpu.LNKernel;
import com.omega.engine.nn.model.LayerInit;
import com.omega.engine.nn.network.Network;
import com.omega.engine.updater.UpdaterFactory;

/* loaded from: input_file:com/omega/engine/nn/layer/normalization/LNLayer.class */
/**
 * Layer-normalization layer.
 *
 * <p>Owns a learnable scale {@code gamma} (initialised with {@code MatrixUtils.one}) and,
 * when built with one of the {@code boolean} constructor overloads, a learnable shift
 * {@code beta}. Forward and backward computation is delegated to
 * {@link LNKernel#forward_llm} / {@link LNKernel#backward_llm}.
 */
public class LNLayer extends NormalizationLayer {
    /** Normalization flavour; inferred during {@link #init()} / {@link #init(Tensor)} when null. */
    public BNType bnType = null;
    /** Length of the gamma/beta vectors (and their gradients). */
    private int meanNum = 0;
    /** Kernel wrapper performing the actual computation; created lazily on first init. */
    public LNKernel kernel;

    /** Creates a layer without a bias (beta) term. */
    public LNLayer() {
        this.hasParams = true;
    }

    /**
     * Creates a layer with a bias (beta) term.
     *
     * @param z ignored; selecting this overload is what enables the bias
     */
    public LNLayer(boolean z) {
        this.hasBias = true;
        this.hasParams = true;
    }

    /**
     * Creates a bias-free layer attached after {@code layer}.
     *
     * <p>NOTE(review): reads {@code this.network} right after {@code setPreLayer}; this assumes
     * {@code setPreLayer} (or the superclass) makes {@code network} non-null — confirm.
     *
     * @param layer the preceding layer
     */
    public LNLayer(Layer layer) {
        setPreLayer(layer);
        this.hasParams = true;
        setUpdater(UpdaterFactory.create(this.network.updater, this.network.updaterParams));
    }

    /**
     * Creates a layer with a bias (beta) term attached after {@code layer}.
     *
     * @param layer the preceding layer
     * @param z ignored; selecting this overload is what enables the bias
     */
    public LNLayer(Layer layer, boolean z) {
        setPreLayer(layer);
        this.hasBias = true;
        this.hasParams = true;
        setUpdater(UpdaterFactory.create(this.network.updater, this.network.updaterParams));
    }

    /**
     * Creates a bias-free layer owned by {@code network}, using the network's updater settings.
     *
     * <p>NOTE(review): unlike the other constructors this one does not set {@code hasParams}
     * — confirm whether that is intentional.
     *
     * @param network the owning network
     */
    public LNLayer(Network network) {
        this.network = network;
        setUpdater(UpdaterFactory.create(this.network.updater, this.network.updaterParams));
    }

    /**
     * Creates a layer with a bias (beta) term owned by {@code network}.
     *
     * <p>NOTE(review): does not set {@code hasParams}; see {@link #LNLayer(Network)}.
     *
     * @param network the owning network
     * @param z ignored; selecting this overload is what enables the bias
     */
    public LNLayer(Network network, boolean z) {
        this.network = network;
        this.hasBias = true;
        setUpdater(UpdaterFactory.create(this.network.updater, this.network.updaterParams));
    }

    /**
     * Initialises shapes (from the preceding layer on first use), parameters and the output
     * buffer for the network's current batch size.
     */
    @Override
    public void init() {
        this.number = this.network.number;
        if (this.preLayer == null) {
            this.preLayer = this.network.getPreLayer(this.index);
        }
        if (this.bnType == null) {
            configureFromPreLayer();
        }
        ensureParams();
        ensureOutput();
    }

    /**
     * Initialises shapes (from {@code tensor} on first use), parameters and the output buffer
     * for {@code tensor}'s batch size. Shapes inferred here always use {@code BNType.fully_bn}.
     *
     * @param tensor the input that will be fed to {@link #forward(Tensor)}
     */
    public void init(Tensor tensor) {
        this.number = tensor.number;
        if (this.bnType == null) {
            this.channel = tensor.channel;
            this.height = tensor.height;
            this.width = tensor.width;
            this.oChannel = this.channel;
            this.oHeight = this.height;
            this.oWidth = this.width;
            setBnType(BNType.fully_bn);
            this.meanNum = this.channel * this.height * this.width;
        }
        ensureParams();
        ensureOutput();
    }

    /** Copies shape from the preceding layer and picks the BN type / parameter length. */
    private void configureFromPreLayer() {
        if (this.preLayer == null) {
            // NOTE(review): channel/height/width are not assigned on this path, so meanNum relies
            // on values set elsewhere before init() — confirm.
            setBnType(BNType.fully_bn);
            this.meanNum = this.channel * this.height * this.width;
            return;
        }
        this.channel = this.preLayer.oChannel;
        this.height = this.preLayer.oHeight;
        this.width = this.preLayer.oWidth;
        this.oChannel = this.channel;
        this.oHeight = this.height;
        this.oWidth = this.width;
        LayerType preType = this.preLayer.getLayerType();
        if (preType == LayerType.conv || preType == LayerType.conv_transpose) {
            // convolutional input: one parameter per spatial position
            setBnType(BNType.conv_bn);
            this.meanNum = this.height * this.width;
        } else {
            // fully-connected and every other layer type: one parameter per feature element
            setBnType(BNType.fully_bn);
            this.meanNum = this.channel * this.height * this.width;
        }
    }

    /** Lazily creates the kernel, gamma (+grad) and, if a bias is enabled, beta (+grad). */
    private void ensureParams() {
        if (this.kernel == null) {
            this.kernel = new LNKernel(this.width, this.bnType);
        }
        if (this.gamma == null) {
            this.gamma = new Tensor(1, 1, 1, this.meanNum, MatrixUtils.one(this.meanNum), true);
            this.diffGamma = newParamGrad();
        }
        if (this.beta == null && this.hasBias) {
            this.beta = new Tensor(1, 1, 1, this.meanNum, true);
            this.diffBeta = newParamGrad();
        }
    }

    /** Allocates a gradient buffer of length {@code meanNum}, via the network when attached. */
    private Tensor newParamGrad() {
        if (this.network != null) {
            return this.network.createParamterGrad(1, 1, 1, this.meanNum, true);
        }
        return new Tensor(1, 1, 1, this.meanNum, true);
    }

    /** (Re)creates the output buffer whenever the batch size changes. */
    private void ensureOutput() {
        if (this.output == null || this.number != this.output.number) {
            this.output = Tensor.createTensor(this.output, this.number, this.oChannel, this.oHeight, this.oWidth, true);
        }
    }

    /** Parameters are allocated lazily in {@link #init()}; nothing to do here. */
    @Override
    public void initParam() {
    }

    /** Ensures the input-gradient buffer matches the current input's batch size. */
    @Override
    public void initBack() {
        // Re-create on batch-size change, mirroring how the output buffer is handled.
        if (this.diff == null || this.diff.number != this.input.number) {
            this.diff = Tensor.createTensor(this.diff, this.input.number, this.input.channel, this.input.height, this.input.width, true);
        }
    }

    /**
     * Ensures the input-gradient buffer matches {@code tensor}'s shape.
     *
     * @param tensor tensor whose shape the gradient buffer should take
     */
    public void initBack(Tensor tensor) {
        if (this.diff == null || this.diff.number != tensor.number) {
            this.diff = Tensor.createTensor(this.diff, tensor.number, tensor.channel, tensor.height, tensor.width, true);
        }
    }

    /** Runs the forward kernel; {@code beta} is null when the layer has no bias. */
    @Override
    public void output() {
        this.kernel.forward_llm(this.gamma, this.beta, this.input, this.output);
    }

    @Override
    public Tensor getOutput() {
        return this.output;
    }

    @Override
    public void forward() {
        init();
        setInput();
        output();
    }

    /** Runs the backward kernel, filling {@code diff}, {@code diffGamma} and {@code diffBeta}. */
    @Override
    public void diff() {
        this.kernel.backward_llm(this.input, this.delta, this.diff, this.gamma, this.diffGamma, this.diffBeta);
    }

    @Override
    public void back() {
        initBack();
        setDelta();
        diff();
        if (this.network.GRADIENT_CHECK) {
            gradientCheck();
        }
    }

    /**
     * Applies the parameter update: delegates to the updater when one is set, otherwise
     * performs a plain SGD step on gamma and (when present) beta. Frozen layers are skipped.
     */
    @Override
    public void update() {
        if (this.freeze) {
            return;
        }
        if (this.updater != null) {
            this.updater.updateForBN(this);
            return;
        }
        sgdStep(this.gamma, this.diffGamma);
        // beta only exists for bias-enabled layers; without this guard the bias-free
        // configuration would throw a NullPointerException here.
        if (this.beta != null) {
            sgdStep(this.beta, this.diffBeta);
        }
    }

    /** In-place {@code param -= learnRate * grad} over the host-side data arrays. */
    private void sgdStep(Tensor param, Tensor grad) {
        for (int i = 0; i < param.dataLength; i++) {
            param.data[i] -= this.learnRate * grad.data[i];
        }
    }

    @Override
    public void showDiff() {
    }

    @Override
    public LayerType getLayerType() {
        return LayerType.layer_norm;
    }

    /** Not supported by this layer; always returns null. */
    @Override
    public LayerInit save() {
        return null;
    }

    /** Array-based forward is not supported by this layer; always returns null. */
    @Override
    public float[][][][] output(float[][][][] fArr) {
        return null;
    }

    public BNType getBnType() {
        return this.bnType;
    }

    public void setBnType(BNType bNType) {
        this.bnType = bNType;
    }

    @Override
    public void initCache() {
    }

    @Override
    public void forward(Tensor tensor) {
        init(tensor);
        setInput(tensor);
        output();
    }

    @Override
    public void back(Tensor tensor) {
        initBack(tensor);
        setDelta(tensor);
        diff();
        if (this.network.GRADIENT_CHECK) {
            gradientCheck();
        }
    }

    @Override
    public void backTemp() {
    }
}
