package greycat.ml.regression;

import greycat.Callback;
import greycat.Graph;
import greycat.Node;
import greycat.ml.BaseMLNode;
import greycat.ml.RegressionNode;
import greycat.plugin.NodeState;
import greycat.utility.Enforcer;
import java.util.Random;

/* loaded from: input_file:greycat/ml/regression/LiveLinearRegressionNode.class */
/**
 * Online ("live") linear regression node.
 *
 * <p>Each {@link #learn(double, Callback)} call obtains the current feature vector through
 * {@code extractFeatures}, records the prediction error made before learning ({@value #LAST_ERR_KEY}),
 * runs {@code ITERATION} plain gradient-descent steps on {@code 0.5 * err^2 + 0.5 * LAMBDA * |w|^2}
 * (the bias term is not regularised), and stores the resulting weights.
 *
 * <p>The weight vector has {@code features.length + 1} entries; the last entry is the intercept
 * (see {@link #calculate(double[], double[])}).
 *
 * <p>Persistence:
 * <ul>
 *   <li>If no weight backup exists yet, the weights, the backup and a sample counter of 1 are
 *       written to the aligned state.</li>
 *   <li>If the maximum absolute difference between the new weights and the backup exceeds
 *       {@code THRESHOLD}, the weights, a new backup and a counter of 1 are written to
 *       {@code phasedState()}.</li>
 *   <li>Otherwise the aligned state's weights are overwritten and its counter is incremented.</li>
 * </ul>
 */
public class LiveLinearRegressionNode extends BaseMLNode implements RegressionNode {
    /** Default learning rate (gradient step size). */
    public static final double ALPHA_DEF = 1.0E-4d;
    /** Default L2 regularisation factor applied to the non-bias weights. */
    public static final double LAMBDA_DEF = 1.0E-5d;
    /** Default number of gradient steps performed per learned sample. */
    public static final int ITERATION_DEF = 5;
    public static final String THRESHOLD_KEY = "THRESHOLD";
    /** Default max-abs weight drift above which a new phased state is written. */
    public static final double THRESHOLD_DEF = 0.01d;
    /** State key holding the prediction error measured before the last learn step. */
    public static final String LAST_ERR_KEY = "_ERR";
    /** State key holding the current weight vector (features..., bias). */
    public static final String WEIGHT_KEY = "_WEIGHT";
    private static final String INTERNAL_TOTAL_KEY = "_TOTAL_KEY";
    private static final String INTERNAL_WEIGHT_BACKUP_KEY = "_WEIGHTBACKUP";
    private static final String MISMATCH_MSG = "Different Imput lengths are not supported";
    public static final String NAME = "LiveLinearRegression";
    public static final String ALPHA_KEY = "ALPHA";
    public static final String LAMBDA_KEY = "LAMBDA";
    public static final String ITERATION_KEY = "ITERATION";

    // Type tags passed to setFromKey. Named after how this class uses them (4 with int values,
    // 5 with double values, 6 with double[] values). NOTE(review): these presumably mirror
    // greycat.Type.INT / DOUBLE / DOUBLE_ARRAY -- confirm and switch to those constants.
    private static final byte TYPE_INT = 4;
    private static final byte TYPE_DOUBLE = 5;
    private static final byte TYPE_DOUBLE_ARRAY = 6;

    /** Scale of the random values used to initialise an untrained weight vector. */
    private static final double INITIAL_WEIGHT_SCALE = 0.001d;

    private static final Enforcer enforcer = new Enforcer().asDouble(ALPHA_KEY).asDouble(LAMBDA_KEY).asInt(ITERATION_KEY);

    public LiveLinearRegressionNode(long j, long j2, long j3, Graph graph) {
        super(j, j2, j3, graph);
    }

    /**
     * Learns one sample: extracts the current feature vector and forwards it, together with
     * {@code value}, to {@link #internalLearn(double[], double, Callback)}.
     *
     * @param value    target value for the current features
     * @param callback notified with {@code true} once learning is done; may be {@code null}
     */
    @Override // greycat.ml.RegressionNode
    public void learn(final double value, final Callback<Boolean> callback) {
        extractFeatures(new Callback<double[]>() {
            @Override
            public void on(double[] features) {
                LiveLinearRegressionNode.this.internalLearn(features, value, callback);
            }
        });
    }

    /**
     * Performs the learning step for an already-extracted feature vector.
     *
     * @param features feature vector; must not be {@code null} and must have exactly one entry
     *                 fewer than any previously stored weight vector
     * @param target   value the model should predict for {@code features}
     * @param callback notified with {@code true} on success; may be {@code null}
     * @throws RuntimeException if {@code features} is {@code null} or its length does not match
     *                          the stored weights
     */
    public void internalLearn(double[] features, double target, Callback<Boolean> callback) {
        // Validate before touching any array or the node state: previously a null or wrongly
        // sized vector failed with NPE / ArrayIndexOutOfBounds, or wrote a bogus _ERR value,
        // before the mismatch check was reached.
        if (features == null) {
            throw new RuntimeException(MISMATCH_MSG);
        }
        NodeState state = this._resolver.alignState(this);
        int iterations = ((Integer) state.getFromKeyWithDefault(ITERATION_KEY, ITERATION_DEF)).intValue();
        double alpha = ((Double) state.getFromKeyWithDefault(ALPHA_KEY, Double.valueOf(ALPHA_DEF))).doubleValue();
        double lambda = ((Double) state.getFromKeyWithDefault(LAMBDA_KEY, Double.valueOf(LAMBDA_DEF))).doubleValue();

        double[] stored = (double[]) state.getFromKey(WEIGHT_KEY);
        // Work on a private copy: the gradient steps mutate the array in place, and it is not
        // visible here whether getFromKey hands back the state's own storage. Copying ensures the
        // stored weights/backup only change through the explicit setFromKey calls below.
        double[] weights = (stored == null) ? initialWeights(features.length + 1) : stored.clone();
        if (weights.length != features.length + 1) {
            throw new RuntimeException(MISMATCH_MSG);
        }

        // Error of the current model on this sample, measured before learning from it.
        state.setFromKey(LAST_ERR_KEY, TYPE_DOUBLE, Double.valueOf(calculate(weights, features) - target));

        gradientSteps(weights, features, target, iterations, alpha, lambda);
        persistWeights(state, weights);

        if (callback != null) {
            callback.on(true);
        }
    }

    /** Creates a weight vector of {@code size} small non-negative random values. */
    private static double[] initialWeights(int size) {
        double[] weights = new double[size];
        Random random = new Random();
        for (int i = 0; i < size; i++) {
            weights[i] = random.nextDouble() * INITIAL_WEIGHT_SCALE;
        }
        return weights;
    }

    /**
     * Applies {@code iterations} gradient-descent steps to {@code weights} in place.
     * The L2 penalty {@code lambda} applies to feature weights only, not to the bias.
     */
    private void gradientSteps(double[] weights, double[] features, double target,
                               int iterations, double alpha, double lambda) {
        int bias = features.length;
        for (int step = 0; step < iterations; step++) {
            double error = calculate(weights, features) - target;
            for (int i = 0; i < bias; i++) {
                weights[i] -= alpha * (error * features[i] + lambda * weights[i]);
            }
            weights[bias] -= alpha * error;
        }
    }

    /** Returns the largest absolute element-wise difference between the two vectors. */
    private static double maxAbsDifference(double[] weights, double[] backup) {
        double max = 0.0d;
        for (int i = 0; i < weights.length; i++) {
            max = Math.max(max, Math.abs(weights[i] - backup[i]));
        }
        return max;
    }

    /**
     * Stores the updated weights: first sample, drift beyond THRESHOLD (new phased state),
     * or in-place update with the sample counter incremented.
     */
    private void persistWeights(NodeState state, double[] weights) {
        double[] backup = (double[]) state.getFromKey(INTERNAL_WEIGHT_BACKUP_KEY);
        if (backup == null) {
            state.setFromKey(WEIGHT_KEY, TYPE_DOUBLE_ARRAY, weights);
            state.setFromKey(INTERNAL_WEIGHT_BACKUP_KEY, TYPE_DOUBLE_ARRAY, weights);
            state.setFromKey(INTERNAL_TOTAL_KEY, TYPE_INT, Integer.valueOf(1));
            return;
        }
        double threshold = ((Double) state.getFromKeyWithDefault(THRESHOLD_KEY, Double.valueOf(THRESHOLD_DEF))).doubleValue();
        if (maxAbsDifference(weights, backup) > threshold) {
            NodeState phased = phasedState();
            phased.setFromKey(WEIGHT_KEY, TYPE_DOUBLE_ARRAY, weights);
            phased.setFromKey(INTERNAL_WEIGHT_BACKUP_KEY, TYPE_DOUBLE_ARRAY, weights);
            phased.setFromKey(INTERNAL_TOTAL_KEY, TYPE_INT, Integer.valueOf(1));
        } else {
            state.setFromKey(WEIGHT_KEY, TYPE_DOUBLE_ARRAY, weights);
            // Default to 0 so a missing counter cannot cause an unboxing NPE.
            int total = ((Integer) state.getFromKeyWithDefault(INTERNAL_TOTAL_KEY, Integer.valueOf(0))).intValue();
            state.setFromKey(INTERNAL_TOTAL_KEY, TYPE_INT, Integer.valueOf(total + 1));
        }
    }

    /**
     * Validates the ALPHA / LAMBDA / ITERATION attribute types before delegating to the base
     * implementation.
     */
    @Override
    public Node set(String name, byte type, Object value) {
        enforcer.check(name, type, value);
        return super.set(name, type, value);
    }

    /**
     * Computes the linear prediction {@code sum(weights[i] * features[i]) + weights[features.length]}.
     *
     * @param weights  weight vector; its entry at index {@code features.length} is the intercept
     * @param features feature vector
     * @return predicted value
     */
    public double calculate(double[] weights, double[] features) {
        double sum = 0.0d;
        for (int i = 0; i < features.length; i++) {
            sum += weights[i] * features[i];
        }
        return sum + weights[features.length];
    }

    /**
     * Predicts a value for the current features using the stored weights. If no weights have
     * been learned yet, the callback receives {@code 0.0}.
     *
     * @param callback receives the prediction; may be {@code null}
     * @throws RuntimeException (from the feature callback) if the extracted features are
     *                          {@code null} or do not match the stored weight length
     */
    @Override // greycat.ml.RegressionNode
    public void extrapolate(final Callback<Double> callback) {
        final double[] weights = (double[]) this._resolver.resolveState(this).getFromKey(WEIGHT_KEY);
        if (weights == null) {
            if (callback != null) {
                callback.on(Double.valueOf(0.0d));
            }
            return;
        }
        extractFeatures(new Callback<double[]>() {
            @Override
            public void on(double[] features) {
                if (features == null || features.length != weights.length - 1) {
                    throw new RuntimeException(LiveLinearRegressionNode.MISMATCH_MSG);
                }
                if (callback != null) {
                    callback.on(Double.valueOf(LiveLinearRegressionNode.this.calculate(weights, features)));
                }
            }
        });
    }
}
