package org.linqs.psl.util;

import java.util.Collections;
import java.util.List;
import java.util.Random;
import org.linqs.psl.config.Options;

/* loaded from: input_file:org/linqs/psl/util/RandUtils.class */
/**
 * Process-wide source of randomness.
 *
 * <p>All methods share a single {@link Random} instance that is created lazily on first use and
 * seeded from {@code Options.RANDOM_SEED} (it can later be reseeded via {@link #seed(int)}).
 * Every public method is {@code static synchronized}, so concurrent callers are serialized on the
 * class lock while using the shared generator.
 */
public final class RandUtils {
    private static final Logger log = Logger.getLogger(RandUtils.class);

    // Created lazily by ensureRNG(); only ever touched from static synchronized methods.
    private static Random rng = null;

    private RandUtils() {
    }

    /**
     * Create the shared generator on first use, seeded from {@code Options.RANDOM_SEED}.
     * Subsequent calls are no-ops.
     */
    private static synchronized void ensureRNG() {
        if (rng != null) {
            return;
        }

        long seed = Options.RANDOM_SEED.getInt();
        log.info("Using random seed: " + seed);
        rng = new Random(seed);
    }

    /**
     * Reseed the shared generator.
     * If the generator does not exist yet it is first created (and the configured seed logged).
     *
     * @param seed the new seed.
     */
    public static synchronized void seed(int seed) {
        ensureRNG();
        rng.setSeed(seed);
    }

    /** @return the next boolean from the shared generator. */
    public static synchronized boolean nextBoolean() {
        ensureRNG();
        return rng.nextBoolean();
    }

    /** @return the next double in [0, 1) from the shared generator. */
    public static synchronized double nextDouble() {
        ensureRNG();
        return rng.nextDouble();
    }

    /** @return the next float in [0, 1) from the shared generator. */
    public static synchronized float nextFloat() {
        ensureRNG();
        return rng.nextFloat();
    }

    /**
     * Return a float drawn uniformly from the half-open range {@code [min, max)}.
     *
     * @param min the inclusive lower bound.
     * @param max the exclusive upper bound.
     * @return a value {@code v} with {@code min <= v < max}.
     * @throws IllegalArgumentException if {@code min} is not strictly less than {@code max}
     *     (including when either bound is NaN).
     */
    public static synchronized float nextFloat(float min, float max) {
        ensureRNG();

        // Negated comparison so NaN bounds are rejected as well.
        if (!(min < max)) {
            throw new IllegalArgumentException(String.format("Min (%f) must be strictly less than max (%f).", min, max));
        }

        float value = (rng.nextFloat() * (max - min)) + min;

        // Float rounding can land exactly on max (e.g. min = 1, max = 2 with the largest
        // nextFloat() value rounds up to 2.0f); keep the upper bound exclusive.
        if (value >= max) {
            value = Math.nextDown(max);
        }

        return value;
    }

    /** @return the next standard-normal (mean 0, stddev 1) value from the shared generator. */
    public static synchronized double nextGaussian() {
        ensureRNG();
        return rng.nextGaussian();
    }

    /** @return the next int (full int range) from the shared generator. */
    public static synchronized int nextInt() {
        ensureRNG();
        return rng.nextInt();
    }

    /**
     * @param bound the exclusive upper bound; must be positive (per {@link Random#nextInt(int)}).
     * @return the next int in {@code [0, bound)} from the shared generator.
     */
    public static synchronized int nextInt(int bound) {
        ensureRNG();
        return rng.nextInt(bound);
    }

    /** @return the next long (full long range) from the shared generator. */
    public static synchronized long nextLong() {
        ensureRNG();
        return rng.nextLong();
    }

    /** Shuffle the list in place using the shared generator. */
    public static synchronized void shuffle(List<?> list) {
        ensureRNG();
        Collections.shuffle(list, rng);
    }

    /**
     * Shuffle several equally sized lists in place with the same permutation.
     * Elements that share an index before the call still share an index afterwards.
     * This is a Fisher-Yates shuffle driven by the shared generator.
     *
     * @param lists the lists to shuffle together; a call with no lists is a no-op.
     * @throws IllegalArgumentException if the lists do not all have the same size.
     */
    @SafeVarargs
    public static synchronized void pairedShuffle(List<?>... lists) {
        ensureRNG();

        if (lists.length == 0) {
            return;
        }

        int size = lists[0].size();
        for (List<?> list : lists) {
            if (list.size() != size) {
                throw new IllegalArgumentException(String.format("Lists must all have a matching size, found %d and %d.", list.size(), size));
            }
        }

        // Draw one swap target per position and apply it to every list so they stay aligned.
        for (int i = size - 1; i >= 0; i--) {
            int j = nextInt(i + 1);
            for (List<?> list : lists) {
                Collections.swap(list, i, j);
            }
        }
    }

    /**
     * Shuffle a list in place and apply the same permutation to the first {@code list.size()}
     * entries of an int array. Any array entries past that point are left untouched.
     *
     * @param list the list to shuffle.
     * @param array the array to permute alongside the list; must be at least as long as the list.
     * @throws IllegalArgumentException if the list is longer than the array.
     */
    public static synchronized <T> void pairedShuffleIndexes(List<T> list, int[] array) {
        ensureRNG();

        // Every list position is also used as an array index, so the array must cover the list.
        if (list.size() > array.length) {
            throw new IllegalArgumentException(String.format("List size (%d) must be less than or equal to array size (%d).", list.size(), array.length));
        }

        for (int i = list.size() - 1; i >= 0; i--) {
            int j = nextInt(i + 1);

            Collections.swap(list, i, j);

            int temp = array[i];
            array[i] = array[j];
            array[j] = temp;
        }
    }

    /**
     * Draw a sample from a Dirichlet distribution.
     * Each component is an independent Gamma(alpha_i, 1) draw, normalized by the sum of all
     * draws.
     *
     * @param alphas the concentration parameters; each must be strictly positive.
     * @return a new array, the same length as {@code alphas}, whose entries sum to (approximately) 1.
     * @throws IllegalArgumentException if any concentration parameter is not strictly positive.
     */
    public static synchronized double[] sampleDirichlet(double[] alphas) {
        double[] samples = new double[alphas.length];
        double total = 0.0;

        for (int i = 0; i < alphas.length; i++) {
            samples[i] = nextGamma(alphas[i], 1.0);
            total += samples[i];
        }

        // NOTE(review): with very small alphas every Gamma draw may underflow to 0, which makes
        // this division produce NaN entries. Confirm callers never pass such alphas.
        for (int i = 0; i < samples.length; i++) {
            samples[i] /= total;
        }

        return samples;
    }

    /**
     * Draw a sample from a Gamma distribution with the given shape and scale, using the
     * Marsaglia-Tsang rejection method.
     * For shape < 1, the sample is drawn with shape + 1 and then multiplied by
     * {@code U^(1 / shape)}, where U is uniform on [0, 1).
     *
     * @param shape the shape parameter; must be strictly positive.
     * @param scale the scale parameter; the unit-scale sample is multiplied by this value.
     * @return a Gamma(shape, scale) sample.
     * @throws IllegalArgumentException if {@code shape} is not strictly positive (or is NaN).
     */
    public static synchronized double nextGamma(double shape, double scale) {
        // Without this guard, shape <= -2/3 makes sqrt(d) NaN below and the rejection loop never exits.
        if (!(shape > 0.0)) {
            throw new IllegalArgumentException(String.format("Gamma shape (%f) must be strictly positive.", shape));
        }

        boolean boostShape = (shape < 1.0);
        double effectiveShape = boostShape ? (shape + 1.0) : shape;

        double d = effectiveShape - 1.0 / 3.0;
        double c = 1.0 / (3.0 * Math.sqrt(d));

        double v;
        while (true) {
            double x = nextGaussian();
            double base = 1.0 + (c * x);
            if (base > 0.0) {
                v = Math.pow(base, 3.0);
                double u = nextDouble();

                // Cheap squeeze test first, then the exact log-acceptance test.
                if (u < 1.0 - (0.0331 * Math.pow(x, 4.0))
                        || Math.log(u) < (0.5 * Math.pow(x, 2.0)) + (d * ((1.0 - v) + Math.log(v)))) {
                    break;
                }
            }
        }

        double sample = scale * d * v;
        if (boostShape) {
            sample *= Math.pow(nextDouble(), 1.0 / shape);
        }

        return sample;
    }
}
