package ai.djl.modality.cv.translator;

import ai.djl.modality.cv.output.Joints;
import ai.djl.modality.cv.translator.BaseImageTranslator;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.index.NDIndex;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.translate.TranslatorContext;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

/**
 * A {@link BaseImageTranslator} that turns per-joint heatmaps into {@link Joints}.
 *
 * <p>The model output is read as a single array whose dimensions 0, 1 and 2 are the number of
 * joints, the heatmap height and the heatmap width. For each joint, the position of the heatmap
 * maximum becomes the joint location. That location is normalized by width (x) and height (y).
 * The maximum value itself is the joint's confidence. Joints whose confidence does not exceed
 * the configured threshold are omitted from the result.
 */
public class SimplePoseTranslator extends BaseImageTranslator<Joints> {

    /** Confidence threshold used when none is configured. */
    private static final float DEFAULT_THRESHOLD = 0.2f;

    private final float threshold;

    /**
     * Creates a {@code SimplePoseTranslator} from the given builder.
     *
     * @param builder the builder holding the image pre-processing settings and the threshold
     */
    public SimplePoseTranslator(Builder builder) {
        super(builder);
        this.threshold = builder.threshold;
    }

    /**
     * Converts the heatmap output into detected joints.
     *
     * @param ctx the translator context
     * @param list the model output; must contain exactly one array
     * @return the joints whose peak value is strictly greater than the threshold
     */
    @Override
    public Joints processOutput(TranslatorContext ctx, NDList list) {
        NDArray pred = list.singletonOrThrow();
        Shape shape = pred.getShape();
        int numJoints = (int) shape.get(0);
        int height = (int) shape.get(1);
        int width = (int) shape.get(2);

        // Flatten each joint's heatmap so one argMax/max along axis 2 finds its peak.
        NDArray predReshaped = pred.reshape(new Shape(1, numJoints, -1));
        NDArray maxIndices =
                predReshaped
                        .argMax(2)
                        .reshape(new Shape(1, numJoints, -1))
                        .toType(DataType.FLOAT32, false);
        NDArray maxValues = predReshaped.max(new int[] {2}, true);

        // Duplicate the flat peak index into two channels and convert it to
        // (x, y) = (idx mod width, floor(idx / width)).
        NDArray coords = maxIndices.tile(2, 2L);
        coords.set(new NDIndex(":, :, 0"), coords.get(":, :, 0").mod(width));
        coords.set(new NDIndex(":, :, 1"), coords.get(":, :, 1").div(width).floor());

        // Zero the coordinates of joints with a non-positive peak. Multiplying by the mask
        // (instead of boolean-mask indexing) keeps exactly two values per joint, so
        // flatCoords[i * 2] always belongs to joint i even when earlier joints are masked out.
        NDArray positiveMask = maxValues.gt(0.0).toType(DataType.FLOAT32, false).tile(2, 2L);
        float[] flatCoords = coords.mul(positiveMask).toFloatArray();
        float[] confidences = maxValues.toFloatArray();

        List<Joints.Joint> joints = new ArrayList<>(numJoints);
        for (int i = 0; i < numJoints; i++) {
            if (confidences[i] > threshold) {
                joints.add(
                        new Joints.Joint(
                                flatCoords[i * 2] / width,
                                flatCoords[i * 2 + 1] / height,
                                confidences[i]));
            }
        }
        return new Joints(joints);
    }

    /**
     * Creates a builder with default settings.
     *
     * @return a new builder
     */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * Creates a builder configured from the given arguments.
     *
     * <p>Both pre-processing and post-processing options are read from {@code arguments}. The
     * post-processing option is {@code "threshold"}.
     *
     * @param arguments the configuration arguments
     * @return a new, configured builder
     */
    public static Builder builder(Map<String, Object> arguments) {
        Builder builder = new Builder();
        builder.configPreProcess(arguments);
        builder.configPostProcess(arguments);
        return builder;
    }

    /** The builder for {@link SimplePoseTranslator}. */
    public static class Builder extends BaseImageTranslator.BaseBuilder<Builder> {

        float threshold = DEFAULT_THRESHOLD;

        Builder() {}

        /**
         * Sets the value a joint's peak must strictly exceed to be included in the output.
         *
         * @param threshold the confidence threshold (defaults to {@code 0.2})
         * @return this builder
         */
        public Builder optThreshold(float threshold) {
            this.threshold = threshold;
            return self();
        }

        /** {@inheritDoc} */
        @Override
        public Builder self() {
            return this;
        }

        /** {@inheritDoc} */
        @Override
        protected void configPostProcess(Map<String, Object> arguments) {
            threshold =
                    BaseImageTranslator.getFloatValue(arguments, "threshold", DEFAULT_THRESHOLD);
        }

        /**
         * Validates the settings and builds the translator.
         *
         * @return a new {@link SimplePoseTranslator}
         */
        public SimplePoseTranslator build() {
            validate();
            return new SimplePoseTranslator(this);
        }
    }
}
