package hex.genmodel.algos.deepwater;

import deepwater.backends.BackendModel;
import deepwater.backends.BackendParams;
import deepwater.backends.BackendTrain;
import deepwater.backends.RuntimeOptions;
import deepwater.datasets.ImageDataSet;
import hex.genmodel.GenModel;
import hex.genmodel.MojoModel;

/* loaded from: input_file:hex/genmodel/algos/deepwater/DeepwaterMojoModel.class */
/**
 * MOJO scoring model backed by a Deep Water network.
 *
 * <p>Scoring converts one input row to a {@code float[]}, runs it through
 * {@link BackendTrain#predict} on {@link #_model}, and maps the raw network output back into
 * the H2O {@code preds} layout (class probabilities plus label for classification, a single
 * optionally de-standardized value for regression).
 *
 * <p>{@link #_backend} and {@link #_model} must be set before {@code score0} is called; which
 * component populates the fields is not visible here (presumably the MOJO reader — TODO confirm).
 */
public class DeepwaterMojoModel extends MojoModel {
    // Model metadata; not read by the scoring code in this file.
    public String _problem_type;
    public int _mini_batch_size;
    public int _height;
    public int _width;
    public int _channels;

    /** Number of numeric input columns; when 0, rows are fed to the network as raw values. */
    public int _nums;
    /** Number of categorical input columns; {@code _catOffsets[_cats]} is the number of input slots they occupy. */
    public int _cats;
    /** Categorical offsets forwarded to {@link GenModel#setInput}; may be null (no categorical slots). */
    public int[] _catOffsets;
    /** Numeric standardization multipliers forwarded to {@link GenModel#setInput}. */
    public double[] _normMul;
    /** Numeric standardization offsets forwarded to {@link GenModel#setInput}. */
    public double[] _normSub;
    /** Regression response multiplier; when this or {@link #_normRespSub} is null the raw output is used. */
    public double[] _normRespMul;
    /** Regression response offset; when this or {@link #_normRespMul} is null the raw output is used. */
    public double[] _normRespSub;
    /** Forwarded to {@link GenModel#setInput}. */
    public boolean _useAllFactorLevels;

    // Serialized network definition and weights; not read by the code in this file.
    transient byte[] _network;
    transient byte[] _parameters;
    /** Values subtracted position-by-position from raw inputs when {@code _nums == 0}; may be null. */
    public transient float[] _meanImageData;

    // Backend state. Only _backend and _model are used by score0.
    BackendTrain _backend;
    BackendModel _model;
    ImageDataSet _imageDataSet;
    RuntimeOptions _opts;
    BackendParams _backendParams;

    /**
     * Creates an empty model; the fields above must be filled in before scoring.
     *
     * @param columns forwarded unchanged to {@link MojoModel} (presumably the column names)
     * @param domains forwarded unchanged to {@link MojoModel} (presumably per-column categorical domains)
     */
    public DeepwaterMojoModel(String[] columns, String[][] domains) {
        super(columns, domains);
    }

    /**
     * Scores a single row.
     *
     * @param row    input values for one row; must not be null
     * @param offset ignored by this model
     * @param preds  output buffer. For classification it must hold at least {@code 1 + _nclasses}
     *               entries: index 0 receives the predicted label, indices 1.. the class
     *               probabilities. For regression only index 0 is written.
     * @return {@code preds}, filled in
     */
    @Override // hex.genmodel.GenModel
    public final double[] score0(double[] row, double offset, double[] preds) {
        assert row != null : "doubles are null";

        final float[] input;
        if (_nums > 0) {
            // Tabular data: numeric columns plus the encoded categorical slots.
            final int catSlots = _catOffsets == null ? 0 : _catOffsets[_cats];
            input = new float[_nums + catSlots];
            GenModel.setInput(row, input, _nums, _cats, _catOffsets, _normMul, _normSub, _useAllFactorLevels);
        } else {
            // Raw (e.g. image) data, optionally mean-subtracted.
            // NOTE(review): assumes _meanImageData, when present, has at least row.length entries.
            input = new float[row.length];
            for (int i = 0; i < input.length; i++) {
                input[i] = ((float) row[i]) - (_meanImageData == null ? 0.0f : _meanImageData[i]);
            }
        }

        final float[] predFloats = _backend.predict(_model, input);
        assert _nclasses == predFloats.length
                : "nclasses " + _nclasses + " predFloats.length " + predFloats.length;

        if (_nclasses > 1) {
            // Class probabilities go after the label slot at index 0.
            for (int i = 0; i < predFloats.length; i++) {
                preds[1 + i] = predFloats[i];
            }
            if (_balanceClasses) {
                GenModel.correctProbabilities(preds, _priorClassDistrib, _modelClassDistrib);
            }
            preds[0] = GenModel.getPrediction(preds, _priorClassDistrib, row, _defaultThreshold);
        } else if (_normRespMul == null || _normRespSub == null) {
            preds[0] = predFloats[0];
        } else {
            // Undo the response standardization applied during training.
            preds[0] = (predFloats[0] * _normRespMul[0]) + _normRespSub[0];
        }
        return preds;
    }

    /** Scores a single row with a zero offset; see {@link #score0(double[], double, double[])}. */
    @Override // hex.genmodel.GenModel
    public double[] score0(double[] row, double[] preds) {
        return score0(row, 0.0d, preds);
    }

    /**
     * Instantiates a Deep Water backend by reflection.
     *
     * @param backend a short alias ({@code "mxnet"}, {@code "tensorflow"}, {@code "caffe"},
     *                {@code "xgrpc"}) or the fully-qualified name of a {@link BackendTrain}
     *                implementation with an accessible no-argument constructor
     * @return a new backend instance, or {@code null} if {@code backend} is null or the class
     *         cannot be found, instantiated, or cast to {@link BackendTrain}. Errors thrown by
     *         the backend's constructor still propagate.
     */
    public static BackendTrain createDeepWaterBackend(String backend) {
        if (backend == null) {
            return null;
        }
        try {
            Class<?> backendClass = Class.forName(resolveBackendClassName(backend));
            return (BackendTrain) backendClass.getDeclaredConstructor().newInstance();
        } catch (java.lang.reflect.InvocationTargetException e) {
            // Constructor.newInstance wraps everything the constructor throws. Rethrow Errors
            // (e.g. native-library link failures) as the former Class.newInstance() call did;
            // other exceptions keep the "no backend" (null) result.
            if (e.getCause() instanceof Error) {
                throw (Error) e.getCause();
            }
            return null;
        } catch (Exception ignored) {
            // Deliberate best-effort contract of this method: any reflective or cast failure is
            // reported to the caller as null rather than thrown.
            return null;
        }
    }

    /** Maps a short backend alias to its implementation class name; any other value passes through. */
    private static String resolveBackendClassName(String backend) {
        if ("mxnet".equals(backend)) {
            return "deepwater.backends.mxnet.MXNetBackend";
        } else if ("tensorflow".equals(backend)) {
            return "deepwater.backends.tensorflow.TensorflowBackend";
        } else if ("caffe".equals(backend)) {
            return "deepwater.backends.caffe.CaffeBackend";
        } else if ("xgrpc".equals(backend)) {
            return "deepwater.backends.grpc.XGRPCBackendTrain";
        }
        return backend;
    }
}
