package org.tribuo.data;

import com.oracle.labs.mlrg.olcut.config.Option;
import com.oracle.labs.mlrg.olcut.config.Options;
import com.oracle.labs.mlrg.olcut.util.Pair;
import java.io.BufferedInputStream;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.nio.file.Path;
import java.util.Locale;
import java.util.logging.Logger;
import org.tribuo.Dataset;
import org.tribuo.ImmutableDataset;
import org.tribuo.Model;
import org.tribuo.MutableDataset;
import org.tribuo.Output;
import org.tribuo.OutputFactory;
import org.tribuo.data.columnar.RowProcessor;
import org.tribuo.data.csv.CSVDataSource;
import org.tribuo.data.csv.CSVLoader;
import org.tribuo.data.text.impl.SimpleTextDataSource;
import org.tribuo.data.text.impl.TextFeatureExtractorImpl;
import org.tribuo.data.text.impl.TokenPipeline;
import org.tribuo.dataset.MinimumCardinalityDataset;
import org.tribuo.datasource.LibSVMDataSource;
import org.tribuo.util.tokens.impl.BreakIteratorTokenizer;

/**
 * Command line / configuration options for loading a training and a testing dataset
 * in one of several {@link InputFormat}s, and for serialising a trained model.
 */
public final class DataOptions implements Options {
    private static final Logger logger = Logger.getLogger(DataOptions.class.getName());

    /** If true, the text pipeline uses term counts rather than boolean presence features. */
    @Option(longName = "term-counting", usage = "Use term counts instead of boolean when using the standard text format.")
    public boolean termCounting;

    /** Destination file used by {@link #saveModel(Model)}. */
    @Option(charName = 'f', longName = "model-output-path", usage = "Path to serialize model to.")
    public Path outputPath;

    /** Name of the response column, required by the {@link InputFormat#CSV} format. */
    @Option(longName = "csv-response-name", usage = "Response name in the csv file.")
    public String csvResponseName;

    /** Row processor, required by the {@link InputFormat#COLUMNAR} format. */
    @Option(longName = "columnar-row-processor", usage = "The name of the row processor from the config file.")
    public RowProcessor<?> rowProcessor;

    /** Path to the training data. */
    @Option(charName = 'u', longName = "training-file", usage = "Path to the training file.")
    public Path trainingPath;

    /** Path to the testing data. */
    @Option(charName = 'v', longName = "testing-file", usage = "Path to the testing file.")
    public Path testingPath;

    /** Feature hashing dimension for the text format; values {@code <= 0} disable hashing. */
    @Option(longName = "hashing-dimension", usage = "Hashing dimension used for standard text format.")
    public int hashDim = 0;

    /** Ngram size passed to the text token pipeline. */
    @Option(longName = "ngram", usage = "Ngram size to generate when using standard text format.")
    public int ngram = 2;

    /** RNG seed. */
    @Option(charName = 'r', longName = "seed", usage = "RNG seed.")
    public long seed = 12345;

    /** Format of the training and testing files. */
    @Option(charName = 's', longName = "input-format", usage = "Loads the data using the specified format.")
    public InputFormat inputFormat = InputFormat.LIBSVM;

    /** Field separator for the CSV and columnar formats. */
    @Option(longName = "csv-delimiter", usage = "Delimiter")
    public Delimiter delimiter = Delimiter.COMMA;

    /** Quote character, used by the columnar format's {@link CSVDataSource}. */
    @Option(longName = "csv-quote-char", usage = "Quote character in the CSV file.")
    public char csvQuoteChar = '\"';

    /** When positive, training features seen fewer than this many times are removed (serialized, libsvm and text formats). */
    @Option(longName = "min-count", usage = "Minimum cardinality of the features.")
    public int minCount = 0;

    /** Field separators supported by the CSV and columnar loaders. */
    public enum Delimiter {
        COMMA(','),
        TAB('\t'),
        SEMICOLON(';');

        /** The separator character. */
        public final char value;

        Delimiter(char c) {
            this.value = c;
        }
    }

    /** The supported on-disk data formats. */
    public enum InputFormat {
        /** Java-serialised {@link Dataset} objects. */
        SERIALIZED,
        /** LibSVM format files. */
        LIBSVM,
        /** Text files read by {@link SimpleTextDataSource}. */
        TEXT,
        /** CSV files read by {@link CSVLoader} (needs a response column name). */
        CSV,
        /** CSV files read by {@link CSVDataSource} with a {@link RowProcessor}. */
        COLUMNAR
    }

    public String getOptionsDescription() {
        return "Options for loading and processing train and test data.";
    }

    /**
     * Loads the training and testing datasets described by these options.
     *
     * @param outputFactory the output factory for the response type.
     * @param <T> the output type.
     * @return a pair of (training dataset, testing dataset).
     * @throws IOException if either file cannot be read.
     * @throws IllegalArgumentException if required options are missing, the row processor's
     *         output factory does not match {@code outputFactory}, or a serialised file
     *         contains an unknown class.
     */
    public <T extends Output<T>> Pair<Dataset<T>, Dataset<T>> load(OutputFactory<T> outputFactory) throws IOException {
        // Every format below reads both files, so fail early with a clear message instead of an NPE.
        if (this.trainingPath == null || this.testingPath == null) {
            throw new IllegalArgumentException("Please supply both a training and a testing path");
        }
        logger.info(String.format("Loading data from %s", this.trainingPath));
        Dataset<T> train;
        Dataset<T> test;
        switch (this.inputFormat) {
            case SERIALIZED: {
                logger.info("Deserialising dataset from " + this.trainingPath);
                // NOTE(review): Java deserialisation of an untrusted file can execute arbitrary code;
                // only load serialised datasets from trusted sources.
                try (ObjectInputStream trainStream = openObjectStream(this.trainingPath);
                     ObjectInputStream testStream = openObjectStream(this.testingPath)) {
                    @SuppressWarnings("unchecked") // the file is expected to contain a Dataset of T
                    Dataset<T> deserialisedTrain = (Dataset<T>) trainStream.readObject();
                    if (this.minCount > 0) {
                        logger.info("Found " + deserialisedTrain.getFeatureIDMap().size() + " features");
                    }
                    train = applyMinCount(deserialisedTrain);
                    logTrainingSummary(train);
                    @SuppressWarnings("unchecked") // the file is expected to contain a Dataset of T
                    Dataset<T> deserialisedTest = (Dataset<T>) testStream.readObject();
                    // Re-map the test data onto the training feature and output domains.
                    test = new ImmutableDataset<>(deserialisedTest, deserialisedTest.getSourceProvenance(),
                            deserialisedTest.getOutputFactory(), train.getFeatureIDMap(), train.getOutputIDInfo(), true);
                } catch (ClassNotFoundException e) {
                    throw new IllegalArgumentException("Unknown class in serialised files", e);
                }
                break;
            }
            case LIBSVM: {
                LibSVMDataSource<T> trainSource = new LibSVMDataSource<>(this.trainingPath, outputFactory);
                Dataset<T> loaded = new MutableDataset<>(trainSource);
                // The test source must use the same indexing convention and max feature id as training.
                boolean isZeroIndexed = trainSource.isZeroIndexed();
                int maxFeatureID = trainSource.getMaxFeatureID();
                train = applyMinCount(loaded);
                logTrainingSummary(train);
                test = new ImmutableDataset<>(new LibSVMDataSource<>(this.testingPath, outputFactory, isZeroIndexed, maxFeatureID),
                        train.getFeatureIDMap(), train.getOutputIDInfo(), false);
                break;
            }
            case TEXT: {
                TokenPipeline pipeline = this.hashDim > 0
                        ? new TokenPipeline(new BreakIteratorTokenizer(Locale.US), this.ngram, this.termCounting, this.hashDim)
                        : new TokenPipeline(new BreakIteratorTokenizer(Locale.US), this.ngram, this.termCounting);
                TextFeatureExtractorImpl<T> extractor = new TextFeatureExtractorImpl<>(pipeline);
                train = applyMinCount(new MutableDataset<>(new SimpleTextDataSource<>(this.trainingPath, outputFactory, extractor)));
                logTrainingSummary(train);
                test = new ImmutableDataset<>(new SimpleTextDataSource<>(this.testingPath, outputFactory, extractor),
                        train.getFeatureIDMap(), train.getOutputIDInfo(), false);
                break;
            }
            case CSV: {
                if (this.csvResponseName == null) {
                    throw new IllegalArgumentException("Please supply a response column name");
                }
                CSVLoader<T> loader = new CSVLoader<>(this.delimiter.value, outputFactory);
                train = new MutableDataset<>(loader.loadDataSource(this.trainingPath, this.csvResponseName));
                logTrainingSummary(train);
                test = new MutableDataset<>(loader.loadDataSource(this.testingPath, this.csvResponseName));
                break;
            }
            case COLUMNAR: {
                if (this.rowProcessor == null) {
                    throw new IllegalArgumentException("Please supply a RowProcessor");
                }
                OutputFactory<?> processorFactory = this.rowProcessor.getResponseProcessor().getOutputFactory();
                if (!processorFactory.equals(outputFactory)) {
                    throw new IllegalArgumentException("The RowProcessor doesn't use the same kind of OutputFactory as the one supplied. RowProcessor has " + processorFactory.getClass().getSimpleName() + ", supplied " + outputFactory.getClass().getName());
                }
                // The factory equality check above is what justifies treating the processor as producing T.
                @SuppressWarnings("unchecked")
                RowProcessor<T> processor = (RowProcessor<T>) this.rowProcessor;
                char separator = this.delimiter.value;
                train = new MutableDataset<>(new CSVDataSource<>(this.trainingPath, processor, true, separator, this.csvQuoteChar));
                logTrainingSummary(train);
                test = new MutableDataset<>(new CSVDataSource<>(this.testingPath, processor, true, separator, this.csvQuoteChar));
                break;
            }
            default:
                throw new IllegalArgumentException("Unsupported input format " + this.inputFormat);
        }
        logger.info(String.format("Loaded %d testing examples", test.size()));
        return new Pair<>(train, test);
    }

    /**
     * Serialises the model to {@link #outputPath} using Java serialisation.
     *
     * @param model the model to write.
     * @param <T> the output type.
     * @throws IOException if the file cannot be written.
     */
    public <T extends Output<T>> void saveModel(Model<T> model) throws IOException {
        // Both streams are closed (object stream first) even if construction or writing fails.
        try (FileOutputStream fileOutputStream = new FileOutputStream(this.outputPath.toFile());
             ObjectOutputStream objectOutputStream = new ObjectOutputStream(fileOutputStream)) {
            objectOutputStream.writeObject(model);
        }
        logger.info("Serialized model to file: " + this.outputPath);
    }

    /**
     * Wraps {@code train} so that features occurring fewer than {@link #minCount} times are
     * dropped, or returns it unchanged when {@code minCount <= 0}.
     */
    private <T extends Output<T>> Dataset<T> applyMinCount(Dataset<T> train) {
        if (this.minCount > 0) {
            logger.info("Removing features that occur fewer than " + this.minCount + " times.");
            return new MinimumCardinalityDataset<>(train, this.minCount);
        }
        return train;
    }

    /** Logs the size, outputs, feature count and response dimension count of a training dataset. */
    private static void logTrainingSummary(Dataset<?> train) {
        logger.info(String.format("Loaded %d training examples for %s", train.size(), train.getOutputs().toString()));
        logger.info("Found " + train.getFeatureIDMap().size() + " features, and " + train.getOutputInfo().size() + " response dimensions");
    }

    /**
     * Opens a buffered object stream on {@code path}, closing the underlying file if the
     * object stream header cannot be read.
     */
    private static ObjectInputStream openObjectStream(Path path) throws IOException {
        FileInputStream fileStream = new FileInputStream(path.toFile());
        try {
            return new ObjectInputStream(new BufferedInputStream(fileStream));
        } catch (IOException | RuntimeException e) {
            try {
                fileStream.close();
            } catch (IOException closeFailure) {
                e.addSuppressed(closeFailure);
            }
            throw e;
        }
    }
}
