package ai.konduit.serving.metrics;

import ai.konduit.serving.config.metrics.MetricsConfig;
import ai.konduit.serving.config.metrics.MetricsRenderer;
import ai.konduit.serving.config.metrics.impl.ClassificationMetricsConfig;
import io.micrometer.core.instrument.Gauge;
import io.micrometer.core.instrument.ImmutableTag;
import io.micrometer.core.instrument.MeterRegistry;
import io.micrometer.core.instrument.Tag;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.function.Supplier;
import org.datavec.api.records.Record;
import org.datavec.api.writable.NDArrayWritable;
import org.nd4j.common.primitives.AtomicDouble;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

/**
 * {@link MetricsRenderer} that counts classification outcomes per configured label.
 *
 * <p>{@link #bindTo(MeterRegistry)} registers one {@link Gauge} per label returned by
 * {@link ClassificationMetricsConfig#getClassificationLabels()}. Each call to
 * {@link #updateMetrics(Object...)} takes the arg-max along dimension {@code -1} of the model
 * output and increments the tracker for the resulting class index. A tracker resets to zero
 * whenever its value is read, so each gauge reading reports the count accumulated since the
 * previous reading.
 */
public class ClassificationMetrics implements MetricsRenderer {
    /** Tags attached to every registered gauge. */
    private final Iterable<Tag> tags;
    /** Gauges registered by {@link #bindTo(MeterRegistry)}, in label order. */
    private final List<Gauge> classCounterIncrement;
    /** Counters backing the gauges; index {@code i} belongs to the {@code i}-th configured label. */
    private final List<CurrentClassTrackerCount> classTrackerCounts;
    private final ClassificationMetricsConfig classificationMetricsConfig;

    /**
     * Accumulating counter used as a gauge value source: reading it returns the running total and
     * resets it to zero.
     *
     * <p>All methods synchronize on the instance. Without this, an increment could arrive between
     * the read and the reset in {@link #get()} and be silently discarded.
     */
    public static class CurrentClassTrackerCount implements Supplier<Number> {
        private double currCounter;

        private CurrentClassTrackerCount() {
            this.currCounter = 0.0d;
        }

        /**
         * Adds {@code d} to the running count.
         *
         * @param d amount to add
         */
        public synchronized void increment(double d) {
            this.currCounter += d;
        }

        /** Sets the running count back to zero. */
        public synchronized void reset() {
            this.currCounter = 0.0d;
        }

        /**
         * Returns the count accumulated since the previous call and resets it to zero, as a single
         * atomic step.
         *
         * @return the accumulated count, boxed as a {@link Double}
         */
        @Override // java.util.function.Supplier
        public synchronized Number get() {
            double accumulated = this.currCounter;
            this.currCounter = 0.0d;
            return Double.valueOf(accumulated);
        }
    }

    /**
     * Creates the metrics with the default tag {@code machinelearning=classification}.
     *
     * @param classificationMetricsConfig configuration supplying the classification labels
     */
    public ClassificationMetrics(ClassificationMetricsConfig classificationMetricsConfig) {
        this(classificationMetricsConfig,
                Arrays.<Tag>asList(new ImmutableTag("machinelearning", "classification")));
    }

    /**
     * Creates the metrics with custom gauge tags.
     *
     * @param classificationMetricsConfig configuration supplying the classification labels
     * @param iterable tags attached to every gauge registered by {@link #bindTo(MeterRegistry)}
     */
    public ClassificationMetrics(ClassificationMetricsConfig classificationMetricsConfig, Iterable<Tag> iterable) {
        this.classificationMetricsConfig = classificationMetricsConfig;
        this.tags = iterable;
        this.classCounterIncrement = new ArrayList<>();
        this.classTrackerCounts = new ArrayList<>();
    }

    /**
     * Registers one gauge per configured classification label, each named after its label and
     * backed by a fresh {@link CurrentClassTrackerCount}.
     *
     * @param meterRegistry registry to register the gauges with
     */
    public void bindTo(MeterRegistry meterRegistry) {
        List<?> labels = this.classificationMetricsConfig.getClassificationLabels();
        for (Object labelObj : labels) {
            String label = (String) labelObj;
            CurrentClassTrackerCount tracker = new CurrentClassTrackerCount();
            this.classTrackerCounts.add(tracker);
            Gauge gauge = Gauge.builder(label, tracker)
                    .tags(this.tags)
                    .description("Classification counts seen so far for label " + label)
                    .baseUnit("classification.outcome")
                    .register(meterRegistry);
            this.classCounterIncrement.add(gauge);
        }
    }

    /**
     * Returns the configuration this renderer was created with.
     *
     * @return the classification metrics configuration
     */
    public MetricsConfig config() {
        return this.classificationMetricsConfig;
    }

    /**
     * Records the classification outcome(s) contained in the first argument.
     *
     * <p>{@code objArr[0]} may be a {@link Record}, a {@code Record[]}, an {@link INDArray} or an
     * {@code INDArray[]}. For the array forms only the first element is used. Any other type, or a
     * {@code null}/empty argument list, is ignored.
     *
     * @param objArr model output(s); only {@code objArr[0]} is inspected
     * @throws IllegalStateException if a predicted class index has no registered label tracker
     */
    public void updateMetrics(Object... objArr) {
        if (objArr == null || objArr.length == 0) {
            return;
        }
        Object output = objArr[0];
        if (output instanceof Record) {
            incrementClassificationCounters(new Record[]{(Record) output});
        } else if (output instanceof Record[]) {
            incrementClassificationCounters((Record[]) output);
        } else if (output instanceof INDArray) {
            incrementClassificationCounters(new INDArray[]{(INDArray) output});
        } else if (output instanceof INDArray[]) {
            incrementClassificationCounters((INDArray[]) output);
        }
    }

    /** Counts the predictions held in the first array; an empty array or null element is ignored. */
    private void incrementClassificationCounters(INDArray[] outputs) {
        if (outputs.length == 0 || outputs[0] == null) {
            return;
        }
        handleNdArray(outputs[0]);
    }

    /**
     * Counts the predictions stored as an {@link NDArrayWritable} in the first field of the first
     * record. An empty array, null record, or record without fields is ignored.
     */
    private void incrementClassificationCounters(Record[] records) {
        if (records.length == 0 || records[0] == null) {
            return;
        }
        List<?> fields = records[0].getRecord();
        if (fields == null || fields.isEmpty()) {
            return;
        }
        handleNdArray(((NDArrayWritable) fields.get(0)).get());
    }

    /**
     * Increments the tracker of the arg-max class (taken along dimension {@code -1}) for every
     * entry of the arg-max result.
     *
     * @param predictions model output scores; assumed to have classes along the last axis
     *     (NOTE(review): confirm against the model output layout)
     * @throws IllegalStateException if a predicted class index has no registered label tracker
     */
    private void handleNdArray(INDArray predictions) {
        // Trackers exist only after bindTo(...) has run; until then there is nowhere to record.
        if (predictions == null || this.classTrackerCounts.isEmpty()) {
            return;
        }
        INDArray argMax = Nd4j.argMax(predictions, -1);
        int trackerCount = this.classTrackerCounts.size();
        for (int i = 0; i < argMax.length(); i++) {
            int classIndex = argMax.getInt(i);
            if (classIndex < 0 || classIndex >= trackerCount) {
                throw new IllegalStateException("Predicted class index " + classIndex
                        + " has no matching label tracker; " + trackerCount + " trackers are registered");
            }
            this.classTrackerCounts.get(classIndex).increment(1.0d);
        }
    }

    /**
     * Returns the gauges registered by {@link #bindTo(MeterRegistry)}.
     *
     * @return the live internal list of gauges
     */
    public List<Gauge> getClassCounterIncrement() {
        return this.classCounterIncrement;
    }

    /**
     * Returns the per-label trackers backing the gauges.
     *
     * @return the live internal list of trackers
     */
    public List<CurrentClassTrackerCount> getClassTrackerCounts() {
        return this.classTrackerCounts;
    }
}
