package com.omega.engine.nn.layer;

import com.omega.common.data.Tensor;
import com.omega.engine.gpu.BaseGPUOP;
import com.omega.engine.gpu.data.CacheDataSet;
import com.omega.engine.nn.model.LayerInit;
import com.omega.engine.nn.network.Network;
import com.omega.engine.updater.Updater;

/* loaded from: input_file:com/omega/engine/nn/layer/Layer.class */
/**
 * Abstract base for every layer that can be placed in a {@link Network}.
 *
 * <p>The class holds the tensors, dimensions and flags shared by all layer
 * implementations as public mutable fields, declares the life-cycle methods a
 * concrete layer must provide, and supplies a few helpers for wiring a layer
 * to its neighbours in the owning network.
 *
 * <p>NOTE(review): the exact meaning of most fields is defined by the concrete
 * subclasses and by the network/updater code, none of which is visible in this
 * file; the field comments below only state what can be read off the
 * declarations.
 */
public abstract class Layer {

    /** Network that owns this layer; used to look up neighbouring layers. */
    public Network network;

    /** Input tensor; assigned by {@link #setInput()} or {@link #setInput(Tensor)}. */
    public Tensor input;

    /** Output tensor; read by the following layer's {@link #setInput()}. */
    public Tensor output;

    /** Tensor read by the preceding layer's {@link #setDelta()} as its delta. */
    public Tensor diff;

    /** Tensor assigned by {@link #setDelta()} or {@link #setDelta(Tensor)}. */
    public Tensor delta;

    public Tensor weight;
    public Tensor bias;
    public Tensor diffW;
    public Tensor diffB;

    /** Optional extra delta that {@link #setDelta()} substitutes for or combines with {@link #delta}. */
    public Tensor cache_delta;

    public Tensor org_delta;
    public LayerType layerType;

    /** Updater assigned through {@link #setUpdater(Updater)}. */
    public Updater updater;

    /** Accessed only through {@link #getTampDataSet()} / {@link #setTampDataSet(CacheDataSet)}. */
    private CacheDataSet tampDataSet;

    public boolean PROPAGATE_DOWN = true;

    /** Position of this layer inside the owning network's layer list. */
    public int index = 0;

    // Dimensions; the o-prefixed ones together with `number` form outputShape().
    public int number = 0;
    public int channel = 0;
    public int height = 0;
    public int width = 0;
    public int oChannel = 0;
    public int oHeight = 0;
    public int oWidth = 0;

    public ParamsInit paramsInit = ParamsInit.linear;
    public boolean hasBias = true;

    // Numeric defaults; presumably training hyper-parameters - confirm against the updater code.
    public float lambda = 0.01f;
    public float learnRate = 0.001f;
    public float eta = 1.0E-5f;

    public boolean freeze = false;
    public boolean hasParams = false;
    public boolean isOutput = false;

    // ------------------------------------------------------------------
    // Contract implemented by concrete layers. The behaviour of each
    // method is defined entirely by the subclass.
    // ------------------------------------------------------------------

    public abstract void init();

    public abstract void initBack();

    public abstract void initParam();

    public abstract void output();

    public abstract Tensor getOutput();

    public abstract void diff();

    public abstract void forward();

    public abstract void back();

    public abstract void backTemp();

    public abstract void forward(Tensor tensor);

    public abstract void back(Tensor tensor);

    public abstract void update();

    public abstract void showDiff();

    /** @return the type tag of this layer; {@link #setDelta()} compares it against {@code LayerType.route} */
    public abstract LayerType getLayerType();

    public abstract float[][][][] output(float[][][][] fArr);

    public abstract void initCache();

    // ------------------------------------------------------------------
    // Shared helpers
    // ------------------------------------------------------------------

    /**
     * Stores the updater on this layer.
     *
     * @param updater the updater to assign to {@link #updater}
     */
    public void setUpdater(Updater updater) {
        this.updater = updater;
    }

    /**
     * Attaches this layer to a network.
     *
     * @param network the network to assign to {@link #network}
     */
    public void setNetwork(Network network) {
        this.network = network;
    }

    /**
     * Records the position of this layer.
     *
     * @param i the value to assign to {@link #index}
     */
    public void setIndex(int i) {
        index = i;
    }

    /**
     * Creates a {@link LayerInit} from this layer.
     *
     * @return a new {@code LayerInit} constructed with this instance
     */
    public LayerInit save() {
        return new LayerInit(this);
    }

    /**
     * Sets the input tensor explicitly.
     *
     * @param tensor the tensor to assign to {@link #input}
     */
    public void setInput(Tensor tensor) {
        input = tensor;
    }

    /**
     * Sets the input tensor to the {@code output} of the layer that the owning
     * network returns from {@code getPreLayer(index)}.
     */
    public void setInput() {
        Layer previous = network.getPreLayer(index);
        input = previous.output;
    }

    /**
     * Resolves the delta tensor of this layer.
     *
     * <ol>
     *   <li>If {@link #delta} is unset, this layer is not the last entry of the
     *       network's layer list, and the next layer is not of type
     *       {@code LayerType.route}, {@code delta} becomes the next layer's
     *       {@link #diff}.</li>
     *   <li>If {@link #cache_delta} is set, it replaces {@code delta} when
     *       {@code delta} is still unset or their {@code number} fields differ;
     *       otherwise, when the two are distinct objects, they are passed to the
     *       GPU {@code axpy_gpu} kernel.</li>
     * </ol>
     */
    public void setDelta() {
        if (delta == null) {
            int lastIndex = network.layerList.size() - 1;
            if (index < lastIndex) {
                Layer following = network.getNextLayer(index);
                if (following.getLayerType() != LayerType.route) {
                    delta = following.diff;
                }
            }
        }

        Tensor cached = cache_delta;
        if (cached == null) {
            return;
        }

        if (delta == null || delta.number != cached.number) {
            delta = cached;
            return;
        }

        if (cached != delta) {
            // Presumably delta += 1.0f * cache_delta (BLAS-style axpy) - confirm against the kernel.
            BaseGPUOP.getKernel().axpy_gpu(cached, delta, delta.getDataLength(), 1.0f, 1, 1);
        }
    }

    /**
     * Sets the delta tensor explicitly.
     *
     * @param tensor the tensor to assign to {@link #delta}
     */
    public void setDelta(Tensor tensor) {
        delta = tensor;
    }

    /**
     * Default gradient check; this base implementation performs no work.
     *
     * @return always {@code 0.0f}
     */
    public float gradientCheck() {
        return 0.0f;
    }

    /** @return the data set stored by {@link #setTampDataSet(CacheDataSet)}, or {@code null} if none was set */
    public CacheDataSet getTampDataSet() {
        return tampDataSet;
    }

    /**
     * Stores a data set on this layer.
     *
     * @param cacheDataSet the data set later returned by {@link #getTampDataSet()}
     */
    public void setTampDataSet(CacheDataSet cacheDataSet) {
        tampDataSet = cacheDataSet;
    }

    /**
     * Reports the output dimensions of this layer.
     *
     * @return a new four-element array {@code {number, oChannel, oHeight, oWidth}}
     */
    public int[] outputShape() {
        int[] shape = {number, oChannel, oHeight, oWidth};
        return shape;
    }
}
