package ai.djl.zero.cv;

import ai.djl.Model;
import ai.djl.basicdataset.cv.ObjectDetectionDataset;
import ai.djl.basicmodelzoo.cv.object_detection.ssd.SingleShotDetection;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.output.DetectedObjects;
import ai.djl.modality.cv.transform.ToTensor;
import ai.djl.modality.cv.translator.SingleShotDetectionTranslator;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Block;
import ai.djl.nn.SequentialBlock;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.DefaultTrainingConfig;
import ai.djl.training.EasyTrain;
import ai.djl.training.Trainer;
import ai.djl.training.dataset.Dataset;
import ai.djl.training.evaluator.BoundingBoxError;
import ai.djl.training.evaluator.SingleShotDetectionAccuracy;
import ai.djl.training.listener.TrainingListener;
import ai.djl.training.loss.SingleShotDetectionLoss;
import ai.djl.translate.TranslateException;
import ai.djl.zero.Performance;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
 * Zero-configuration training entry point for object detection.
 *
 * <p>Builds a small Single Shot Detection (SSD) network, trains it on an {@link
 * ObjectDetectionDataset} and wraps the trained network in a {@link ZooModel} that maps {@link
 * Image} inputs to {@link DetectedObjects}.
 */
public final class ObjectDetection {

    /** Number of training epochs passed to {@link EasyTrain#fit}. */
    private static final int EPOCHS = 50;

    /** Threshold handed to the inference translator via {@code optThreshold}. */
    private static final float DETECTION_THRESHOLD = 0.6f;

    /**
     * Values passed to {@link SingleShotDetection#getDownSamplingBlock(int)} when building the
     * SSD base network (presumably the number of filters of each block — confirm against DJL).
     */
    private static final int[] BASE_NETWORK_FILTERS = {16, 32, 64};

    /** Aspect ratios used for every anchor-size entry. */
    private static final float[] ANCHOR_RATIOS = {1.0f, 2.0f, 0.5f};

    private ObjectDetection() {}

    /**
     * Trains an SSD object-detection model on the given dataset.
     *
     * <p>The dataset is randomly split in an 8:2 ratio into training and validation parts, and
     * the network is trained for a fixed {@value #EPOCHS} epochs. If training fails, the
     * partially built {@link Model} is closed before the exception propagates, so native
     * resources are not leaked.
     *
     * @param objectDetectionDataset the dataset to train on; its images must have a fixed height
     *     and width
     * @param performance the requested performance level; currently not consulted by this
     *     implementation (the training schedule is fixed)
     * @return a model whose translator applies {@link ToTensor}, uses the dataset classes as its
     *     synset and a detection threshold of {@value #DETECTION_THRESHOLD}
     * @throws IllegalArgumentException if the dataset has no fixed image height or width
     * @throws IOException if splitting the dataset or training fails with an I/O error
     * @throws TranslateException if translating dataset records fails during splitting or
     *     training
     */
    public static ZooModel<Image, DetectedObjects> train(
            ObjectDetectionDataset objectDetectionDataset, Performance performance)
            throws IOException, TranslateException {
        List<String> classes = objectDetectionDataset.getClasses();

        // Same evaluation order as before: channels, then height, then width.
        int channels = objectDetectionDataset.getImageChannels();
        int height =
                objectDetectionDataset
                        .getImageHeight()
                        .orElseThrow(
                                () ->
                                        new IllegalArgumentException(
                                                "The dataset must have a fixed image height"));
        int width =
                objectDetectionDataset
                        .getImageWidth()
                        .orElseThrow(
                                () ->
                                        new IllegalArgumentException(
                                                "The dataset must have a fixed image width"));
        Shape imageShape = new Shape(channels, height, width);

        Dataset[] splits = objectDetectionDataset.randomSplit(new int[] {8, 2});
        Dataset trainingSet = splits[0];
        Dataset validationSet = splits[1];

        Block block = getSsdTrainBlock(classes.size());
        Model model = Model.newInstance("ObjectDetection");
        try {
            model.setBlock(block);
            try (Trainer trainer =
                    model.newTrainer(
                            new DefaultTrainingConfig(new SingleShotDetectionLoss())
                                    .addEvaluator(
                                            new SingleShotDetectionAccuracy("classAccuracy"))
                                    .addEvaluator(new BoundingBoxError("boundingBoxError"))
                                    .addTrainingListeners(TrainingListener.Defaults.basic()))) {
                // Prepend a batch dimension of 1 to the (C, H, W) image shape.
                trainer.initialize(new Shape(1).addAll(imageShape));
                EasyTrain.fit(trainer, EPOCHS, trainingSet, validationSet);
            }
            return new ZooModel<>(
                    model,
                    SingleShotDetectionTranslator.builder()
                            .addTransform(new ToTensor())
                            .optSynset(classes)
                            .optThreshold(DETECTION_THRESHOLD)
                            .build());
        } catch (Throwable t) {
            // On failure, the caller never receives the model, so release it here.
            try {
                model.close();
            } catch (Throwable closeFailure) {
                t.addSuppressed(closeFailure);
            }
            throw t; // precise rethrow: only IOException, TranslateException or unchecked
        }
    }

    /**
     * Builds the SSD network used for training.
     *
     * @param numClasses number of object classes the detector predicts
     * @return an SSD block with three down-sampling blocks as its base network, 3 features,
     *     global pooling enabled, and one anchor-size pair plus ratio list per entry (5 entries;
     *     presumably one per feature map — confirm against DJL's SingleShotDetection)
     */
    private static Block getSsdTrainBlock(int numClasses) {
        SequentialBlock baseNetwork = new SequentialBlock();
        for (int filters : BASE_NETWORK_FILTERS) {
            baseNetwork.add(SingleShotDetection.getDownSamplingBlock(filters));
        }

        List<List<Float>> sizes = new ArrayList<>();
        sizes.add(Arrays.asList(0.2f, 0.272f));
        sizes.add(Arrays.asList(0.37f, 0.447f));
        sizes.add(Arrays.asList(0.54f, 0.619f));
        sizes.add(Arrays.asList(0.71f, 0.79f));
        sizes.add(Arrays.asList(0.88f, 0.961f));

        // One ratio list per size entry, each a distinct list instance as before.
        List<List<Float>> ratios = new ArrayList<>(sizes.size());
        for (int i = 0; i < sizes.size(); i++) {
            List<Float> entry = new ArrayList<>(ANCHOR_RATIOS.length);
            for (float ratio : ANCHOR_RATIOS) {
                entry.add(ratio);
            }
            ratios.add(entry);
        }

        return SingleShotDetection.builder()
                .setNumClasses(numClasses)
                .setNumFeatures(3)
                .optGlobalPool(true)
                .setRatios(ratios)
                .setSizes(sizes)
                .setBaseNetwork(baseNetwork)
                .build();
    }
}
