package com.omega.engine.nn.network;

import com.omega.common.data.Tensor;
import com.omega.common.utils.MatrixOperation;
import com.omega.engine.ad.op.TensorOP;
import com.omega.engine.loss.LossFunction;
import com.omega.engine.nn.layer.Layer;
import com.omega.engine.nn.model.NetworkInit;
import com.omega.engine.updater.UpdaterFactory;
import com.omega.engine.updater.UpdaterType;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import jcuda.Pointer;

/* loaded from: input_file:com/omega/engine/nn/network/Network.class */
public abstract class Network {
    public int time;
    public NetworkType networkType;
    /** Parameters forwarded to {@link UpdaterFactory#create} for layers added without an updater. */
    public Map<String, Float> updaterParams;
    public LossFunction lossFunction;
    /** Input tensor most recently set through {@link #setInputData(Tensor)}. */
    public Tensor input;
    /** Loss gradient most recently set through {@link #setLossDiff(Tensor)}. */
    public Tensor lossDiff;
    public Pointer workspace;
    public boolean CUDNN = false;
    private int threadNum = 8;
    /**
     * Tensors registered via {@link #createParamterGrad} / {@link #addPamamter}; these are the
     * tensors rescaled by {@link #clipGradNorm(float)}.
     */
    public List<Tensor> paramters = new ArrayList<>();
    /** Updater type used to build a layer updater when {@link #addLayer(Layer)} finds none. */
    public UpdaterType updater = UpdaterType.none;
    public RunModel RUN_MODEL = RunModel.TRAIN;
    public boolean GRADIENT_CHECK = false;
    /**
     * Layer count used by the index-based accessors, {@link #update()} and {@link #unfreeze()}.
     * NOTE(review): {@link #addLayer(Layer)} does not update this field; presumably subclasses
     * set it (e.g. in {@link #init()}) — TODO confirm.
     */
    public int layerCount = 0;
    public int trainingTime = 100;
    public int currentTrainingTime = 0;
    /** Layers in the order they were added; a layer's index is its position in this list. */
    public List<Layer> layerList = new ArrayList<>();
    public float accuracy = 0.01f;
    /** Learning rate copied onto every layer before it updates in {@link #update()}. */
    public float learnRate = 0.01f;
    public float errorRate = 0.001f;
    public float currentError = 0.0f;
    public int number = 0;
    public int channel = 0;
    public int height = 0;
    public int width = 0;
    public int oChannel = 0;
    public int oHeight = 0;
    public int oWidth = 0;
    /** Incremented once per call to {@link #update()}. */
    public int train_time = 0;
    public long workspaceSize = 0;
    public boolean PROPAGATE_DOWN = false;

    public abstract void init() throws Exception;

    public abstract Tensor predict(Tensor tensor);

    public abstract Tensor forward(Tensor tensor);

    public abstract void back(Tensor tensor);

    public abstract Tensor loss(Tensor tensor, Tensor tensor2);

    public abstract Tensor lossDiff(Tensor tensor, Tensor tensor2);

    public abstract Tensor loss(Tensor tensor, Tensor tensor2, Tensor tensor3);

    public abstract Tensor lossDiff(Tensor tensor, Tensor tensor2, Tensor tensor3);

    public abstract NetworkType getNetworkType();

    public abstract void clearGrad();

    /**
     * Allocates a new tensor, registers it in {@link #paramters}, and returns it.
     *
     * <p>NOTE(review): the four ints and the boolean are passed straight to the {@code Tensor}
     * constructor; presumably (number, channel, height, width, onGpu) — TODO confirm.
     */
    public Tensor createParamterGrad(int i, int i2, int i3, int i4, boolean z) {
        Tensor tensor = new Tensor(i, i2, i3, i4, z);
        addPamamter(tensor);
        return tensor;
    }

    /** Appends {@code tensor} to {@link #paramters}. */
    public void addPamamter(Tensor tensor) {
        this.paramters.add(tensor);
    }

    /**
     * Returns the {@code diff} of the layer at index 1 (the layer after layer 0).
     * Throws {@link NullPointerException} when {@link #layerCount} is less than 2.
     */
    public Tensor getDiff() {
        return getNextLayer(0).diff;
    }

    public void setNumber(int i) {
        this.number = i;
    }

    /** Stores {@code tensor} as the network input and copies its batch size into {@link #number}. */
    public void setInputData(Tensor tensor) {
        this.number = tensor.number;
        this.input = tensor;
    }

    /** Stores the loss gradient and sets it as the delta of the last layer. */
    public void setLossDiff(Tensor tensor) {
        this.lossDiff = tensor;
        getLastLayer().setDelta(this.lossDiff);
    }

    /** Returns the output of the last layer. */
    public Tensor getOutput() {
        return getLastLayer().getOutput();
    }

    /** Returns channel * height * width of the last layer's output. */
    public int getOutputNum() {
        Tensor output = getLastLayer().getOutput();
        return output.channel * output.height * output.width;
    }

    /** Returns the last layer in {@link #layerList}; throws if the list is empty. */
    public Layer getLastLayer() {
        return this.layerList.get(this.layerList.size() - 1);
    }

    /** Returns the layer before index {@code i}, or {@code null} when {@code i} is not in [1, layerCount). */
    public Layer getPreLayer(int i) {
        if (i <= 0 || i >= this.layerCount) {
            return null;
        }
        return this.layerList.get(i - 1);
    }

    /** Returns the layer after index {@code i}, or {@code null} when {@code i >= layerCount - 1}. */
    public Layer getNextLayer(int i) {
        if (i < this.layerCount - 1) {
            return this.layerList.get(i + 1);
        }
        return null;
    }

    /** Returns the layer at index {@code i}, or {@code null} when {@code i >= layerCount}. */
    public Layer getLayer(int i) {
        if (i < this.layerCount) {
            return this.layerList.get(i);
        }
        return null;
    }

    /**
     * Returns {@link #lossDiff} when {@code i <= 0} or {@code i >= layerCount - 1}. Otherwise it
     * returns the {@code diff} of the layer at {@code i + 1}.
     */
    public Tensor getDelta(int i) {
        return (i <= 0 || i >= this.layerCount - 1) ? this.lossDiff : this.layerList.get(i + 1).diff;
    }

    /**
     * Attaches {@code layer} to this network and appends it to {@link #layerList}.
     *
     * <p>The layer's index is set to its position in the list. An updater is created from
     * {@link #updater}/{@link #updaterParams} if the layer has none. Layers at index 0 and 1 get
     * {@code PROPAGATE_DOWN = false}. This method does not change {@link #layerCount}.
     */
    public void addLayer(Layer layer) {
        layer.setNetwork(this);
        layer.setIndex(this.layerList.size());
        if (layer.updater == null) {
            layer.setUpdater(UpdaterFactory.create(this.updater, this.updaterParams));
        }
        if (layer.index <= 1) {
            layer.PROPAGATE_DOWN = false;
        }
        this.layerList.add(layer);
    }

    /** Builds a {@link NetworkInit} from this network and adds each layer's saved state, in order. */
    public NetworkInit save() {
        NetworkInit networkInit = new NetworkInit(this);
        for (Layer layer : this.layerList) {
            networkInit.getLayers().add(layer.save());
        }
        return networkInit;
    }

    /**
     * Increments {@link #train_time}, then updates layers from index {@code layerCount - 1} down
     * to 0. Before each update, the network's {@link #learnRate} is copied onto the layer.
     */
    public void update() {
        this.train_time++;
        for (int i = this.layerCount - 1; i >= 0; i--) {
            Layer layer = this.layerList.get(i);
            layer.learnRate = this.learnRate;
            layer.update();
        }
    }

    /**
     * Same as {@link #update()}; the argument is ignored.
     *
     * @param i unused
     */
    public void update(int i) {
        update();
    }

    /** Sets {@code freeze = false} on the first {@link #layerCount} layers. */
    public void unfreeze() {
        for (int i = 0; i < this.layerCount; i++) {
            this.layerList.get(i).freeze = false;
        }
    }

    /** Currently a no-op. */
    public void saveToJson(String str) {
    }

    public int getThreadNum() {
        return this.threadNum;
    }

    public void setThreadNum(int i) {
        this.threadNum = i;
    }

    /**
     * Rescales every tensor in {@link #paramters} in place when the clip coefficient is below 1.
     *
     * <p>The coefficient is {@code f / (sqrt(MatrixOperation.norm(perTensorNorms)) + 1e-6)}, where
     * {@code perTensorNorms[k] = paramters.get(k).norm()}.
     *
     * @param f maximum allowed norm
     */
    public void clipGradNorm(float f) {
        float[] norms = new float[this.paramters.size()];
        for (int i = 0; i < norms.length; i++) {
            norms[i] = this.paramters.get(i).norm();
        }
        // NOTE(review): the extra square root is only right if MatrixOperation.norm returns a sum of
        // squares rather than an L2 norm; behaviour is kept as-is — TODO confirm.
        float clipCoef = f / (((float) Math.pow(MatrixOperation.norm(norms), 0.5d)) + 1.0E-6f);
        if (clipCoef < 1.0f) {
            for (Tensor tensor : this.paramters) {
                TensorOP.mul(tensor, clipCoef, tensor);
            }
        }
    }
}
