package com.omega.engine.updater;

import com.omega.engine.nn.layer.Layer;
import com.omega.engine.nn.layer.normalization.NormalizationLayer;
import com.omega.engine.updater.gpu.RMSPropKernel;

/* loaded from: input_file:com/omega/engine/updater/RMSProp.class */
/**
 * {@link Updater} that applies RMSProp parameter updates by delegating to an {@link RMSPropKernel}.
 *
 * <p>The kernel is created lazily on the first update call and cached for the lifetime of this
 * instance. NOTE(review): because the kernel is sized from the first layer it sees, this assumes
 * an instance is only ever used with one layer, or with layers whose parameter lengths match.
 * Confirm this against how updaters are assigned to layers.
 */
public class RMSProp extends Updater {
    /** Lazily created kernel that performs the actual weight/bias updates. */
    private RMSPropKernel kernel;

    // NOTE(review): clamp/min/max are not read anywhere in this class. They are presumably
    // reserved for value clipping; confirm before relying on them.
    private boolean clamp = false;
    private float min = -0.01f;
    private float max = 0.01f;

    /**
     * Updates the layer's weights from {@code layer.diffW} and, when the layer has a bias, its
     * bias from {@code layer.diffB}, using the layer's network and learning rate.
     *
     * @param layer the layer whose parameters are updated
     */
    @Override // com.omega.engine.updater.Updater
    public void update(Layer layer) {
        RMSPropKernel k = kernelFor(layer);
        k.updateW(layer.diffW, layer.weight, layer.network, layer.learnRate);
        if (layer.hasBias) {
            k.updateB(layer.diffB, layer.bias, layer.network, layer.learnRate);
        }
    }

    /**
     * Matrix-style update. This implementation does nothing.
     *
     * <p>NOTE(review): callers that route updates through this method will see no parameter
     * change. Confirm that this no-op is intended.
     *
     * @param layer ignored
     */
    @Override // com.omega.engine.updater.Updater
    public void updateForMatrix(Layer layer) {
    }

    /**
     * Updates a normalization layer's {@code gamma} from {@code diffGamma} and its {@code beta}
     * from {@code diffBeta}.
     *
     * <p>If no kernel exists yet, it is created with the lengths of {@code gamma} and
     * {@code beta}. An already-cached kernel is reused as-is.
     *
     * @param normalizationLayer the normalization layer whose parameters are updated
     */
    @Override // com.omega.engine.updater.Updater
    public void updateForBN(NormalizationLayer normalizationLayer) {
        if (this.kernel == null) {
            this.kernel = new RMSPropKernel(normalizationLayer.gamma.dataLength, normalizationLayer.beta.dataLength);
        }
        this.kernel.updateW(normalizationLayer.diffGamma, normalizationLayer.gamma, normalizationLayer.network, normalizationLayer.learnRate);
        this.kernel.updateB(normalizationLayer.diffBeta, normalizationLayer.beta, normalizationLayer.network, normalizationLayer.learnRate);
    }

    /**
     * Identifies this updater.
     *
     * @return {@link UpdaterType#RMSProp}
     */
    @Override // com.omega.engine.updater.Updater
    public UpdaterType getUpdaterType() {
        return UpdaterType.RMSProp;
    }

    /**
     * Same as {@link #update(Layer)} but forwards an extra integer argument to the kernel.
     *
     * @param layer the layer whose parameters are updated
     * @param i passed unchanged to the kernel's update calls. Its meaning is defined by
     *     {@link RMSPropKernel}; it is presumably a step or batch index (TODO confirm).
     */
    @Override // com.omega.engine.updater.Updater
    public void update(Layer layer, int i) {
        RMSPropKernel k = kernelFor(layer);
        k.updateW(layer.diffW, layer.weight, layer.network, layer.learnRate, i);
        if (layer.hasBias) {
            k.updateB(layer.diffB, layer.bias, layer.network, layer.learnRate, i);
        }
    }

    /**
     * Returns the cached kernel, creating it on first use.
     *
     * <p>The new kernel is sized from the layer's weight length, plus its bias length when the
     * layer has a bias. {@code layer.bias} is only read when {@code layer.hasBias} is true.
     */
    private RMSPropKernel kernelFor(Layer layer) {
        if (this.kernel == null) {
            this.kernel = layer.hasBias
                    ? new RMSPropKernel(layer.weight.dataLength, layer.bias.dataLength)
                    : new RMSPropKernel(layer.weight.dataLength);
        }
        return this.kernel;
    }
}
