package math.dl;

import math.rng.Stc64;

/* loaded from: input_file:math/dl/Softmax.class */
/**
 * Static helpers for discrete probability distributions: softmax normalization, temperature
 * re-weighting, and sampling a class index. Every operation has a {@code double} flavor and a
 * {@code float} flavor (method name suffixed with {@code F}).
 *
 * <p>Not instantiable.
 */
public final class Softmax {

    /**
     * Writes the softmax of {@code in[offset .. offset+length)} into
     * {@code out[outOffset .. outOffset+length)}.
     *
     * <p>The range maximum is subtracted before {@link Math#exp} so that large scores cannot
     * overflow. {@code in} and {@code out} may be the same array only when
     * {@code offset == outOffset}; any other overlap overwrites inputs before they are read.
     *
     * @param length    number of elements to transform; a non-positive value is a no-op
     * @param offset    index of the first input element in {@code in}
     * @param in        input scores; not modified unless it is also {@code out}
     * @param outOffset index of the first output element in {@code out}
     * @param out       destination array
     * @return {@code out}, for call chaining
     * @throws ArrayIndexOutOfBoundsException if a non-empty range does not fit its array
     */
    public static double[] softmax(int length, int offset, double[] in, int outOffset, double[] out) {
        if (length <= 0) {
            // Nothing to normalize; also avoids max() reading in[offset] for an empty range.
            return out;
        }
        final double max = max(length, offset, in);
        double sum = 0.0d;
        for (int k = 0; k < length; k++) {
            final double e = Math.exp(in[offset + k] - max);
            sum += e;
            out[outOffset + k] = e;
        }
        final double scale = 1.0d / sum;
        for (int k = outOffset; k < outOffset + length; k++) {
            out[k] *= scale;
        }
        return out;
    }

    /**
     * Float variant of {@link #softmax(int, int, double[], int, double[])}; same contract,
     * including the aliasing restriction and the no-op for non-positive {@code length}.
     *
     * @param length    number of elements to transform; a non-positive value is a no-op
     * @param offset    index of the first input element in {@code in}
     * @param in        input scores; not modified unless it is also {@code out}
     * @param outOffset index of the first output element in {@code out}
     * @param out       destination array
     * @return {@code out}, for call chaining
     * @throws ArrayIndexOutOfBoundsException if a non-empty range does not fit its array
     */
    public static float[] softmaxF(int length, int offset, float[] in, int outOffset, float[] out) {
        if (length <= 0) {
            return out;
        }
        final float max = maxF(length, offset, in);
        float sum = 0.0f;
        for (int k = 0; k < length; k++) {
            final float e = (float) Math.exp(in[offset + k] - max);
            sum += e;
            out[outOffset + k] = e;
        }
        final float scale = 1.0f / sum;
        for (int k = outOffset; k < outOffset + length; k++) {
            out[k] *= scale;
        }
        return out;
    }

    /**
     * Re-weights a distribution by a temperature: each output element is proportional to
     * {@code p[k]^(1/temperature)}, and the result is normalized to sum to 1. Temperatures below 1
     * sharpen the distribution towards its mode; temperatures above 1 flatten it.
     *
     * <p>Zero entries are replaced by {@link Double#MIN_NORMAL} before taking logarithms, so every
     * logarithm is finite. Weights are computed relative to the largest entry. The mode therefore
     * always receives weight 1 before normalization. This keeps the result accurate at very low
     * temperatures, where {@code p^(1/T)} itself would underflow to 0, and ties at the mode stay
     * tied instead of collapsing onto the first maximum.
     *
     * @param probs       non-negative weights, normally a probability distribution
     * @param temperature strictly positive temperature; exactly {@code 1.0} short-circuits
     * @return {@code probs} itself (same reference; not copied, validated or normalized) when
     *         {@code temperature == 1.0}; otherwise a new normalized array of the same length
     *         (empty for empty input)
     * @throws IllegalArgumentException if {@code temperature} is not strictly positive (this
     *         includes NaN), or if any element of {@code probs} is negative or NaN
     */
    public static double[] reweigh(double[] probs, double temperature) {
        if (temperature == 1.0d) {
            return probs;
        }
        // Written as !(t > 0) rather than t <= 0 so that NaN is rejected as well.
        if (!(temperature > 0.0d)) {
            throw new IllegalArgumentException("temperature must be strictly positive: " + temperature);
        }
        // log is monotone, so the largest clamped probability gives the largest log. clamp() never
        // returns less than MIN_NORMAL, so 0 is a safe seed. This pass also validates the input.
        double maxP = 0.0d;
        for (double p : probs) {
            maxP = Math.max(maxP, clamp(p));
        }
        final double logMax = Math.log(maxP);

        final double[] result = new double[probs.length];
        double sum = 0.0d;
        for (int k = 0; k < probs.length; k++) {
            // (log p - log pMax) <= 0, so exp() cannot overflow; the mode yields exp(0) == 1,
            // which keeps sum >= 1 and makes a zero-sum fallback unnecessary.
            final double w = Math.exp((Math.log(clamp(probs[k])) - logMax) / temperature);
            result[k] = w;
            sum += w;
        }
        for (int k = 0; k < result.length; k++) {
            result[k] /= sum;
        }
        return result;
    }

    /**
     * Float variant of {@link #reweigh(double[], double)}; same contract. Zero entries are
     * replaced by {@link Float#MIN_NORMAL}. Intermediate logarithms and exponentials are evaluated
     * in {@code double}, then stored as {@code float}.
     *
     * @param probs       non-negative weights, normally a probability distribution
     * @param temperature strictly positive temperature; exactly {@code 1.0f} short-circuits
     * @return {@code probs} itself when {@code temperature == 1.0f}; otherwise a new normalized
     *         array of the same length (empty for empty input)
     * @throws IllegalArgumentException if {@code temperature} is not strictly positive (this
     *         includes NaN), or if any element of {@code probs} is negative or NaN
     */
    public static float[] reweighF(float[] probs, float temperature) {
        if (temperature == 1.0f) {
            return probs;
        }
        if (!(temperature > 0.0f)) {
            throw new IllegalArgumentException("temperature must be strictly positive: " + temperature);
        }
        float maxP = 0.0f;
        for (float p : probs) {
            maxP = Math.max(maxP, clampF(p));
        }
        final double logMax = Math.log(maxP);

        final float[] result = new float[probs.length];
        float sum = 0.0f;
        for (int k = 0; k < probs.length; k++) {
            // Relative to the mode: float exp would otherwise underflow already around T ~ 0.005.
            final float w = (float) Math.exp((Math.log(clampF(probs[k])) - logMax) / temperature);
            result[k] = w;
            sum += w;
        }
        for (int k = 0; k < result.length; k++) {
            result[k] /= sum;
        }
        return result;
    }

    /**
     * Draws a class index from the categorical distribution {@code probs}. The uniform variate
     * comes from the default {@link Stc64} generator; exactly one {@code nextDouble()} is drawn
     * per call.
     *
     * <p>Index {@code k} is returned when the variate {@code u} satisfies
     * {@code p[0]+…+p[k-1] <= u < p[0]+…+p[k]}. Entries with zero probability are therefore never
     * chosen, unless no entry is positive. If rounding leaves {@code u} beyond the running total
     * (for example, when the probabilities sum to slightly less than 1), the last class with
     * positive probability is returned rather than class 0.
     *
     * @param probs class probabilities, expected to be non-negative and to sum to 1
     * @return the sampled index; {@code 0} when {@code probs} is empty or has no positive entry
     */
    public static int sampleClass(double[] probs) {
        // NOTE(review): assumes nextDouble() is uniform on [0, 1) -- confirm against Stc64.
        return pickClass(probs, Stc64.getDefault().nextDouble());
    }

    /**
     * Float variant of {@link #sampleClass(double[])}; draws one {@code nextFloat()} from the
     * default {@link Stc64} generator per call.
     *
     * @param probs class probabilities, expected to be non-negative and to sum to 1
     * @return the sampled index; {@code 0} when {@code probs} is empty or has no positive entry
     */
    public static int sampleClassF(float[] probs) {
        return pickClassF(probs, Stc64.getDefault().nextFloat());
    }

    /** Inverse-CDF lookup of variate {@code u} in {@code probs} (see {@link #sampleClass}). */
    private static int pickClass(double[] probs, double u) {
        double remaining = u;
        int lastPositive = 0;
        for (int k = 0; k < probs.length; k++) {
            final double p = probs[k];
            if (remaining < p) {
                return k;
            }
            remaining -= p;
            if (p > 0.0d) {
                lastPositive = k;
            }
        }
        // Only reached when rounding (or an under-normalized input) leaves mass unassigned.
        return lastPositive;
    }

    /** Float inverse-CDF lookup of variate {@code u} in {@code probs}. */
    private static int pickClassF(float[] probs, float u) {
        float remaining = u;
        int lastPositive = 0;
        for (int k = 0; k < probs.length; k++) {
            final float p = probs[k];
            if (remaining < p) {
                return k;
            }
            remaining -= p;
            if (p > 0.0f) {
                lastPositive = k;
            }
        }
        return lastPositive;
    }

    /**
     * Returns the largest of {@code arr[offset .. offset+length)}. Always reads
     * {@code arr[offset]}, even when {@code length <= 0}. A NaN element makes the result NaN,
     * as with {@link Math#max(double, double)}.
     */
    static double max(int length, int offset, double[] arr) {
        double best = arr[offset];
        for (int k = offset + 1; k < offset + length; k++) {
            best = Math.max(best, arr[k]);
        }
        return best;
    }

    /** Float variant of {@link #max(int, int, double[])}; same contract. */
    static float maxF(int length, int offset, float[] arr) {
        float best = arr[offset];
        for (int k = offset + 1; k < offset + length; k++) {
            best = Math.max(best, arr[k]);
        }
        return best;
    }

    /**
     * Validates a probability and makes it safe for {@link Math#log}. Positive values pass
     * through unchanged, and zero (of either sign) becomes {@link Double#MIN_NORMAL}.
     *
     * @throws IllegalArgumentException for negative values and for NaN
     */
    static double clamp(double p) {
        if (p > 0.0d) {
            return p;
        }
        if (p == 0.0d) {
            return Double.MIN_NORMAL;
        }
        throw new IllegalArgumentException("negative probability: " + p);
    }

    /**
     * Float variant of {@link #clamp(double)}; zero becomes {@link Float#MIN_NORMAL}.
     *
     * @throws IllegalArgumentException for negative values and for NaN
     */
    static float clampF(float p) {
        if (p > 0.0f) {
            return p;
        }
        if (p == 0.0f) {
            return Float.MIN_NORMAL;
        }
        throw new IllegalArgumentException("negative probability: " + p);
    }

    /** Static utility holder; never instantiated. */
    private Softmax() {
        throw new AssertionError();
    }
}
