package com.omega.engine.nn.layer.active;

import com.omega.common.data.Tensor;
import com.omega.common.task.Task;
import com.omega.common.task.TaskEngine;
import com.omega.engine.nn.layer.Layer;
import com.omega.engine.nn.layer.LayerType;
import com.omega.engine.nn.layer.active.gpu.TanhKernel;
import com.omega.engine.nn.network.Network;
import java.util.Vector;

/* loaded from: input_file:com/omega/engine/nn/layer/active/TanhLayer.class */
/**
 * Hyperbolic-tangent activation layer.
 *
 * <p>The tensor-based paths ({@link #forward()}, {@link #back()} and their overloads) delegate
 * the element-wise math to a {@link TanhKernel}. {@link #output(float[][][][])} is a separate
 * array-based implementation that dispatches one task per outer index to the network's
 * {@link TaskEngine}.
 */
public class TanhLayer extends ActiveFunctionLayer {

    /** Kernel performing the tanh forward/backward math; created lazily by {@code init}. */
    private TanhKernel kernel;

    public TanhLayer() {
    }

    /**
     * Creates a tanh layer chained after {@code layer}.
     *
     * @param layer the layer registered as this layer's predecessor
     */
    public TanhLayer(Layer layer) {
        setPreLayer(layer);
    }

    /**
     * Creates a tanh layer attached to {@code network}.
     *
     * @param network owning network (its thread count and GRADIENT_CHECK flag are read later)
     */
    public TanhLayer(Network network) {
        this.network = network;
    }

    /** No-op: this layer has no parameters to initialize. */
    @Override
    public void initParam() {
    }

    @Override
    public void init() {
        super.init();
        ensureKernel();
    }

    @Override
    public void init(Tensor tensor) {
        super.init(tensor);
        ensureKernel();
    }

    /** Creates the kernel on first use so repeated {@code init} calls reuse one instance. */
    private void ensureKernel() {
        if (this.kernel == null) {
            this.kernel = new TanhKernel();
        }
    }

    /** Runs the tanh forward kernel from {@code input} into {@code output}. */
    @Override
    public void output() {
        this.kernel.forward(this.input, this.output);
    }

    @Override
    public Tensor getOutput() {
        return this.output;
    }

    /**
     * Runs the tanh backward kernel using the cached forward {@code output} and the incoming
     * {@code delta}, with {@code diff} as the destination argument.
     */
    @Override
    public void diff() {
        this.kernel.backward(this.output, this.delta, this.diff);
    }

    /** Same as {@link #diff()} but calls {@code TanhKernel.backwardTemp} (see that kernel for semantics). */
    public void diffTemp() {
        this.kernel.backwardTemp(this.output, this.delta, this.diff);
    }

    @Override
    public void forward() {
        init();
        setInput();
        output();
    }

    @Override
    public void back() {
        initBack();
        setDelta();
        diff();
        if (this.network.GRADIENT_CHECK) {
            gradientCheck();
        }
    }

    @Override
    public void backTemp() {
        initBack();
        setDelta();
        diffTemp();
        if (this.network.GRADIENT_CHECK) {
            gradientCheck();
        }
    }

    @Override
    public void showDiff() {
    }

    @Override
    public LayerType getLayerType() {
        return LayerType.tanh;
    }

    /**
     * Array-based forward pass: applies {@code tanh} element-wise to {@code fArr}.
     *
     * <p>The result is allocated as {@code [number][oChannel][oHeight][oWidth]} and filled over
     * {@code [number][channel][height][width]}; the two are assumed equal for an activation layer
     * (NOTE(review): confirm). One task per first-dimension index is dispatched to the task
     * engine.
     *
     * @param fArr input values, indexed like the result
     * @return a new array holding {@code tanh(fArr[...])}
     */
    @Override
    public float[][][][] output(final float[][][][] fArr) {
        final float[][][][] result = new float[this.number][this.oChannel][this.oHeight][this.oWidth];
        Vector<Task<Object>> tasks = new Vector<>();
        for (int n = 0; n < this.number; n++) {
            final int index = n;
            tasks.add(new Task<Object>(index) {
                @Override
                public Object call() throws Exception {
                    for (int c = 0; c < TanhLayer.this.channel; c++) {
                        for (int h = 0; h < TanhLayer.this.height; h++) {
                            for (int w = 0; w < TanhLayer.this.width; w++) {
                                // tanh, not the logistic sigmoid 1 / (1 + e^-x) computed here before.
                                result[index][c][h][w] = (float) Math.tanh(fArr[index][c][h][w]);
                            }
                        }
                    }
                    return null;
                }
            });
        }
        TaskEngine.getInstance(this.network.getThreadNum()).dispatchTask(tasks);
        return result;
    }

    @Override
    public void initCache() {
    }

    /**
     * Uses {@code tensor} as the gradient buffer ({@code diff}) for the backward pass.
     *
     * @param tensor buffer assigned to {@code diff}
     */
    public void initBack(Tensor tensor) {
        this.diff = tensor;
    }

    @Override
    public void forward(Tensor tensor) {
        init(tensor);
        setInput(tensor);
        output();
    }

    @Override
    public void back(Tensor tensor) {
        initBack(tensor);
        setDelta(tensor);
        diff();
        if (this.network.GRADIENT_CHECK) {
            gradientCheck();
        }
    }

    @Override
    public void forward(Tensor tensor, int i, int i2) {
        init(tensor);
        setInput(tensor);
        output(i, i2);
    }

    @Override
    public void back(Tensor tensor, int i, int i2) {
        initBack(tensor);
        setDelta(tensor);
        diff(i, i2);
        if (this.network.GRADIENT_CHECK) {
            gradientCheck();
        }
    }

    /**
     * Runs the forward kernel over one slice of the tensors.
     *
     * @param i  number of {@code getOnceSize()}-sized items in the slice
     * @param i2 slice index; the slice starts at {@code i2 * i * input.getOnceSize()}
     */
    @Override
    public void output(int i, int i2) {
        this.kernel.forward(this.input, this.output, i2 * i * this.input.getOnceSize(), i * this.input.getOnceSize());
    }

    /**
     * Runs the backward kernel over one slice of the tensors.
     *
     * <p>tanh'(x) = 1 - tanh(x)^2 depends on the forward result. This overload therefore passes
     * the cached {@code output}, matching {@link #diff()} and {@link #diffTemp()}. It previously
     * passed {@code input}, which is inconsistent with those methods.
     *
     * @param i  number of {@code getOnceSize()}-sized items in the slice
     * @param i2 slice index; the slice starts at {@code i2 * i * diff.getOnceSize()}
     */
    @Override
    public void diff(int i, int i2) {
        this.kernel.backward(this.output, this.delta, this.diff, i2 * i * this.diff.getOnceSize(), i * this.diff.getOnceSize());
    }
}
