package com.omega.engine.updater;

import com.omega.common.utils.MatrixOperation;
import com.omega.common.utils.MatrixUtils;
import com.omega.engine.nn.layer.ConvolutionLayer;
import com.omega.engine.nn.layer.Layer;
import com.omega.engine.nn.layer.LayerType;
import com.omega.engine.nn.layer.normalization.NormalizationLayer;

/**
 * {@link Updater} that applies a momentum step to a layer's parameters.
 *
 * <p>Each step passes the layer's gradient ({@code diffW} / {@code diffB}) together with the
 * current velocity buffer to {@link MomentumUtils#momentum}. The returned velocity is stored back
 * into the buffer and then added to the parameters via {@link MatrixOperation#add}. The actual
 * momentum formula is defined in {@code MomentumUtils}, not here.
 *
 * <p>Velocity buffers are the inherited fields {@code vdw}/{@code vdb} (used by
 * {@link #update(Layer)}) and {@code vdmw}/{@code vdmb} (used by {@link #updateForMatrix(Layer)}).
 * They are allocated as zero vectors on the first call and reused on every later call, so an
 * instance carries state between steps.
 * NOTE(review): this suggests one Momentum instance per layer — confirm with the caller.
 */
public class Momentum extends Updater {

    /** @return {@link UpdaterType#momentum} */
    @Override
    public UpdaterType getUpdaterType() {
        return UpdaterType.momentum;
    }

    /**
     * Runs one momentum step on a layer whose weight velocity is sized {@code width * oWidth}.
     *
     * <p>{@code layer.weight.data} (and {@code layer.bias.data} when {@code layer.hasBias} is set)
     * is replaced by the sum of its previous value and the freshly computed velocity.
     *
     * @param layer layer providing the gradients, learning rate and parameters to update
     */
    @Override
    public void update(Layer layer) {
        boolean withBias = layer.hasBias;
        if (vdw == null) {
            allocateLayerVelocity(layer, withBias);
        }

        // Fold the gradients into the velocity buffers first ...
        vdw = MomentumUtils.momentum(layer.diffW.data, vdw, layer.learnRate);
        if (withBias) {
            vdb = MomentumUtils.momentum(layer.diffB.data, vdb, layer.learnRate);
        }

        // ... then apply the updated velocities to the parameters.
        layer.weight.data = MatrixOperation.add(layer.weight.data, vdw);
        if (withBias) {
            layer.bias.data = MatrixOperation.add(layer.bias.data, vdb);
        }
    }

    /**
     * Indexed variant of {@link #update(Layer)}; currently does nothing.
     *
     * @param layer ignored
     * @param i ignored
     */
    @Override
    public void update(Layer layer, int i) {
        // No-op in this updater.
    }

    /**
     * Runs one momentum step on a convolution layer using the {@code vdmw}/{@code vdmb} buffers.
     *
     * <p>The weight velocity is sized {@code kernelNum * channel * kHeight * kWidth} and the bias
     * velocity {@code kernelNum}. After the step, {@code weight.data} (and {@code bias.data} when
     * {@code layer.hasBias} is set) hold the previous values plus the new velocity.
     *
     * @param layer must report {@link LayerType#conv} and be a {@link ConvolutionLayer}
     * @throws RuntimeException if {@code layer} does not report the conv layer type
     */
    @Override
    public void updateForMatrix(Layer layer) {
        boolean isConv = layer.getLayerType().equals(LayerType.conv);
        if (!isConv) {
            throw new RuntimeException("this function param must be conv layer.");
        }

        ConvolutionLayer conv = (ConvolutionLayer) layer;
        boolean withBias = layer.hasBias;
        if (vdmw == null) {
            allocateConvVelocity(conv, withBias);
        }

        // Fold the gradients into the velocity buffers first ...
        vdmw = MomentumUtils.momentum(conv.diffW.data, vdmw, conv.learnRate);
        if (withBias) {
            vdmb = MomentumUtils.momentum(conv.diffB.data, vdmb, conv.learnRate);
        }

        // ... then apply the updated velocities to the kernel weights and bias.
        conv.weight.data = MatrixOperation.add(conv.weight.data, vdmw);
        if (withBias) {
            conv.bias.data = MatrixOperation.add(conv.bias.data, vdmb);
        }
    }

    /**
     * Normalization-layer hook; this updater leaves such layers untouched.
     *
     * @param normalizationLayer ignored
     */
    @Override
    public void updateForBN(NormalizationLayer normalizationLayer) {
        // No-op: nothing is updated for normalization layers here.
        // NOTE(review): confirm BN parameters are meant to be skipped by this updater.
    }

    /** Allocates zeroed {@code vdw} (width * oWidth) and, if requested, {@code vdb} (oWidth). */
    private void allocateLayerVelocity(Layer layer, boolean withBias) {
        vdw = MatrixUtils.zero(layer.width * layer.oWidth);
        if (withBias) {
            vdb = MatrixUtils.zero(layer.oWidth);
        }
    }

    /** Allocates zeroed {@code vdmw} (one entry per kernel weight) and, if requested, {@code vdmb}. */
    private void allocateConvVelocity(ConvolutionLayer conv, boolean withBias) {
        vdmw = MatrixUtils.zero(conv.kernelNum * conv.channel * conv.kHeight * conv.kWidth);
        if (withBias) {
            vdmb = MatrixUtils.zero(conv.kernelNum);
        }
    }
}
