/*
 * Decompiled with CFR 0.152.
 */
package com.whylogs.core.metrics;

import com.google.common.base.Preconditions;
import com.whylogs.core.message.RegressionMetricsMessage;
import java.util.Map;
import java.util.Objects;
import lombok.NonNull;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Accumulates residual statistics between a prediction column and a target column.
 *
 * <p>For every tracked row the residual {@code diff = prediction - target} is computed. It is
 * folded into running sums of {@code |diff|}, {@code diff} and {@code diff * diff}, and a row
 * count is kept alongside. Aggregate errors can be derived from these sums, for example mean
 * error ({@code sumDiff / count}), mean absolute error ({@code sumAbsDiff / count}) and
 * root-mean-square error ({@code sqrt(sum2Diff / count)}).
 *
 * <p>Instances are mutable and perform no synchronization.
 */
public class RegressionMetrics {
    private static final Logger log = LoggerFactory.getLogger(RegressionMetrics.class);

    /** Key in each tracked row whose value is used as the prediction. */
    @NonNull
    private final String predictionField;
    /** Key in each tracked row whose value is used as the target. */
    @NonNull
    private final String targetField;
    /** Running sum of {@code |prediction - target|}. */
    private double sumAbsDiff;
    /** Running sum of {@code prediction - target}. */
    private double sumDiff;
    /** Running sum of {@code (prediction - target)^2}. */
    private double sum2Diff;
    /** Number of rows folded into the sums. */
    private long count;

    /**
     * Folds one row into the running sums.
     *
     * <p>All validation happens before any field is updated. A row that fails validation
     * therefore leaves the accumulated state untouched.
     *
     * @param columns row values keyed by column name; must map both the prediction field and the
     *     target field to non-null {@link Number} values
     * @throws ClassCastException if either value is not a {@link Number}
     * @throws NullPointerException if {@code columns} is null, or either field is absent or maps
     *     to {@code null}
     */
    public void track(Map<String, ?> columns) {
        // Cast first so non-numeric values fail exactly as before (prediction checked before target).
        Number prediction = (Number) columns.get(this.predictionField);
        Number target = (Number) columns.get(this.targetField);
        // Name the offending field instead of surfacing an anonymous NPE from unboxing.
        if (prediction == null) {
            throw new NullPointerException("Missing value for prediction field: " + this.predictionField);
        }
        if (target == null) {
            throw new NullPointerException("Missing value for target field: " + this.targetField);
        }
        double diff = prediction.doubleValue() - target.doubleValue();
        this.sumAbsDiff += Math.abs(diff);
        this.sumDiff += diff;
        this.sum2Diff += diff * diff;
        ++this.count;
    }

    /**
     * Returns an independent copy with the same field names and accumulated values.
     *
     * @return a new instance; later updates to either object do not affect the other
     */
    public RegressionMetrics copy() {
        RegressionMetrics res = new RegressionMetrics(this.predictionField, this.targetField);
        res.sumAbsDiff = this.sumAbsDiff;
        res.sumDiff = this.sumDiff;
        res.sum2Diff = this.sum2Diff;
        res.count = this.count;
        return res;
    }

    /**
     * Combines this instance with {@code other} into a new instance. Neither input is modified.
     *
     * @param other metrics to merge; may be {@code null}, in which case a copy of this instance is
     *     returned
     * @return a new instance whose sums and count are the element-wise totals of both inputs
     * @throws IllegalStateException if the prediction or target field names differ
     */
    public RegressionMetrics merge(RegressionMetrics other) {
        if (other == null) {
            return this.copy();
        }
        Preconditions.checkState(
                Objects.equals(this.predictionField, other.predictionField),
                "Mismatched prediction fields: %s vs %s",
                this.predictionField,
                other.predictionField);
        Preconditions.checkState(
                Objects.equals(this.targetField, other.targetField),
                "Mismatched target fields: %s vs %s",
                this.targetField,
                other.targetField);
        RegressionMetrics result = new RegressionMetrics(this.predictionField, this.targetField);
        result.sumAbsDiff = this.sumAbsDiff + other.sumAbsDiff;
        result.sumDiff = this.sumDiff + other.sumDiff;
        result.sum2Diff = this.sum2Diff + other.sum2Diff;
        result.count = this.count + other.count;
        return result;
    }

    /**
     * Serializes the field names, sums and count into a protobuf builder.
     *
     * @return a populated builder; the caller is responsible for calling {@code build()}
     */
    public RegressionMetricsMessage.Builder toProtobuf() {
        return RegressionMetricsMessage.newBuilder()
                .setPredictionField(this.predictionField)
                .setTargetField(this.targetField)
                .setSumAbsDiff(this.sumAbsDiff)
                .setSumDiff(this.sumDiff)
                .setSum2Diff(this.sum2Diff)
                .setCount(this.count);
    }

    /**
     * Rebuilds an instance from its protobuf form.
     *
     * @param msg serialized metrics; may be {@code null}
     * @return the decoded metrics, or {@code null} in three cases: {@code msg} is {@code null};
     *     {@code msg} serializes to zero bytes; or {@code msg} has an empty prediction or target
     *     field name (this last case is also logged as a warning)
     */
    public static RegressionMetrics fromProtobuf(RegressionMetricsMessage msg) {
        if (msg == null || msg.getSerializedSize() == 0) {
            return null;
        }
        // Protobuf string getters return "" (never null) for unset fields.
        if (msg.getPredictionField().isEmpty() || msg.getTargetField().isEmpty()) {
            log.warn("Skipping Regression metrics: prediction or target field not set");
            return null;
        }
        RegressionMetrics res = new RegressionMetrics(msg.getPredictionField(), msg.getTargetField());
        res.sumAbsDiff = msg.getSumAbsDiff();
        res.sumDiff = msg.getSumDiff();
        res.sum2Diff = msg.getSum2Diff();
        res.count = msg.getCount();
        return res;
    }

    /**
     * Creates an empty accumulator. All sums and the count start at zero.
     *
     * @param predictionField key of the prediction value in tracked rows
     * @param targetField key of the target value in tracked rows
     * @throws NullPointerException if either argument is {@code null}
     */
    public RegressionMetrics(@NonNull String predictionField, @NonNull String targetField) {
        if (predictionField == null) {
            throw new NullPointerException("predictionField is marked non-null but is null");
        }
        if (targetField == null) {
            throw new NullPointerException("targetField is marked non-null but is null");
        }
        this.predictionField = predictionField;
        this.targetField = targetField;
    }

    /** @return the key of the prediction value in tracked rows */
    @NonNull
    public String getPredictionField() {
        return this.predictionField;
    }

    /** @return the key of the target value in tracked rows */
    @NonNull
    public String getTargetField() {
        return this.targetField;
    }

    /** @return the running sum of {@code |prediction - target|} */
    public double getSumAbsDiff() {
        return this.sumAbsDiff;
    }

    /** @return the running sum of {@code prediction - target} */
    public double getSumDiff() {
        return this.sumDiff;
    }

    /** @return the running sum of {@code (prediction - target)^2} */
    public double getSum2Diff() {
        return this.sum2Diff;
    }

    /** @return the number of rows tracked (including rows merged in from other instances) */
    public long getCount() {
        return this.count;
    }
}

