package fr.janalyse.sotohp.processor;

import ai.djl.ModelException;
import ai.djl.inference.Predictor;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.transform.Normalize;
import ai.djl.modality.cv.transform.ToTensor;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.Pipeline;
import ai.djl.translate.TranslateException;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;
import java.io.IOException;
import java.nio.file.Paths;
import java.util.Arrays;
import java.util.Objects;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:fr/janalyse/sotohp/processor/FeatureExtraction.class */
public final class FeatureExtraction {
    private static final Logger logger = LoggerFactory.getLogger(FeatureExtraction.class);

    /* loaded from: input_file:fr/janalyse/sotohp/processor/FeatureExtraction$FaceFeatureTranslator.class */
    public static final class FaceFeatureTranslator implements Translator<Image, float[]> {
        public NDList processInput(TranslatorContext translatorContext, Image image) {
            NDArray nDArray = image.toNDArray(translatorContext.getNDManager(), Image.Flag.COLOR);
            Pipeline pipeline = new Pipeline();
            pipeline.add(new ToTensor()).add(new Normalize(new float[]{0.5f, 0.5f, 0.5f}, new float[]{0.5019608f, 0.5019608f, 0.5019608f}));
            return pipeline.transform(new NDList(new NDArray[]{nDArray}));
        }

        /* renamed from: processOutput, reason: merged with bridge method [inline-methods] */
        public float[] m16processOutput(TranslatorContext translatorContext, NDList nDList) {
            NDList nDList2 = new NDList();
            long j = nDList.singletonOrThrow().getShape().get(0);
            for (int i = 0; i < j; i++) {
                nDList2.add(nDList.singletonOrThrow().get(new long[]{i}));
            }
            float[][] fArr = (float[][]) nDList2.stream().map((v0) -> {
                return v0.toFloatArray();
            }).toArray(i2 -> {
                return new float[i2];
            });
            float[] fArr2 = new float[fArr.length];
            for (int i3 = 0; i3 < fArr.length; i3++) {
                fArr2[i3] = fArr[i3][0];
            }
            return fArr2;
        }
    }

    private FeatureExtraction() {
    }

    public static void main(String[] strArr) throws IOException, ModelException, TranslateException {
        float[] predict = predict(ImageFactory.getInstance().fromFile(Paths.get("src/test/resources/kana1.jpg", new String[0])));
        if (predict != null) {
            logger.info(Arrays.toString(predict));
        }
    }

    public static float[] predict(Image image) throws IOException, ModelException, TranslateException {
        image.getWrappedImage();
        ZooModel loadModel = Criteria.builder().setTypes(Image.class, float[].class).optModelUrls("https://resources.djl.ai/test-models/pytorch/face_feature.zip").optModelName("face_feature").optTranslator(new FaceFeatureTranslator()).optProgress(new ProgressBar()).optEngine("PyTorch").build().loadModel();
        try {
            float[] fArr = (float[]) loadModel.newPredictor().predict(image);
            if (loadModel != null) {
                loadModel.close();
            }
            return fArr;
        } catch (Throwable th) {
            if (loadModel != null) {
                try {
                    loadModel.close();
                } catch (Throwable th2) {
                    th.addSuppressed(th2);
                }
            }
            throw th;
        }
    }
}
