package net.finmath.optimizer;

import java.io.Serializable;
import java.util.Arrays;
import java.util.List;
import java.util.Vector;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.concurrent.FutureTask;
import java.util.logging.Level;
import java.util.logging.Logger;
import net.finmath.functions.LinearAlgebra;
import net.finmath.montecarlo.RandomVariable;
import net.finmath.stochastic.RandomVariableInterface;
import org.apache.commons.math3.optimization.direct.CMAESOptimizer;

/**
 * Levenberg-Marquardt solver for least-squares problems whose parameters and values are random variables.
 *
 * <p>Subclasses implement {@link #setValues(RandomVariableInterface[], RandomVariableInterface[])}. The solver
 * minimizes {@link #getMeanSquaredError(RandomVariableInterface[])}, i.e. the mean over all values of the path
 * average of the squared difference to the target values. Unless a subclass overrides
 * {@link #setDerivatives(RandomVariableInterface[], RandomVariableInterface[][])}, derivatives are obtained by
 * forward finite differences. The normal equations are assembled from path averages.
 *
 * <p>Instances are mutable and keep the state of the last {@link #run()}.
 */
public abstract class StochasticLevenbergMarquardt implements Serializable, Cloneable, StochasticOptimizerInterface {

    private static final long serialVersionUID = 4560864869394838155L;

    // Static because java.util.logging.Logger is not Serializable; as an instance field it made every
    // instance of this Serializable class fail to serialize.
    private static final Logger logger = Logger.getLogger("net.finmath");

    /** Controls how lambda enters the diagonal of the normal equations. */
    public enum RegularizationMethod {
        /** Adds lambda to each diagonal entry. */
        LEVENBERG,
        /** Multiplies each diagonal entry by (1 + lambda); a zero diagonal entry is replaced by lambda. */
        LEVENBERG_MARQUARDT
    }

    private final RegularizationMethod regularizationMethod;

    private RandomVariableInterface[] initialParameters;
    // Finite-difference step per parameter; null selects (|p| + 1) * 1e-8.
    private RandomVariableInterface[] parameterSteps;
    private RandomVariableInterface[] targetValues;
    private int maxIteration;
    private double errorTolerance;

    private double lambdaInitialValue = 0.001;
    private double lambda = lambdaInitialValue;
    private double lambdaDivisor = 1.3;
    private double lambdaMultiplicator = 2.0;

    // State of the current (or last) run.
    private int iteration = 0;
    private RandomVariableInterface[] parameterTest;
    private RandomVariableInterface[] valueTest;
    private RandomVariableInterface[] parameterCurrent;
    private RandomVariableInterface[] valueCurrent;
    private RandomVariableInterface[][] derivativeCurrent;
    // NOTE(review): not reset by run() (setErrorMeanSquaredCurrent may be used to seed it), so a second run()
    // only accepts points that beat the previous error - confirm this is intended.
    private double errorMeanSquaredCurrent = Double.POSITIVE_INFINITY;
    private double errorRootMeanSquaredChange = Double.POSITIVE_INFINITY;
    private boolean isParameterCurrentDerivativeValid;

    // Transient because ExecutorService does not extend Serializable; after deserialization the
    // derivatives are evaluated sequentially.
    private transient ExecutorService executor;
    private boolean executorShutdownWhenDone;

    /**
     * Small demonstration: fits two parameters so that the squared model values match 25 and 100.
     *
     * @param args ignored
     * @throws SolverException if the solver fails
     */
    public static void main(String[] args) throws SolverException {
        final StochasticLevenbergMarquardt optimizer = new StochasticLevenbergMarquardt(
                new RandomVariableInterface[] { new RandomVariable(2.0), new RandomVariable(2.0) },
                new RandomVariableInterface[] { new RandomVariable(25.0), new RandomVariable(100.0) },
                new RandomVariableInterface[] { new RandomVariable(1.0), new RandomVariable(1.0) },
                100, 1.0E-12, null) {
            private static final long serialVersionUID = -282626938650139518L;

            @Override
            public void setValues(RandomVariableInterface[] parameters, RandomVariableInterface[] values) {
                values[0] = parameters[0].mult(0.0).add(parameters[1]).squared();
                values[1] = parameters[0].mult(2.0).add(parameters[1]).squared();
            }
        };
        optimizer.run();

        final RandomVariableInterface[] bestFitParameters = optimizer.getBestFitParameters();
        System.out.println("The solver for problem 1 required " + optimizer.getIterations() + " iterations. The best fit parameters are:");
        for (int i = 0; i < bestFitParameters.length; i++) {
            System.out.println("\tparameter[" + i + "]: " + bestFitParameters[i]);
        }
        System.out.println("The solver accuracy is " + optimizer.getRootMeanSquaredError());
    }

    /**
     * Creates a solver.
     *
     * @param regularizationMethod how lambda enters the diagonal of the normal equations
     * @param initialParameters    starting point (copied at the start of each run)
     * @param targetValues         values the model should reproduce
     * @param parameterSteps       finite-difference steps per parameter, or null for a relative default
     * @param maxIteration         the solver stops once the iteration count exceeds this value
     * @param errorTolerance       the solver stops once an accepted step improves the RMS error by at most this
     * @param executorService      executor for parallel derivative evaluation, or null for sequential evaluation;
     *                             a supplied executor is never shut down by this class
     */
    public StochasticLevenbergMarquardt(RegularizationMethod regularizationMethod, RandomVariableInterface[] initialParameters, RandomVariableInterface[] targetValues, RandomVariableInterface[] parameterSteps, int maxIteration, double errorTolerance, ExecutorService executorService) {
        this.regularizationMethod = regularizationMethod;
        this.initialParameters = initialParameters;
        this.targetValues = targetValues;
        this.parameterSteps = parameterSteps;
        this.maxIteration = maxIteration;
        this.errorTolerance = errorTolerance;
        this.executor = executorService;
        this.executorShutdownWhenDone = executorService == null;
    }

    /**
     * Creates a solver using {@link RegularizationMethod#LEVENBERG_MARQUARDT}.
     *
     * @see #StochasticLevenbergMarquardt(RegularizationMethod, RandomVariableInterface[], RandomVariableInterface[], RandomVariableInterface[], int, double, ExecutorService)
     */
    public StochasticLevenbergMarquardt(RandomVariableInterface[] initialParameters, RandomVariableInterface[] targetValues, RandomVariableInterface[] parameterSteps, int maxIteration, double errorTolerance, ExecutorService executorService) {
        this(RegularizationMethod.LEVENBERG_MARQUARDT, initialParameters, targetValues, parameterSteps, maxIteration, errorTolerance, executorService);
    }

    /** @return the current damping parameter lambda (its final value after a run) */
    public double getLambda() {
        return lambda;
    }

    /**
     * Sets the damping parameter lambda. The value is also used as the starting lambda of the next
     * {@link #run()}; previously run() reset lambda and silently discarded this setting.
     *
     * @param lambda the damping parameter
     */
    public void setLambda(double lambda) {
        this.lambda = lambda;
        this.lambdaInitialValue = lambda;
    }

    /** @return factor by which lambda is multiplied after a rejected step */
    public double getLambdaMultiplicator() {
        return lambdaMultiplicator;
    }

    /**
     * @param lambdaMultiplicator factor by which lambda is multiplied after a rejected step
     * @throws IllegalArgumentException if {@code lambdaMultiplicator <= 1}
     */
    public void setLambdaMultiplicator(double lambdaMultiplicator) {
        if (lambdaMultiplicator <= 1.0) {
            throw new IllegalArgumentException("Parameter lambdaMultiplicator is required to be > 1.");
        }
        this.lambdaMultiplicator = lambdaMultiplicator;
    }

    /** @return divisor applied to lambda after an accepted step */
    public double getLambdaDivisor() {
        return lambdaDivisor;
    }

    /**
     * @param lambdaDivisor divisor applied to lambda after an accepted step
     * @throws IllegalArgumentException if {@code lambdaDivisor <= 1}
     */
    public void setLambdaDivisor(double lambdaDivisor) {
        if (lambdaDivisor <= 1.0) {
            throw new IllegalArgumentException("Parameter lambdaDivisor is required to be > 1.");
        }
        this.lambdaDivisor = lambdaDivisor;
    }

    /** @return the best (accepted) parameters of the last run, or null before the first run */
    @Override
    public RandomVariableInterface[] getBestFitParameters() {
        return parameterCurrent;
    }

    /** @return square root of the mean squared error at the best parameters */
    @Override
    public double getRootMeanSquaredError() {
        return Math.sqrt(errorMeanSquaredCurrent);
    }

    /**
     * Overrides the current mean squared error. A subsequent run() only accepts points with a lower error,
     * since run() does not reset this value.
     *
     * @param errorMeanSquaredCurrent the error level to beat
     */
    public void setErrorMeanSquaredCurrent(double errorMeanSquaredCurrent) {
        this.errorMeanSquaredCurrent = errorMeanSquaredCurrent;
    }

    /** @return number of iterations performed by the last run */
    @Override
    public int getIterations() {
        return iteration;
    }

    /**
     * Hook invoked for every function evaluation; the default simply delegates to
     * {@link #setValues(RandomVariableInterface[], RandomVariableInterface[])}.
     *
     * @param parameters parameters at which to evaluate
     * @param values     output array receiving one value per target value
     * @throws SolverException if the evaluation fails
     */
    protected void prepareAndSetValues(RandomVariableInterface[] parameters, RandomVariableInterface[] values) throws SolverException {
        setValues(parameters, values);
    }

    /**
     * Hook invoked to compute derivatives at the current accepted point; the default delegates to
     * {@link #setDerivatives(RandomVariableInterface[], RandomVariableInterface[][])} and ignores {@code values}.
     *
     * @param parameters  parameters at which to differentiate
     * @param values      model values at {@code parameters}
     * @param derivatives output, indexed {@code [parameterIndex][valueIndex]}
     * @throws SolverException if the evaluation fails
     */
    protected void prepareAndSetDerivatives(RandomVariableInterface[] parameters, RandomVariableInterface[] values, RandomVariableInterface[][] derivatives) throws SolverException {
        setDerivatives(parameters, derivatives);
    }

    /**
     * Computes the model values for the given parameters.
     *
     * @param parameters parameters at which to evaluate
     * @param values     output array receiving one value per target value
     * @throws SolverException if the evaluation fails
     */
    public abstract void setValues(RandomVariableInterface[] parameters, RandomVariableInterface[] values) throws SolverException;

    /**
     * Computes forward finite-difference derivatives, one task per parameter, on the executor if one is set.
     *
     * <p>The differences are taken against {@code valueCurrent}, so the base point is always the current
     * accepted parameters; the {@code parameters} argument is not used by this default implementation.
     * A failed evaluation yields NaN values, and NaN derivative entries are replaced by 0.0.
     *
     * @param parameters  ignored by this implementation (see above)
     * @param derivatives output, indexed {@code [parameterIndex][valueIndex]}; each row is reused as a buffer
     * @throws SolverException if a derivative task fails or the calling thread is interrupted
     */
    public void setDerivatives(RandomVariableInterface[] parameters, RandomVariableInterface[][] derivatives) throws SolverException {
        final RandomVariableInterface[] baseParameters = parameterCurrent;
        final List<Future<RandomVariableInterface[]>> futures = new Vector<Future<RandomVariableInterface[]>>(baseParameters.length);

        for (int parameterIndex = 0; parameterIndex < baseParameters.length; parameterIndex++) {
            final RandomVariableInterface[] shiftedParameters = baseParameters.clone();
            final RandomVariableInterface[] derivative = derivatives[parameterIndex];
            final int index = parameterIndex;
            final Callable<RandomVariableInterface[]> worker = new Callable<RandomVariableInterface[]>() {
                @Override
                public RandomVariableInterface[] call() {
                    final RandomVariableInterface step = parameterSteps != null
                            ? parameterSteps[index]
                            : shiftedParameters[index].abs().add(1.0).mult(1.0E-8);
                    shiftedParameters[index] = shiftedParameters[index].add(step);
                    try {
                        StochasticLevenbergMarquardt.this.prepareAndSetValues(shiftedParameters, derivative);
                    } catch (Exception ignored) {
                        // Deliberate best effort: a failed evaluation becomes NaN, which is mapped to 0 below.
                        Arrays.fill(derivative, new RandomVariable(Double.NaN));
                    }
                    for (int valueIndex = 0; valueIndex < valueCurrent.length; valueIndex++) {
                        final RandomVariableInterface difference = derivative[valueIndex].sub(valueCurrent[valueIndex]).div(step);
                        // Replace NaN entries by 0.0 (assumes isNaN() is a 0/1 indicator and barrier() picks its
                        // second argument where the trigger is non-negative).
                        derivative[valueIndex] = difference.barrier(difference.isNaN().sub(0.5).mult(-1.0), difference, 0.0);
                    }
                    return derivative;
                }
            };

            if (executor != null) {
                futures.add(executor.submit(worker));
            } else {
                final FutureTask<RandomVariableInterface[]> task = new FutureTask<RandomVariableInterface[]>(worker);
                task.run();
                futures.add(task);
            }
        }

        for (int parameterIndex = 0; parameterIndex < baseParameters.length; parameterIndex++) {
            try {
                derivatives[parameterIndex] = futures.get(parameterIndex).get();
            } catch (InterruptedException e) {
                // Restore the interrupt status so callers further up can still observe it.
                Thread.currentThread().interrupt();
                throw new SolverException(e);
            } catch (ExecutionException e) {
                throw new SolverException(e);
            }
        }
    }

    /**
     * @return true once the iteration count exceeds maxIteration, or the last accepted step improved the
     *         RMS error by at most errorTolerance
     */
    boolean done() {
        return iteration > maxIteration || errorRootMeanSquaredChange <= errorTolerance;
    }

    /**
     * Runs the Levenberg-Marquardt iteration starting from the initial parameters.
     *
     * <p>Each iteration evaluates the test point and accepts it if its mean squared error is lower than the
     * current one (a NaN error is never accepted). It then adapts lambda, refreshes the derivatives at the
     * accepted point if that point changed, and solves the damped normal equations for the next test point.
     *
     * @throws SolverException if an evaluation fails or the damped system cannot be solved
     */
    @Override
    public void run() throws SolverException {
        try {
            final int numberOfParameters = initialParameters.length;
            final int numberOfValues = targetValues.length;

            parameterTest = initialParameters.clone();
            parameterCurrent = initialParameters.clone();
            valueTest = new RandomVariableInterface[numberOfValues];
            valueCurrent = new RandomVariableInterface[numberOfValues];
            Arrays.fill(valueCurrent, new RandomVariable(Double.NaN));
            derivativeCurrent = new RandomVariableInterface[numberOfParameters][numberOfValues];

            iteration = 0;
            lambda = lambdaInitialValue;
            // Per-run convergence state; a stale value from a previous run could otherwise stop this run at once.
            errorRootMeanSquaredChange = Double.POSITIVE_INFINITY;
            isParameterCurrentDerivativeValid = false;

            while (true) {
                iteration++;

                prepareAndSetValues(parameterTest, valueTest);
                final double errorMeanSquaredTest = getMeanSquaredError(valueTest);

                // The comparison is false for NaN, so a NaN error counts as a rejected point.
                final boolean isPointAccepted = errorMeanSquaredCurrent > errorMeanSquaredTest;
                if (isPointAccepted) {
                    parameterCurrent = parameterTest.clone();
                    valueCurrent = valueTest.clone();
                    errorRootMeanSquaredChange = Math.sqrt(errorMeanSquaredCurrent) - Math.sqrt(errorMeanSquaredTest);
                    errorMeanSquaredCurrent = errorMeanSquaredTest;
                    isParameterCurrentDerivativeValid = false;
                }

                if (done()) {
                    break;
                }

                if (isPointAccepted) {
                    lambda /= lambdaDivisor;
                } else {
                    lambda *= lambdaMultiplicator;
                }

                // Derivatives belong to the accepted point; after a rejection they are still valid.
                if (!isParameterCurrentDerivativeValid) {
                    prepareAndSetDerivatives(parameterCurrent, valueCurrent, derivativeCurrent);
                    isParameterCurrentDerivativeValid = true;
                }

                final double[] parameterIncrement = solveForParameterIncrement();
                for (int i = 0; i < numberOfParameters; i++) {
                    parameterTest[i] = parameterCurrent[i].add(parameterIncrement[i]);
                }

                logIteration();
            }
        } finally {
            // Only shut down an executor this instance owns. NOTE(review): with the visible constructors the
            // flag is true only when no executor was supplied, so this branch currently never fires.
            if (executor != null && executorShutdownWhenDone) {
                executor.shutdown();
                executor = null;
            }
        }
    }

    /**
     * Solves the damped normal equations at the current point. If the solve throws, lambda is increased
     * by a factor of 16 and the solve is retried.
     *
     * @return the parameter increment
     * @throws SolverException if lambda overflows before a solvable system is found
     */
    private double[] solveForParameterIncrement() throws SolverException {
        final int numberOfParameters = parameterCurrent.length;
        final int numberOfValues = valueCurrent.length;

        // Lambda-independent parts, assembled once: J^T J and J^T (target - value), both path-averaged.
        final double[][] normalMatrix = new double[numberOfParameters][numberOfParameters];
        for (int i = 0; i < numberOfParameters; i++) {
            for (int j = i; j < numberOfParameters; j++) {
                double sum = 0.0;
                for (int k = 0; k < numberOfValues; k++) {
                    if (derivativeCurrent[i][k] != null && derivativeCurrent[j][k] != null) {
                        sum += derivativeCurrent[i][k].mult(derivativeCurrent[j][k]).getAverage();
                    }
                }
                normalMatrix[i][j] = sum;
                normalMatrix[j][i] = sum;
            }
        }

        final double[] rightHandSide = new double[numberOfParameters];
        for (int i = 0; i < numberOfParameters; i++) {
            double sum = 0.0;
            for (int k = 0; k < numberOfValues; k++) {
                if (derivativeCurrent[i][k] != null) {
                    sum += targetValues[k].sub(valueCurrent[k]).mult(derivativeCurrent[i][k]).getAverage();
                }
            }
            rightHandSide[i] = sum;
        }

        while (true) {
            // Fresh copies per attempt, in case the solver modifies its inputs.
            final double[][] dampedMatrix = new double[numberOfParameters][];
            for (int i = 0; i < numberOfParameters; i++) {
                dampedMatrix[i] = normalMatrix[i].clone();
                dampedMatrix[i][i] = regularizeDiagonal(normalMatrix[i][i]);
            }
            try {
                return LinearAlgebra.solveLinearEquationSymmetric(dampedMatrix, rightHandSide.clone());
            } catch (Exception e) {
                lambda *= 16.0;
                if (Double.isInfinite(lambda) || Double.isNaN(lambda)) {
                    // Damping overflowed without producing a solvable system: fail instead of looping forever.
                    throw new SolverException(e);
                }
            }
        }
    }

    /** Applies the configured regularization with the current lambda to one diagonal entry. */
    private double regularizeDiagonal(double diagonal) {
        if (regularizationMethod == RegularizationMethod.LEVENBERG) {
            return diagonal + lambda;
        }
        return diagonal == 0.0 ? lambda : diagonal * (1.0 + lambda);
    }

    /** Logs the state of the current iteration at {@link Level#FINE}. */
    private void logIteration() {
        if (!logger.isLoggable(Level.FINE)) {
            return;
        }
        final StringBuilder message = new StringBuilder();
        message.append("Iteration: ").append(iteration)
                .append("\tLambda=").append(lambda)
                .append("\tError Current:").append(errorMeanSquaredCurrent)
                .append("\tError Change:").append(errorRootMeanSquaredChange)
                .append("\t");
        for (int i = 0; i < parameterCurrent.length; i++) {
            message.append("[").append(i).append("] = ").append(parameterCurrent[i]).append("\t");
        }
        logger.fine(message.toString());
    }

    /**
     * Returns the mean over all values of the path average of {@code (value - target)^2}.
     *
     * @param values model values, one per target value
     * @return the mean squared error
     */
    public double getMeanSquaredError(RandomVariableInterface[] values) {
        double sum = 0.0;
        for (int i = 0; i < values.length; i++) {
            sum += values[i].sub(targetValues[i]).squared().getAverage();
        }
        return sum / values.length;
    }

    /**
     * Cloning is not supported by this base class; subclasses that want to use
     * {@code getCloneWithModifiedTargetValues} must override this method.
     *
     * @throws CloneNotSupportedException always
     */
    public StochasticLevenbergMarquardt mo175clone() throws CloneNotSupportedException {
        throw new CloneNotSupportedException();
    }

    /**
     * Creates a clone with different target values.
     *
     * @param newTargetValues                       new target values (copied)
     * @param newWeights                            currently ignored
     * @param isUseBestParametersAsInitialParameters if true and this solver is done, the clone starts from
     *                                              this solver's best fit parameters
     * @return the clone
     * @throws CloneNotSupportedException if {@link #mo175clone()} is not overridden
     */
    public StochasticLevenbergMarquardt getCloneWithModifiedTargetValues(RandomVariableInterface[] newTargetValues, RandomVariableInterface[] newWeights, boolean isUseBestParametersAsInitialParameters) throws CloneNotSupportedException {
        final StochasticLevenbergMarquardt clone = mo175clone();
        clone.targetValues = newTargetValues.clone();
        if (isUseBestParametersAsInitialParameters && done()) {
            clone.initialParameters = getBestFitParameters();
        }
        return clone;
    }

    /**
     * List-based variant of
     * {@link #getCloneWithModifiedTargetValues(RandomVariableInterface[], RandomVariableInterface[], boolean)}.
     *
     * @param newTargetValues                       new target values (copied)
     * @param newWeights                            currently ignored
     * @param isUseBestParametersAsInitialParameters if true and this solver is done, the clone starts from
     *                                              this solver's best fit parameters
     * @return the clone
     * @throws CloneNotSupportedException if {@link #mo175clone()} is not overridden
     */
    public StochasticLevenbergMarquardt getCloneWithModifiedTargetValues(List<RandomVariableInterface> newTargetValues, List<RandomVariableInterface> newWeights, boolean isUseBestParametersAsInitialParameters) throws CloneNotSupportedException {
        final StochasticLevenbergMarquardt clone = mo175clone();
        clone.targetValues = newTargetValues.toArray(new RandomVariableInterface[0]);
        if (isUseBestParametersAsInitialParameters && done()) {
            clone.initialParameters = getBestFitParameters();
        }
        return clone;
    }
}
