package ai.konduit.serving.metrics;

import ai.konduit.serving.config.metrics.ColumnDistribution;
import ai.konduit.serving.config.metrics.MetricsConfig;
import ai.konduit.serving.config.metrics.MetricsRenderer;
import ai.konduit.serving.config.metrics.impl.RegressionMetricsConfig;
import ai.konduit.serving.util.MetricRenderUtils;
import io.micrometer.core.instrument.Gauge;
import io.micrometer.core.instrument.ImmutableTag;
import io.micrometer.core.instrument.MeterRegistry;
import io.micrometer.core.instrument.Tag;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.function.Supplier;
import org.datavec.api.records.Record;
import org.datavec.api.transform.analysis.counter.StatCounter;
import org.datavec.api.writable.NDArrayWritable;
import org.nd4j.linalg.api.ndarray.INDArray;

/**
 * Metrics renderer that keeps running statistics over regression outputs.
 *
 * <p>{@link #bindTo(MeterRegistry)} registers one {@link Gauge} per configured regression column
 * label. Each gauge reports the {@link RegressionMetricsConfig.SampleType} configured for that
 * column, computed from a {@link StatCounter} that {@link #updateMetrics(Object...)} feeds.
 *
 * <p>Counter updates happen while holding the {@code statCounters} monitor; the gauge suppliers
 * read their counters without taking that lock.
 */
public class RegressionMetrics implements MetricsRenderer {
    /** Tags attached to every registered gauge. */
    private final Iterable<Tag> tags;
    /** Gauges registered by {@link #bindTo(MeterRegistry)}, in label order. */
    private final List<Gauge> outputStatsGauges;
    /** One running-statistics accumulator per regression column, in label order. */
    private final List<StatCounter> statCounters;
    /** Labels, sample types and optional column distributions driving this renderer. */
    private final RegressionMetricsConfig regressionMetricsConfig;

    /**
     * Supplies a gauge value by reading one statistic from a {@link StatCounter}, optionally
     * de-normalizing it with a {@link ColumnDistribution}.
     */
    private static class StatCounterSupplier implements Serializable, Supplier<Number> {
        private final StatCounter statCounter;
        private final RegressionMetricsConfig.SampleType sampleType;
        private final ColumnDistribution columnDistribution;

        /**
         * @param statCounter        counter to read from
         * @param sampleType         which statistic of the counter to report
         * @param columnDistribution distribution used to de-normalize the statistic, or {@code null}
         *                           to report it unchanged
         */
        StatCounterSupplier(StatCounter statCounter, RegressionMetricsConfig.SampleType sampleType,
                            ColumnDistribution columnDistribution) {
            this.statCounter = statCounter;
            this.sampleType = sampleType;
            this.columnDistribution = columnDistribution;
        }

        /**
         * Returns the configured statistic of the wrapped counter.
         *
         * @return the statistic (de-normalized when a column distribution is set), or {@code 0.0}
         *         for a sample type this supplier does not recognize
         */
        @Override
        public Number get() {
            double value;
            switch (sampleType) {
                case SUM:
                    value = statCounter.getSum();
                    break;
                case MEAN:
                    value = statCounter.getMean();
                    break;
                case MIN:
                    value = statCounter.getMin();
                    break;
                case MAX:
                    value = statCounter.getMax();
                    break;
                case STDDEV_POP:
                    value = statCounter.getStddev(true);
                    break;
                case STDDEV_NOPOP:
                    value = statCounter.getStddev(false);
                    break;
                case VARIANCE_POP:
                    value = statCounter.getVariance(true);
                    break;
                case VARIANCE_NOPOP:
                    value = statCounter.getVariance(false);
                    break;
                default:
                    // Unrecognized sample types report 0 and are never de-normalized.
                    return Double.valueOf(0.0d);
            }
            if (columnDistribution != null) {
                value = MetricRenderUtils.deNormalizeValue(value, columnDistribution);
            }
            return Double.valueOf(value);
        }
    }

    /**
     * Creates a renderer whose gauges are tagged {@code machinelearning=regression}.
     *
     * @param regressionMetricsConfig labels, sample types and optional column distributions
     */
    public RegressionMetrics(RegressionMetricsConfig regressionMetricsConfig) {
        this(regressionMetricsConfig, Arrays.<Tag>asList(new ImmutableTag("machinelearning", "regression")));
    }

    /**
     * Creates a renderer with explicit gauge tags.
     *
     * @param regressionMetricsConfig labels, sample types and optional column distributions
     * @param iterable                tags attached to every gauge registered by {@link #bindTo}
     */
    public RegressionMetrics(RegressionMetricsConfig regressionMetricsConfig, Iterable<Tag> iterable) {
        this.regressionMetricsConfig = regressionMetricsConfig;
        this.tags = iterable;
        this.outputStatsGauges = new ArrayList<>();
        this.statCounters = new ArrayList<>();
    }

    /**
     * Registers one gauge per regression column label and creates the counter that backs it.
     *
     * <p>Column distributions are applied only when the config supplies exactly one per label;
     * otherwise every gauge reports raw (non de-normalized) values.
     *
     * @param meterRegistry registry the gauges are registered with
     */
    public void bindTo(MeterRegistry meterRegistry) {
        List<?> labels = regressionMetricsConfig.getRegressionColumnLabels();
        List<?> sampleTypes = regressionMetricsConfig.getSampleTypes();
        List<?> distributions = regressionMetricsConfig.getColumnDistributions();
        boolean applyDistributions = distributions != null && distributions.size() == labels.size();

        for (int i = 0; i < labels.size(); i++) {
            String label = (String) labels.get(i);
            RegressionMetricsConfig.SampleType sampleType = (RegressionMetricsConfig.SampleType) sampleTypes.get(i);
            ColumnDistribution distribution = applyDistributions ? (ColumnDistribution) distributions.get(i) : null;

            StatCounter statCounter = new StatCounter();
            statCounters.add(statCounter);
            outputStatsGauges.add(Gauge.builder(label, new StatCounterSupplier(statCounter, sampleType, distribution))
                    .tags(tags)
                    .description("Regression values seen so far for label " + label)
                    .baseUnit("regression.outcome")
                    .register(meterRegistry));
        }
    }

    /**
     * @return the configuration this renderer was created with
     */
    public MetricsConfig config() {
        return regressionMetricsConfig;
    }

    /**
     * Feeds new regression outputs into the running counters.
     *
     * <p>Only the first argument is inspected. It may be a {@link Record}, {@code Record[]},
     * {@link INDArray} or {@code INDArray[]}; any other type, or no argument at all, is ignored.
     *
     * @param args values to record; only {@code args[0]} is used
     */
    public void updateMetrics(Object... args) {
        if (args == null || args.length == 0) {
            return;
        }
        Object first = args[0];
        if (first instanceof Record) {
            incrementRegressionCounters(new Record[]{(Record) first});
        } else if (first instanceof Record[]) {
            incrementRegressionCounters((Record[]) first);
        } else if (first instanceof INDArray) {
            incrementRegressionCounters(new INDArray[]{(INDArray) first});
        } else if (first instanceof INDArray[]) {
            incrementRegressionCounters((INDArray[]) first);
        }
    }

    /**
     * Records the first array of {@code arrays}; further arrays are not read.
     * An empty array is a no-op.
     */
    private void incrementRegressionCounters(INDArray[] arrays) {
        if (arrays.length == 0) {
            return;
        }
        synchronized (statCounters) {
            handleNdArray(arrays[0]);
        }
    }

    /**
     * Records the array held by the first writable of the first record; further records are not
     * read. An empty array is a no-op.
     *
     * @throws ClassCastException if that writable is not an {@link NDArrayWritable}
     */
    private void incrementRegressionCounters(Record[] records) {
        if (records.length == 0) {
            return;
        }
        synchronized (statCounters) {
            handleNdArray(((NDArrayWritable) records[0].getRecord().get(0)).get());
        }
    }

    /**
     * Adds the values of {@code array} to the per-column counters. Must be called while holding
     * the {@code statCounters} monitor.
     *
     * <ul>
     *   <li>vector: element {@code i} feeds counter {@code i};</li>
     *   <li>matrix with more than one element: one row per example, one column per regression
     *       label, so column {@code j} of every row feeds counter {@code j};</li>
     *   <li>scalar: feeds counter 0.</li>
     * </ul>
     *
     * @throws IllegalArgumentException for any other shape
     */
    private void handleNdArray(INDArray array) {
        if (array.isVector()) {
            for (int i = 0; i < array.length(); i++) {
                statCounters.get(i).add(array.getDouble(i));
            }
        } else if (array.isMatrix() && array.length() > 1) {
            for (int row = 0; row < array.rows(); row++) {
                for (int col = 0; col < array.columns(); col++) {
                    // Each column is one regression label; previously counters were indexed by row,
                    // which mixed all labels of an example into one counter.
                    statCounters.get(col).add(array.getDouble(row, col));
                }
            }
        } else if (array.isScalar()) {
            statCounters.get(0).add(array.sumNumber().doubleValue());
        } else {
            throw new IllegalArgumentException("Only vectors and matrices supported right now");
        }
    }

    /**
     * @return the live list of gauges registered by {@link #bindTo(MeterRegistry)}, in label order
     */
    public List<Gauge> getOutputStatsGauges() {
        return outputStatsGauges;
    }
}
