package com.omega.engine.nn.grad;

import com.omega.common.data.Tensor;
import com.omega.common.utils.MatrixOperation;

/* loaded from: input_file:com/omega/engine/nn/grad/GradClipping.class */
public class GradClipping {

    /**
     * Rescales the data of {@code tensor} so that its L2 norm is at most {@code f}, writing the
     * result back into the tensor's host array ({@code tensor.data}) in place.
     *
     * <p>If the tensor reports a GPU copy ({@code isHasGPU()}), {@code syncHost()} is called before
     * clipping and {@code hostToDevice()} afterwards, so the clipped values are what get pushed to
     * the device.
     *
     * @param tensor tensor whose data is clipped in place
     * @param f maximum allowed L2 norm; must be non-negative
     * @return the same {@code tensor} instance, for chaining
     * @throws IllegalArgumentException if {@code f} is negative or NaN
     */
    public static Tensor gradClipping(Tensor tensor, float f) {
        if (tensor.isHasGPU()) {
            tensor.syncHost();
            // Must modify tensor.data itself: the previous version discarded the clipped copy.
            clip(tensor.data, f, true);
            tensor.hostToDevice();
        } else {
            clip(tensor.data, f, true);
        }
        return tensor;
    }

    /**
     * Returns {@code fArr} rescaled so that its L2 norm is at most {@code f}.
     *
     * <p>If the norm is already {@code <= f} (or is NaN), {@code fArr} itself is returned
     * unchanged. Otherwise a new array holding {@code fArr[i] * (f / norm)} is returned and
     * {@code fArr} is left untouched.
     *
     * @param fArr values to clip
     * @param f maximum allowed L2 norm; must be non-negative
     * @return {@code fArr} or a newly allocated clipped copy
     * @throws IllegalArgumentException if {@code f} is negative or NaN
     */
    public static float[] grad_clipping_cpu(float[] fArr, float f) {
        return clip(fArr, f, false);
    }

    /**
     * Shared clipping routine: scales {@code values} by {@code maxNorm / norm} when the L2 norm
     * exceeds {@code maxNorm}, either in place or into a fresh array.
     */
    private static float[] clip(float[] values, float maxNorm, boolean inPlace) {
        // A negative bound would flip the sign of every element; NaN makes the bound meaningless.
        if (!(maxNorm >= 0.0f)) {
            throw new IllegalArgumentException("max norm must be non-negative, got " + maxNorm);
        }
        double sumOfSquares = 0.0; // double accumulator limits precision loss on long arrays
        for (float v : values) {
            sumOfSquares += (double) v * v;
        }
        float norm = (float) Math.sqrt(sumOfSquares);
        // Written as !(norm > maxNorm) so a NaN norm is left untouched, matching the original check.
        if (!(norm > maxNorm)) {
            return values;
        }
        float scale = maxNorm / norm;
        float[] out = inPlace ? values : new float[values.length];
        for (int i = 0; i < values.length; i++) {
            out[i] = values[i] * scale;
        }
        return out;
    }
}
