package com.omega.engine.active;

import com.omega.common.data.Tensor;
import com.omega.common.utils.JsonUtils;
import com.omega.common.utils.MatrixOperation;
import com.omega.common.utils.MatrixUtils;
import jcuda.Pointer;

/* loaded from: input_file:com/omega/engine/active/ActiveFunction.class */
/**
 * Base class for activation functions.
 *
 * <p>Subclasses provide forward ({@code active}) and derivative ({@code diff}) implementations
 * for flat arrays, 3-D/4-D arrays, {@link Tensor}s and raw device {@link Pointer}s. This class
 * itself only implements {@link #gradientCheck(float[])}, a debugging helper that compares the
 * analytical derivative with a central finite difference.
 */
public abstract class ActiveFunction {
    /** Which activation this instance implements; {@link #gradientCheck} special-cases {@code relu}. */
    public ActiveType activeType;
    /** Half-width of the central-difference step used by {@link #gradientCheck}. */
    public float eta = 1.0E-5f;
    // NOTE(review): the buffers below are filled in by subclasses; their exact meaning
    // (e.g. whether `diff` holds f'(x) or an upstream-scaled gradient) is not visible here — confirm.
    /** Most recent 1-D input; {@link #gradientCheck} sets it before calling {@link #diff()}. */
    public float[] input;
    public float[] output;
    public float[] diff;
    public float[][][] input2d;
    public float[][][] output2d;
    public float[][][] diff2d;

    /** Forward pass over a flat array (presumably element-wise — see implementations). */
    public abstract float[] active(float[] fArr);

    /** Forward pass from the first tensor into the second (argument roles per implementation). */
    public abstract void active(Tensor tensor, Tensor tensor2);

    /** Forward pass on device memory; {@code i} is presumably the element count — confirm. */
    public abstract void active(Pointer pointer, Pointer pointer2, int i);

    /** Analytical derivative for the current 1-D {@link #input}, as relied on by {@link #gradientCheck}. */
    public abstract float[] diff();

    /** Backward pass on tensors (argument roles per implementation). */
    public abstract void diff(Tensor tensor, Tensor tensor2);

    /** Backward pass on device memory; {@code i} is presumably the element count — confirm. */
    public abstract void diff(Pointer pointer, Pointer pointer2, Pointer pointer3, int i);

    /** Forward variant that, judging by the name, does not retain state — confirm in implementations. */
    public abstract float[] activeTemp(float[] fArr);

    /** Derivative variant that, judging by the name, does not retain state — confirm in implementations. */
    public abstract float[] diffTemp(float[] fArr);

    /** Forward pass over a 3-D array. */
    public abstract float[][][] active(float[][][] fArr);

    /** Forward pass over a 4-D array. */
    public abstract float[][][][] active(float[][][][] fArr);

    /** Derivative for the current 3-D input (see {@link #input2d}). */
    public abstract float[][][] diff2d();

    /** Derivative for a 4-D array. */
    public abstract float[][][][] diff(float[][][][] fArr);

    /** 3-D counterpart of {@link #activeTemp(float[])}. */
    public abstract float[][][] activeTemp(float[][][] fArr);

    /** 3-D counterpart of {@link #diffTemp(float[])}. */
    public abstract float[][][] diffTemp(float[][][] fArr);

    /**
     * Compares the analytical derivative {@link #diff()} with a central finite difference
     * {@code (f(x + eta) - f(x - eta)) / (2 * eta)} and prints both to stdout.
     *
     * <p>Side effects: overwrites {@link #input} with a copy of {@code fArr}, and whatever state the
     * subclass's {@code active}/{@code diff} implementations mutate.
     *
     * @param fArr point at which to check the derivative
     * @return sum over elements of {@code |analytic - numeric|}; 0 means perfect agreement. Absolute
     *     values are used so positive and negative errors cannot cancel each other out.
     * @throws IllegalStateException if {@link #diff()} and {@link #active(float[])} disagree on length
     */
    public float gradientCheck(float[] fArr) {
        this.input = MatrixUtils.clone(fArr);
        // Copy defensively: an implementation may return a reused internal buffer
        // (e.g. its `diff`/`output` field), which the following active(...) calls could overwrite.
        float[] analytic = diff().clone();
        float[] forward = active(MatrixOperation.add(fArr, this.eta)).clone();
        float[] backward = active(MatrixOperation.subtraction(fArr, this.eta));
        float[] numeric = MatrixOperation.division(MatrixOperation.subtraction(forward, backward), 2.0f * this.eta);
        if (this.activeType == ActiveType.relu) {
            // ReLU has a kink at exactly 0, where the central difference straddles both slopes.
            // Zero the numeric estimate there; this assumes the relu diff() reports 0 at x == 0 — confirm.
            for (int i = 0; i < numeric.length; i++) {
                if (fArr[i] == 0.0f) {
                    numeric[i] = 0.0f;
                }
            }
        }
        System.out.println("diff:" + JsonUtils.toJson(analytic));
        System.out.println("gc:" + JsonUtils.toJson(numeric));
        if (analytic.length != numeric.length) {
            throw new IllegalStateException("gradientCheck: diff() returned " + analytic.length
                    + " values but the numeric gradient has " + numeric.length);
        }
        float error = 0.0f;
        for (int i = 0; i < analytic.length; i++) {
            error += Math.abs(analytic[i] - numeric[i]);
        }
        return error;
    }
}
