package ai.libs.jaicore.math.gradientdescent;

import java.util.Map;
import org.aeonbits.owner.ConfigFactory;
import org.api4.java.common.math.IVector;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:ai/libs/jaicore/math/gradientdescent/GradientDescentOptimizer.class */
public class GradientDescentOptimizer implements IGradientBasedOptimizer {
    private static final Logger log = LoggerFactory.getLogger(GradientDescentOptimizer.class);

    /** Factor by which each gradient component is scaled before being subtracted from its weight. */
    private final double learningRate;
    /** Gradient components whose absolute value is below this value are considered converged and are not applied. */
    private final double gradientThreshold;
    /** Upper bound on the number of gradient evaluations/update steps performed by {@link #optimize}. */
    private final int maxIterations;

    /**
     * Creates an optimizer whose learning rate, gradient threshold and iteration limit are read from the given config.
     *
     * @param iGradientDescentOptimizerConfig source of the optimizer's hyper-parameters
     */
    public GradientDescentOptimizer(final IGradientDescentOptimizerConfig iGradientDescentOptimizerConfig) {
        this.learningRate = iGradientDescentOptimizerConfig.learningRate();
        this.gradientThreshold = iGradientDescentOptimizerConfig.gradientThreshold();
        this.maxIterations = iGradientDescentOptimizerConfig.maxIterations();
    }

    /**
     * Creates an optimizer from an {@link IGradientDescentOptimizerConfig} instance built by OWNER's
     * {@link ConfigFactory} (presumably populated from that interface's declared defaults/sources — confirm there).
     */
    public GradientDescentOptimizer() {
        this(ConfigFactory.create(IGradientDescentOptimizerConfig.class));
    }

    /**
     * Runs plain gradient descent starting from {@code iVector}.
     *
     * <p>Each iteration evaluates the gradient at the current weights, subtracts {@code learningRate * gradient}
     * from every weight whose gradient component is finite and not below the threshold, and stops early once every
     * gradient component is either below the threshold or non-finite. At most {@code maxIterations} iterations are
     * performed; a non-positive limit performs none.
     *
     * @param iGradientDescendableFunction the function being minimized; not evaluated by this implementation, only
     *        the supplied gradient function is used
     * @param iGradientFunction computes the gradient at a given weight vector
     * @param iVector the initial weights; this vector is modified in place
     * @return the same vector instance as {@code iVector}, holding the optimized weights
     */
    @Override
    public IVector optimize(final IGradientDescendableFunction iGradientDescendableFunction, final IGradientFunction iGradientFunction, final IVector iVector) {
        int iteration = 0;
        while (iteration < this.maxIterations) {
            IVector gradient = iGradientFunction.apply(iVector);
            iteration++;
            this.applyGradientStep(iVector, gradient);
            log.debug("iteration {}:\n weights \t{} \n gradients \t{}", iteration, iVector, gradient);
            if (this.allGradientsAreBelowThreshold(gradient)) {
                break;
            }
        }
        log.debug("Gradient descent based optimization took {} iterations.", iteration);
        return iVector;
    }

    /**
     * Convergence test: true when no gradient component would cause an update step, i.e. every component is either
     * non-finite (NaN/±Infinity) or has an absolute value below the threshold.
     */
    private boolean allGradientsAreBelowThreshold(final IVector gradient) {
        return gradient.stream().allMatch(d -> !Double.isFinite(d) || Math.abs(d) < this.gradientThreshold);
    }

    /**
     * Performs one descent step in place: {@code weights[i] -= learningRate * gradient[i]} for every component whose
     * gradient is finite and whose magnitude reaches the threshold.
     */
    private void applyGradientStep(final IVector weights, final IVector gradient) {
        for (int i = 0; i < weights.length(); i++) {
            double g = gradient.getValue(i);
            // Non-finite components are skipped (consistent with the convergence test): applying an infinite step
            // would turn the weight into +-Infinity and then NaN with no way to recover.
            if (Double.isFinite(g) && Math.abs(g) >= this.gradientThreshold) {
                weights.setValue(i, weights.getValue(i) - this.learningRate * g);
            }
        }
    }
}
