package org.clulab.fatdynet.apps;

import edu.cmu.dynet.ComputationGraph$;
import edu.cmu.dynet.Dim$;
import edu.cmu.dynet.Expression;
import edu.cmu.dynet.Expression$;
import edu.cmu.dynet.FloatPointer;
import edu.cmu.dynet.Initialize$;
import edu.cmu.dynet.ParameterCollection;
import edu.cmu.dynet.SimpleSGDTrainer;
import edu.cmu.dynet.SimpleSGDTrainer$;
import org.clulab.fatdynet.Model;
import org.clulab.fatdynet.Repo;
import org.clulab.fatdynet.Repo$;
import org.clulab.fatdynet.apps.ExternalLookupParameterExampleApp;
import org.clulab.fatdynet.utils.CloseableModelSaver;
import org.clulab.fatdynet.utils.Closer$;
import scala.MatchError;
import scala.Predef$;
import scala.Predef$ArrowAssoc$;
import scala.StringContext;
import scala.Tuple2;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.collection.immutable.Map;
import scala.runtime.BoxesRunTime;
import scala.runtime.IntRef;
import scala.runtime.RichInt$;
import scala.util.Random;

/* compiled from: ExternalLookupParameterExampleApp.scala */
/* loaded from: input_file:org/clulab/fatdynet/apps/ExternalLookupParameterExampleApp$.class */
/**
 * Singleton ({@code MODULE$}) holder for the Scala object {@code ExternalLookupParameterExampleApp}.
 *
 * <p>Trains a small feed-forward network {@code v * tanh(w * x + b) + a} on four XOR cases,
 * saves it to {@code XorModel.dat}, reloads it, and asserts that predictions are unchanged.
 */
public final class ExternalLookupParameterExampleApp$ {
    /**
     * The single instance. It is assigned in the static initializer: a {@code static final} field
     * cannot legally be reassigned from an instance constructor.
     */
    public static final ExternalLookupParameterExampleApp$ MODULE$;
    private final Random random;
    private final int INPUT_SIZE;
    private final int HIDDEN_SIZE;
    private final int OUTPUT_SIZE;
    private final int ITERATIONS;
    private final Seq<ExternalLookupParameterExampleApp.XorTransformation> transformations;

    static {
        MODULE$ = new ExternalLookupParameterExampleApp$();
    }

    /** @return the {@link Random} seeded with 1234 at construction. */
    public Random random() {
        return this.random;
    }

    /** @return the input width of the network (2). */
    public int INPUT_SIZE() {
        return this.INPUT_SIZE;
    }

    /** @return the hidden-layer width of the network (2). */
    public int HIDDEN_SIZE() {
        return this.HIDDEN_SIZE;
    }

    /** @return the output width of the network (1). */
    public int OUTPUT_SIZE() {
        return this.OUTPUT_SIZE;
    }

    /** @return the number of training iterations (400). */
    public int ITERATIONS() {
        return this.ITERATIONS;
    }

    /** @return the four XOR cases used for both training and prediction. */
    public Seq<ExternalLookupParameterExampleApp.XorTransformation> transformations() {
        return this.transformations;
    }

    /**
     * Builds, in a freshly renewed computation graph, the expression {@code v * tanh(w * x + b) + a},
     * where {@code x} is {@code xorTransformation.transform()}.
     *
     * @param xorModel          model supplying parameters w, b, v and a
     * @param xorTransformation case supplying the input expression
     * @return the output expression of the network
     */
    public Expression mkPredictionGraph(ExternalLookupParameterExampleApp.XorModel xorModel, ExternalLookupParameterExampleApp.XorTransformation xorTransformation) {
        // Discard any nodes left over from the previous prediction.
        ComputationGraph$.MODULE$.renew();
        Expression w = Expression$.MODULE$.parameter(xorModel.w());
        Expression b = Expression$.MODULE$.parameter(xorModel.b());
        Expression v = Expression$.MODULE$.parameter(xorModel.v());
        Expression x = xorTransformation.transform();
        Expression hidden = Expression$.MODULE$.tanh(w.$times(x).$plus(b));
        // Keep the original node-creation order: v * hidden is built before parameter a is added.
        Expression scaled = v.$times(hidden);
        return scaled.$plus(Expression$.MODULE$.parameter(xorModel.a()));
    }

    /**
     * Creates a fresh model, trains it for {@link #ITERATIONS()} iterations with a
     * {@link SimpleSGDTrainer}, and then evaluates it.
     *
     * @return the trained model paired with the predictions from {@link #predict}
     */
    public Tuple2<ExternalLookupParameterExampleApp.XorModel, Seq<Object>> train() {
        ParameterCollection parameterCollection = new ParameterCollection();
        // Second constructor argument uses the trainer's Scala default value.
        SimpleSGDTrainer simpleSGDTrainer = new SimpleSGDTrainer(parameterCollection, SimpleSGDTrainer$.MODULE$.$lessinit$greater$default$2());
        // Parameters are added in the order w, b, v, a; load() relies on this order (indices 0..3).
        ExternalLookupParameterExampleApp.XorModel xorModel = new ExternalLookupParameterExampleApp.XorModel(
                parameterCollection.addParameters(
                        Dim$.MODULE$.apply(Predef$.MODULE$.wrapIntArray(new int[]{HIDDEN_SIZE(), INPUT_SIZE()})),
                        parameterCollection.addParameters$default$2()),
                parameterCollection.addParameters(
                        Dim$.MODULE$.apply(Predef$.MODULE$.wrapIntArray(new int[]{HIDDEN_SIZE()})),
                        parameterCollection.addParameters$default$2()),
                parameterCollection.addParameters(
                        Dim$.MODULE$.apply(Predef$.MODULE$.wrapIntArray(new int[]{OUTPUT_SIZE(), HIDDEN_SIZE()})),
                        parameterCollection.addParameters$default$2()),
                parameterCollection.addParameters(
                        Dim$.MODULE$.apply(Predef$.MODULE$.wrapIntArray(new int[]{OUTPUT_SIZE()})),
                        parameterCollection.addParameters$default$2()),
                parameterCollection);
        // The per-iteration training step lives in the compiler-generated anonfun class (not shown here).
        RichInt$.MODULE$.until$extension0(Predef$.MODULE$.intWrapper(0), ITERATIONS()).foreach$mVc$sp(new ExternalLookupParameterExampleApp$$anonfun$train$1(simpleSGDTrainer, xorModel, new FloatPointer()));
        return new Tuple2<>(xorModel, predict(xorModel));
    }

    /**
     * Maps every XOR case through the model and prints an accuracy line.
     *
     * @param xorModel the model to evaluate
     * @return one result per entry of {@link #transformations()}
     */
    public Seq<Object> predict(ExternalLookupParameterExampleApp.XorModel xorModel) {
        // Shared counter mutated by the mapping function; it is reported as the accuracy numerator.
        IntRef correct = IntRef.create(0);
        Predef$.MODULE$.println();
        Seq<Object> results = (Seq) transformations().map(new ExternalLookupParameterExampleApp$$anonfun$2(xorModel, correct), Seq$.MODULE$.canBuildFrom());
        int total = transformations().size();
        // Widen before dividing; int / int would truncate the ratio (e.g. 3 / 4 == 0).
        float accuracy = (float) correct.elem / total;
        Predef$.MODULE$.println("Accuracy: " + correct.elem + " / " + total + " = " + accuracy);
        return results;
    }

    /**
     * Writes the model to {@code str} through a {@link CloseableModelSaver}, closing the saver afterwards.
     *
     * @param str      destination file name
     * @param xorModel model to save
     */
    public void save(String str, ExternalLookupParameterExampleApp.XorModel xorModel) {
        Closer$.MODULE$.AutoCloser(new CloseableModelSaver(str)).autoClose(new ExternalLookupParameterExampleApp$$anonfun$save$1(xorModel));
    }

    /**
     * Reads the model stored under the name {@code "/model"} in the file {@code str}.
     *
     * @param str file previously written by {@link #save}
     * @return a model whose parameters 0..3 are passed to XorModel in the same order used by {@link #train()}
     */
    public ExternalLookupParameterExampleApp.XorModel load(String str) {
        Repo repo = Repo$.MODULE$.apply(str);
        Model model = repo.getModel(repo.getDesigns(repo.getDesigns$default$1()), "/model");
        return new ExternalLookupParameterExampleApp.XorModel(model.getParameter(0), model.getParameter(1), model.getParameter(2), model.getParameter(3), model.getParameterCollection());
    }

    /**
     * Initializes DyNet with a fixed seed, then trains, saves and reloads the model. It asserts that
     * predictions are identical before and after the round trip.
     *
     * @param strArr command-line arguments (unused)
     */
    public void main(String[] strArr) {
        Initialize$.MODULE$.initialize((Map<String, Object>) Predef$.MODULE$.Map().apply(Predef$.MODULE$.wrapRefArray(new Tuple2[]{Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc("random-seed"), BoxesRunTime.boxToLong(2522620396L))})));
        Tuple2<ExternalLookupParameterExampleApp.XorModel, Seq<Object>> trained = train();
        // Mirrors the Scala tuple pattern match, which fails with MatchError on null.
        if (trained == null) {
            throw new MatchError(trained);
        }
        ExternalLookupParameterExampleApp.XorModel xorModel = trained._1();
        Seq<Object> trainedResults = trained._2();
        Seq<Object> results = predict(xorModel);
        String fileName = "XorModel.dat";
        save(fileName, xorModel);
        Seq<Object> reloadedResults = predict(load(fileName));
        Predef$.MODULE$.assert(trainedResults != null ? trainedResults.equals(results) : results == null);
        Predef$.MODULE$.assert(results != null ? results.equals(reloadedResults) : reloadedResults == null);
    }

    private ExternalLookupParameterExampleApp$() {
        this.random = new Random(1234L);
        this.INPUT_SIZE = 2;
        this.HIDDEN_SIZE = 2;
        this.OUTPUT_SIZE = 1;
        this.ITERATIONS = 400;
        // Four XOR cases; the constructor arguments presumably are (index, x1, x2, expected) — confirm against XorTransformation.
        this.transformations = Seq$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new ExternalLookupParameterExampleApp.XorTransformation[]{new ExternalLookupParameterExampleApp.XorTransformation(0, 0, 0, 0), new ExternalLookupParameterExampleApp.XorTransformation(1, 0, 1, 1), new ExternalLookupParameterExampleApp.XorTransformation(2, 1, 0, 1), new ExternalLookupParameterExampleApp.XorTransformation(3, 1, 1, 0)}));
    }
}
