package com.omega.engine.nn.layer;

import com.omega.common.data.Tensor;
import com.omega.engine.gpu.cudnn.PoolingCudnnKernel;
import com.omega.engine.nn.layer.gpu.PoolingBaseKernel;
import com.omega.engine.nn.layer.gpu.PoolingKernel;
import com.omega.engine.pooling.PoolingType;

/* loaded from: input_file:com/omega/engine/nn/layer/PoolingLayer.class */
/**
 * Pooling layer that delegates its forward and backward computation to a
 * {@link PoolingBaseKernel}.
 *
 * <p>The kernel is created lazily on the first {@link #init()} call: a
 * {@link PoolingCudnnKernel} when {@code network.CUDNN} is set, otherwise a
 * {@link PoolingKernel}. The output channel count equals the input channel
 * count; the output width and height are derived in {@link #initParam()}.
 */
public class PoolingLayer extends Layer {
    /** Pooling mode handed to the kernel. */
    public PoolingType poolingType;
    /** Pooling window width, used to derive {@code oWidth}. */
    public int pWidth;
    /** Pooling window height, used to derive {@code oHeight}. */
    public int pHeight;
    /** Step between pooling windows; must be positive. */
    public int stride;
    /** Padding added to the input extent when deriving the output size. */
    public int padding;
    /** Lazily created kernel that performs the actual pooling. */
    private PoolingBaseKernel kernel;

    /**
     * Creates a pooling layer without padding.
     *
     * @param channel     number of input (and output) channels
     * @param width       input width
     * @param height      input height
     * @param pWidth      pooling window width
     * @param pHeight     pooling window height
     * @param stride      step between pooling windows; must be positive
     * @param poolingType pooling mode handed to the kernel
     * @throws IllegalArgumentException if {@code stride} is not positive
     */
    public PoolingLayer(int channel, int width, int height, int pWidth, int pHeight, int stride, PoolingType poolingType) {
        this(channel, width, height, pWidth, pHeight, stride, 0, poolingType);
    }

    /**
     * Creates a pooling layer with explicit padding.
     *
     * @param channel     number of input (and output) channels
     * @param width       input width
     * @param height      input height
     * @param pWidth      pooling window width
     * @param pHeight     pooling window height
     * @param stride      step between pooling windows; must be positive
     * @param padding     padding added to the input extent when deriving the output size
     * @param poolingType pooling mode handed to the kernel
     * @throws IllegalArgumentException if {@code stride} is not positive
     */
    public PoolingLayer(int channel, int width, int height, int pWidth, int pHeight, int stride, int padding, PoolingType poolingType) {
        this.channel = channel;
        this.width = width;
        this.height = height;
        this.pWidth = pWidth;
        this.pHeight = pHeight;
        this.stride = stride;
        this.padding = padding;
        this.poolingType = poolingType;
        initParam();
    }

    /**
     * Prepares the layer for a forward pass: copies the batch size from the
     * network, (re)creates the output tensor when it is missing or sized for a
     * different batch, and creates the kernel on first use.
     */
    @Override
    public void init() {
        this.number = this.network.number;
        if (this.output == null || this.output.number != this.number) {
            this.output = Tensor.createTensor(this.output, this.number, this.oChannel, this.oHeight, this.oWidth, true);
        }
        if (this.kernel != null) {
            return;
        }
        // NOTE(review): the cuDNN kernel is passed (pWidth, pHeight) while the plain
        // kernel is passed (pHeight, pWidth); confirm that both constructors really
        // declare their window parameters in these orders.
        if (this.network.CUDNN) {
            this.kernel = new PoolingCudnnKernel(this.poolingType, this.channel, this.height, this.width, this.oHeight, this.oWidth, this.pWidth, this.pHeight, this.stride, this.padding);
        } else {
            this.kernel = new PoolingKernel(this.poolingType, this.channel, this.height, this.width, this.pHeight, this.pWidth, this.stride, this.padding);
        }
    }

    /**
     * Prepares the layer for a backward pass by allocating {@code diff} with the
     * input dimensions when it is missing or sized for a different batch.
     */
    @Override
    public void initBack() {
        // NOTE(review): init() goes through Tensor.createTensor with the existing
        // tensor, whereas this allocates a fresh Tensor; confirm the replaced diff
        // tensor does not need an explicit release.
        if (this.diff == null || this.diff.number != this.number) {
            this.diff = new Tensor(this.number, this.channel, this.height, this.width, true);
        }
    }

    /**
     * Derives the output dimensions from the input dimensions, window size,
     * stride and padding.
     *
     * @throws IllegalArgumentException if {@code stride} is not positive (a zero
     *         stride would otherwise surface as a bare division-by-zero error)
     */
    @Override
    public void initParam() {
        if (this.stride <= 0) {
            throw new IllegalArgumentException("stride must be positive, got " + this.stride);
        }
        this.oChannel = this.channel;
        this.oWidth = outputExtent(this.width, this.pWidth);
        this.oHeight = outputExtent(this.height, this.pHeight);
    }

    /**
     * Computes one output dimension as
     * {@code (inputExtent + padding - windowExtent) / stride + 1} using integer division.
     */
    private int outputExtent(int inputExtent, int windowExtent) {
        // NOTE(review): padding is added once, not doubled; presumably it denotes the
        // total padding across both edges - confirm against the kernels.
        return ((inputExtent + this.padding) - windowExtent) / this.stride + 1;
    }

    /** Runs the kernel's forward pass from {@code input} into {@code output}. */
    @Override
    public void output() {
        this.kernel.forward(this.input, this.output);
    }

    /** Runs the kernel's backward pass, writing the result into {@code diff}. */
    @Override
    public void diff() {
        this.kernel.backward(this.input, this.output, this.delta, this.diff);
    }

    /** Forward pass using the input supplied by {@code setInput()}. */
    @Override
    public void forward() {
        init();
        setInput();
        output();
    }

    /** Backward pass using the delta supplied by {@code setDelta()}. */
    @Override
    public void back() {
        initBack();
        setDelta();
        computeDiffAndCheck();
    }

    /** No-op in this layer. */
    @Override
    public void update() {
    }

    /** @return {@code LayerType.pooling} */
    @Override
    public LayerType getLayerType() {
        return LayerType.pooling;
    }

    /** @return the current output tensor; {@code null} until {@link #init()} has run */
    @Override
    public Tensor getOutput() {
        return this.output;
    }

    /** No-op in this layer. */
    @Override
    public void showDiff() {
    }

    /**
     * Array-based forward pass; not implemented by this layer.
     *
     * @param values ignored
     * @return always {@code null}
     */
    @Override
    public float[][][][] output(float[][][][] values) {
        return null;
    }

    /** No-op in this layer. */
    @Override
    public void initCache() {
    }

    /**
     * Forward pass using an explicitly supplied input tensor.
     *
     * @param inputTensor tensor handed to {@code setInput(Tensor)}
     */
    @Override
    public void forward(Tensor inputTensor) {
        init();
        setInput(inputTensor);
        output();
    }

    /**
     * Backward pass using an explicitly supplied delta tensor.
     *
     * @param deltaTensor tensor handed to {@code setDelta(Tensor)}
     */
    @Override
    public void back(Tensor deltaTensor) {
        initBack();
        setDelta(deltaTensor);
        computeDiffAndCheck();
    }

    /** No-op in this layer. */
    @Override
    public void backTemp() {
    }

    /** Shared tail of both {@code back} variants: computes diff, then optionally runs the gradient check. */
    private void computeDiffAndCheck() {
        diff();
        if (this.network.GRADIENT_CHECK) {
            gradientCheck();
        }
    }
}
