package hex.genmodel.algos.gam;

import hex.genmodel.ConverterFactoryProvidingModel;
import hex.genmodel.GenModel;
import hex.genmodel.MojoModel;
import hex.genmodel.easy.CategoricalEncoder;
import hex.genmodel.easy.EasyPredictModelWrapper;
import hex.genmodel.easy.RowData;
import hex.genmodel.easy.RowToRawDataConverter;
import hex.genmodel.utils.ArrayUtils;
import hex.genmodel.utils.DistributionFamily;
import hex.genmodel.utils.LinkFunctionType;
import java.util.Map;
import org.apache.commons.logging.impl.SimpleLog;

/* loaded from: input_file:hex/genmodel/algos/gam/GamMojoModelBase.class */
/**
 * Common scoring machinery shared by GAM MOJO models.
 *
 * <p>Concrete subclasses implement {@link #gamScore0(double[], double[])}. This base class provides:
 * <ul>
 *   <li>optional in-place filling of missing inputs before scoring ({@code score0});</li>
 *   <li>inversion of the model's link function ({@code evalLink});</li>
 *   <li>evaluation of a linear predictor from coefficients and a row ({@code generateEta});</li>
 *   <li>expansion of raw GAM predictor values taken from a {@link RowData} into spline basis
 *       columns ({@code addExpandGamCols}).</li>
 * </ul>
 *
 * <p>Sorted GAM column layout, as assumed by the index arithmetic in {@link #init()}:
 * <ul>
 *   <li>the first {@code _numCSCol} sorted columns are cubic-regression splines ({@code _bs_sorted == 0});</li>
 *   <li>the next {@code _numISCol} are I-splines ({@code _bs_sorted == 2});</li>
 *   <li>the remaining {@code _numTPCol} are thin-plate splines ({@code _bs_sorted == 1}).</li>
 * </ul>
 */
public abstract class GamMojoModelBase extends MojoModel implements ConverterFactoryProvidingModel, Cloneable {
    // Fields are presumably populated by the MOJO reader (not visible here); fields marked
    // "derived" are (re)built by init().
    public LinkFunctionType _link_function;
    boolean _useAllFactorLevels;
    int _cats;
    int[] _catNAFills;
    int[] _catOffsets;
    int _nums;
    int _numsCenter;
    double[] _numNAFillsCenter;
    boolean _meanImputation;
    double[] _beta_no_center;
    double[] _beta_center;
    double[][] _beta_multinomial;
    double[][] _beta_multinomial_no_center;
    double[][] _beta_multinomial_center;
    int[] _spline_orders;
    int[] _spline_orders_sorted;
    DistributionFamily _family;
    String[][] _gam_columns;
    String[][] _gam_columns_sorted;
    int[] _d;
    int[] _m;
    int[] _M;
    int[] _gamPredSize;
    int _num_gam_columns;
    int[] _bs;
    int[] _bs_sorted;
    int[] _num_knots;
    int[] _num_knots_sorted;
    int[] _num_knots_sorted_minus1; // derived: _num_knots_sorted[i] - 1
    int[] _numBasisSize;            // derived: basis values produced per I-spline column
    int[] _num_knots_TP;
    double[][][] _knots;
    double[][][] _binvD;
    double[][][] _zTranspose;
    double[][][] _zTransposeCS;
    String[][] _gamColNames;
    String[][] _gamColNamesCenter;
    String[] _names_no_centering;
    int _totFeatureSize;
    int _betaSizePerClass;
    int _betaCenterSizePerClass;
    double _tweedieLinkPower;
    double[][] _hj;                 // derived: ArrayUtils.eleDiff of each cubic-spline column's knots
    int _numExpandedGamCols;
    int _numExpandedGamColsCenter;
    int _lastClass;                 // derived: _nclasses - 1
    int[][][] _allPolyBasisList;
    int _numTPCol;
    int _numCSCol;
    int _numISCol;
    int[] _tpDistzCSSize;           // derived: per thin-plate column
    boolean[] _dEven;               // derived: per thin-plate column, whether _d is even
    double[] _constantTerms;        // derived: per thin-plate column
    double[][] _gamColMeansRaw;
    double[][] _oneOGamColStd;
    boolean _standardize;
    ISplines[] _iSplineBasis;       // derived: one fresh evaluator per I-spline column

    /* JADX INFO: Access modifiers changed from: package-private */
    public GamMojoModelBase(String[] strArr, String[][] strArr2, String str) {
        super(strArr, strArr2, str);
    }

    /**
     * Scores a single row.
     *
     * <p>When {@code _meanImputation} is set, NaN entries of {@code row} are overwritten in place
     * with the stored fill values before delegating to {@link #gamScore0(double[], double[])}.
     *
     * @param row   raw input values; may be modified in place
     * @param preds second argument, forwarded unchanged to {@code gamScore0}
     * @return the value returned by {@code gamScore0}
     */
    @Override // hex.genmodel.GenModel
    public double[] score0(double[] row, double[] preds) {
        if (_meanImputation) {
            imputeMissingWithMeans(row);
        }
        return gamScore0(row, preds);
    }

    /**
     * Rebuilds the arrays derived from the model parameters and creates fresh {@link ISplines}
     * evaluators.
     *
     * <p>Also invoked by {@link #internal_threadSafeInstance()} on each clone.
     */
    /* JADX INFO: Access modifiers changed from: package-private */
    public void init() {
        _num_knots_sorted_minus1 = new int[_num_knots_sorted.length];
        for (int col = 0; col < _num_knots_sorted.length; col++) {
            _num_knots_sorted_minus1[col] = _num_knots_sorted[col] - 1;
        }
        if (_numCSCol > 0) {
            // Jagged array: one knot-difference vector per cubic-spline column.
            _hj = new double[_numCSCol][];
            for (int csIdx = 0; csIdx < _numCSCol; csIdx++) {
                _hj[csIdx] = ArrayUtils.eleDiff(_knots[csIdx][0]);
            }
        }
        if (_numISCol > 0) {
            _numBasisSize = new int[_numISCol];
            _iSplineBasis = new ISplines[_numISCol];
            for (int isIdx = 0; isIdx < _numISCol; isIdx++) {
                int col = isIdx + _numCSCol; // I-spline columns follow the cubic-spline columns
                _numBasisSize[isIdx] = _num_knots_sorted[col] + _spline_orders_sorted[col] - 2;
                _iSplineBasis[isIdx] = new ISplines(_spline_orders_sorted[col], _knots[col][0]);
            }
        }
        if (_numTPCol > 0) {
            _tpDistzCSSize = new int[_numTPCol];
            _dEven = new boolean[_numTPCol];
            _constantTerms = new double[_numTPCol];
            for (int tpIdx = 0; tpIdx < _numTPCol; tpIdx++) {
                int col = tpIdx + _numCSCol + _numISCol; // thin-plate columns come last
                _tpDistzCSSize[tpIdx] = _num_knots_sorted[col] - _M[tpIdx];
                _dEven[tpIdx] = _d[col] % 2 == 0;
                _constantTerms[tpIdx] =
                        GamUtilsThinPlateRegression.calTPConstantTerm(_m[tpIdx], _d[col], _dEven[tpIdx]);
            }
        }
        _lastClass = _nclasses - 1;
    }

    /**
     * Returns a shallow clone whose derived arrays and {@link ISplines} evaluators are rebuilt by
     * {@link #init()}.
     *
     * @throws RuntimeException wrapping {@link CloneNotSupportedException} if cloning fails
     */
    @Override // hex.genmodel.GenModel
    public GenModel internal_threadSafeInstance() {
        try {
            GamMojoModelBase copy = (GamMojoModelBase) clone();
            copy.init();
            return copy;
        } catch (CloneNotSupportedException e) {
            throw new RuntimeException(e);
        }
    }

    /** Subclass-specific scoring of an (optionally imputed) row; see {@link #score0(double[], double[])}. */
    abstract double[] gamScore0(double[] dArr, double[] dArr2);

    /**
     * Replaces NaN values in place.
     *
     * <ul>
     *   <li>The first {@code _cats} entries are filled from {@code _catNAFills}.</li>
     *   <li>The following {@code _numsCenter} entries are filled from {@code _numNAFillsCenter}.</li>
     * </ul>
     */
    private void imputeMissingWithMeans(double[] row) {
        for (int c = 0; c < _cats; c++) {
            if (Double.isNaN(row[c])) {
                row[c] = _catNAFills[c];
            }
        }
        for (int n = 0; n < _numsCenter; n++) {
            int pos = n + _cats;
            if (Double.isNaN(row[pos])) {
                row[pos] = _numNAFillsCenter[n];
            }
        }
    }

    /**
     * Applies the inverse of the model's link function to {@code eta}.
     *
     * @throws UnsupportedOperationException for link functions other than identity, logit, log,
     *                                       inverse and tweedie
     */
    /* JADX INFO: Access modifiers changed from: package-private */
    public double evalLink(double eta) {
        switch (_link_function) {
            case identity:
                return GenModel.GLM_identityInv(eta);
            case logit:
                return GenModel.GLM_logitInv(eta);
            case log:
                return GenModel.GLM_logInv(eta);
            case inverse:
                return GenModel.GLM_inverseInv(eta);
            case tweedie:
                return GenModel.GLM_tweedieInv(eta, _tweedieLinkPower);
            default:
                throw new UnsupportedOperationException("Unexpected link function " + _link_function);
        }
    }

    /**
     * Maps the level code of categorical column {@code catIndex} to a coefficient index.
     *
     * <p>When {@code _useAllFactorLevels} is false, the code is first shifted down by one, so
     * level 0 maps to -1.
     *
     * @return {@code -1} if the (adjusted) level is negative, otherwise the level plus
     *         {@code _catOffsets[catIndex]}
     */
    int readCatVal(double levelCode, int catIndex) {
        int level = _useAllFactorLevels ? (int) levelCode : ((int) levelCode) - 1;
        if (level < 0) {
            return -1;
        }
        return level + _catOffsets[catIndex];
    }

    /**
     * Computes the linear predictor for one row.
     *
     * <p>The result is the sum of:
     * <ul>
     *   <li>one coefficient per categorical column, selected via {@link #readCatVal} and kept only
     *       if it lies inside that column's offset range;</li>
     *   <li>the dot product of the numeric row entries with the matching coefficients;</li>
     *   <li>the last coefficient, added as a constant term.</li>
     * </ul>
     *
     * @param beta coefficient vector
     * @param row  row values, categorical columns first
     */
    /* JADX INFO: Access modifiers changed from: package-private */
    public double generateEta(double[] beta, double[] row) {
        double eta = 0.0d;
        int numCatCols = _catOffsets.length - 1;
        for (int c = 0; c < numCatCols; c++) {
            int betaIdx = readCatVal(row[c], c);
            if (betaIdx < _catOffsets[c + 1] && betaIdx >= 0) {
                eta += beta[betaIdx];
            }
        }
        // Numeric column j uses coefficient j + numericShift.
        int numericShift = _catOffsets[_cats] - _cats;
        int numericEnd = (beta.length - 1) - numericShift;
        for (int j = _cats; j < numericEnd; j++) {
            eta += beta[numericShift + j] * row[j];
        }
        return eta + beta[beta.length - 1];
    }

    /** Returns true when every entry of {@code row} from index {@code start} onward is NaN. */
    private boolean gamificationNeeded(double[] row, int start) {
        for (int k = start; k < row.length; k++) {
            if (!Double.isNaN(row[k])) {
                return false;
            }
        }
        return true;
    }

    /**
     * Converts a {@link RowData} value to a double.
     *
     * <p>Strings are parsed. Any {@link Number} (not only {@link Double}) is accepted via
     * {@link Number#doubleValue()}.
     */
    private static double rowValueAsDouble(Object value) {
        return value instanceof String ? Double.parseDouble((String) value) : ((Number) value).doubleValue();
    }

    /**
     * Writes the centered cubic-regression-spline expansion of sorted GAM column {@code col}.
     *
     * <p>The expansion is {@code _num_knots_sorted_minus1[col]} values written into {@code out}
     * starting at {@code offset}. A missing ({@code null}) predictor leaves {@code out} untouched.
     *
     * @return {@code offset}, unchanged
     */
    int addCSGamification(RowData rowData, int col, int offset, double[] out) {
        Object raw = rowData.get(_gam_columns_sorted[col][0]);
        if (raw == null) {
            return offset;
        }
        double value = rowValueAsDouble(raw);
        double[] basis = new double[_num_knots_sorted[col]];
        double[] centered = new double[_num_knots_sorted_minus1[col]];
        GamUtilsCubicRegression.expandOneGamCol(value, _binvD[col], basis, _hj[col], _knots[col][0]);
        ArrayUtils.multArray(basis, _zTranspose[col], centered);
        System.arraycopy(centered, 0, out, offset, _num_knots_sorted_minus1[col]);
        return offset;
    }

    /**
     * Writes the I-spline basis values of sorted GAM column {@code col}.
     *
     * <p>{@code isIdx} is the column's position among the I-spline columns. The
     * {@code _numBasisSize[isIdx]} values are written into {@code out} starting at {@code offset}.
     * A missing ({@code null}) predictor leaves {@code out} untouched.
     *
     * @return {@code offset}, unchanged
     */
    int addISGamification(RowData rowData, int col, int isIdx, int offset, double[] out) {
        Object raw = rowData.get(_gam_columns_sorted[col][0]);
        if (raw == null) {
            return offset;
        }
        double value = rowValueAsDouble(raw);
        double[] basis = new double[_numBasisSize[isIdx]];
        _iSplineBasis[isIdx].gamifyVal(basis, value);
        System.arraycopy(basis, 0, out, offset, _numBasisSize[isIdx]);
        return offset;
    }

    /**
     * Returns {@code rawData} extended with the GAM basis columns computed from {@code rowData}.
     *
     * <p>The trailing {@code _numExpandedGamColsCenter} slots are checked first. If any of them is
     * not NaN, {@code rawData} is assumed to be already expanded and is returned as-is. Otherwise a
     * new NaN-filled array of length {@code _nfeatures} is built:
     * <ul>
     *   <li>the leading raw values are copied in;</li>
     *   <li>each GAM column's expansion is appended in sorted order;</li>
     *   <li>columns whose predictor is missing stay NaN.</li>
     * </ul>
     *
     * @throws IllegalArgumentException if a sorted column has an unknown spline type
     */
    /* JADX INFO: Access modifiers changed from: package-private */
    public double[] addExpandGamCols(double[] rawData, RowData rowData) {
        int offset = _nfeatures - _numExpandedGamColsCenter;
        if (!gamificationNeeded(rawData, offset)) {
            return rawData;
        }
        double[] expanded = ArrayUtils.nanArray(_nfeatures);
        System.arraycopy(rawData, 0, expanded, 0, offset);
        int tpIdx = 0;
        int isIdx = 0;
        for (int col = 0; col < _num_gam_columns; col++) {
            // Number of output slots this column occupies; must equal what the helper writes.
            int width;
            switch (_bs_sorted[col]) {
                case 0: // cubic regression spline
                    addCSGamification(rowData, col, offset, expanded);
                    width = _num_knots_sorted_minus1[col];
                    break;
                case 1: // thin-plate spline
                    addTPGamification(rowData, col, tpIdx, offset, expanded);
                    tpIdx++;
                    width = _num_knots_sorted_minus1[col];
                    break;
                case 2: // I-spline: writes _numBasisSize values, not _num_knots_sorted_minus1
                    addISGamification(rowData, col, isIdx, offset, expanded);
                    width = _numBasisSize[isIdx];
                    isIdx++;
                    break;
                default:
                    throw new IllegalArgumentException("spline type not implemented!");
            }
            offset += width;
        }
        return expanded;
    }

    /**
     * Writes the centered thin-plate-spline expansion of sorted GAM column {@code col}.
     *
     * <p>{@code tpIdx} is the column's position among the thin-plate columns. The expansion is
     * built as follows:
     * <ol>
     *   <li>compute the distance terms and multiply them by {@code _zTransposeCS[tpIdx]};</li>
     *   <li>append the polynomial basis;</li>
     *   <li>multiply the combined vector by {@code _zTranspose[col]}.</li>
     * </ol>
     * The result is written into {@code out} at {@code offset}. If any predictor of the column is
     * missing, {@code out} is left untouched.
     *
     * @return {@code offset}, unchanged
     */
    int addTPGamification(RowData rowData, int col, int tpIdx, int offset, double[] out) {
        double[] predictors = grabPredictorVals(_gam_columns_sorted[col], rowData);
        if (predictors == null) {
            return offset;
        }
        double[] distances = new double[_num_knots_sorted[col]];
        GamUtilsThinPlateRegression.calculateDistance(distances, predictors, _num_knots_sorted[col], _knots[col],
                _d[col], _m[tpIdx], _dEven[tpIdx], _constantTerms[tpIdx], _oneOGamColStd[tpIdx], _standardize);
        double[] distancesCS = new double[_tpDistzCSSize[tpIdx]];
        ArrayUtils.multArray(distances, _zTransposeCS[tpIdx], distancesCS);
        double[] polyBasis = new double[_M[tpIdx]];
        GamUtilsThinPlateRegression.calculatePolynomialBasis(polyBasis, predictors, _d[col], _M[tpIdx],
                _allPolyBasisList[tpIdx], _gamColMeansRaw[tpIdx], _oneOGamColStd[tpIdx], _standardize);
        double[] combined = new double[_num_knots_sorted[col]];
        double[] centered = new double[_num_knots_sorted_minus1[col]];
        System.arraycopy(distancesCS, 0, combined, 0, distancesCS.length);
        System.arraycopy(polyBasis, 0, combined, distancesCS.length, _M[tpIdx]);
        ArrayUtils.multArray(combined, _zTranspose[col], centered);
        System.arraycopy(centered, 0, out, offset, centered.length);
        return offset;
    }

    /**
     * Reads the named predictors from {@code rowData} as doubles.
     *
     * @return the values in {@code names} order, or {@code null} if any of them is missing
     */
    double[] grabPredictorVals(String[] names, RowData rowData) {
        double[] values = new double[names.length];
        for (int k = 0; k < names.length; k++) {
            Object raw = rowData.get(names[k]);
            if (raw == null) {
                return null;
            }
            values[k] = rowValueAsDouble(raw);
        }
        return values;
    }

    /** Creates the row-to-raw-data converter used by the easy-predict wrapper for GAM models. */
    @Override // hex.genmodel.ConverterFactoryProvidingModel
    public RowToRawDataConverter makeConverterFactory(Map<String, Integer> map, Map<Integer, CategoricalEncoder> map2, EasyPredictModelWrapper.ErrorConsumer errorConsumer, EasyPredictModelWrapper.Config config) {
        return new GamRowToRawDataConverter(this, map, map2, errorConsumer, config);
    }
}
