package net.sf.tweety.math.opt.solver;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import net.sf.tweety.commons.util.VectorTools;
import net.sf.tweety.math.GeneralMathException;
import net.sf.tweety.math.equation.Equation;
import net.sf.tweety.math.opt.ConstraintSatisfactionProblem;
import net.sf.tweety.math.opt.OptimizationProblem;
import net.sf.tweety.math.opt.Solver;
import net.sf.tweety.math.term.FloatConstant;
import net.sf.tweety.math.term.FloatVariable;
import net.sf.tweety.math.term.IntegerConstant;
import net.sf.tweety.math.term.Product;
import net.sf.tweety.math.term.Term;
import net.sf.tweety.math.term.Variable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* JADX WARN: Classes with same name are omitted:
  input_file:net.sf.tweety.math-1.15.jar:net/sf/tweety/math/opt/solver/HessianGradientDescent.class
 */
/* loaded from: input_file:net.sf.tweety.math-1.16.jar:net/sf/tweety/math/opt/solver/HessianGradientDescent.class */
/**
 * Solver for unconstrained optimization problems that iteratively moves a
 * current guess until the gradient of the target function is (numerically)
 * zero.
 *
 * <p>In every iteration the gradient {@code g} and the Hessian {@code H} of the
 * target function are evaluated at the current guess, the linear system
 * {@code H * x = -g} is solved for the search direction {@code x}, and a
 * step along that direction is chosen by repeatedly halving the step length
 * until the Manhattan norm of the gradient decreases.
 */
public class HessianGradientDescent extends Solver {
    private static final Logger LOG = LoggerFactory.getLogger(HessianGradientDescent.class);

    /** The search stops once the Manhattan norm of the gradient drops below this value. */
    private static final double PRECISION = 1.0E-5d;

    /** Maximum number of step lengths (1, 1/2, 1/4, ...) tried per line search. */
    private static final int MAX_STEP_HALVINGS = 1000;

    /** Initial value for every variable of the target function. */
    private final Map<Variable, Term> startingPoint;

    /**
     * Creates a new solver.
     *
     * @param startingPoint the point the search starts from; it has to map every
     *        variable of the target function to a term with a numeric value
     */
    public HessianGradientDescent(Map<Variable, Term> startingPoint) {
        this.startingPoint = startingPoint;
    }

    /**
     * Searches for a point where the gradient of the target function vanishes.
     * Maximization problems are handled by minimizing the negated target function.
     *
     * @param problem the problem to solve; it must not contain any constraints and
     *        is cast to {@link OptimizationProblem}
     * @return an assignment of every variable of the target function to a
     *         {@link FloatConstant}
     * @throws IllegalArgumentException if the problem contains constraints or the
     *         starting point lacks a value for some variable
     * @throws ClassCastException if the problem is not an {@link OptimizationProblem}
     * @throws GeneralMathException declared by the overridden method
     */
    @Override // net.sf.tweety.math.opt.Solver
    public Map<Variable, Term> solve(ConstraintSatisfactionProblem problem) throws GeneralMathException {
        if (problem.size() > 0) {
            throw new IllegalArgumentException("The gradient descent method works only for optimization problems without constraints.");
        }
        LOG.trace("Solving the following optimization problem using hessian gradient descent:\n===BEGIN===\n{}\n===END===", problem);
        OptimizationProblem optimizationProblem = (OptimizationProblem) problem;
        Term targetFunction = optimizationProblem.getTargetFunction();
        if (optimizationProblem.getType() == OptimizationProblem.MAXIMIZE) {
            // maximizing f is the same as minimizing -f
            targetFunction = new IntegerConstant(-1).mult(targetFunction);
        }

        // Fix an order on the variables; all vectors and matrices below follow it.
        List<Variable> variables = new ArrayList<>(targetFunction.getVariables());

        List<Term> gradient = new ArrayList<>(variables.size());
        for (Variable variable : variables) {
            gradient.add(targetFunction.derive(variable).simplify());
        }

        List<List<Term>> hessian = new ArrayList<>(gradient.size());
        for (Term partialDerivative : gradient) {
            List<Term> row = new ArrayList<>(variables.size());
            for (Variable variable : variables) {
                row.add(partialDerivative.derive(variable).simplify());
            }
            hessian.add(row);
        }

        double[] currentGuess = new double[variables.size()];
        for (int i = 0; i < currentGuess.length; i++) {
            Variable variable = variables.get(i);
            Term start = this.startingPoint.get(variable);
            if (start == null) {
                throw new IllegalArgumentException("The starting point does not define a value for variable " + variable + ".");
            }
            currentGuess[i] = start.doubleValue();
        }

        LOG.trace("Starting optimization.");
        // NOTE(review): there is no iteration limit here; if the line search never
        // finds an improving step this loop does not terminate.
        while (true) {
            double[] evaluatedGradient = Term.evaluateVector(gradient, currentGuess, variables);
            double distanceToZero = VectorTools.manhattanDistanceToZero(evaluatedGradient);
            LOG.trace("Current manhattan distance of gradient to zero: {}", distanceToZero);
            if (distanceToZero < PRECISION) {
                break;
            }
            double[][] evaluatedHessian = Term.evaluateMatrix(hessian, currentGuess, variables);
            double[] direction = getDirection(evaluatedHessian, evaluatedGradient);
            currentGuess = bestGuess(currentGuess, direction, gradient, variables);
        }

        Map<Variable, Term> result = new HashMap<>();
        for (int i = 0; i < currentGuess.length; i++) {
            result.put(variables.get(i), new FloatConstant(currentGuess[i]));
        }
        LOG.trace("Optimum found: {}", result);
        return result;
    }

    /**
     * Performs a backtracking line search along the given direction. Starting with
     * step length 1, the step length is halved until the Manhattan norm of the
     * gradient at the new point is smaller than at the current point.
     *
     * @param currentGuess the current point
     * @param direction the search direction
     * @param gradient the symbolic gradient of the target function
     * @param variables the variables, in the order used by {@code currentGuess}
     * @return the first improving point found, or the last candidate tried if none
     *         of the {@value #MAX_STEP_HALVINGS} step lengths improved the gradient norm
     */
    private double[] bestGuess(double[] currentGuess, double[] direction, List<Term> gradient, List<Variable> variables) {
        double currentDistance = VectorTools.manhattanDistanceToZero(Term.evaluateVector(gradient, currentGuess, variables));
        int dimension = variables.size();
        double[] candidate = new double[dimension];
        double stepLength = 1.0d;
        for (int attempt = 0; attempt < MAX_STEP_HALVINGS; attempt++) {
            for (int i = 0; i < dimension; i++) {
                candidate[i] = currentGuess[i] + (stepLength * direction[i]);
            }
            double candidateDistance = VectorTools.manhattanDistanceToZero(Term.evaluateVector(gradient, candidate, variables));
            if (candidateDistance < currentDistance) {
                return candidate;
            }
            stepLength /= 2.0d;
        }
        return candidate;
    }

    /**
     * Computes the search direction {@code x} as a solution of the linear system
     * {@code hessian * x = -gradient}, which is built as a constraint satisfaction
     * problem and handed to {@link ApacheCommonsSimplex}.
     *
     * @param hessian the Hessian evaluated at the current point
     * @param gradient the gradient evaluated at the current point
     * @return the components of the direction, in the order of {@code gradient}
     * @throws RuntimeException if the linear system could not be solved; the
     *         original failure is attached as the cause
     */
    private double[] getDirection(double[][] hessian, double[] gradient) {
        int dimension = gradient.length;
        ConstraintSatisfactionProblem linearSystem = new ConstraintSatisfactionProblem();

        // One unknown per direction component.
        List<Variable> unknowns = new ArrayList<>(dimension);
        for (int i = 0; i < dimension; i++) {
            unknowns.add(new FloatVariable("X" + i));
        }

        // Row i of the system: sum_j hessian[i][j] * X_j = -gradient[i].
        for (int row = 0; row < dimension; row++) {
            Term leftHandSide = null;
            for (int col = 0; col < dimension; col++) {
                Term summand = unknowns.get(col).mult(new FloatConstant(hessian[row][col]));
                leftHandSide = leftHandSide == null ? summand : leftHandSide.add(summand);
            }
            linearSystem.add(new Equation(leftHandSide, new FloatConstant(-gradient[row])));
        }

        try {
            Map<Variable, Term> solution = new ApacheCommonsSimplex().solve(linearSystem);
            double[] direction = new double[dimension];
            for (int i = 0; i < dimension; i++) {
                direction[i] = solution.get(unknowns.get(i)).doubleValue();
            }
            return direction;
        } catch (Exception e) {
            // keep the original failure visible to the caller instead of discarding it
            throw new RuntimeException("Could not determine the search direction for hessian gradient descent.", e);
        }
    }

    /**
     * Reports whether this solver can be used.
     *
     * @return always {@code true}
     * @throws UnsupportedOperationException never thrown by this implementation
     */
    public static boolean isInstalled() throws UnsupportedOperationException {
        return true;
    }
}
