package ai.sklearn4j.core.libraries;

import ai.sklearn4j.core.libraries.numpy.Numpy;
import ai.sklearn4j.core.libraries.numpy.NumpyArray;
import ai.sklearn4j.core.libraries.numpy.NumpyArrayFactory;

/* loaded from: input_file:ai/sklearn4j/core/libraries/Scipy.class */
/**
 * Static numerical helpers modelled on SciPy routines, implemented on top of the
 * project's {@link Numpy} / {@link NumpyArray} wrappers.
 */
public class Scipy {
    /**
     * Computes {@code log(sum(exp(x)))} along the given axis.
     *
     * <p>The per-axis maximum is subtracted before exponentiating and added back after
     * the logarithm (the usual log-sum-exp shift that keeps {@code exp} from overflowing).
     * Non-finite maxima are replaced by {@code 0} in the shift only; the value added back
     * at the end is the unmodified maximum.
     *
     * <p>NOTE(review): the shift is reshaped into an n-by-1 column, which lines up with a
     * 2-D input reduced along its second axis. Behaviour for other axes or ranks depends
     * on how {@code Numpy.subtract} broadcasts -- confirm before relying on it.
     *
     * @param numpyArray the values to reduce
     * @param i          the axis passed to {@code Numpy.arrayMax} and {@code Numpy.sum}
     * @return the result of adding the per-axis maximum to the log of the summed,
     *         shifted exponentials
     */
    public static NumpyArray<Double> logSumExponent(NumpyArray<Double> numpyArray, int i) {
        NumpyArray<Double> maxima = Numpy.arrayMax(numpyArray, i);

        // Work on a column-shaped copy so that zeroing non-finite entries leaves `maxima` intact.
        // The return value of applyToEachElement is not used, so this step only matters if the
        // method updates the array in place -- presumably it does; confirm against NumpyArray.
        NumpyArray<Double> shift = to2DArrayShape(maxima);
        shift.applyToEachElement(value -> Double.isFinite(value.doubleValue()) ? value : Double.valueOf(0.0d));

        NumpyArray<Double> shifted = Numpy.subtract(numpyArray, shift);
        NumpyArray<Double> summed = Numpy.sum(Numpy.exp(shifted), i);
        return Numpy.add(Numpy.log(summed), maxima);
    }

    /**
     * Copies a flat array of doubles into a new n-by-1 (single column) array.
     *
     * <p>The raw backing array is cast to {@code double[]}, so the argument must be backed
     * by a one-dimensional primitive double array; anything else fails with a
     * {@link ClassCastException}.
     *
     * @param numpyArray the one-dimensional source array
     * @return a new array holding each source element in its own single-element row
     */
    private static NumpyArray<Double> to2DArrayShape(NumpyArray<Double> numpyArray) {
        double[] flat = (double[]) numpyArray.getWrapper().getRawArray();
        int rows = flat.length;
        double[][] column = new double[rows][];
        for (int row = 0; row < rows; row++) {
            column[row] = new double[] {flat[row]};
        }
        return NumpyArrayFactory.from(column);
    }
}
