package ai.konduit.serving.pipeline.impl.step.ml.regression;

import ai.konduit.serving.annotation.runner.CanRun;
import ai.konduit.serving.pipeline.api.context.Context;
import ai.konduit.serving.pipeline.api.data.Data;
import ai.konduit.serving.pipeline.api.data.NDArray;
import ai.konduit.serving.pipeline.api.data.ValueType;
import ai.konduit.serving.pipeline.api.step.PipelineStep;
import ai.konduit.serving.pipeline.api.step.PipelineStepRunner;
import ai.konduit.serving.pipeline.util.DataUtils;
import ai.konduit.serving.pipeline.util.NDArrayUtils;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;

@CanRun({RegressionOutputStep.class})
/**
 * Runner for {@link RegressionOutputStep}: maps columns of a rank 1 or rank 2 NDArray to named outputs.
 *
 * <p>The input array comes from the field named by {@code step.inputName()}. When that is {@code null},
 * the field is inferred as the single NDArray field in the input {@link Data}. The array is converted to
 * double precision. Then, for every {@code name -> index} entry in {@code step.names()}:
 * <ul>
 *   <li>a single row (rank 1, or rank 2 with at most one row) is squeezed, and the scalar at {@code index}
 *       is stored under {@code name};</li>
 *   <li>a batch (rank 2 with more than one row) stores a list under {@code name}. The list holds the value
 *       at column {@code index} from every row, in row order.</li>
 * </ul>
 * The outputs are written into the same {@link Data} instance that was passed in, and that instance is returned.
 */
public class RegressionOutputRunner implements PipelineStepRunner {

    /** The step configuration this runner executes. */
    protected final RegressionOutputStep step;

    /**
     * Creates a runner for the given step configuration.
     *
     * @param regressionOutputStep the step configuration; its {@code inputName()} and {@code names()}
     *                             are read on every {@link #exec(Context, Data)} call
     */
    public RegressionOutputRunner(RegressionOutputStep regressionOutputStep) {
        this.step = regressionOutputStep;
    }

    /** No resources are held, so there is nothing to release. */
    @Override
    public void close() {
    }

    @Override
    public PipelineStep getPipelineStep() {
        return this.step;
    }

    /**
     * Extracts the configured regression outputs from the input NDArray and adds them to {@code data}.
     *
     * @param context pipeline context (unused by this step)
     * @param data    input data; it is modified in place by adding one entry per configured name
     * @return the same {@code data} instance, now holding the named outputs
     * @throws UnsupportedOperationException if the input has rank greater than 2, or if {@code names()} is null or empty
     * @throws NullPointerException          if a configured name maps to a {@code null} index
     * @throws IndexOutOfBoundsException     if a configured index falls outside the output width
     */
    @Override
    public Data exec(Context context, Data data) {
        String inputName = this.step.inputName();
        if (inputName == null) {
            inputName = DataUtils.inferField(data, ValueType.NDARRAY, false,
                    "NDArray field name was not provided and could not be inferred: multiple NDArray fields exist: %s and %s",
                    "NDArray field name was not provided and could not be inferred: no NDArray field exists");
        }

        NDArray input = data.getNDArray(inputName);
        if (input.shape().length > 2) {
            throw new UnsupportedOperationException("Invalid input to RegressionOutputStep: only rank 1 or 2 inputs are available, got array with shape "
                    + Arrays.toString(input.shape()));
        }

        Map<String, Integer> names = this.step.names();
        if (names == null || names.isEmpty()) {
            throw new UnsupportedOperationException("RegressionOutputStep names field was not provided or is null");
        }

        NDArray values = NDArrayUtils.FloatNDArrayToDouble(input);
        // Only a rank-2 array with more than one row counts as a batch. Anything else is squeezed to a single row.
        boolean batched = values.shape().length == 2 && values.shape()[0] > 1;
        if (batched) {
            putBatch(data, values, names);
        } else {
            putSingle(data, NDArrayUtils.squeeze(values), names);
        }
        return data;
    }

    /** Stores one scalar per configured name, taken from a single squeezed row of values. */
    private static void putSingle(Data data, double[] row, Map<String, Integer> names) {
        for (Map.Entry<String, Integer> entry : names.entrySet()) {
            data.put(entry.getKey(), row[indexFor(entry, row.length)]);
        }
    }

    /** Stores one list per configured name, holding that name's column across every row of the batch. */
    private static void putBatch(Data data, NDArray values, Map<String, Integer> names) {
        double[][] rows = (double[][]) values.getAs(double[][].class);
        // A batch has more than one row (checked by the caller), so rows[0] exists.
        int width = rows[0].length;
        for (Map.Entry<String, Integer> entry : names.entrySet()) {
            int column = indexFor(entry, width);
            List<Double> columnValues = new ArrayList<>(rows.length);
            for (double[] row : rows) {
                columnValues.add(row[column]);
            }
            data.putListDouble(entry.getKey(), columnValues);
        }
    }

    /**
     * Returns the index configured for {@code entry}, checked against {@code width}.
     *
     * <p>The check replaces the opaque NPE / ArrayIndexOutOfBoundsException with a message that names the output.
     */
    private static int indexFor(Map.Entry<String, Integer> entry, int width) {
        Integer index = entry.getValue();
        if (index == null) {
            throw new NullPointerException("RegressionOutputStep: no index configured for output name \"" + entry.getKey() + "\"");
        }
        if (index < 0 || index >= width) {
            throw new IndexOutOfBoundsException("RegressionOutputStep: index " + index + " for output name \""
                    + entry.getKey() + "\" is out of range for regression output of width " + width);
        }
        return index;
    }
}
