package ai.libs.jaicore.ml.classification.singlelabel.timeseries.util;

import ai.libs.jaicore.basic.sets.Pair;
import ai.libs.jaicore.ml.classification.singlelabel.timeseries.dataset.TimeSeriesDataset2;
import ai.libs.jaicore.ml.classification.singlelabel.timeseries.exception.TimeSeriesLengthException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

/**
 * Static helper functions for time series classification.
 *
 * <p>Time series are handled either as {@link INDArray} objects (expected to have rank 1) or as plain
 * {@code double[]} arrays. The class also offers helpers for {@link TimeSeriesDataset2} objects
 * (shuffling, class extraction, fold splitting), simple statistics and several discrete derivative
 * estimators.
 */
public final class TimeSeriesUtil {

    private TimeSeriesUtil() {
        // Utility class: not meant to be instantiated.
    }

    /**
     * Checks whether every given array has rank 1.
     *
     * @param arrays the arrays to check
     * @return {@code true} if all arrays have rank 1 (also for an empty argument list)
     */
    public static boolean isTimeSeries(INDArray... arrays) {
        for (INDArray array : arrays) {
            if (array.rank() != 1) {
                return false;
            }
        }
        return true;
    }

    /**
     * Checks whether every given array has rank 1 and exactly the given length.
     *
     * @param length the required length
     * @param arrays the arrays to check
     * @return {@code true} if all arrays have rank 1 and length {@code length}
     */
    public static boolean isTimeSeries(int length, INDArray... arrays) {
        for (INDArray array : arrays) {
            // Reject as soon as EITHER property is violated.
            if (array.rank() != 1 || array.length() != length) {
                return false;
            }
        }
        return true;
    }

    /**
     * Checks whether every given series has exactly the given length.
     *
     * @param length the required length
     * @param series the series to check
     * @return {@code true} if all series have length {@code length}
     */
    public static boolean isTimeSeries(int length, double[]... series) {
        for (double[] s : series) {
            if (s.length != length) {
                return false;
            }
        }
        return true;
    }

    /**
     * Verifies that every given array has rank 1.
     *
     * @param arrays the arrays to check
     * @throws IllegalArgumentException naming the rank of the first array that does not have rank 1
     */
    public static void isTimeSeriesOrException(INDArray... arrays) {
        for (INDArray array : arrays) {
            if (!isTimeSeries(array)) {
                throw new IllegalArgumentException(String.format("The given INDArray is no time series. It should have rank 1, but has a rank of %d.", array.rank()));
            }
        }
    }

    /**
     * Verifies that every given array has rank 1 and exactly the given length.
     *
     * @param length the required length
     * @param arrays the arrays to check
     * @throws IllegalArgumentException for the first array with a wrong rank or length
     */
    public static void isTimeSeriesOrException(int length, INDArray... arrays) {
        for (INDArray array : arrays) {
            if (!isTimeSeries(array)) {
                throw new IllegalArgumentException(String.format("The given INDArray is no time series. It should have rank 1, but has a rank of %d.", array.rank()));
            }
            if (!isTimeSeries(length, array)) {
                throw new IllegalArgumentException(String.format("The given time series should have length %d, but has a length of %d.", length, array.length()));
            }
        }
    }

    /**
     * Verifies that every given series has exactly the given length.
     *
     * @param length the required length
     * @param series the series to check
     * @throws IllegalArgumentException for the first series with a wrong length
     */
    public static void isTimeSeriesOrException(int length, double[]... series) {
        for (double[] s : series) {
            if (!isTimeSeries(length, s)) {
                throw new IllegalArgumentException(String.format("The given time series should have length %d, but has a length of %d.", length, s.length));
            }
        }
    }

    /**
     * Checks whether all {@code others} have the same length as {@code reference}.
     *
     * @param reference the array whose length is used as reference
     * @param others the arrays to compare against the reference
     * @return {@code true} if all lengths match
     */
    public static boolean isSameLength(INDArray reference, INDArray... others) {
        for (INDArray other : others) {
            if (reference.length() != other.length()) {
                return false;
            }
        }
        return true;
    }

    /**
     * Checks whether all {@code others} have the same length as {@code reference}.
     *
     * @param reference the series whose length is used as reference
     * @param others the series to compare against the reference
     * @return {@code true} if all lengths match
     */
    public static boolean isSameLength(double[] reference, double[]... others) {
        for (double[] other : others) {
            if (reference.length != other.length) {
                return false;
            }
        }
        return true;
    }

    /**
     * Verifies that all {@code others} have the same length as {@code reference}.
     *
     * @param reference the array whose length is used as reference
     * @param others the arrays to compare against the reference
     * @throws TimeSeriesLengthException for the first array whose length differs
     */
    public static void isSameLengthOrException(INDArray reference, INDArray... others) {
        for (INDArray other : others) {
            if (!isSameLength(reference, other)) {
                throw new TimeSeriesLengthException(String.format("Length of the given time series are not equal: Length first time series: (%d). Length of seconds time series: (%d)", reference.length(), other.length()));
            }
        }
    }

    /**
     * Verifies that all {@code others} have the same length as {@code reference}.
     *
     * @param reference the series whose length is used as reference
     * @param others the series to compare against the reference
     * @throws TimeSeriesLengthException for the first series whose length differs
     */
    public static void isSameLengthOrException(double[] reference, double[]... others) {
        for (double[] other : others) {
            if (!isSameLength(reference, other)) {
                throw new TimeSeriesLengthException(String.format("Length of the given time series are not equal: Length first time series: (%d). Length of seconds time series: (%d)", reference.length, other.length));
            }
        }
    }

    /**
     * Creates the timestamps {@code 0, 1, ..., n-1} for a series of length {@code n}.
     *
     * @param timeSeries the series whose length determines the number of timestamps
     * @return an {@link INDArray} of shape {@code [n]} holding the timestamps
     */
    public static INDArray createEquidistantTimestamps(INDArray timeSeries) {
        int length = (int) timeSeries.length();
        return Nd4j.create(IntStream.range(0, length).asDoubleStream().toArray(), new int[] {length});
    }

    /**
     * Creates the timestamps {@code 0, 1, ..., n-1} for a series of length {@code n}.
     *
     * @param timeSeries the series whose length determines the number of timestamps
     * @return the timestamps as a new array
     */
    public static double[] createEquidistantTimestamps(double[] timeSeries) {
        return IntStream.range(0, timeSeries.length).asDoubleStream().toArray();
    }

    /**
     * Copies the half-open interval {@code [start, end)} of a series.
     *
     * @param timeSeries the source series
     * @param start inclusive start index
     * @param end exclusive end index
     * @return a new array of length {@code end - start}
     * @throws IllegalArgumentException if {@code end <= start}
     * @throws IndexOutOfBoundsException if the interval exceeds the bounds of {@code timeSeries}
     */
    public static double[] getInterval(double[] timeSeries, int start, int end) {
        if (end <= start) {
            throw new IllegalArgumentException("The end index must be greater than the start index.");
        }
        double[] interval = new double[end - start];
        System.arraycopy(timeSeries, start, interval, 0, interval.length);
        return interval;
    }

    /**
     * Normalizes an {@link INDArray} by subtracting its mean and dividing by its standard deviation,
     * both computed along dimension 1.
     *
     * @param array the array to normalize
     * @param inplace if {@code true}, {@code array} itself is modified and returned; otherwise a new
     *     array is returned
     * @return the normalized array
     * @throws IllegalArgumentException if the shape has more than two entries and the first is not 1
     */
    public static INDArray normalizeINDArray(INDArray array, boolean inplace) {
        if (array.shape().length > 2 && array.shape()[0] != 1) {
            throw new IllegalArgumentException(String.format("Input INDArray object must be a vector with shape size 1. Actual shape: (%s)", Arrays.toString(array.shape())));
        }
        // Read both statistics before any in-place mutation of the input.
        double mean = array.mean(new int[] {1}).getDouble(0L);
        double std = array.std(new int[] {1}).getDouble(0L);
        INDArray centered = inplace ? array.subi(Double.valueOf(mean)) : array.sub(Double.valueOf(mean));
        // NOTE(review): EPS is added to the numerator, so it does not guard against std == 0 — confirm intent.
        return centered.addi(Double.valueOf(Nd4j.EPS_THRESHOLD)).divi(Double.valueOf(std));
    }

    /**
     * Returns the most frequent value of the given array.
     *
     * @param values the values
     * @return the most frequent value (ties resolved by {@link HashMap} iteration order), or -1 for an
     *     empty array
     */
    public static int getMode(int[] values) {
        Map<Integer, Integer> counts = new HashMap<>();
        for (int value : values) {
            counts.merge(value, 1, Integer::sum);
        }
        Integer mode = getMaximumKeyByValue(counts);
        return mode != null ? mode : -1;
    }

    /**
     * Returns the key that is mapped to the largest strictly positive value.
     *
     * @param map the map to search
     * @param <T> the key type
     * @return the first key (in iteration order) holding the maximum value, or {@code null} if the map
     *     is empty or contains no positive value
     */
    public static <T> T getMaximumKeyByValue(Map<T, Integer> map) {
        T bestKey = null;
        int bestValue = 0;
        for (Map.Entry<T, Integer> entry : map.entrySet()) {
            int value = entry.getValue().intValue();
            if (value > bestValue) {
                bestValue = value;
                bestKey = entry.getKey();
            }
        }
        return bestKey;
    }

    /**
     * Z-normalizes a series, subtracting its mean and dividing by its standard deviation.
     *
     * @param dataVector the series to normalize
     * @param besselsCorrection if {@code true}, the standard deviation divides by {@code n - 1}
     *     instead of {@code n}
     * @return a new normalized array; all zeros if the standard deviation is zero or undefined (for
     *     example a single value with Bessel's correction)
     */
    public static double[] zNormalize(double[] dataVector, boolean besselsCorrection) {
        int degreesOfFreedom = dataVector.length - (besselsCorrection ? 1 : 0);
        if (degreesOfFreedom <= 0) {
            // The standard deviation is undefined (0/0). Treat it like a constant series.
            return new double[dataVector.length];
        }

        double sum = 0.0d;
        for (double value : dataVector) {
            sum += value;
        }
        double mean = sum / dataVector.length;

        double squaredDeviations = 0.0d;
        for (double value : dataVector) {
            squaredDeviations += Math.pow(value - mean, 2.0d);
        }
        double std = Math.sqrt(squaredDeviations / degreesOfFreedom);

        double[] normalized = new double[dataVector.length];
        if (std == 0.0d) {
            return normalized;
        }
        for (int i = 0; i < normalized.length; i++) {
            normalized[i] = (dataVector[i] - mean) / std;
        }
        return normalized;
    }

    /**
     * Returns the indices of {@code values}, sorted by the absolute value of the referenced entries.
     *
     * @param values the values used as sort keys
     * @param ascending {@code true} for ascending, {@code false} for descending order
     * @return the sorted indices (stable for equal absolute values)
     */
    public static List<Integer> sortIndexes(double[] values, boolean ascending) {
        Integer[] indices = new Integer[values.length];
        for (int i = 0; i < indices.length; i++) {
            indices[i] = i;
        }
        int direction = ascending ? 1 : -1;
        Arrays.sort(indices, (a, b) -> direction * Double.compare(Math.abs(values[a]), Math.abs(values[b])));
        return Arrays.asList(indices);
    }

    /**
     * Returns the number of distinct target values in the dataset.
     *
     * @param dataset the dataset
     * @return the number of distinct classes
     * @throws IllegalArgumentException if the dataset or its targets are {@code null}
     */
    public static int getNumberOfClasses(TimeSeriesDataset2 dataset) {
        return getClassesInDataset(dataset).size();
    }

    /**
     * Returns the distinct target values of the dataset.
     *
     * @param dataset the dataset
     * @return a new list holding every distinct target value once
     * @throws IllegalArgumentException if the dataset or its targets are {@code null}
     */
    public static List<Integer> getClassesInDataset(TimeSeriesDataset2 dataset) {
        if (dataset == null || dataset.getTargets() == null) {
            throw new IllegalArgumentException("Given parameter 'dataset' must not be null and must contain a target matrix!");
        }
        Set<Integer> classes = IntStream.of(dataset.getTargets()).boxed().collect(Collectors.toSet());
        return new ArrayList<>(classes);
    }

    /**
     * Shuffles the instances of the dataset in place. One random permutation, derived from
     * {@code seed}, is applied to all value matrices, all timestamp matrices and the targets.
     *
     * @param dataset the dataset to shuffle
     * @param seed the seed of the random permutation
     */
    public static void shuffleTimeSeriesDataset(TimeSeriesDataset2 dataset, int seed) {
        List<Integer> permutation = IntStream.range(0, dataset.getNumberOfInstances()).boxed().collect(Collectors.toList());
        Collections.shuffle(permutation, new Random(seed));

        List<double[][]> valueMatrices = dataset.getValueMatrices();
        List<double[][]> timestampMatrices = dataset.getTimestampMatrices();
        int[] targets = dataset.getTargets();

        if (valueMatrices != null) {
            dataset.setValueMatrices(shuffleMatrices(valueMatrices, permutation));
        }
        if (timestampMatrices != null) {
            dataset.setTimestampMatrices(shuffleMatrices(timestampMatrices, permutation));
        }
        if (targets != null) {
            dataset.setTargets(shuffleMatrix(targets, permutation));
        }
    }

    /** Applies {@code permutation} to the rows of every matrix in {@code matrices}. */
    private static List<double[][]> shuffleMatrices(List<double[][]> matrices, List<Integer> permutation) {
        List<double[][]> shuffled = new ArrayList<>(matrices.size());
        for (double[][] matrix : matrices) {
            shuffled.add(shuffleMatrix(matrix, permutation));
        }
        return shuffled;
    }

    /**
     * Returns a new matrix whose row {@code i} is row {@code indices.get(i)} of {@code srcMatrix}.
     * The row arrays are shared with the source, not copied.
     */
    private static double[][] shuffleMatrix(double[][] srcMatrix, List<Integer> indices) {
        if (srcMatrix == null || srcMatrix.length < 1) {
            throw new IllegalArgumentException("Parameter 'srcMatrix' must not be null or empty!");
        }
        if (indices == null || indices.size() != srcMatrix.length) {
            throw new IllegalArgumentException("Parameter 'indices' must not be null and must have the same length as the number of instances in the source matrix!");
        }
        // Only the outer array is needed; every row reference is replaced below.
        double[][] result = new double[srcMatrix.length][];
        for (int i = 0; i < indices.size(); i++) {
            result[i] = srcMatrix[indices.get(i).intValue()];
        }
        return result;
    }

    /** Returns a new array whose entry {@code i} is entry {@code indices.get(i)} of {@code srcMatrix}. */
    private static int[] shuffleMatrix(int[] srcMatrix, List<Integer> indices) {
        if (srcMatrix == null || srcMatrix.length < 1) {
            throw new IllegalArgumentException("Parameter 'srcMatrix' must not be null or empty!");
        }
        if (indices == null || indices.size() != srcMatrix.length) {
            throw new IllegalArgumentException("Parameter 'indices' must not be null and must have the same length as the number of instances in the source matrix!");
        }
        int[] result = new int[srcMatrix.length];
        for (int i = 0; i < indices.size(); i++) {
            result[i] = srcMatrix[indices.get(i).intValue()];
        }
        return result;
    }

    /**
     * Splits the data into a training and a test dataset for one fold of a contiguous (unshuffled)
     * k-fold cross-validation.
     *
     * <p>Each fold holds {@code n / numberOfFolds} consecutive rows. The last fold additionally
     * receives the remaining rows. The training set is every row outside the test fold.
     *
     * @param fold the zero-based index of the test fold
     * @param numberOfFolds the total number of folds
     * @param values the value matrix (one row per instance)
     * @param targets the targets (one entry per instance)
     * @return a pair of (training dataset, test dataset)
     * @throws IllegalArgumentException if {@code numberOfFolds < 1}, {@code fold} is out of range, or
     *     {@code values} and {@code targets} differ in length
     */
    public static Pair<TimeSeriesDataset2, TimeSeriesDataset2> getTrainingAndTestDataForFold(int fold, int numberOfFolds, double[][] values, int[] targets) {
        if (numberOfFolds < 1) {
            throw new IllegalArgumentException("The number of folds must be at least 1, but was " + numberOfFolds + ".");
        }
        if (fold < 0 || fold >= numberOfFolds) {
            throw new IllegalArgumentException("The fold index must be in [0, " + numberOfFolds + "), but was " + fold + ".");
        }
        if (values.length != targets.length) {
            throw new IllegalArgumentException("Values and targets must have the same number of instances, but have " + values.length + " and " + targets.length + ".");
        }
        return new Pair<>(selectTrainingDataForFold(fold, numberOfFolds, values, targets), selectTestDataForFold(fold, numberOfFolds, values, targets));
    }

    /** Returns the exclusive end row of the test fold; the last fold also takes the remainder rows. */
    private static int testFoldEnd(int fold, int numberOfFolds, int foldSize, int numberOfInstances) {
        return fold == numberOfFolds - 1 ? numberOfInstances : (fold + 1) * foldSize;
    }

    /** Builds the training dataset: every row that is not part of the test fold {@code fold}. */
    private static TimeSeriesDataset2 selectTrainingDataForFold(int fold, int numberOfFolds, double[][] values, int[] targets) {
        int numberOfInstances = values.length;
        int foldSize = numberOfInstances / numberOfFolds;
        int testStart = fold * foldSize;
        int testEnd = testFoldEnd(fold, numberOfFolds, foldSize, numberOfInstances);
        int trainingSize = numberOfInstances - (testEnd - testStart);

        // Rows before the test fold, followed by all rows after it. This includes the remainder rows
        // that do not fill a complete fold, so training and test together cover the whole data.
        double[][] trainingValues = new double[trainingSize][];
        int[] trainingTargets = new int[trainingSize];
        System.arraycopy(values, 0, trainingValues, 0, testStart);
        System.arraycopy(values, testEnd, trainingValues, testStart, numberOfInstances - testEnd);
        System.arraycopy(targets, 0, trainingTargets, 0, testStart);
        System.arraycopy(targets, testEnd, trainingTargets, testStart, numberOfInstances - testEnd);

        List<double[][]> valueMatrices = new ArrayList<>();
        valueMatrices.add(trainingValues);
        return new TimeSeriesDataset2(valueMatrices, trainingTargets);
    }

    /** Builds the test dataset: the consecutive rows of fold {@code fold}. */
    private static TimeSeriesDataset2 selectTestDataForFold(int fold, int numberOfFolds, double[][] values, int[] targets) {
        int numberOfInstances = values.length;
        int foldSize = numberOfInstances / numberOfFolds;
        int testStart = fold * foldSize;
        int testSize = testFoldEnd(fold, numberOfFolds, foldSize, numberOfInstances) - testStart;

        double[][] testValues = new double[testSize][];
        int[] testTargets = new int[testSize];
        System.arraycopy(values, testStart, testValues, 0, testSize);
        System.arraycopy(targets, testStart, testTargets, 0, testSize);

        List<double[][]> valueMatrices = new ArrayList<>();
        valueMatrices.add(testValues);
        return new TimeSeriesDataset2(valueMatrices, testTargets);
    }

    /**
     * Creates a dataset from the given value matrices and optional targets.
     *
     * @param targets the targets, or {@code null} for an unlabeled dataset
     * @param valueMatrices at least one value matrix
     * @return the new dataset
     * @throws IllegalArgumentException if no value matrix is given
     */
    public static TimeSeriesDataset2 createDatasetForMatrix(int[] targets, double[][]... valueMatrices) {
        if (valueMatrices.length == 0) {
            throw new IllegalArgumentException("There must be at least one value matrix to generate a TimeSeriesDataset object!");
        }
        List<double[][]> matrices = Arrays.asList(valueMatrices);
        if (targets == null) {
            return new TimeSeriesDataset2(matrices);
        }
        return new TimeSeriesDataset2(matrices, targets);
    }

    /**
     * Creates an unlabeled dataset from the given value matrices.
     *
     * @param valueMatrices at least one value matrix
     * @return the new dataset
     * @throws IllegalArgumentException if no value matrix is given
     */
    public static TimeSeriesDataset2 createDatasetForMatrix(double[][]... valueMatrices) {
        return createDatasetForMatrix((int[]) null, valueMatrices);
    }

    /**
     * Formats a series as {@code {v0, v1, ...}}.
     *
     * @param timeSeries the series to format
     * @return the string representation; {@code "{}"} for an empty series
     */
    public static String toString(double[] timeSeries) {
        if (timeSeries.length == 0) {
            return "{}";
        }
        StringBuilder sb = new StringBuilder(2 + timeSeries.length * 3 - 1);
        sb.append('{').append(timeSeries[0]);
        for (int i = 1; i < timeSeries.length; i++) {
            sb.append(", ").append(timeSeries[i]);
        }
        return sb.append('}').toString();
    }

    /**
     * Computes the Keogh derivative {@code ((x[i] - x[i-1]) + (x[i+1] - x[i-1]) / 2) / 2} for every
     * interior point.
     *
     * @param t the series (length at least 2)
     * @return an array of length {@code t.length - 2}
     */
    public static double[] keoghDerivate(double[] t) {
        double[] derivative = new double[t.length - 2];
        for (int i = 1; i < t.length - 1; i++) {
            derivative[i - 1] = ((t[i] - t[i - 1]) + ((t[i + 1] - t[i - 1]) / 2.0d)) / 2.0d;
        }
        return derivative;
    }

    /**
     * Computes the Keogh derivative and pads both ends by repeating the nearest interior value.
     *
     * @param t the series (length at least 2)
     * @return an array of length {@code t.length}
     */
    public static double[] keoghDerivateWithBoundaries(double[] t) {
        double[] derivative = new double[t.length];
        for (int i = 1; i < t.length - 1; i++) {
            derivative[i] = ((t[i] - t[i - 1]) + ((t[i + 1] - t[i - 1]) / 2.0d)) / 2.0d;
        }
        derivative[0] = derivative[1];
        derivative[t.length - 1] = derivative[t.length - 2];
        return derivative;
    }

    /**
     * Computes the backward difference {@code x[i] - x[i-1]} for {@code i >= 1}.
     *
     * @param t the series (non-empty)
     * @return an array of length {@code t.length - 1}
     */
    public static double[] backwardDifferenceDerivate(double[] t) {
        double[] derivative = new double[t.length - 1];
        for (int i = 1; i < t.length; i++) {
            derivative[i - 1] = t[i] - t[i - 1];
        }
        return derivative;
    }

    /**
     * Computes the backward difference and sets the first entry to the value of the second.
     *
     * @param t the series (length at least 2)
     * @return an array of length {@code t.length}
     */
    public static double[] backwardDifferenceDerivateWithBoundaries(double[] t) {
        double[] derivative = new double[t.length];
        for (int i = 1; i < t.length; i++) {
            derivative[i] = t[i] - t[i - 1];
        }
        derivative[0] = derivative[1];
        return derivative;
    }

    /**
     * Computes the forward difference {@code x[i+1] - x[i]} for {@code i < n - 1}.
     *
     * @param t the series (non-empty)
     * @return an array of length {@code t.length - 1}
     */
    public static double[] forwardDifferenceDerivate(double[] t) {
        double[] derivative = new double[t.length - 1];
        for (int i = 0; i < t.length - 1; i++) {
            derivative[i] = t[i + 1] - t[i];
        }
        return derivative;
    }

    /**
     * Computes the forward difference and sets the last entry to the value of the one before it.
     *
     * @param t the series (length at least 2)
     * @return an array of length {@code t.length}
     */
    public static double[] forwardDifferenceDerivateWithBoundaries(double[] t) {
        double[] derivative = new double[t.length];
        for (int i = 0; i < t.length - 1; i++) {
            derivative[i] = t[i + 1] - t[i];
        }
        derivative[t.length - 1] = derivative[t.length - 2];
        return derivative;
    }

    /**
     * Computes the Gullo derivative, the central difference {@code (x[i+1] - x[i-1]) / 2}, for every
     * interior point.
     *
     * @param t the series
     * @return an array of length {@code max(0, t.length - 2)}
     */
    public static double[] gulloDerivate(double[] t) {
        double[] derivative = new double[Math.max(0, t.length - 2)];
        // Only interior points have both neighbours; the previous loop read past the array end.
        for (int i = 1; i < t.length - 1; i++) {
            derivative[i - 1] = (t[i + 1] - t[i - 1]) / 2.0d;
        }
        return derivative;
    }

    /**
     * Computes the Gullo (central difference) derivative and pads both ends by repeating the nearest
     * interior value.
     *
     * @param t the series (length at least 2)
     * @return an array of length {@code t.length}
     */
    public static double[] gulloDerivateWithBoundaries(double[] t) {
        double[] derivative = new double[t.length];
        for (int i = 1; i < t.length - 1; i++) {
            derivative[i] = (t[i + 1] - t[i - 1]) / 2.0d;
        }
        derivative[0] = derivative[1];
        derivative[t.length - 1] = derivative[t.length - 2];
        return derivative;
    }

    /**
     * Sums all values.
     *
     * @param t the values
     * @return the sum; 0 for an empty array
     */
    public static double sum(double[] t) {
        double total = 0.0d;
        for (double value : t) {
            total += value;
        }
        return total;
    }

    /**
     * Computes the arithmetic mean.
     *
     * @param t the values
     * @return the mean; NaN for an empty array
     */
    public static double mean(double[] t) {
        return sum(t) / t.length;
    }

    /**
     * Computes the population variance (dividing by {@code n}).
     *
     * @param t the values
     * @return the variance; NaN for an empty array
     */
    public static double variance(double[] t) {
        double mean = mean(t);
        double squaredDeviations = 0.0d;
        for (double value : t) {
            squaredDeviations += (value - mean) * (value - mean);
        }
        return squaredDeviations / t.length;
    }

    /**
     * Computes the population standard deviation.
     *
     * @param t the values
     * @return the square root of {@link #variance(double[])}
     */
    public static double standardDeviation(double[] t) {
        return Math.sqrt(variance(t));
    }

    /**
     * Divides every value by the population standard deviation of the series.
     *
     * @param t the series
     * @return a new array; all zeros if the standard deviation is zero
     */
    public static double[] normalizeByStandardDeviation(double[] t) {
        double standardDeviation = standardDeviation(t);
        double[] normalized = new double[t.length];
        if (standardDeviation == 0.0d) {
            return normalized;
        }
        for (int i = 0; i < t.length; i++) {
            normalized[i] = t[i] / standardDeviation;
        }
        return normalized;
    }
}
