package hex.gam.GamSplines;

import hex.gam.GAMModel;
import hex.genmodel.utils.MathUtils;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import water.MRTask;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.NewChunk;
import water.util.ArrayUtils;

/* loaded from: input_file:hex/gam/GamSplines/ThinPlateRegressionUtils.class */
public class ThinPlateRegressionUtils {

    /**
     * MRTask that rescales the thin plate penalty matrix {@link #_penaltyMat} in place.
     *
     * <p>{@link #map} records, for each chunk, the largest sum of absolute values taken across the
     * columns of a single row. {@link #postGlobal} takes the maximum of those per-chunk values
     * ({@code maxRowSum}), multiplies {@code _penaltyMat} by
     * {@code maxRowSum * maxRowSum / ArrayUtils.rNorm(_penaltyMat, 'i')} and stores the reciprocal
     * of that factor in {@link #_s_scale}.
     */
    public static class ScaleTPPenalty extends MRTask<ScaleTPPenalty> {
        /** Penalty matrix; multiplied in place by the scale factor in {@link #postGlobal}. */
        public double[][] _penaltyMat;
        /** Maximum absolute row sum per chunk, indexed by chunk index. */
        double[] _maxAbsRowSum;
        /** Number of chunks in the first vec of the frame passed to the constructor. */
        public int _initChunks;
        /** After {@link #postGlobal}: the reciprocal of the factor applied to {@link #_penaltyMat}. */
        public double _s_scale;

        /**
         * @param penaltyMat penalty matrix to rescale in place
         * @param frame      frame whose first vec's chunk count sizes the per-chunk result array;
         *                   presumably the same frame this task is run over — verify against caller
         */
        public ScaleTPPenalty(double[][] penaltyMat, Frame frame) {
            _penaltyMat = penaltyMat;
            _initChunks = frame.vec(0).nChunks();
        }

        @Override // water.MRTask
        public void map(Chunk[] chunks, NewChunk[] newChunks) {
            // A fresh array per map call: only this chunk's slot is written and every other slot
            // stays 0.0, so reduce can merge partial results by element-wise addition.
            _maxAbsRowSum = new double[_initChunks];
            int cidx = chunks[0].cidx();
            // Stays -Infinity for an empty chunk, i.e. an empty chunk contributes no row sum.
            double maxRowSum = Double.NEGATIVE_INFINITY;
            int numRows = chunks[0]._len;
            for (int row = 0; row < numRows; row++) {
                double absRowSum = 0.0d;
                for (Chunk chunk : chunks) {
                    absRowSum += Math.abs(chunk.atd(row));
                }
                // Plain '>' (not Math.max) on purpose: a row containing NaN yields a NaN sum,
                // which this comparison skips instead of propagating.
                if (absRowSum > maxRowSum) {
                    maxRowSum = absRowSum;
                }
            }
            _maxAbsRowSum[cidx] = maxRowSum;
        }

        @Override // water.MRTask
        public void reduce(ScaleTPPenalty other) {
            // Either side may not have run map (e.g. no local chunks); never drop the other's data.
            if (other._maxAbsRowSum == null) {
                return;
            }
            if (_maxAbsRowSum == null) {
                _maxAbsRowSum = other._maxAbsRowSum;
            } else {
                ArrayUtils.add(_maxAbsRowSum, other._maxAbsRowSum);
            }
        }

        @Override // water.MRTask
        public void postGlobal() {
            double maxRowSum = ArrayUtils.maxValue(_maxAbsRowSum);
            // NOTE(review): 'i' presumably selects the infinity norm in ArrayUtils.rNorm — confirm.
            _s_scale = (maxRowSum * maxRowSum) / ArrayUtils.rNorm(_penaltyMat, 'i');
            ArrayUtils.mult(_penaltyMat, _s_scale);
            _s_scale = 1.0d / _s_scale;
        }
    }

    /**
     * Returns {@code floor((d + 1) / 2) + 1}; presumably the default polynomial order {@code m}
     * for {@code d} predictors.
     */
    public static int calculatem(int d) {
        return ((int) Math.floor((d + 1.0d) * 0.5d)) + 1;
    }

    /**
     * Returns {@code MathUtils.combinatorial(d + m - 1, d)}; presumably the number of polynomial
     * basis functions of degree below {@code m} in {@code d} variables.
     */
    public static int calculateM(int d, int m) {
        return MathUtils.combinatorial((d + m - 1), d);
    }

    /**
     * Enumerates the power vectors of all monomials of total degree {@code 0 .. polyOrder - 1}
     * in {@code numPred} variables.
     *
     * <p>Each entry has length {@code numPred}, and element {@code k} is the power of predictor
     * {@code k}. The all-zero (constant) vector is always first.
     *
     * @param numPred   number of predictors (length of each power vector)
     * @param polyOrder polynomial order; degrees {@code 1 .. polyOrder - 1} are enumerated
     * @return list of power vectors, constant term first
     * @throws IllegalArgumentException if {@code polyOrder < 1}
     */
    public static List<Integer[]> findPolyBasis(int numPred, int polyOrder) {
        if (polyOrder < 1) {
            throw new IllegalArgumentException("polyOrder must be >= 1 but is " + polyOrder);
        }
        if (polyOrder == 1) {
            // Only the constant term; the general path would fail on an empty combination list.
            List<Integer[]> constantOnly = new ArrayList<>();
            Integer[] zeros = new Integer[numPred];
            Arrays.fill(zeros, 0);
            constantOnly.add(zeros);
            return constantOnly;
        }
        int maxDegree = polyOrder - 1;
        int[] degrees = new int[maxDegree];
        for (int degree = 1; degree <= maxDegree; degree++) {
            degrees[degree - 1] = degree;
        }
        Integer[] scratch = new Integer[numPred];
        List<Integer[]> combos = new ArrayList<>();
        for (int degree : degrees) {
            ArrayList<int[]> degreeCombos = new ArrayList<>();
            findOnePerm(degree, degrees, 0, degreeCombos, null);
            mergeCombos(degreeCombos, scratch, degrees, combos);
        }
        return findAllPolybasis(combos);
    }

    /**
     * Expands every power vector in {@code list} into all of its distinct permutations, and
     * prepends an all-zero vector of the same length.
     *
     * @param list non-empty list of power vectors of equal length
     * @return all distinct permutations, preceded by the all-zero vector
     * @throws IllegalArgumentException if {@code list} is empty
     */
    public static List<Integer[]> findAllPolybasis(List<Integer[]> list) {
        if (list.isEmpty()) {
            throw new IllegalArgumentException("list of power combinations must not be empty");
        }
        List<Integer[]> allBasis = new ArrayList<>();
        for (Integer[] combo : list) {
            int[] freq = generateOrderFreq(combo);
            List<List<Integer>> perms = new ArrayList<>();
            findPermute(freq, new ArrayList<>(), combo.length, perms);
            addPermutationList(allBasis, perms);
        }
        Integer[] constant = new Integer[list.get(0).length];
        Arrays.fill(constant, 0);
        allBasis.add(0, constant);
        return allBasis;
    }

    /** Converts each permutation in {@code list2} to an {@code Integer[]} and appends it to {@code list}. */
    public static void addPermutationList(List<Integer[]> list, List<List<Integer>> list2) {
        for (List<Integer> perm : list2) {
            list.add(perm.toArray(new Integer[0]));
        }
    }

    /**
     * Recursively generates every distinct sequence of length {@code remaining} appended to
     * {@code prefix}, where value {@code v} may be used at most {@code freq[v]} times.
     *
     * <p>{@code freq} is temporarily decremented during recursion and restored before return.
     * Each completed sequence is added to {@code results}.
     */
    public static void findPermute(int[] freq, List<Integer> prefix, int remaining, List<List<Integer>> results) {
        if (remaining == 0) {
            results.add(prefix);
            return;
        }
        for (int value = 0; value < freq.length; value++) {
            int count = freq[value];
            if (count > 0) {
                freq[value] = count - 1;
                List<Integer> extended = new ArrayList<>(prefix);
                extended.add(value);
                findPermute(freq, extended, remaining - 1, results);
                freq[value] = count;
            }
        }
    }

    /**
     * Counts occurrences: element {@code v} of the result is the number of times {@code v}
     * appears in {@code numArr}. Values are assumed non-negative.
     */
    public static int[] generateOrderFreq(Integer[] numArr) {
        int[] freq = new int[ArrayUtils.maxValue(numArr) + 1];
        for (Integer value : numArr) {
            freq[value]++;
        }
        return freq;
    }

    /**
     * Expands every multiplicity vector in {@code combos} into a power vector via
     * {@link #expandCombo}, and appends a copy of each to {@code out}.
     *
     * @param combos  multiplicity vectors (one count per entry of {@code degrees})
     * @param scratch reusable buffer of length numPred; zero-filled before each expansion
     * @param degrees degree value associated with each multiplicity slot
     * @param out     receives one cloned power vector per combo
     */
    public static void mergeCombos(ArrayList<int[]> combos, Integer[] scratch, int[] degrees, List<Integer[]> out) {
        for (int[] combo : combos) {
            Arrays.fill(scratch, 0);
            expandCombo(combo, degrees, scratch);
            out.add(scratch.clone());
        }
    }

    /**
     * Writes {@code degrees[k]} into {@code out} {@code combo[k]} times, for each {@code k}.
     * A slot with multiplicity 0 writes a single {@code 0} instead.
     */
    public static void expandCombo(int[] combo, int[] degrees, Integer[] out) {
        int pos = 0;
        for (int k = 0; k < degrees.length; k++) {
            if (combo[k] == 0) {
                out[pos++] = 0;
            } else {
                for (int rep = 0; rep < combo[k]; rep++) {
                    out[pos++] = degrees[k];
                }
            }
        }
    }

    /**
     * Recursively finds every multiplicity vector {@code c} with {@code sum(c[k] * degrees[k]) == remaining},
     * considering slots from {@code index} onward.
     *
     * @param remaining value still to be represented
     * @param degrees   candidate summands
     * @param index     first slot still free to assign
     * @param results   receives a copy of each complete multiplicity vector
     * @param current   working vector, or {@code null} on the initial call
     */
    public static void findOnePerm(int remaining, int[] degrees, int index, ArrayList<int[]> results, int[] current) {
        if (remaining == 0) {
            if (current != null) {
                results.add(current.clone());
            }
            return;
        }
        if (remaining < 0 || index >= degrees.length) {
            return;
        }
        int maxCount = remaining / degrees[index];
        if (current == null) {
            current = degrees.clone();
        }
        for (int count = 0; count <= maxCount; count++) {
            setCombo(current, index, count);
            findOnePerm(remaining - (count * degrees[index]), degrees, index + 1, results, current);
        }
    }

    /** Sets {@code iArr[i] = i2} and zeroes every element after index {@code i}. */
    public static void setCombo(int[] iArr, int i, int i2) {
        iArr[i] = i2;
        Arrays.fill(iArr, i + 1, iArr.length, 0);
    }

    /**
     * Evaluates polynomial basis functions at the given points.
     *
     * <p>Entry {@code [p][b]} of the result is the product over predictors {@code k} of
     * {@code x[k][p] ^ powers.get(b)[k]}. Here {@code x[k][p]} is {@code dArr[k][p] - dArr2[k]}, further
     * multiplied by {@code dArr3[k]} when {@code z} is true.
     *
     * @param dArr  input values indexed {@code [predictor][point]}
     * @param list  power vectors, one per basis function
     * @param dArr2 per-predictor offset to subtract (presumably the mean — confirm with caller)
     * @param dArr3 per-predictor multiplier, used only when {@code z} is true
     * @param z     whether to apply {@code dArr3}
     * @return matrix indexed {@code [point][basis function]}
     */
    public static double[][] generateStarT(double[][] dArr, List<Integer[]> list, double[] dArr2, double[] dArr3, boolean z) {
        int numPoints = dArr[0].length;
        int numBasis = list.size();
        int numPred = dArr.length;
        double[][] centered = new double[numPred][numPoints];
        for (int pred = 0; pred < numPred; pred++) {
            for (int point = 0; point < numPoints; point++) {
                double shifted = dArr[pred][point] - dArr2[pred];
                centered[pred][point] = z ? shifted * dArr3[pred] : shifted;
            }
        }
        double[][] result = new double[numPoints][numBasis];
        for (int point = 0; point < numPoints; point++) {
            for (int basis = 0; basis < numBasis; basis++) {
                Integer[] powers = list.get(basis);
                double product = 1.0d;
                for (int pred = 0; pred < numPred; pred++) {
                    product *= Math.pow(centered[pred][point], powers[pred]);
                }
                result[point][basis] = product;
            }
        }
        return result;
    }

    /** Appends {@code d} to each of the first {@code i} NewChunks. */
    public static void fillRowOneValue(NewChunk[] newChunkArr, int i, double d) {
        for (int col = 0; col < i; col++) {
            newChunkArr[col].addNum(d);
        }
    }

    /** Appends {@code dArr[k]} to NewChunk {@code k}, for each of the first {@code i} NewChunks. */
    public static void fillRowArray(NewChunk[] newChunkArr, int i, double[] dArr) {
        for (int col = 0; col < i; col++) {
            newChunkArr[col].addNum(dArr[col]);
        }
    }

    /** Returns true if any chunk's value at row {@code i} is NaN. */
    public static boolean checkRowNA(Chunk[] chunkArr, int i) {
        for (Chunk chunk : chunkArr) {
            if (Double.isNaN(chunk.atd(i))) {
                return true;
            }
        }
        return false;
    }

    /** Returns true if any column of {@code frame} holds NaN at row {@code j}. */
    public static boolean checkFrameRowNA(Frame frame, long j) {
        int numCols = frame.numCols();
        for (int col = 0; col < numCols; col++) {
            if (Double.isNaN(frame.vec(col).at(j))) {
                return true;
            }
        }
        return false;
    }

    /**
     * Builds a column-name prefix: every sorted GAM column name of entry {@code i} followed by
     * {@code "_"}, then {@code _bs_sorted[i]} followed by {@code "_"}.
     */
    public static String genThinPlateNameStart(GAMModel.GAMParameters gAMParameters, int i) {
        StringBuilder prefix = new StringBuilder();
        for (int col = 0; col < gAMParameters._gam_columns_sorted[i].length; col++) {
            prefix.append(gAMParameters._gam_columns_sorted[i][col]).append("_");
        }
        prefix.append(gAMParameters._bs_sorted[i]).append("_");
        return prefix.toString();
    }

    /**
     * Copies {@code length} names from {@code src} (starting at {@code srcStart}) into a new array.
     * The names are placed starting at {@code destStart}.
     *
     * <p>The result has {@code destStart + length} slots, so any non-negative {@code destStart}
     * fits. Leading slots before {@code destStart} are {@code null}.
     */
    public static String[] extractColNames(String[] src, int srcStart, int destStart, int length) {
        // Previously sized to 'length' only, so any destStart > 0 overflowed the destination.
        String[] names = new String[destStart + length];
        System.arraycopy(src, srcStart, names, destStart, length);
        return names;
    }

    /**
     * Unboxes the first {@code i} entries of {@code list} into an {@code int[i][]}.
     * Each row takes the length of the corresponding list entry.
     */
    public static int[][] convertList2Array(List<Integer[]> list, int i, int i2) {
        int[][] result = new int[i][i2];
        for (int row = 0; row < i; row++) {
            result[row] = Arrays.stream(list.get(row)).mapToInt(Integer::intValue).toArray();
        }
        return result;
    }

    /**
     * Picks knot locations for a multi-predictor GAM column.
     *
     * <p>The frame is sorted by its first column and the rows are split into
     * {@code _num_knots[i]} equal-size buckets of {@code floor(numRows / numKnots)} rows. The knot
     * for each bucket is the first row in it with no NaN in any column.
     *
     * <p>NOTE(review): a bucket whose rows are all NaN leaves its knot at 0.0 — confirm intended.
     *
     * @param frame         predictor columns, presumably in {@code _gam_columns[i]} order
     * @param gAMParameters parameters supplying {@code _num_knots[i]} and {@code _gam_columns[i]}
     * @param i             index of the GAM column
     * @return knots indexed {@code [predictor][knot]}
     */
    public static double[][] genKnotsMultiplePreds(Frame frame, GAMModel.GAMParameters gAMParameters, int i) {
        Frame sorted = frame.sort(new int[]{0});
        try {
            int numKnots = gAMParameters._num_knots[i];
            long numRows = sorted.numRows();
            long rowStep = (long) Math.floor((1.0d / numKnots) * numRows);
            int numPred = gAMParameters._gam_columns[i].length;
            double[][] knots = new double[numPred][numKnots];
            for (int knot = 0; knot < numKnots; knot++) {
                long bucketEnd = (knot + 1) * rowStep;
                // Take the first fully non-NA row of this bucket. The decompiled loop here never
                // terminated: it neither broke out after a hit nor exited at the bucket end.
                for (long row = knot * rowStep; row < numRows && row < bucketEnd; row++) {
                    if (!checkFrameRowNA(sorted, row)) {
                        for (int pred = 0; pred < numPred; pred++) {
                            knots[pred][knot] = sorted.vec(pred).at(row);
                        }
                        break;
                    }
                }
            }
            gAMParameters._num_knots[i] = knots[0].length;
            return knots;
        } finally {
            // Always drop the temporary sorted frame, even if knot extraction fails.
            sorted.remove();
        }
    }
}
