package com.hankcs.hanlp.mining.word2vec;

import com.hankcs.hanlp.utility.Predefine;
import com.hankcs.hanlp.utility.TextUtility;
import java.io.IOException;

/* loaded from: input_file:com/hankcs/hanlp/mining/word2vec/Word2VecTrainer.class */
/**
 * Fluent holder of word2vec hyper-parameters that launches a training run.
 *
 * <p>Each setter validates its argument through {@code Preconditions} and (except
 * {@link #setCallback(TrainingCallback)}) returns {@code this}, so calls can be chained.
 * The collected values are copied into a fresh {@code Config} when
 * {@link #train(String, String)} is invoked.
 *
 * <p>NOTE(review): fields are mutable and unsynchronized; treat an instance as
 * single-threaded.
 */
public class Word2VecTrainer {
    /** Set to {@code true} by {@link #useHierarchicalSoftmax()}; there is no way to unset it. */
    private boolean useHierarchicalSoftmax;
    /** Explicit initial learning rate, or {@code null} to use the network type's default. */
    private Float initialLearningRate;
    /** Optional progress callback forwarded to {@code Config#setCallback}; may be {@code null}. */
    private TrainingCallback callback;
    private int layerSize = 200;
    private int windowSize = 5;
    private int numThreads = Runtime.getRuntime().availableProcessors();
    private int negativeSamples = 25;
    private int minFrequency = 5;
    private float downSampleRate = 1.0E-4f;
    private int iterations = 15;
    private NeuralNetworkType type = NeuralNetworkType.CBOW;

    /**
     * Sets the callback passed to the training configuration.
     *
     * @param trainingCallback callback to forward; {@code null} is accepted and stored as-is
     */
    public void setCallback(TrainingCallback trainingCallback) {
        this.callback = trainingCallback;
    }

    /**
     * Sets the layer size forwarded to {@code Config#setLayer1Size} (default 200).
     * In word2vec terms this is presumably the dimensionality of the learned vectors.
     *
     * @param size must be positive
     * @return this trainer
     */
    public Word2VecTrainer setLayerSize(int size) {
        Preconditions.checkArgument(size > 0, "Value must be positive");
        this.layerSize = size;
        return this;
    }

    /**
     * Sets the context window forwarded to {@code Config#setWindow} (default 5).
     *
     * @param size must be positive
     * @return this trainer
     */
    public Word2VecTrainer setWindowSize(int size) {
        Preconditions.checkArgument(size > 0, "Value must be positive");
        this.windowSize = size;
        return this;
    }

    /**
     * Sets the thread count forwarded to {@code Config#setNumThreads}
     * (default: {@code Runtime.availableProcessors()} at construction time).
     *
     * @param threads must be positive
     * @return this trainer
     */
    public Word2VecTrainer useNumThreads(int threads) {
        Preconditions.checkArgument(threads > 0, "Value must be positive");
        this.numThreads = threads;
        return this;
    }

    /**
     * Selects the network type (default CBOW). Only CBOW enables
     * {@code Config#setUseContinuousBagOfWords}; any other type disables it. The type also
     * supplies the default learning rate when none was set explicitly.
     *
     * @param neuralNetworkType must not be {@code null}
     * @return this trainer
     */
    public Word2VecTrainer type(NeuralNetworkType neuralNetworkType) {
        Preconditions.checkNotNull(neuralNetworkType);
        this.type = neuralNetworkType;
        return this;
    }

    /**
     * Enables hierarchical softmax in the training configuration.
     *
     * @return this trainer
     */
    public Word2VecTrainer useHierarchicalSoftmax() {
        this.useHierarchicalSoftmax = true;
        return this;
    }

    /**
     * Sets the negative-sample count forwarded to {@code Config#setNegative} (default 25).
     * Zero is accepted (presumably disabling negative sampling — confirm in Word2VecTraining).
     *
     * @param samples must be non-negative
     * @return this trainer
     */
    public Word2VecTrainer useNegativeSamples(int samples) {
        Preconditions.checkArgument(samples >= 0, "Value must be non-negative");
        this.negativeSamples = samples;
        return this;
    }

    /**
     * Sets the minimum word frequency forwarded to {@code Config#setMinCount} (default 5).
     *
     * @param frequency must be non-negative
     * @return this trainer
     */
    public Word2VecTrainer setMinVocabFrequency(int frequency) {
        Preconditions.checkArgument(frequency >= 0, "Value must be non-negative");
        this.minFrequency = frequency;
        return this;
    }

    /**
     * Sets the initial learning rate forwarded to {@code Config#setAlpha}. When never set,
     * {@link #train(String, String)} uses {@code type.getDefaultInitialLearningRate()}.
     *
     * @param rate must be non-negative (NaN is rejected by the {@code >=} check)
     * @return this trainer
     */
    public Word2VecTrainer setInitialLearningRate(float rate) {
        Preconditions.checkArgument(rate >= 0.0f, "Value must be non-negative");
        this.initialLearningRate = Float.valueOf(rate);
        return this;
    }

    /**
     * Sets the down-sampling rate forwarded to {@code Config#setSample} (default 1e-4).
     *
     * @param rate must be non-negative
     * @return this trainer
     */
    public Word2VecTrainer setDownSamplingRate(float rate) {
        Preconditions.checkArgument(rate >= 0.0f, "Value must be non-negative");
        this.downSampleRate = rate;
        return this;
    }

    /**
     * Sets the iteration count forwarded to {@code Config#setIter} (default 15).
     *
     * @param count must be positive
     * @return this trainer
     */
    public Word2VecTrainer setNumIterations(int count) {
        Preconditions.checkArgument(count > 0, "Value must be positive");
        this.iterations = count;
        return this;
    }

    /**
     * Builds a {@code Config} from the current settings, runs {@code Word2VecTraining},
     * prints the elapsed time to standard output, and loads the result from
     * {@code modelFileName}.
     *
     * @param trainFileName path forwarded to {@code Config#setInputFile}
     * @param modelFileName path forwarded to {@code Config#setOutputFile} and then passed to
     *     the {@code WordVectorModel} constructor (presumably the trained vectors are written
     *     there by {@code trainModel()})
     * @return the loaded model, or {@code null} if an {@code IOException} is thrown inside the
     *     training/loading step (the exception is logged as a warning, not rethrown)
     */
    public WordVectorModel train(String trainFileName, String modelFileName) {
        Config config = new Config();
        config.setInputFile(trainFileName);
        config.setLayer1Size(layerSize);
        config.setUseContinuousBagOfWords(type == NeuralNetworkType.CBOW);
        config.setUseHierarchicalSoftmax(useHierarchicalSoftmax);
        config.setNegative(negativeSamples);
        config.setNumThreads(numThreads);
        config.setAlpha(initialLearningRate == null
                ? type.getDefaultInitialLearningRate()
                : initialLearningRate.floatValue());
        config.setSample(downSampleRate);
        config.setWindow(windowSize);
        config.setIter(iterations);
        config.setMinCount(minFrequency);
        config.setOutputFile(modelFileName);
        // Complete the configuration (callback included) before the trainer sees it, so the
        // outcome does not depend on whether Word2VecTraining reads Config eagerly.
        config.setCallback(callback);
        Word2VecTraining training = new Word2VecTraining(config);
        // nanoTime is monotonic; currentTimeMillis can jump when the wall clock is adjusted.
        long startNanos = System.nanoTime();
        try {
            training.trainModel();
            long elapsedMillis = (System.nanoTime() - startNanos) / 1000000L;
            System.out.println();
            // Message: "Training finished, total time taken: %s"
            System.out.printf("训练结束，一共耗时：%s\n", Utility.humanTime(elapsedMillis));
            return new WordVectorModel(modelFileName);
        } catch (IOException e) {
            // Message: "An IO exception occurred during training"
            Predefine.logger.warning("训练过程中发生IO异常\n" + TextUtility.exceptionToString(e));
            return null;
        }
    }
}
