package org.hipparchus.analysis.function;

import org.hipparchus.analysis.FunctionUtils;
import org.hipparchus.analysis.differentiation.DSFactory;
import org.hipparchus.analysis.differentiation.DerivativeStructure;
import org.hipparchus.analysis.differentiation.UnivariateDifferentiableFunction;
import org.hipparchus.analysis.function.Logit;
import org.hipparchus.exception.MathIllegalArgumentException;
import org.hipparchus.exception.NullArgumentException;
import org.hipparchus.random.Well1024a;
import org.hipparchus.util.FastMath;
import org.junit.Assert;
import org.junit.Test;

/**
 * Unit tests for {@link Logit}: domain checks, plain values, first and
 * high-order derivatives, the {@link Logit.Parametric} form, and the
 * composition with {@link Sigmoid} on the same interval.
 */
public class LogitTest {
    /** Tolerance of one ulp of 1.0, used where results should be (nearly) exact. */
    private static final double EPS = Math.ulp(1.0d);

    /** An argument below the lower bound must be rejected. */
    @Test(expected = MathIllegalArgumentException.class)
    public void testPreconditions1() {
        new Logit(-1.0d, 2.0d).value(-2.0d);
    }

    /** An argument above the upper bound must be rejected. */
    @Test(expected = MathIllegalArgumentException.class)
    public void testPreconditions2() {
        new Logit(-1.0d, 2.0d).value(3.0d);
    }

    /** On [1, 2] the bounds map to -/+ infinity and the midpoint maps to 0. */
    @Test
    public void testSomeValues() {
        final Logit logit = new Logit(1.0d, 2.0d);
        Assert.assertEquals(Double.NEGATIVE_INFINITY, logit.value(1.0d), EPS);
        Assert.assertEquals(Double.POSITIVE_INFINITY, logit.value(2.0d), EPS);
        Assert.assertEquals(0.0d, logit.value(1.5d), EPS);
    }

    /** On [1, 2] the first derivative at the midpoint 1.5 is 4. */
    @Test
    public void testDerivative() {
        final Logit logit = new Logit(1.0d, 2.0d);
        final DerivativeStructure result = logit.value(new DSFactory(1, 1).variable(0, 1.5d));
        Assert.assertEquals(4.0d, result.getPartialDerivative(1), EPS);
    }

    /**
     * Differentiating at arguments far outside the domain must throw
     * {@link MathIllegalArgumentException} and no other exception type.
     */
    @Test
    public void testDerivativeLargeArguments() {
        final Logit logit = new Logit(1.0d, 2.0d);
        final DSFactory factory = new DSFactory(1, 1);
        final double[] args = {
            Double.NEGATIVE_INFINITY, -Double.MAX_VALUE, -1.0e155d,
            1.0e155d, Double.MAX_VALUE, Double.POSITIVE_INFINITY
        };
        for (final double arg : args) {
            try {
                logit.value(factory.variable(0, arg));
                // Assert.fail throws an AssertionError (an Error, not an Exception),
                // so it is not swallowed by the catch clauses below.
                Assert.fail("an exception should have been thrown");
            } catch (MathIllegalArgumentException expected) {
                // expected: the argument lies outside the domain [1, 2].
                // This clause must precede catch (Exception), otherwise it is
                // unreachable and the file does not compile.
            } catch (Exception e) {
                Assert.fail("wrong exception caught: " + e.getMessage());
            }
        }
    }

    /** Value and derivatives up to order 5 at x = 1.2 on [1, 3], against reference values. */
    @Test
    public void testDerivativesHighOrder() {
        final DerivativeStructure value =
                new Logit(1.0d, 3.0d).value(new DSFactory(1, 5).variable(0, 1.2d));
        Assert.assertEquals(-2.1972245773362196d, value.getPartialDerivative(0), 1.0e-16d);
        Assert.assertEquals(5.555555555555555d, value.getPartialDerivative(1), 9.0e-16d);
        Assert.assertEquals(-24.691358024691358d, value.getPartialDerivative(2), 2.0e-14d);
        Assert.assertEquals(250.3429355281207d, value.getPartialDerivative(3), 2.0e-13d);
        Assert.assertEquals(-3749.4284407864657d, value.getPartialDerivative(4), 4.0e-12d);
        Assert.assertEquals(75001.27013158564d, value.getPartialDerivative(5), 8.0e-11d);
    }

    /** A null parameter array passed to value() must be rejected. */
    @Test(expected = NullArgumentException.class)
    public void testParametricUsage1() {
        new Logit.Parametric().value(0.0d, (double[]) null);
    }

    /** A parameter array of the wrong length passed to value() must be rejected. */
    @Test(expected = MathIllegalArgumentException.class)
    public void testParametricUsage2() {
        new Logit.Parametric().value(0.0d, new double[] {0.0d});
    }

    /** A null parameter array passed to gradient() must be rejected. */
    @Test(expected = NullArgumentException.class)
    public void testParametricUsage3() {
        new Logit.Parametric().gradient(0.0d, (double[]) null);
    }

    /** A parameter array of the wrong length passed to gradient() must be rejected. */
    @Test(expected = MathIllegalArgumentException.class)
    public void testParametricUsage4() {
        new Logit.Parametric().gradient(0.0d, new double[] {0.0d});
    }

    /** An argument below the parametric lower bound must be rejected. */
    @Test(expected = MathIllegalArgumentException.class)
    public void testParametricUsage5() {
        new Logit.Parametric().value(-1.0d, new double[] {0.0d, 1.0d});
    }

    /** An argument above the parametric upper bound must be rejected. */
    @Test(expected = MathIllegalArgumentException.class)
    public void testParametricUsage6() {
        new Logit.Parametric().value(2.0d, new double[] {0.0d, 1.0d});
    }

    /** The parametric form with {lo, hi} must match the plain function exactly. */
    @Test
    public void testParametricValue() {
        final double[] bounds = {2.0d, 3.0d};
        final Logit logit = new Logit(bounds[0], bounds[1]);
        final Logit.Parametric parametric = new Logit.Parametric();
        for (final double x : new double[] {2.0d, 2.34567d, 3.0d}) {
            Assert.assertEquals(logit.value(x), parametric.value(x, bounds), 0.0d);
        }
    }

    /** sigmoid(logit(x)) must return x for random points and both bounds of [2, 3]. */
    @Test
    public void testValueWithInverseFunction() {
        final double lo = 2.0d;
        final double hi = 3.0d;
        final UnivariateDifferentiableFunction logit = new Logit(lo, hi);
        final UnivariateDifferentiableFunction sigmoid = new Sigmoid(lo, hi);
        final Well1024a random = new Well1024a(5301102751131602357L);
        // compose(f, g) evaluates f(g(x)), so this is sigmoid(logit(x))
        final UnivariateDifferentiableFunction identity = FunctionUtils.compose(sigmoid, logit);
        final DSFactory factory = new DSFactory(1, 1);
        for (int i = 0; i < 10; i++) {
            final double x = lo + random.nextDouble() * (hi - lo);
            Assert.assertEquals(x, identity.value(factory.variable(0, x)).getValue(), EPS);
        }
        Assert.assertEquals(lo, identity.value(factory.variable(0, lo)).getValue(), EPS);
        Assert.assertEquals(hi, identity.value(factory.variable(0, hi)).getValue(), EPS);
    }

    /**
     * The derivatives of sigmoid(logit(x)) must match those of the identity
     * for orders 0 to 5 at random interior points. At the bounds, the
     * infinite/NaN behavior is checked by {@link #checkBoundary}.
     */
    @Test
    public void testDerivativesWithInverseFunction() {
        // tolerance for each derivation order, indexed by order
        final double[] epsilon = {1.0e-20d, 4.0e-16d, 3.0e-15d, 2.0e-11d, 3.0e-9d, 1.0e-6d};
        final double lo = 2.0d;
        final double hi = 3.0d;
        final UnivariateDifferentiableFunction logit = new Logit(lo, hi);
        final UnivariateDifferentiableFunction sigmoid = new Sigmoid(lo, hi);
        final Well1024a random = new Well1024a(-7599720346551202139L);
        final UnivariateDifferentiableFunction identity = FunctionUtils.compose(sigmoid, logit);
        for (int maxOrder = 0; maxOrder < epsilon.length; maxOrder++) {
            final DSFactory factory = new DSFactory(1, maxOrder);
            for (int i = 0; i < 10; i++) {
                final DerivativeStructure dsX =
                        factory.variable(0, lo + random.nextDouble() * (hi - lo));
                Assert.assertEquals("maxOrder = " + maxOrder,
                                    dsX.getPartialDerivative(maxOrder),
                                    identity.value(dsX).getPartialDerivative(maxOrder),
                                    epsilon[maxOrder]);
            }
            checkBoundary(logit, identity, factory, maxOrder, lo, epsilon[maxOrder]);
            checkBoundary(logit, identity, factory, maxOrder, hi, epsilon[maxOrder]);
        }
    }

    /**
     * Checks the derivative of the given order at a bound of the domain.
     * <ul>
     *   <li>Order 0: logit is infinite and the composition recovers the bound.</li>
     *   <li>Order 1: logit is infinite and the composition is NaN.</li>
     *   <li>Higher orders: both are NaN.</li>
     * </ul>
     *
     * @param logit    the logit function under test
     * @param identity sigmoid composed with logit
     * @param factory  factory for the derivative structure at the bound
     * @param order    derivation order to inspect
     * @param bound    the domain bound to evaluate at
     * @param epsilon  tolerance for the order-0 value of the composition
     */
    private static void checkBoundary(final UnivariateDifferentiableFunction logit,
                                      final UnivariateDifferentiableFunction identity,
                                      final DSFactory factory, final int order,
                                      final double bound, final double epsilon) {
        final DerivativeStructure dsBound = factory.variable(0, bound);
        final double logitDerivative = logit.value(dsBound).getPartialDerivative(order);
        final double identityDerivative = identity.value(dsBound).getPartialDerivative(order);
        if (order == 0) {
            Assert.assertTrue(Double.isInfinite(logitDerivative));
            Assert.assertEquals(bound, identityDerivative, epsilon);
        } else if (order == 1) {
            Assert.assertTrue(Double.isInfinite(logitDerivative));
            Assert.assertTrue(Double.isNaN(identityDerivative));
        } else {
            Assert.assertTrue(Double.isNaN(logitDerivative));
            Assert.assertTrue(Double.isNaN(identityDerivative));
        }
    }
}
