package org.clulab.lm;

import com.typesafe.config.ConfigFactory;
import edu.cmu.dynet.Dim$;
import edu.cmu.dynet.LookupParameter;
import edu.cmu.dynet.LstmBuilder;
import edu.cmu.dynet.LstmBuilder$;
import edu.cmu.dynet.ParameterCollection;
import org.clulab.embeddings.word2vec.Word2Vec;
import org.clulab.sequences.LstmUtils;
import org.clulab.sequences.LstmUtils$;
import org.clulab.utils.Serializer$;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.None$;
import scala.Predef$;
import scala.Some;
import scala.StringContext;
import scala.collection.immutable.Map;
import scala.collection.immutable.Nil$;
import scala.runtime.BoxesRunTime;

/* compiled from: RnnLMTrain.scala */
/* loaded from: input_file:org/clulab/lm/RnnLMTrain$.class */
/**
 * Entry point (decompiled form of the Scala {@code object RnnLMTrain}) that builds an
 * {@link RnnLM} from the {@code rnnlm-en} configuration, optionally trains it, and saves it to the
 * path given by {@code rnnlm.train.model}.
 *
 * <p>Config values are read, and parameters are added to the {@link ParameterCollection}, in the
 * same order as the original implementation. NOTE(review): DyNet models are presumably saved and
 * reloaded by parameter registration order, so do not reorder the {@code add*} / {@code LstmBuilder}
 * calls without confirming.
 */
public final class RnnLMTrain$ {
    /** Scala-object singleton; assigned by the private constructor run from the static initializer. */
    public static RnnLMTrain$ MODULE$;
    private final Logger logger;

    static {
        new RnnLMTrain$();
    }

    public Logger logger() {
        return this.logger;
    }

    /**
     * Builds, optionally trains, and saves the RNN language model.
     *
     * @param strArr command-line arguments; not read — all settings come from the {@code rnnlm-en}
     *     configuration
     */
    public void main(String[] strArr) {
        LstmUtils$.MODULE$.initializeDyNet(LstmUtils$.MODULE$.initializeDyNet$default$1(), LstmUtils$.MODULE$.initializeDyNet$default$2());
        FlairConfig flairConfig = new FlairConfig(ConfigFactory.load("rnnlm-en"));

        int charEmbeddingSize = flairConfig.getArgInt("rnnlm.train.charEmbeddingSize", new Some(BoxesRunTime.boxToInteger(32)));
        int charRnnStateSize = flairConfig.getArgInt("rnnlm.train.charRnnStateSize", new Some(BoxesRunTime.boxToInteger(16)));
        int wordRnnStateSize = flairConfig.getArgInt("rnnlm.train.wordRnnStateSize", new Some(BoxesRunTime.boxToInteger(256)));

        Map charMap = loadCharMap(flairConfig);

        // Word embeddings: config values are read in the same order the original passed them.
        logger().debug("Loading word embeddings...");
        String docFreqPath = flairConfig.getArgString("rnnlm.train.docFreq", None$.MODULE$);
        int minWordFreq = flairConfig.getArgInt("rnnlm.train.minWordFreq", new Some(BoxesRunTime.boxToInteger(100)));
        String embeddingsPath = flairConfig.getArgString("rnnlm.train.embed", None$.MODULE$);
        String mandatoryWordsPath = flairConfig.getArgString("rnnlm.train.mandatoryWords", None$.MODULE$);
        int minMandatoryWordFreq = flairConfig.getArgInt("rnnlm.train.minMandatoryWordFreq", new Some(BoxesRunTime.boxToInteger(1)));
        Word2Vec embeddings = LstmUtils$.MODULE$.loadEmbeddings(new Some(docFreqPath), minWordFreq, embeddingsPath, new Some(mandatoryWordsPath), minMandatoryWordFreq);
        Map<String, Object> wordVocab = LstmUtils$.MODULE$.mkWordVocab(embeddings);

        ParameterCollection parameters = new ParameterCollection();
        LookupParameter wordLookup = parameters.addLookupParameters(wordVocab.size(), Dim$.MODULE$.apply(Predef$.MODULE$.wrapIntArray(new int[]{embeddings.dimensions()})));
        LstmUtils$.MODULE$.initializeEmbeddings(embeddings, wordVocab, wordLookup);
        logger().debug("Completed loading word embeddings.");

        LookupParameter charLookup = parameters.addLookupParameters(charMap.size(), Dim$.MODULE$.apply(Predef$.MODULE$.wrapIntArray(new int[]{charEmbeddingSize})));

        // Two single-layer character LSTMs (presumably forward/backward — confirm against RnnLM).
        LstmBuilder charLstm1 = new LstmBuilder(1L, charEmbeddingSize, charRnnStateSize, parameters, LstmBuilder$.MODULE$.$lessinit$greater$default$5());
        LstmBuilder charLstm2 = new LstmBuilder(1L, charEmbeddingSize, charRnnStateSize, parameters, LstmBuilder$.MODULE$.$lessinit$greater$default$5());

        // Word-level input size: two char-RNN states plus the word-embedding width
        // (presumably a concatenation of both char-LSTM outputs with the embedding — confirm).
        int wordRnnInputSize = (2 * charRnnStateSize) + embeddings.dimensions();
        LstmBuilder wordLstm1 = new LstmBuilder(1L, wordRnnInputSize, wordRnnStateSize, parameters, LstmBuilder$.MODULE$.$lessinit$greater$default$5());
        LstmBuilder wordLstm2 = new LstmBuilder(1L, wordRnnInputSize, wordRnnStateSize, parameters, LstmBuilder$.MODULE$.$lessinit$greater$default$5());

        // NOTE(review): the +2 presumably reserves extra labels beyond lmLabelCount (e.g. special
        // tokens) — confirm against RnnLM.
        int labelCount = flairConfig.getArgInt("rnnlm.train.lmLabelCount", new Some(BoxesRunTime.boxToInteger(40000))) + 2;

        // The two [labelCount x wordRnnStateSize] output matrices are registered last, after all
        // builders, exactly as in the original argument evaluation order.
        RnnLM rnnLM = new RnnLM(wordVocab, charMap, wordRnnStateSize, charRnnStateSize, labelCount, parameters, wordLookup, charLookup, charLstm1, charLstm2, wordLstm1, wordLstm2,
                parameters.addParameters(Dim$.MODULE$.apply(Predef$.MODULE$.wrapIntArray(new int[]{labelCount, wordRnnStateSize})), parameters.addParameters$default$2()),
                parameters.addParameters(Dim$.MODULE$.apply(Predef$.MODULE$.wrapIntArray(new int[]{labelCount, wordRnnStateSize})), parameters.addParameters$default$2()));

        trainIfRequested(flairConfig, rnnLM, labelCount);

        // The model is saved whether or not training ran.
        rnnLM.save(flairConfig.getArgString("rnnlm.train.model", None$.MODULE$));
        logger().info("Done.");
    }

    /**
     * Loads the character map from the file named by {@code rnnlm.train.c2i}, using
     * {@code LstmUtils.ByLineCharIntMapBuilder} over the file's lines.
     *
     * @param flairConfig configuration to read the path from
     * @return the loaded character map
     */
    private Map loadCharMap(FlairConfig flairConfig) {
        logger().debug("Loading the character map...");
        Map charMap = (Map) Serializer$.MODULE$.using(LstmUtils$.MODULE$.newSource(flairConfig.getArgString("rnnlm.train.c2i", None$.MODULE$)), source -> {
            return new LstmUtils.ByLineCharIntMapBuilder().build(source.getLines());
        });
        logger().debug("Loaded a character map with {} entries.", charMap.keySet().size());
        return charMap;
    }

    /**
     * Trains {@code rnnLM} when {@code rnnlm.train.doTrain} is true and {@code rnnlm.train.train} is
     * configured. If training is requested but no training path is set, a warning is logged instead
     * of silently skipping.
     *
     * @param flairConfig configuration to read training settings from
     * @param rnnLM the model to train
     * @param labelCount number of output labels passed to {@code trainLM}
     */
    private void trainIfRequested(FlairConfig flairConfig, RnnLM rnnLM, int labelCount) {
        boolean doTrain = flairConfig.getArgBoolean("rnnlm.train.doTrain", new Some(BoxesRunTime.boxToBoolean(false)));
        if (!doTrain) {
            return;
        }
        if (!flairConfig.contains("rnnlm.train.train")) {
            logger().warn("rnnlm.train.doTrain is true but rnnlm.train.train is not set; skipping training.");
            return;
        }
        String trainPath = flairConfig.getArgString("rnnlm.train.train", None$.MODULE$);
        // NOTE(review): dev is looked up with no default even though trainLM receives it wrapped in
        // Some; if getArgString fails on a missing key, a dev path is effectively mandatory — confirm.
        String devPath = flairConfig.getArgString("rnnlm.train.dev", None$.MODULE$);
        int logCheckpoint = flairConfig.getArgInt("rnnlm.train.logCheckpoint", new Some(BoxesRunTime.boxToInteger(1000)));
        int saveCheckpoint = flairConfig.getArgInt("rnnlm.train.saveCheckpoint", new Some(BoxesRunTime.boxToInteger(50000)));
        int batchSize = flairConfig.getArgInt("rnnlm.train.batchSize", new Some(BoxesRunTime.boxToInteger(1)));
        rnnLM.trainLM(trainPath, new Some(devPath), labelCount, logCheckpoint, saveCheckpoint, batchSize);
    }

    private RnnLMTrain$() {
        MODULE$ = this;
        this.logger = LoggerFactory.getLogger(RnnLMTrain.class);
    }
}
