package com.gengoai.apollo.ml.model;

import com.gengoai.LogUtils;
import com.gengoai.Stopwatch;
import com.gengoai.apollo.ml.model.Params;
import java.io.Serializable;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.function.IntToDoubleFunction;
import java.util.logging.Logger;

/* loaded from: input_file:com/gengoai/apollo/ml/model/StoppingCriteria.class */
/**
 * Decides when an iterative optimization should stop, based on a short history of a tracked
 * criteria value (e.g. a loss) and a maximum number of iterations.
 *
 * <p>Defaults: {@code historySize = 5}, {@code maxIterations = 100}, {@code tolerance = 1e-6} and
 * {@code reportInterval = -1}. Any {@code reportInterval <= 0} disables progress logging.
 *
 * <p>Instances hold mutable state (the value history) and are not synchronized.
 */
public final class StoppingCriteria implements Serializable {
    private static final long serialVersionUID = 1L;

    /** Name of the tracked value; used only in log messages. */
    private final String criteriaName;
    /** Values recorded by {@link #check(double)}, most recent first. */
    private final LinkedList<Double> history = new LinkedList<>();
    /**
     * {@link Logger} is not {@link Serializable}, so it is excluded from the serialized form and
     * re-resolved lazily by {@link #logger()} (e.g. after deserialization).
     */
    private transient Logger logger = LogUtils.getGlobalLogger();
    private int historySize = 5;
    private int maxIterations = 100;
    private int reportInterval = -1;
    private double tolerance = 1.0E-6d;

    private StoppingCriteria(String str) {
        this.criteriaName = str;
    }

    /**
     * Creates criteria with default settings that track a value named {@code "loss"}.
     *
     * @return a new StoppingCriteria
     */
    public static StoppingCriteria create() {
        return new StoppingCriteria("loss");
    }

    /**
     * Creates criteria with default settings that track a value with the given name.
     *
     * @param str the name of the tracked value, used in log messages
     * @return a new StoppingCriteria
     */
    public static StoppingCriteria create(String str) {
        return new StoppingCriteria(str);
    }

    /**
     * Creates criteria configured from fit parameters.
     *
     * <p>Reads {@code historySize}, {@code tolerance} and {@code maxIterations} from the
     * {@code Params.Optimizable} settings, falling back to the class defaults. The
     * {@code reportInterval} setting is honoured only when {@code verbose} is true; otherwise
     * progress logging is disabled.
     *
     * @param str           the name of the tracked value, used in log messages
     * @param fitParameters the parameters to read settings from
     * @return a new, configured StoppingCriteria
     */
    public static StoppingCriteria create(String str, FitParameters<?> fitParameters) {
        StoppingCriteria stoppingCriteria = new StoppingCriteria(str);
        stoppingCriteria.reportInterval = ((Boolean) fitParameters.verbose.value()).booleanValue() ? ((Integer) fitParameters.getOrDefault(Params.Optimizable.reportInterval, -1)).intValue() : -1;
        stoppingCriteria.historySize = ((Integer) fitParameters.getOrDefault(Params.Optimizable.historySize, 5)).intValue();
        stoppingCriteria.tolerance = ((Double) fitParameters.getOrDefault(Params.Optimizable.tolerance, Double.valueOf(1.0E-6d))).doubleValue();
        stoppingCriteria.maxIterations = ((Integer) fitParameters.getOrDefault(Params.Optimizable.maxIterations, 100)).intValue();
        return stoppingCriteria;
    }

    /**
     * Records a new value and reports whether iteration should stop.
     *
     * <p>A non-finite value stops immediately and is not recorded. Otherwise, once
     * {@code historySize} values have been recorded, the oldest recorded value is evicted. The
     * result is {@code true} when both of these hold:
     * <ul>
     *   <li>the new value is within {@code tolerance} of the evicted value, and</li>
     *   <li>the new value is not more than {@code tolerance} below any other recorded value.</li>
     * </ul>
     * The new value is then recorded as the most recent.
     *
     * @param loss the latest value of the tracked criteria
     * @return {@code true} if iteration should stop
     */
    public boolean check(double loss) {
        if (!Double.isFinite(loss)) {
            logger().warning("Non Finite loss, aborting");
            return true;
        }
        boolean converged = false;
        // The isEmpty() guard prevents removeLast() from throwing when historySize <= 0.
        if (!this.history.isEmpty() && this.history.size() >= this.historySize) {
            double oldest = this.history.removeLast();
            converged = Math.abs(loss - oldest) <= this.tolerance;
            // Every newer value must be within tolerance or lower than the current one,
            // i.e. the criteria has stopped decreasing meaningfully over the window.
            for (Iterator<Double> it = this.history.iterator(); converged && it.hasNext(); ) {
                double previous = it.next();
                converged = Math.abs(loss - previous) <= this.tolerance || loss > previous;
            }
        }
        this.history.addFirst(loss);
        return converged;
    }

    /**
     * @return the name of the tracked value
     */
    public String criteriaName() {
        return this.criteriaName;
    }

    /**
     * @return the number of past values considered when checking for convergence
     */
    public int historySize() {
        return this.historySize;
    }

    /**
     * Sets the number of past values considered when checking for convergence.
     *
     * @param i the history size
     * @return this StoppingCriteria
     */
    public StoppingCriteria historySize(int i) {
        this.historySize = i;
        return this;
    }

    /**
     * @return the most recently recorded (finite) value
     * @throws java.util.NoSuchElementException if no finite value has been recorded yet
     */
    public double lastLoss() {
        return this.history.getFirst().doubleValue();
    }

    /**
     * @return the logger used for progress messages, never {@code null}
     */
    public Logger logger() {
        // The field is transient and therefore null after deserialization.
        if (this.logger == null) {
            this.logger = LogUtils.getGlobalLogger();
        }
        return this.logger;
    }

    /**
     * Sets the logger used for progress messages.
     *
     * @param logger the logger, or {@code null} to use the global logger
     * @return this StoppingCriteria
     */
    public StoppingCriteria logger(Logger logger) {
        this.logger = logger == null ? LogUtils.getGlobalLogger() : logger;
        return this;
    }

    /**
     * @return the maximum number of iterations {@link #untilTermination} will run
     */
    public int maxIterations() {
        return this.maxIterations;
    }

    /**
     * Sets the maximum number of iterations {@link #untilTermination} will run.
     *
     * @param i the iteration limit
     * @return this StoppingCriteria
     */
    public StoppingCriteria maxIterations(int i) {
        this.maxIterations = i;
        return this;
    }

    /**
     * @return how often (in iterations) progress is logged; values {@code <= 0} disable logging
     */
    public int reportInterval() {
        return this.reportInterval;
    }

    /**
     * Sets how often (in iterations) progress is logged.
     *
     * @param i the report interval; values {@code <= 0} disable logging
     * @return this StoppingCriteria
     */
    public StoppingCriteria reportInterval(int i) {
        this.reportInterval = i;
        return this;
    }

    /**
     * @return the tolerance used when comparing values in {@link #check(double)}
     */
    public double tolerance() {
        return this.tolerance;
    }

    /**
     * Sets the tolerance used when comparing values in {@link #check(double)}.
     *
     * @param d the tolerance
     * @return this StoppingCriteria
     */
    public StoppingCriteria tolerance(double d) {
        this.tolerance = d;
        return this;
    }

    /**
     * Runs {@code iteration} for indices {@code 0 .. maxIterations - 1}. The run ends early as
     * soon as {@link #check(double)} reports that it should stop.
     *
     * <p>When {@code reportInterval > 0}, the following are logged:
     * <ul>
     *   <li>progress every {@code reportInterval} iterations,</li>
     *   <li>the iteration at which the run stopped early, and</li>
     *   <li>a final line when the iteration limit is reached, unless that iteration was already
     *       logged.</li>
     * </ul>
     *
     * @param iteration computes the tracked value for the given zero-based iteration index
     * @return the zero-based index of the iteration at which the run stopped early, or
     *         {@code maxIterations} if the iteration limit was reached
     */
    public int untilTermination(IntToDoubleFunction iteration) {
        Stopwatch stopwatch = Stopwatch.createStopped();
        double loss = 0.0d;
        for (int i = 0; i < this.maxIterations; i++) {
            stopwatch.reset();
            stopwatch.start();
            loss = iteration.applyAsDouble(i);
            stopwatch.stop();
            // Stopping must not depend on whether progress reporting is enabled;
            // reportInterval only controls whether the event is logged.
            if (check(loss)) {
                if (this.reportInterval > 0) {
                    LogUtils.logInfo(logger(), "iteration {0}: {1}={2}, time={3}, Converged", new Object[]{Integer.valueOf(i + 1), this.criteriaName, Double.valueOf(loss), stopwatch});
                }
                return i;
            }
            if (this.reportInterval > 0 && (i + 1) % this.reportInterval == 0) {
                LogUtils.logInfo(logger(), "iteration {0}: {1}={2}, time={3}", new Object[]{Integer.valueOf(i + 1), this.criteriaName, Double.valueOf(loss), stopwatch});
            }
        }
        // Iteration number maxIterations was already logged above iff it is a multiple of
        // reportInterval; only emit the summary line otherwise, to avoid a duplicate entry.
        if (this.reportInterval > 0 && this.maxIterations % this.reportInterval != 0) {
            LogUtils.logInfo(logger(), "iteration {0}: {1}={2}, time={3}, Max. Iterations Reached", new Object[]{Integer.valueOf(this.maxIterations), this.criteriaName, Double.valueOf(loss), stopwatch});
        }
        return this.maxIterations;
    }
}
