package ai.djl.zero.tabular;

import ai.djl.Model;
import ai.djl.basicdataset.tabular.TabularDataset;
import ai.djl.basicmodelzoo.tabular.TabNet;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Block;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.DefaultTrainingConfig;
import ai.djl.training.EasyTrain;
import ai.djl.training.Trainer;
import ai.djl.training.dataset.Dataset;
import ai.djl.training.listener.TrainingListener;
import ai.djl.training.loss.TabNetRegressionLoss;
import ai.djl.translate.NoopTranslator;
import ai.djl.translate.TranslateException;
import ai.djl.zero.Performance;
import java.io.IOException;

/* loaded from: input_file:ai/djl/zero/tabular/TabularRegression.class */
/** Static entry point for training a TabNet regression model on a {@link TabularDataset}. */
public final class TabularRegression {

    /** Number of epochs passed to {@link EasyTrain#fit}. */
    private static final int EPOCHS = 20;

    private TabularRegression() {
    }

    /**
     * Trains a TabNet regression model on the given dataset.
     *
     * <p>The dataset is split with {@code randomSplit(8, 2)} into a training part and a validation
     * part. A TabNet block is sized from the dataset's feature and label counts, and the network is
     * trained for {@value #EPOCHS} epochs with {@link TabNetRegressionLoss} and the basic default
     * training listeners.
     *
     * @param tabularDataset the dataset to split and train on
     * @param performance selects the TabNet configuration: {@link Performance#FAST} uses one
     *     independent and one shared layer, {@link Performance#BALANCED} uses the builder defaults,
     *     and any other value uses four independent and four shared layers
     * @return the trained model wrapped with a {@link NoopTranslator}. It is not closed here, so the
     *     caller is responsible for closing it.
     * @throws IOException propagated from dataset splitting or training
     * @throws TranslateException propagated from dataset splitting or training
     */
    public static ZooModel<NDList, NDList> train(TabularDataset tabularDataset, Performance performance)
            throws IOException, TranslateException {
        Dataset[] splits = tabularDataset.randomSplit(8, 2);
        Dataset trainDataset = splits[0];
        Dataset validateDataset = splits[1];
        int featureSize = tabularDataset.getFeatureSize();
        int labelSize = tabularDataset.getLabelSize();

        Block block = buildBlock(featureSize, labelSize, performance);

        Model model = Model.newInstance("tabular");
        try {
            model.setBlock(block);
            DefaultTrainingConfig config =
                    new DefaultTrainingConfig(new TabNetRegressionLoss())
                            .addTrainingListeners(TrainingListener.Defaults.basic());
            try (Trainer trainer = model.newTrainer(config)) {
                // Initialize parameters from a single-row input of featureSize columns.
                trainer.initialize(new Shape(1, featureSize));
                EasyTrain.fit(trainer, EPOCHS, trainDataset, validateDataset);
            }
            // On success, ownership of the model passes to the returned ZooModel.
            return new ZooModel<>(model, new NoopTranslator());
        } catch (Throwable t) {
            // Nothing outside this method references the model yet, so release it before
            // propagating the failure instead of leaking it.
            try {
                model.close();
            } catch (Throwable closeError) {
                t.addSuppressed(closeError);
            }
            // Precise rethrow: only IOException, TranslateException, or unchecked types reach here.
            throw t;
        }
    }

    /**
     * Builds the TabNet block for the given input/output sizes, with the number of independent
     * and shared layers chosen by {@code performance}.
     */
    private static Block buildBlock(int featureSize, int labelSize, Performance performance) {
        if (performance.equals(Performance.FAST)) {
            // One independent and one shared layer.
            return TabNet.builder()
                    .setInputDim(featureSize)
                    .setOutDim(labelSize)
                    .optNumIndependent(1)
                    .optNumShared(1)
                    .build();
        }
        if (performance.equals(Performance.BALANCED)) {
            // Builder defaults for layer counts.
            return TabNet.builder().setInputDim(featureSize).setOutDim(labelSize).build();
        }
        // Every other setting: four independent and four shared layers.
        return TabNet.builder()
                .setInputDim(featureSize)
                .setOutDim(labelSize)
                .optNumIndependent(4)
                .optNumShared(4)
                .build();
    }
}
