package org.xyou.xcommon.math;

import java.util.function.BiFunction;
import java.util.function.Function;

import lombok.NonNull;

/**
 * Static helpers for small dense {@code double[]} vectors: vector lengths,
 * normalization, activation functions, arg-max and element-wise arithmetic.
 *
 * <p>Array arguments are not length-checked against each other. Every binary
 * operation iterates over the length of its first operand, so a shorter second
 * operand raises {@link ArrayIndexOutOfBoundsException} and the surplus
 * elements of a longer one are ignored.
 */
public final class XMath {

    /**
     * Computes the Manhattan (L1) length of a vector: the sum of the absolute
     * values of its elements.
     *
     * @param vec the vector to measure
     * @return the sum of {@code |vec[i]|}; {@code 0.0} for an empty vector
     */
    public static double computeLengthManhattan(@NonNull double[] vec) {
        double sum = 0.0;
        for (double value : vec) {
            // Absolute value is what makes this a norm; without it
            // components of opposite sign would cancel each other out.
            sum += Math.abs(value);
        }
        return sum;
    }

    /**
     * Computes the Euclidean (L2) length of a vector: the square root of the
     * sum of its squared elements.
     *
     * @param vec the vector to measure
     * @return the L2 length; {@code 0.0} for an empty vector
     */
    public static double computeLengthEuclid(@NonNull double[] vec) {
        double sumOfSquares = 0.0;
        for (double value : vec) {
            sumOfSquares += value * value;
        }
        return Math.sqrt(sumOfSquares);
    }

    /**
     * Divides every element of {@code arrInput} by the length reported by
     * {@code computeLength} and stores the quotients in {@code arrOutput}.
     *
     * <p>There is no guard against a zero length: in that case the quotients
     * follow IEEE 754 division ({@code NaN} for {@code 0/0}, an infinity
     * otherwise). {@code arrOutput} may be the same array as {@code arrInput},
     * because the length is computed before any element is written.
     *
     * @param arrInput      the vector to normalize
     * @param arrOutput     receives the result in its first
     *                      {@code arrInput.length} slots
     * @param computeLength maps the input vector to the divisor
     * @return {@code arrOutput}
     */
    public static double[] normalize(
        @NonNull double[] arrInput,
        @NonNull double[] arrOutput,
        @NonNull Function<double[], Double> computeLength
    ) {
        double normValue = computeLength.apply(arrInput);
        for (int idxEle = 0; idxEle < arrInput.length; ++idxEle) {
            arrOutput[idxEle] = arrInput[idxEle] / normValue;
        }
        return arrOutput;
    }

    /**
     * Returns a new array holding {@code arr} divided by its Manhattan length.
     *
     * @param arr the vector to normalize; left unmodified
     * @return a newly allocated normalized copy
     */
    public static double[] normalizeL1(@NonNull double[] arr) {
        return normalize(arr, new double[arr.length], XMath::computeLengthManhattan);
    }

    /**
     * Divides {@code arr} by its Manhattan length, overwriting its elements.
     *
     * @param arr the vector to normalize in place
     * @return the same {@code arr} instance
     */
    public static double[] normalizeL1InPlace(@NonNull double[] arr) {
        return normalize(arr, arr, XMath::computeLengthManhattan);
    }

    /**
     * Returns a new array holding {@code arr} divided by its Euclidean length.
     *
     * @param arr the vector to normalize; left unmodified
     * @return a newly allocated normalized copy
     */
    public static double[] normalizeL2(@NonNull double[] arr) {
        return normalize(arr, new double[arr.length], XMath::computeLengthEuclid);
    }

    /**
     * Divides {@code arr} by its Euclidean length, overwriting its elements.
     *
     * @param arr the vector to normalize in place
     * @return the same {@code arr} instance
     */
    public static double[] normalizeL2InPlace(@NonNull double[] arr) {
        return normalize(arr, arr, XMath::computeLengthEuclid);
    }

    /**
     * Evaluates the logistic function {@code 1 / (1 + e^-x)}.
     *
     * @param inpt the input value
     * @return the logistic function of {@code inpt}
     */
    public static double sigmoid(@NonNull Double inpt) {
        return 1.0 / (1.0 + Math.exp(-inpt));
    }

    /**
     * Evaluates the hyperbolic tangent of the input.
     *
     * @param inpt the input value
     * @return {@code tanh(inpt)}
     */
    public static double tanh(@NonNull Double inpt) {
        // Math.tanh avoids the cancellation that 2 / (1 + e^-2x) - 1 suffers
        // for inputs close to zero, where that formula collapses to 0.0.
        return Math.tanh(inpt);
    }

    /**
     * Finds the index of the largest element; the first occurrence wins ties.
     *
     * <p>Returns {@code 0} when the array is empty or when no element compares
     * greater than negative infinity (for example all {@code NaN}).
     *
     * @param arr the values to scan
     * @return the index of the first maximum, or {@code 0} as described above
     */
    public static int argMax(@NonNull double[] arr) {
        // Seed with negative infinity: -Double.MIN_VALUE is merely the negative
        // number closest to zero, so it would outrank every ordinary negative value.
        double valMax = Double.NEGATIVE_INFINITY;
        int idxMax = 0;
        for (int idx = 0; idx < arr.length; idx++) {
            double val = arr[idx];
            if (val > valMax) {
                idxMax = idx;
                valMax = val;
            }
        }
        return idxMax;
    }

    /**
     * Computes the dot product of two vectors over the length of {@code arr1}.
     *
     * @param arr1 the first vector; its length bounds the iteration
     * @param arr2 the second vector
     * @return the sum of {@code arr1[i] * arr2[i]}
     */
    public static double dot(@NonNull double[] arr1, @NonNull double[] arr2) {
        double product = 0.0;
        for (int i = 0; i < arr1.length; ++i) {
            product += arr1[i] * arr2[i];
        }
        return product;
    }

    /**
     * Applies {@code op} to each pair of elements and stores the results.
     *
     * @param arrRes receives the results; may alias {@code arr1}
     * @param arr1   the left operands; its length bounds the iteration
     * @param arr2   the right operands
     * @param op     the element-wise operation
     * @return {@code arrRes}
     */
    private static double[] computeVector(
        @NonNull double[] arrRes,
        @NonNull double[] arr1,
        @NonNull double[] arr2,
        @NonNull BiFunction<Double, Double, Double> op
    ) {
        for (int idxEle = 0; idxEle < arr1.length; idxEle++) {
            arrRes[idxEle] = op.apply(arr1[idxEle], arr2[idxEle]);
        }
        return arrRes;
    }

    /**
     * Returns a new array holding the element-wise sum {@code arr1 + arr2}.
     *
     * @param arr1 the left operands
     * @param arr2 the right operands
     * @return a newly allocated array of length {@code arr1.length}
     */
    public static double[] add(@NonNull double[] arr1, @NonNull double[] arr2) {
        return computeVector(new double[arr1.length], arr1, arr2, (ele1, ele2) -> ele1 + ele2);
    }

    /**
     * Adds {@code arr2} to {@code arr1} element-wise, overwriting {@code arr1}.
     *
     * @param arr1 the left operands and the destination
     * @param arr2 the right operands
     * @return the same {@code arr1} instance
     */
    public static double[] addInPlace(@NonNull double[] arr1, @NonNull double[] arr2) {
        return computeVector(arr1, arr1, arr2, (ele1, ele2) -> ele1 + ele2);
    }

    /**
     * Returns a new array holding the element-wise product {@code arr1 * arr2}.
     *
     * @param arr1 the left operands
     * @param arr2 the right operands
     * @return a newly allocated array of length {@code arr1.length}
     */
    public static double[] mul(@NonNull double[] arr1, @NonNull double[] arr2) {
        return computeVector(new double[arr1.length], arr1, arr2, (ele1, ele2) -> ele1 * ele2);
    }

    /**
     * Multiplies {@code arr1} by {@code arr2} element-wise, overwriting {@code arr1}.
     *
     * @param arr1 the left operands and the destination
     * @param arr2 the right operands
     * @return the same {@code arr1} instance
     */
    public static double[] mulInPlace(@NonNull double[] arr1, @NonNull double[] arr2) {
        return computeVector(arr1, arr1, arr2, (ele1, ele2) -> ele1 * ele2);
    }

    /**
     * Returns a new array holding the element-wise difference {@code arr1 - arr2}.
     *
     * @param arr1 the left operands
     * @param arr2 the right operands
     * @return a newly allocated array of length {@code arr1.length}
     */
    public static double[] sub(@NonNull double[] arr1, @NonNull double[] arr2) {
        return computeVector(new double[arr1.length], arr1, arr2, (ele1, ele2) -> ele1 - ele2);
    }

    /**
     * Subtracts {@code arr2} from {@code arr1} element-wise, overwriting {@code arr1}.
     *
     * @param arr1 the left operands and the destination
     * @param arr2 the right operands
     * @return the same {@code arr1} instance
     */
    public static double[] subInPlace(@NonNull double[] arr1, @NonNull double[] arr2) {
        return computeVector(arr1, arr1, arr2, (ele1, ele2) -> ele1 - ele2);
    }

    /**
     * Returns a new array holding the element-wise quotient {@code arr1 / arr2}.
     * Division by zero follows IEEE 754 and does not throw.
     *
     * @param arr1 the dividends
     * @param arr2 the divisors
     * @return a newly allocated array of length {@code arr1.length}
     */
    public static double[] div(@NonNull double[] arr1, @NonNull double[] arr2) {
        return computeVector(new double[arr1.length], arr1, arr2, (ele1, ele2) -> ele1 / ele2);
    }

    /**
     * Divides {@code arr1} by {@code arr2} element-wise, overwriting {@code arr1}.
     * Division by zero follows IEEE 754 and does not throw.
     *
     * @param arr1 the dividends and the destination
     * @param arr2 the divisors
     * @return the same {@code arr1} instance
     */
    public static double[] divInPlace(@NonNull double[] arr1, @NonNull double[] arr2) {
        return computeVector(arr1, arr1, arr2, (ele1, ele2) -> ele1 / ele2);
    }

}
