package ai.djl.fasttext.engine;

import ai.djl.Device;
import ai.djl.training.DataManager;
import ai.djl.training.TrainingConfig;
import ai.djl.training.evaluator.Evaluator;
import ai.djl.training.initializer.Initializer;
import ai.djl.training.listener.TrainingListener;
import ai.djl.training.loss.Loss;
import ai.djl.training.optimizer.Optimizer;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.Objects;

/**
 * Immutable set of options for a fastText training run.
 *
 * <p>Instances are created through {@link #builder()} and can be turned into a fastText
 * command line with {@link #toCommand(String)}. Numeric options default to {@code -1} and string
 * options to {@code null}; an option left at its default is omitted from the generated command.
 *
 * <p>The accessors required by {@link TrainingConfig} that have no counterpart in this
 * configuration return {@code null} (or an empty array for {@link #getDevices()}).
 */
public class FtTrainingConfig implements TrainingConfig {
    private final FtTrainingMode trainingMode;
    private final Path outputDir;
    private final String modelName;
    private final int epoch;
    private final int minWordCount;
    private final int minLabelCount;
    private final int maxNgramLength;
    private final int minCharLength;
    private final int maxCharLength;
    private final int bucket;
    private final float samplingThreshold;
    private final String labelPrefix;
    private final float learningRate;
    private final int learningRateUpdateRate;
    private final int wordVecSize;
    private final int contextWindow;
    private final int numNegativesSampled;
    private final int threads;
    private final String loss;

    /**
     * Builder for {@link FtTrainingConfig}.
     *
     * <p>{@code set*} methods configure values that {@link FtTrainingConfig#toCommand(String)}
     * requires; {@code opt*} methods configure values that may be left unset.
     */
    public static final class Builder {
        Path outputDir;
        String modelName;
        String labelPrefix;
        String loss;
        FtTrainingMode trainingMode = FtTrainingMode.SUPERVISED;
        // A negative value marks a numeric option as "not set".
        int epoch = -1;
        int minWordCount = -1;
        int minLabelCount = -1;
        int maxNgramLength = -1;
        int minCharLength = -1;
        int maxCharLength = -1;
        int bucket = -1;
        float samplingThreshold = -1.0f;
        float learningRate = -1.0f;
        int learningRateUpdateRate = -1;
        int wordVecSize = -1;
        int contextWindow = -1;
        int numNegativesSampled = -1;
        int threads = -1;

        Builder() {
        }

        /**
         * Sets the directory against which the model name is resolved.
         *
         * @param outputDir the output directory
         * @return this builder
         */
        public Builder setOutputDir(Path outputDir) {
            this.outputDir = outputDir;
            return this;
        }

        /**
         * Sets the model name, resolved against the output directory to form {@code -output}.
         *
         * @param modelName the model name
         * @return this builder
         */
        public Builder setModelName(String modelName) {
            this.modelName = modelName;
            return this;
        }

        /**
         * Sets the training mode; defaults to {@link FtTrainingMode#SUPERVISED}.
         *
         * @param trainingMode the training mode
         * @return this builder
         */
        public Builder optTrainingMode(FtTrainingMode trainingMode) {
            this.trainingMode = trainingMode;
            return this;
        }

        /**
         * Sets the value passed as {@code -epoch}.
         *
         * @param epoch the number of epochs; a negative value omits the flag
         * @return this builder
         */
        public Builder optEpoch(int epoch) {
            this.epoch = epoch;
            return this;
        }

        /**
         * Sets the value passed as {@code -minCount}.
         *
         * @param minWordCount the minimum word count; a negative value omits the flag
         * @return this builder
         */
        public Builder optMinWordCount(int minWordCount) {
            this.minWordCount = minWordCount;
            return this;
        }

        /**
         * Sets the value passed as {@code -minCountLabel}.
         *
         * @param minLabelCount the minimum label count; a negative value omits the flag
         * @return this builder
         */
        public Builder optMinLabelCount(int minLabelCount) {
            this.minLabelCount = minLabelCount;
            return this;
        }

        /**
         * Sets the value passed as {@code -wordNgrams}.
         *
         * @param maxNgramLength the maximum word n-gram length; a negative value omits the flag
         * @return this builder
         */
        public Builder optMaxNGramLength(int maxNgramLength) {
            this.maxNgramLength = maxNgramLength;
            return this;
        }

        /**
         * Sets the value passed as {@code -minn}.
         *
         * @param minCharLength the minimum char n-gram length; a negative value omits the flag
         * @return this builder
         */
        public Builder optMinCharLength(int minCharLength) {
            this.minCharLength = minCharLength;
            return this;
        }

        /**
         * Sets the value passed as {@code -maxn}.
         *
         * @param maxCharLength the maximum char n-gram length; a negative value omits the flag
         * @return this builder
         */
        public Builder optMaxCharLength(int maxCharLength) {
            this.maxCharLength = maxCharLength;
            return this;
        }

        /**
         * Sets the value passed as {@code -bucket}.
         *
         * @param bucket the bucket count; a negative value omits the flag
         * @return this builder
         */
        public Builder optBucket(int bucket) {
            this.bucket = bucket;
            return this;
        }

        /**
         * Sets the value passed as {@code -t}.
         *
         * @param samplingThreshold the sampling threshold; a negative value omits the flag
         * @return this builder
         */
        public Builder optSamplingThreshold(float samplingThreshold) {
            this.samplingThreshold = samplingThreshold;
            return this;
        }

        /**
         * Sets the value passed as {@code -label}.
         *
         * @param labelPrefix the label prefix; {@code null} omits the flag
         * @return this builder
         */
        public Builder optLabelPrefix(String labelPrefix) {
            this.labelPrefix = labelPrefix;
            return this;
        }

        /**
         * Sets the value passed as {@code -lr}.
         *
         * @param learningRate the learning rate; a negative value omits the flag
         * @return this builder
         */
        public Builder optLearningRate(float learningRate) {
            this.learningRate = learningRate;
            return this;
        }

        /**
         * Sets the value passed as {@code -lrUpdateRate}.
         *
         * @param learningRateUpdateRate the update rate; a negative value omits the flag
         * @return this builder
         */
        public Builder optLearningRateUpdateRate(int learningRateUpdateRate) {
            this.learningRateUpdateRate = learningRateUpdateRate;
            return this;
        }

        /**
         * Sets the value passed as {@code -dim}.
         *
         * @param wordVecSize the word vector size; a negative value omits the flag
         * @return this builder
         */
        public Builder optWordVecSize(int wordVecSize) {
            this.wordVecSize = wordVecSize;
            return this;
        }

        /**
         * Sets the value passed as {@code -ws}.
         *
         * @param contextWindow the context window size; a negative value omits the flag
         * @return this builder
         */
        public Builder optContextWindow(int contextWindow) {
            this.contextWindow = contextWindow;
            return this;
        }

        /**
         * Sets the value passed as {@code -neg}.
         *
         * @param numNegativesSampled the number of negatives; a negative value omits the flag
         * @return this builder
         */
        public Builder optNumNegativesSampled(int numNegativesSampled) {
            this.numNegativesSampled = numNegativesSampled;
            return this;
        }

        /**
         * Sets the value passed as {@code -thread}.
         *
         * @param threads the number of threads; a negative value omits the flag
         * @return this builder
         */
        public Builder optThreads(int threads) {
            this.threads = threads;
            return this;
        }

        /**
         * Sets the value passed as {@code -loss}, stored as the lower-cased enum constant name.
         *
         * @param ftLoss the loss to use; must not be {@code null}
         * @return this builder
         * @throws NullPointerException if {@code ftLoss} is {@code null}
         */
        public Builder optLoss(FtLoss ftLoss) {
            // Locale.ROOT keeps the result independent of the JVM default locale.
            this.loss = ftLoss.name().toLowerCase(Locale.ROOT);
            return this;
        }

        /**
         * Creates a configuration from the values currently held by this builder.
         *
         * @return a new {@link FtTrainingConfig}
         */
        public FtTrainingConfig build() {
            return new FtTrainingConfig(this);
        }
    }

    /** Loss identifiers; the lower-cased constant name is the value passed as {@code -loss}. */
    public enum FtLoss {
        NS,
        HS,
        SOFTMAX,
        OVA
    }

    /**
     * Copies every option from the builder.
     *
     * @param builder the builder holding the option values
     */
    FtTrainingConfig(Builder builder) {
        this.trainingMode = builder.trainingMode;
        this.outputDir = builder.outputDir;
        this.modelName = builder.modelName;
        this.epoch = builder.epoch;
        this.minWordCount = builder.minWordCount;
        this.minLabelCount = builder.minLabelCount;
        this.maxNgramLength = builder.maxNgramLength;
        this.minCharLength = builder.minCharLength;
        this.maxCharLength = builder.maxCharLength;
        this.bucket = builder.bucket;
        this.samplingThreshold = builder.samplingThreshold;
        this.labelPrefix = builder.labelPrefix;
        this.learningRate = builder.learningRate;
        this.learningRateUpdateRate = builder.learningRateUpdateRate;
        this.wordVecSize = builder.wordVecSize;
        this.contextWindow = builder.contextWindow;
        this.numNegativesSampled = builder.numNegativesSampled;
        this.threads = builder.threads;
        this.loss = builder.loss;
    }

    /** @return the training mode */
    public FtTrainingMode getTrainingMode() {
        return this.trainingMode;
    }

    /** @return the output directory, or {@code null} if none was set */
    public Path getOutputDir() {
        return this.outputDir;
    }

    /** @return the model name, or {@code null} if none was set */
    public String getModelName() {
        return this.modelName;
    }

    /** @return the {@code -epoch} value, negative if unset */
    public int getEpoch() {
        return this.epoch;
    }

    /** @return the {@code -minCount} value, negative if unset */
    public int getMinWordCount() {
        return this.minWordCount;
    }

    /** @return the {@code -minCountLabel} value, negative if unset */
    public int getMinLabelCount() {
        return this.minLabelCount;
    }

    /** @return the {@code -wordNgrams} value, negative if unset */
    public int getMaxNgramLength() {
        return this.maxNgramLength;
    }

    /** @return the {@code -minn} value, negative if unset */
    public int getMinCharLength() {
        return this.minCharLength;
    }

    /** @return the {@code -maxn} value, negative if unset */
    public int getMaxCharLength() {
        return this.maxCharLength;
    }

    /** @return the {@code -bucket} value, negative if unset */
    public int getBucket() {
        return this.bucket;
    }

    /** @return the {@code -t} value, negative if unset */
    public float getSamplingThreshold() {
        return this.samplingThreshold;
    }

    /** @return the {@code -label} value, or {@code null} if unset */
    public String getLabelPrefix() {
        return this.labelPrefix;
    }

    /** @return the {@code -lr} value, negative if unset */
    public float getLearningRate() {
        return this.learningRate;
    }

    /** @return the {@code -lrUpdateRate} value, negative if unset */
    public int getLearningRateUpdateRate() {
        return this.learningRateUpdateRate;
    }

    /** @return the {@code -dim} value, negative if unset */
    public int getWordVecSize() {
        return this.wordVecSize;
    }

    /** @return the {@code -ws} value, negative if unset */
    public int getContextWindow() {
        return this.contextWindow;
    }

    /** @return the {@code -neg} value, negative if unset */
    public int getNumNegativesSampled() {
        return this.numNegativesSampled;
    }

    /** @return the {@code -thread} value, negative if unset */
    public int getThreads() {
        return this.threads;
    }

    /** @return the lower-cased {@code -loss} value, or {@code null} if unset */
    public String getLoss() {
        return this.loss;
    }

    /** @return always an empty array */
    public Device[] getDevices() {
        return new Device[0];
    }

    /** @return always {@code null} */
    public Initializer getInitializer() {
        return null;
    }

    /** @return always {@code null} */
    public Optimizer getOptimizer() {
        return null;
    }

    /** @return always {@code null} */
    public Loss getLossFunction() {
        return null;
    }

    /** @return always {@code null} */
    public DataManager getDataManager() {
        return null;
    }

    /** @return always {@code null}; callers must null-check before iterating */
    public List<Evaluator> getEvaluators() {
        return null;
    }

    /** @return always {@code null}; callers must null-check before iterating */
    public List<TrainingListener> getTrainingListeners() {
        return null;
    }

    /**
     * Builds the fastText command line that corresponds to this configuration.
     *
     * <p>The result always starts with {@code fasttext}, the lower-cased training mode name,
     * {@code -input <input>} and {@code -output <outputDir/modelName as absolute path>}. Each
     * remaining option is appended as a flag/value pair only when it has been set (numeric
     * values {@code >= 0}, string values non-{@code null}).
     *
     * @param str the value passed as {@code -input}; added to the command as-is
     * @return the command line arguments
     * @throws NullPointerException if the training mode, output directory or model name is
     *     {@code null}
     */
    public String[] toCommand(String str) {
        Objects.requireNonNull(this.trainingMode, "trainingMode must not be null");
        Objects.requireNonNull(this.outputDir, "outputDir is required, call Builder.setOutputDir");
        Objects.requireNonNull(this.modelName, "modelName is required, call Builder.setModelName");
        Path modelPath = this.outputDir.resolve(this.modelName).toAbsolutePath();

        List<String> command = new ArrayList<>();
        command.add("fasttext");
        // Locale.ROOT: the default locale may map 'I' to a dotless 'ı' (e.g. Turkish),
        // which would corrupt sub-command names such as "supervised".
        command.add(this.trainingMode.name().toLowerCase(Locale.ROOT));
        command.add("-input");
        command.add(str);
        command.add("-output");
        command.add(modelPath.toString());

        addIntOption(command, "-epoch", this.epoch);
        addIntOption(command, "-minCount", this.minWordCount);
        addIntOption(command, "-minCountLabel", this.minLabelCount);
        addIntOption(command, "-wordNgrams", this.maxNgramLength);
        addIntOption(command, "-minn", this.minCharLength);
        addIntOption(command, "-maxn", this.maxCharLength);
        addIntOption(command, "-bucket", this.bucket);
        addFloatOption(command, "-t", this.samplingThreshold);
        addStringOption(command, "-label", this.labelPrefix);
        addFloatOption(command, "-lr", this.learningRate);
        addIntOption(command, "-lrUpdateRate", this.learningRateUpdateRate);
        addIntOption(command, "-dim", this.wordVecSize);
        addIntOption(command, "-ws", this.contextWindow);
        addIntOption(command, "-neg", this.numNegativesSampled);
        addIntOption(command, "-thread", this.threads);
        addStringOption(command, "-loss", this.loss);

        return command.toArray(new String[0]);
    }

    /** Appends {@code flag value} when the integer option has been set (is non-negative). */
    private static void addIntOption(List<String> command, String flag, int value) {
        if (value >= 0) {
            command.add(flag);
            command.add(String.valueOf(value));
        }
    }

    /** Appends {@code flag value} when the float option has been set (is non-negative). */
    private static void addFloatOption(List<String> command, String flag, float value) {
        if (value >= 0.0f) {
            command.add(flag);
            command.add(String.valueOf(value));
        }
    }

    /** Appends {@code flag value} when the string option has been set (is non-null). */
    private static void addStringOption(List<String> command, String flag, String value) {
        if (value != null) {
            command.add(flag);
            command.add(value);
        }
    }

    /**
     * Creates a new builder with every option unset and the training mode set to
     * {@link FtTrainingMode#SUPERVISED}.
     *
     * @return a new builder
     */
    public static Builder builder() {
        return new Builder();
    }
}
