package ai.konduit.serving.metrics;

import ai.konduit.serving.config.metrics.MetricsConfig;
import ai.konduit.serving.config.metrics.MetricsRenderer;
import ai.konduit.serving.config.metrics.impl.MultiLabelMetricsConfig;
import io.micrometer.core.instrument.Gauge;
import io.micrometer.core.instrument.ImmutableTag;
import io.micrometer.core.instrument.MeterRegistry;
import io.micrometer.core.instrument.Tag;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.function.Supplier;
import org.datavec.api.records.Record;
import org.datavec.api.writable.NDArrayWritable;
import org.nd4j.common.primitives.AtomicDouble;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

/* loaded from: input_file:ai/konduit/serving/metrics/MultiLabelMetrics.class */
/**
 * {@link MetricsRenderer} that counts, per configured label, how many predictions had that label as
 * their arg-max class.
 *
 * <p>{@link #bindTo(MeterRegistry)} registers one {@link Gauge} per label (named after the label).
 * Each gauge reads its value through {@link CurrentClassTrackerCount#get()}. That call returns the
 * count accumulated since the previous read and resets it to zero.
 *
 * <p>Updates received before {@link #bindTo(MeterRegistry)} has created any counters are ignored.
 */
public class MultiLabelMetrics implements MetricsRenderer {
    private final MultiLabelMetricsConfig multiLabelMetricsConfig;
    private final List<CurrentClassTrackerCount> classTrackerCounts;
    private final Iterable<Tag> tags;
    private final List<Gauge> classCounterIncrement;

    /**
     * Per-label counter used as the value supplier of a Micrometer gauge.
     *
     * <p>Reading it via {@link #get()} drains it: the accumulated total is returned and the counter
     * is reset to zero.
     */
    public static class CurrentClassTrackerCount implements Supplier<Number> {
        private final AtomicDouble currCounter;

        private CurrentClassTrackerCount() {
            this.currCounter = new AtomicDouble(0.0f);
        }

        /**
         * Atomically adds {@code d} to the counter.
         *
         * @param d amount to add
         */
        public void increment(double d) {
            this.currCounter.getAndAdd(d);
        }

        /** Resets the counter to zero. */
        public void reset() {
            this.currCounter.set(0.0d);
        }

        /**
         * Returns the value accumulated since the previous call and resets the counter to zero.
         *
         * <p>The read and the reset are one atomic {@code getAndSet}. An increment that races with a
         * read is therefore carried into the next read instead of being wiped out. (The previous
         * get-then-reset sequence lost such increments.)
         *
         * @return the drained counter value
         */
        @Override
        public Number get() {
            return Double.valueOf(this.currCounter.getAndSet(0.0d));
        }
    }

    /**
     * @param multiLabelMetricsConfig configuration providing the label names
     * @param iterable                tags attached to every registered gauge
     */
    public MultiLabelMetrics(MultiLabelMetricsConfig multiLabelMetricsConfig, Iterable<Tag> iterable) {
        this.multiLabelMetricsConfig = multiLabelMetricsConfig;
        this.tags = iterable;
        this.classTrackerCounts = new ArrayList<>();
        this.classCounterIncrement = new ArrayList<>();
    }

    /**
     * Creates the metrics with the default tag {@code machinelearning=multilabel}.
     *
     * @param multiLabelMetricsConfig configuration providing the label names
     */
    public MultiLabelMetrics(MultiLabelMetricsConfig multiLabelMetricsConfig) {
        this(multiLabelMetricsConfig, Arrays.asList(new ImmutableTag("machinelearning", "multilabel")));
    }

    /** @return the configuration this renderer was created with */
    public MetricsConfig config() {
        return this.multiLabelMetricsConfig;
    }

    /**
     * Records predictions from the first argument.
     *
     * <p>The first argument may be a {@link Record}, a {@code Record[]}, an {@link INDArray}, or an
     * {@code INDArray[]}. Any other type, a {@code null} array, or an empty array is ignored.
     * Further arguments are not inspected.
     *
     * @param objArr prediction payload(s)
     */
    public void updateMetrics(Object... objArr) {
        if (objArr == null || objArr.length == 0) {
            return;
        }
        Object first = objArr[0];
        if (first instanceof Record) {
            incrementClassificationCounters(new Record[]{(Record) first});
        } else if (first instanceof Record[]) {
            incrementClassificationCounters((Record[]) first);
        } else if (first instanceof INDArray) {
            incrementClassificationCounters(new INDArray[]{(INDArray) first});
        } else if (first instanceof INDArray[]) {
            incrementClassificationCounters((INDArray[]) first);
        }
    }

    /** Counts arg-max classes of the first array; other array elements are not used. */
    private void incrementClassificationCounters(INDArray[] iNDArrayArr) {
        if (iNDArrayArr.length == 0 || iNDArrayArr[0] == null) {
            return;
        }
        countArgMax(iNDArrayArr[0]);
    }

    /**
     * Counts arg-max classes of the first record. The record's first entry is cast to
     * {@link NDArrayWritable}; other records are not used.
     */
    private void incrementClassificationCounters(Record[] recordArr) {
        if (recordArr.length == 0 || recordArr[0] == null) {
            return;
        }
        INDArray predictions = ((NDArrayWritable) recordArr[0].getRecord().get(0)).get();
        countArgMax(predictions);
    }

    /** Increments, by one, the counter of the arg-max class along the last axis of {@code predictions}. */
    private void countArgMax(INDArray predictions) {
        if (this.classTrackerCounts.isEmpty()) {
            // Not bound to a registry yet: there are no counters to update.
            return;
        }
        INDArray argMax = Nd4j.argMax(predictions, -1);
        for (int i = 0; i < argMax.length(); i++) {
            incrementLabel(argMax.getInt(i), 1.0d);
        }
    }

    /**
     * Adds {@code amount} to the counter at {@code labelIndex}.
     *
     * <p>Indices with no configured label are ignored, so a model emitting more classes than there
     * are labels cannot make metric collection throw.
     */
    private void incrementLabel(int labelIndex, double amount) {
        if (labelIndex >= 0 && labelIndex < this.classTrackerCounts.size()) {
            this.classTrackerCounts.get(labelIndex).increment(amount);
        }
    }

    /**
     * Adds raw per-label values from {@code iNDArray} to the counters.
     *
     * <p>A scalar goes to label 0. A matrix adds each column to the label with that column index,
     * summed over rows. A vector adds each element to the label at its position.
     *
     * <p>NOTE(review): currently unused. The previous body used array values as label indices and
     * called a single-index {@code getInt} on a matrix. The per-position interpretation here is an
     * assumption to confirm.
     */
    private void handleNdArray(INDArray iNDArray) {
        if (iNDArray.isScalar()) {
            incrementLabel(0, iNDArray.getDouble(0L));
        } else if (iNDArray.isMatrix()) {
            for (int row = 0; row < iNDArray.rows(); row++) {
                for (int col = 0; col < iNDArray.columns(); col++) {
                    incrementLabel(col, iNDArray.getDouble(row, col));
                }
            }
        } else if (iNDArray.isVector()) {
            for (int i = 0; i < iNDArray.length(); i++) {
                incrementLabel(i, iNDArray.getDouble(i));
            }
        }
    }

    /**
     * Creates one counter per configured label and registers a gauge for it in {@code meterRegistry}.
     *
     * <p>Each call appends a new set of counters. Updates only ever reach the first set (by label
     * index), so gauges from a second call would always report zero.
     *
     * @param meterRegistry registry to register the gauges in
     */
    public void bindTo(MeterRegistry meterRegistry) {
        List<String> labels = this.multiLabelMetricsConfig.getLabels();
        for (String label : labels) {
            CurrentClassTrackerCount tracker = new CurrentClassTrackerCount();
            this.classTrackerCounts.add(tracker);
            this.classCounterIncrement.add(Gauge.builder(label, tracker)
                    .tags(this.tags)
                    .description("Multi-label Classification counts seen so far for label " + label)
                    .baseUnit("multilabelclassification.outcome")
                    .register(meterRegistry));
        }
    }

    /** @return the configuration this renderer was created with */
    public MultiLabelMetricsConfig getMultiLabelMetricsConfig() {
        return this.multiLabelMetricsConfig;
    }

    /** @return the live list of per-label counters, in label order */
    public List<CurrentClassTrackerCount> getClassTrackerCounts() {
        return this.classTrackerCounts;
    }

    /** @return the live list of gauges registered by {@link #bindTo(MeterRegistry)} */
    public List<Gauge> getClassCounterIncrement() {
        return this.classCounterIncrement;
    }
}
