package ai.konduit.serving.pipeline.impl.step.ml.classifier;

import ai.konduit.serving.annotation.runner.CanRun;
import ai.konduit.serving.pipeline.api.context.Context;
import ai.konduit.serving.pipeline.api.data.Data;
import ai.konduit.serving.pipeline.api.data.NDArray;
import ai.konduit.serving.pipeline.api.data.ValueType;
import ai.konduit.serving.pipeline.api.step.PipelineStep;
import ai.konduit.serving.pipeline.api.step.PipelineStepRunner;
import ai.konduit.serving.pipeline.registry.MicrometerRegistry;
import ai.konduit.serving.pipeline.settings.KonduitSettings;
import ai.konduit.serving.pipeline.util.DataUtils;
import ai.konduit.serving.pipeline.util.NDArrayUtils;
import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.MeterRegistry;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import lombok.NonNull;

/**
 * Runner for {@link ClassifierOutputStep}: turns the per-class output of a classification model
 * into the top prediction's probability, class index and class label.
 *
 * <p>Accepted inputs are rank 1 arrays ({@code [numClasses]}) and rank 2 arrays
 * ({@code [batchSize, numClasses]}). A rank 2 array with a single row is handled like a rank 1
 * array and produces scalar outputs (or single-element lists when {@code topN > 1}); a batched
 * input produces one list entry per row.
 *
 * <p>When {@link MicrometerRegistry#getRegistry()} returns a registry on the first execution, one
 * counter per class label is registered and incremented for every predicted class.
 */
@CanRun({ClassifierOutputStep.class})
public class ClassifierOutputRunner implements PipelineStepRunner {

    /**
     * Output key for the class index when the step does not configure one.
     * NOTE(review): if ClassifierOutputStep declares its own default index/label key constants,
     * prefer referencing those here so the defaults stay in one place.
     */
    private static final String DEFAULT_INDEX_NAME = "index";

    /** Output key for the class label when the step does not configure one (see note above). */
    private static final String DEFAULT_LABEL_NAME = "label";

    /** Key under which the full probability vector(s) are stored when requested. */
    private static final String ALL_PROBABILITIES_NAME = "allProbabilities";

    @NonNull
    protected final ClassifierOutputStep step;

    /** Whether the metric counters have been initialised (done lazily on the first exec call). */
    private boolean metricsSetup = false;

    /** Registry obtained on the first exec call; null when metrics are not available. */
    private MeterRegistry registry = null;

    /** One counter per class label, in label order; empty when no registry is available. */
    private final List<Counter> classificationMetricsCounters = new ArrayList<>();

    public ClassifierOutputRunner(@NonNull ClassifierOutputStep classifierOutputStep) {
        if (classifierOutputStep == null) {
            throw new NullPointerException("step is marked non-null but is null");
        }
        this.step = classifierOutputStep;
    }

    @Override
    public void close() {
        // Nothing to release: this runner holds no closeable resources.
    }

    @Override
    public PipelineStep getPipelineStep() {
        return this.step;
    }

    /**
     * Computes the top prediction(s) for the configured input array and adds them to {@code data}.
     *
     * @param context pipeline context (not used by this runner)
     * @param data    input data; the outputs are written into this same instance
     * @return {@code data}, with the probability / index / label outputs (and optionally the full
     *         probability vectors) added according to the step configuration
     * @throws UnsupportedOperationException if the input array is not rank 1 or rank 2
     */
    @Override
    public Data exec(Context context, Data data) {
        String inputName = this.step.inputName();
        if (inputName == null) {
            inputName = DataUtils.inferField(data, ValueType.NDARRAY, false, "NDArray field name was not provided and could not be inferred: multiple NDArray fields exist: %s and %s", "NDArray field name was not provided and could not be inferred: no image NDArray exist");
        }
        String probName = this.step.probName() == null ? ClassifierOutputStep.DEFAULT_PROB_NAME : this.step.probName();
        // Each output gets its own default key; sharing DEFAULT_PROB_NAME made them overwrite each other.
        String indexName = this.step.indexName() == null ? DEFAULT_INDEX_NAME : this.step.indexName();
        String labelName = this.step.labelName() == null ? DEFAULT_LABEL_NAME : this.step.labelName();

        NDArray input = data.getNDArray(inputName);
        long[] inputShape = input.shape();
        if (inputShape.length < 1 || inputShape.length > 2) {
            throw new UnsupportedOperationException("Invalid input to ClassifierOutputStep: only rank 1 or 2 inputs are available, got array with shape " + Arrays.toString(inputShape));
        }

        NDArray scores = NDArrayUtils.FloatNDArrayToDouble(input);
        long[] shape = scores.shape();
        // The class dimension is always the last one: [numClasses] or [batchSize, numClasses].
        int numClasses = (int) shape[shape.length - 1];
        boolean batched = shape.length == 2 && shape[0] > 1;

        List<String> labels = resolveLabels(numClasses);
        setupMetrics(labels);

        if (batched) {
            putBatchOutput(data, scores, labels, probName, indexName, labelName);
        } else {
            putSingleOutput(data, scores, labels, probName, indexName, labelName);
        }
        return data;
    }

    /** Returns the configured labels, or "0".."numClasses-1" when none are configured. */
    private List<String> resolveLabels(int numClasses) {
        List<String> configured = this.step.labels();
        if (configured != null && !configured.isEmpty()) {
            return configured;
        }
        // Build a fresh list instead of filling the step's own (possibly unmodifiable) list,
        // so the step configuration is never mutated.
        List<String> generated = new ArrayList<>(numClasses);
        for (int i = 0; i < numClasses; i++) {
            generated.add(Integer.toString(i));
        }
        return generated;
    }

    /** Registers one counter per label on the first call, if a metrics registry is available. */
    private void setupMetrics(List<String> labels) {
        if (this.metricsSetup) {
            return;
        }
        this.registry = MicrometerRegistry.getRegistry();
        if (this.registry != null) {
            for (String label : labels) {
                this.classificationMetricsCounters.add(Counter.builder(label)
                        .description("Classification counts seen so far for class label: " + label)
                        .tag("servingId", KonduitSettings.getServingId())
                        .baseUnit("classification.outcome")
                        .register(this.registry));
            }
        }
        this.metricsSetup = true;
    }

    /** Increments the counter for the given class index, if metrics are enabled and it exists. */
    private void recordClassification(long classIndex) {
        if (this.registry != null && classIndex < this.classificationMetricsCounters.size()) {
            this.classificationMetricsCounters.get((int) classIndex).increment();
        }
    }

    /** Writes the outputs for a single example (rank 1 input, or rank 2 with one row). */
    private void putSingleOutput(Data data, NDArray scores, List<String> labels,
                                 String probName, String indexName, String labelName) {
        double[] probabilities = NDArrayUtils.squeeze(scores);
        double[] maxValueAndIndex = NDArrayUtils.getMaxValueAndIndex(probabilities);
        double prob = maxValueAndIndex[0];
        long index = (long) maxValueAndIndex[1];
        String label = labels.get((int) index);
        recordClassification(index);

        Integer topN = this.step.topN();
        if (topN == null || topN.intValue() <= 1) {
            if (this.step.returnProb()) {
                data.put(probName, prob);
            }
            if (this.step.returnIndex()) {
                data.put(indexName, index);
            }
            if (this.step.returnLabel()) {
                data.put(labelName, label);
            }
        } else {
            // topN > 1 switches the outputs to list form.
            // NOTE(review): only the single top prediction is placed in each list here.
            if (this.step.returnProb()) {
                data.putListDouble(probName, Collections.singletonList(prob));
            }
            if (this.step.returnIndex()) {
                data.putListInt64(indexName, Collections.singletonList(index));
            }
            if (this.step.returnLabel()) {
                data.putListString(labelName, Collections.singletonList(label));
            }
        }
        if (this.step.allProbabilities()) {
            data.put(ALL_PROBABILITIES_NAME, NDArray.create(probabilities));
        }
    }

    /** Writes list outputs with one entry per row of a [batchSize, numClasses] input. */
    private void putBatchOutput(Data data, NDArray scores, List<String> labels,
                                String probName, String indexName, String labelName) {
        double[][] rows = (double[][]) scores.getAs(double[][].class);
        // Iterate over the batch (rows), not over the class dimension.
        int batchSize = rows.length;
        List<Double> probs = new ArrayList<>(batchSize);
        List<Long> indices = new ArrayList<>(batchSize);
        List<String> predictedLabels = new ArrayList<>(batchSize);
        List<NDArray> allProbabilities = new ArrayList<>(batchSize);
        for (double[] row : rows) {
            double[] maxValueAndIndex = NDArrayUtils.getMaxValueAndIndex(row);
            long index = (long) maxValueAndIndex[1];
            String label = labels.get((int) index);
            recordClassification(index);
            probs.add(maxValueAndIndex[0]);
            indices.add(index);
            predictedLabels.add(label);
            allProbabilities.add(NDArray.create(row));
        }
        if (this.step.returnProb()) {
            data.putListDouble(probName, probs);
        }
        if (this.step.returnIndex()) {
            data.putListInt64(indexName, indices);
        }
        if (this.step.returnLabel()) {
            data.putListString(labelName, predictedLabels);
        }
        if (this.step.allProbabilities()) {
            data.putListNDArray(ALL_PROBABILITIES_NAME, allProbabilities);
        }
    }
}
