package org.clulab.fatdynet.apps;

import edu.cmu.dynet.ComputationGraph$;
import edu.cmu.dynet.Dim$;
import edu.cmu.dynet.Expression;
import edu.cmu.dynet.Expression$;
import edu.cmu.dynet.FloatPointer;
import edu.cmu.dynet.Initialize$;
import edu.cmu.dynet.LstmBuilder;
import edu.cmu.dynet.LstmBuilder$;
import edu.cmu.dynet.Parameter;
import edu.cmu.dynet.ParameterCollection;
import edu.cmu.dynet.RnnBuilder;
import edu.cmu.dynet.SimpleSGDTrainer;
import edu.cmu.dynet.SimpleSGDTrainer$;
import org.clulab.fatdynet.Model;
import org.clulab.fatdynet.Repo;
import org.clulab.fatdynet.utils.CloseableModelSaver;
import org.clulab.fatdynet.utils.Closer$;
import org.clulab.fatdynet.utils.Transducer$;
import scala.MatchError;
import scala.Predef$;
import scala.Predef$ArrowAssoc$;
import scala.StringContext;
import scala.Tuple2;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.collection.TraversableOnce;
import scala.collection.immutable.Map;
import scala.math.Numeric$FloatIsFractional$;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.IntRef;
import scala.runtime.RichFloat$;
import scala.runtime.RichInt$;
import scala.util.Random;

/* compiled from: PairExampleApp.scala */
/* loaded from: input_file:org/clulab/fatdynet/apps/PairExampleApp$.class */
/**
 * Java form of the Scala {@code object PairExampleApp}.
 *
 * <p>The app does the following, in order:
 * <ol>
 *   <li>Trains a DyNet LSTM-based scorer on short 0/1 sequences.</li>
 *   <li>Checks that predicting again with the trained model reproduces the predictions made at
 *       the end of training.</li>
 *   <li>Saves the model to {@code PairModel.dat} and reloads it.</li>
 *   <li>Checks that the reloaded model predicts the same values.</li>
 * </ol>
 *
 * <p>The single instance is created by the static initializer and published through
 * {@link #MODULE$} from inside the private constructor.
 */
public final class PairExampleApp$ {
    /** The singleton instance; assigned at the start of the private constructor. */
    public static PairExampleApp$ MODULE$;
    /** Generator seeded with 1234, used to shuffle the training examples on every iteration. */
    private final Random random;
    /** Number of LSTM layers. */
    private final int LAYERS_SIZE;
    /** Width of each LSTM input step (one scalar per sequence element). */
    private final int INPUT_SIZE;
    /** Width of the LSTM hidden state and of the hidden dense layer. */
    private final int HIDDEN_SIZE;
    /** Width of the final output (a single score). */
    private final int OUTPUT_SIZE;
    /** Number of passes over the training examples. */
    private final int ITERATIONS;
    /** The fixed training/evaluation examples; see the constructor for their contents. */
    private final Seq<PairTransformation> transformations;

    static {
        new PairExampleApp$();
    }

    /** @return the seeded generator used to shuffle examples during training */
    public Random random() {
        return this.random;
    }

    /** @return the number of LSTM layers */
    public int LAYERS_SIZE() {
        return this.LAYERS_SIZE;
    }

    /** @return the width of each LSTM input step */
    public int INPUT_SIZE() {
        return this.INPUT_SIZE;
    }

    /** @return the width of the LSTM hidden state */
    public int HIDDEN_SIZE() {
        return this.HIDDEN_SIZE;
    }

    /** @return the width of the model output */
    public int OUTPUT_SIZE() {
        return this.OUTPUT_SIZE;
    }

    /** @return the number of training passes */
    public int ITERATIONS() {
        return this.ITERATIONS;
    }

    /** @return the fixed list of examples used for both training and evaluation */
    public Seq<PairTransformation> transformations() {
        return this.transformations;
    }

    /**
     * Builds a fresh computation graph that scores one input sequence.
     *
     * <p>The shared {@code ComputationGraph} and the builder's graph state are renewed first,
     * then the graph {@code v * tanh(w * h + b) + a} is constructed. Here {@code h} is the last
     * expression produced by running {@code rnnBuilder} over the inputs.
     *
     * @param pairModel  supplies the parameters {@code w}, {@code b}, {@code v} and {@code a}
     * @param seq        the input values as boxed floats; each becomes one scalar input step
     * @param rnnBuilder the recurrent layer run over {@code seq}
     * @return the output expression of the newly built graph
     */
    public Expression mkPredictionGraph(PairModel pairModel, Seq<Object> seq, RnnBuilder rnnBuilder) {
        ComputationGraph$.MODULE$.renew();
        rnnBuilder.newGraph(rnnBuilder.newGraph$default$1());
        Expression lastHidden = (Expression) Transducer$.MODULE$.transduce(rnnBuilder, (Seq) seq.map(obj -> {
            return $anonfun$mkPredictionGraph$1(BoxesRunTime.unboxToFloat(obj));
        }, Seq$.MODULE$.canBuildFrom())).last();
        // Parameter expressions are created in the same order as before (w, b, v, then a),
        // so the graph nodes are added in an unchanged sequence.
        Expression w = Expression$.MODULE$.parameter(pairModel.w());
        Expression b = Expression$.MODULE$.parameter(pairModel.b());
        Expression v = Expression$.MODULE$.parameter(pairModel.v());
        Expression hidden = Expression$.MODULE$.tanh(w.$times(lastHidden).$plus(b));
        return v.$times(hidden).$plus(Expression$.MODULE$.parameter(pairModel.a()));
    }

    /**
     * Creates a new model and trains it with SGD.
     *
     * <p>Parameter shapes:
     * <ul>
     *   <li>{@code w}: HIDDEN x HIDDEN</li>
     *   <li>{@code b}: HIDDEN</li>
     *   <li>{@code v}: OUTPUT x HIDDEN</li>
     *   <li>{@code a}: OUTPUT</li>
     * </ul>
     * These are created alongside an LSTM in one {@code ParameterCollection}.
     *
     * <p>Each of the {@link #ITERATIONS()} passes does the following:
     * <ol>
     *   <li>Visits the examples in a freshly shuffled order, taking one SGD step per example.</li>
     *   <li>Prints the iteration index and the summed loss.</li>
     *   <li>Multiplies the learning rate by 0.999.</li>
     * </ol>
     *
     * @return the trained model paired with its predictions for every example, in
     *         {@link #transformations()} order (boxed floats), as produced by {@link #predict}
     */
    public Tuple2<PairModel, Seq<Object>> train() {
        ParameterCollection parameterCollection = new ParameterCollection();
        SimpleSGDTrainer simpleSGDTrainer = new SimpleSGDTrainer(parameterCollection, SimpleSGDTrainer$.MODULE$.$lessinit$greater$default$2());
        // Shapes match their use in mkPredictionGraph: v * tanh(w * h + b) + a.
        Parameter w = parameterCollection.addParameters(Dim$.MODULE$.apply(Predef$.MODULE$.wrapIntArray(new int[]{HIDDEN_SIZE(), HIDDEN_SIZE()})), parameterCollection.addParameters$default$2());
        Parameter b = parameterCollection.addParameters(Dim$.MODULE$.apply(Predef$.MODULE$.wrapIntArray(new int[]{HIDDEN_SIZE()})), parameterCollection.addParameters$default$2());
        Parameter v = parameterCollection.addParameters(Dim$.MODULE$.apply(Predef$.MODULE$.wrapIntArray(new int[]{OUTPUT_SIZE(), HIDDEN_SIZE()})), parameterCollection.addParameters$default$2());
        Parameter a = parameterCollection.addParameters(Dim$.MODULE$.apply(Predef$.MODULE$.wrapIntArray(new int[]{OUTPUT_SIZE()})), parameterCollection.addParameters$default$2());
        LstmBuilder lstmBuilder = new LstmBuilder(LAYERS_SIZE(), INPUT_SIZE(), HIDDEN_SIZE(), parameterCollection, LstmBuilder$.MODULE$.$lessinit$greater$default$5());
        PairModel pairModel = new PairModel(w, b, v, a, lstmBuilder, parameterCollection);
        // Shared across all training steps and handed to PairTransformation.transform.
        FloatPointer floatPointer = new FloatPointer();
        RichInt$.MODULE$.until$extension0(Predef$.MODULE$.intWrapper(0), ITERATIONS()).foreach$mVc$sp(i -> {
            // One SGD step per example, in a new random order, summing the per-example losses.
            float loss = BoxesRunTime.unboxToFloat(((TraversableOnce) MODULE$.random().shuffle(MODULE$.transformations(), Seq$.MODULE$.canBuildFrom()).map(pairTransformation -> {
                return BoxesRunTime.boxToFloat($anonfun$train$2(simpleSGDTrainer, lstmBuilder, pairModel, floatPointer, pairTransformation));
            }, Seq$.MODULE$.canBuildFrom())).sum(Numeric$FloatIsFractional$.MODULE$));
            Predef$.MODULE$.println(new StringContext(Predef$.MODULE$.wrapRefArray(new String[]{"index = ", ", loss = ", ""})).s(Predef$.MODULE$.genericWrapArray(new Object[]{BoxesRunTime.boxToInteger(i), BoxesRunTime.boxToFloat(loss)})));
            simpleSGDTrainer.learningRate_$eq(simpleSGDTrainer.learningRate() * 0.999f);
        });
        return new Tuple2<>(pairModel, predict(pairModel, lstmBuilder));
    }

    /**
     * Scores every example with the given model and prints the results.
     *
     * <p>For each example, prints a line with the transformation, the prediction, and whether the
     * prediction rounded to the nearest integer equals the expected output. Finally prints an
     * {@code Accuracy: correct / total = ratio} line.
     *
     * @param pairModel  the model to evaluate
     * @param rnnBuilder the recurrent layer to use with that model
     * @return the raw predictions (boxed floats) in {@link #transformations()} order
     */
    public Seq<Object> predict(PairModel pairModel, RnnBuilder rnnBuilder) {
        IntRef correctCount = IntRef.create(0);
        Predef$.MODULE$.println();
        Seq<Object> predictions = (Seq) transformations().map(pairTransformation -> {
            return BoxesRunTime.boxToFloat($anonfun$predict$1(pairModel, rnnBuilder, correctCount, pairTransformation));
        }, Seq$.MODULE$.canBuildFrom());
        int total = transformations().size();
        // Divide as floats: int / int would truncate the ratio to 0 or 1 (and throw on an empty
        // example list) instead of reporting the actual fraction correct.
        float accuracy = (float) correctCount.elem / total;
        Predef$.MODULE$.println(new StringContext(Predef$.MODULE$.wrapRefArray(new String[]{"Accuracy: ", " / ", " = ", ""})).s(Predef$.MODULE$.genericWrapArray(new Object[]{BoxesRunTime.boxToInteger(correctCount.elem), BoxesRunTime.boxToInteger(total), BoxesRunTime.boxToFloat(accuracy)})));
        return predictions;
    }

    /**
     * Writes the model's parameter collection to {@code str} under the key {@code "/model"}.
     * The saver is closed when the write finishes.
     *
     * @param str       destination file name
     * @param pairModel the model whose {@code model()} collection is saved
     */
    public void save(String str, PairModel pairModel) {
        Closer$.MODULE$.AutoCloser(new CloseableModelSaver(str)).autoClose(closeableModelSaver -> {
            $anonfun$save$1(pairModel, closeableModelSaver);
            return BoxedUnit.UNIT;
        });
    }

    /**
     * Reads a model previously written by {@link #save} from {@code str}.
     *
     * <p>Parameters 0..3 are passed to {@code PairModel} in the same positions that {@link #train}
     * uses for w, b, v and a. This assumes the repo preserves the order in which
     * {@code addParameters} was called; TODO confirm against {@code Repo}/{@code Model}.
     *
     * @param str file name to read
     * @return a model rebuilt from the {@code "/model"} entry
     */
    public PairModel load(String str) {
        Repo repo = new Repo(str);
        Model model = repo.getModel(repo.getDesigns(repo.getDesigns$default$1()), "/model");
        return new PairModel(model.getParameter(0), model.getParameter(1), model.getParameter(2), model.getParameter(3), model.getRnnBuilder(0), model.getParameterCollection());
    }

    /**
     * Entry point: train, re-predict, save, reload, re-predict, and assert the predictions match.
     *
     * @param strArr command-line arguments (unused)
     */
    public void main(String[] strArr) {
        // Initialize DyNet with a fixed "random-seed" option.
        Initialize$.MODULE$.initialize((Map<String, Object>) Predef$.MODULE$.Map().apply(Predef$.MODULE$.wrapRefArray(new Tuple2[]{Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc("random-seed"), BoxesRunTime.boxToLong(2522620396L))})));
        Tuple2<PairModel, Seq<Object>> trained = train();
        // Presumably the compiled form of a Scala tuple pattern; a null result is a MatchError.
        if (trained == null) {
            throw new MatchError(trained);
        }
        PairModel pairModel = trained._1();
        Seq<Object> trainPredictions = trained._2();
        Seq<Object> predictions = predict(pairModel, pairModel.rnnBuilder());
        save("PairModel.dat", pairModel);
        PairModel reloaded = load("PairModel.dat");
        Seq<Object> reloadedPredictions = predict(reloaded, reloaded.rnnBuilder());
        // Predicting twice with the same in-memory model must give identical results.
        Predef$.MODULE$.assert(trainPredictions != null ? trainPredictions.equals(predictions) : predictions == null);
        // The saved-and-reloaded model must reproduce the in-memory model's predictions.
        Predef$.MODULE$.assert(predictions != null ? predictions.equals(reloadedPredictions) : reloadedPredictions == null);
    }

    /** Body of the lambda in {@link #mkPredictionGraph}: wraps one float as an input expression. */
    public static final /* synthetic */ Expression $anonfun$mkPredictionGraph$1(float f) {
        return Expression$.MODULE$.input(f);
    }

    /**
     * Body of the per-example training lambda in {@link #train}: one forward/backward/update step.
     *
     * <p>The inputs are written into a fresh array sized to {@code pairTransformation.inputs()};
     * presumably {@code transform} fills it with the example's values (TODO confirm).
     *
     * @return the squared-distance loss of this example, read before the parameter update
     */
    public static final /* synthetic */ float $anonfun$train$2(SimpleSGDTrainer simpleSGDTrainer, LstmBuilder lstmBuilder, PairModel pairModel, FloatPointer floatPointer, PairTransformation pairTransformation) {
        float[] inputs = new float[pairTransformation.inputs().length];
        pairTransformation.transform(inputs, floatPointer);
        Expression prediction = MODULE$.mkPredictionGraph(pairModel, Predef$.MODULE$.wrapFloatArray(inputs), lstmBuilder);
        Expression loss = Expression$.MODULE$.squaredDistance(prediction, Expression$.MODULE$.input(pairTransformation.output()));
        float lossValue = loss.value().toFloat();
        ComputationGraph$.MODULE$.backward(loss);
        simpleSGDTrainer.update();
        return lossValue;
    }

    /**
     * Body of the per-example lambda in {@link #predict}: scores one example and prints it.
     *
     * <p>Increments {@code intRef} when the rounded prediction equals the expected output.
     *
     * @return the raw (unrounded) prediction
     */
    public static final /* synthetic */ float $anonfun$predict$1(PairModel pairModel, RnnBuilder rnnBuilder, IntRef intRef, PairTransformation pairTransformation) {
        float[] inputs = new float[pairTransformation.inputs().length];
        pairTransformation.transform(inputs);
        float prediction = MODULE$.mkPredictionGraph(pairModel, Predef$.MODULE$.wrapFloatArray(inputs), rnnBuilder).value().toFloat();
        boolean correct = pairTransformation.output() == RichFloat$.MODULE$.round$extension(Predef$.MODULE$.floatWrapper(prediction));
        if (correct) {
            intRef.elem++;
        }
        Predef$.MODULE$.println(new StringContext(Predef$.MODULE$.wrapRefArray(new String[]{"TRANSFORMATION = ", ", PREDICTION = ", ", CORRECT = ", ""})).s(Predef$.MODULE$.genericWrapArray(new Object[]{pairTransformation, BoxesRunTime.boxToFloat(prediction), BoxesRunTime.boxToBoolean(correct)})));
        return prediction;
    }

    /** Body of the lambda in {@link #save}: registers the model's collection under "/model". */
    public static final /* synthetic */ void $anonfun$save$1(PairModel pairModel, CloseableModelSaver closeableModelSaver) {
        closeableModelSaver.addModel(pairModel.model(), "/model");
    }

    /**
     * Initializes the singleton: hyper-parameters and the example set.
     *
     * <p>The examples enumerate every 0/1 sequence of length 2 through 5. In every entry below,
     * the expected output is 1 exactly when the sequence contains two adjacent 1s. The list
     * order is significant because it is what the seeded shuffle permutes.
     */
    private PairExampleApp$() {
        MODULE$ = this;
        this.random = new Random(1234L);
        this.LAYERS_SIZE = 1;
        this.INPUT_SIZE = 1;
        this.HIDDEN_SIZE = 4;
        this.OUTPUT_SIZE = 1;
        this.ITERATIONS = 200;
        this.transformations = Seq$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new PairTransformation[]{
            // length 2
            new PairTransformation(new int[]{0, 0}, 0),
            new PairTransformation(new int[]{0, 1}, 0),
            new PairTransformation(new int[]{1, 0}, 0),
            new PairTransformation(new int[]{1, 1}, 1),
            // length 3
            new PairTransformation(new int[]{0, 0, 0}, 0),
            new PairTransformation(new int[]{0, 0, 1}, 0),
            new PairTransformation(new int[]{0, 1, 0}, 0),
            new PairTransformation(new int[]{0, 1, 1}, 1),
            new PairTransformation(new int[]{1, 0, 0}, 0),
            new PairTransformation(new int[]{1, 0, 1}, 0),
            new PairTransformation(new int[]{1, 1, 0}, 1),
            new PairTransformation(new int[]{1, 1, 1}, 1),
            // length 4
            new PairTransformation(new int[]{0, 0, 0, 0}, 0),
            new PairTransformation(new int[]{0, 0, 0, 1}, 0),
            new PairTransformation(new int[]{0, 0, 1, 0}, 0),
            new PairTransformation(new int[]{0, 0, 1, 1}, 1),
            new PairTransformation(new int[]{0, 1, 0, 0}, 0),
            new PairTransformation(new int[]{0, 1, 0, 1}, 0),
            new PairTransformation(new int[]{0, 1, 1, 0}, 1),
            new PairTransformation(new int[]{0, 1, 1, 1}, 1),
            new PairTransformation(new int[]{1, 0, 0, 0}, 0),
            new PairTransformation(new int[]{1, 0, 0, 1}, 0),
            new PairTransformation(new int[]{1, 0, 1, 0}, 0),
            new PairTransformation(new int[]{1, 0, 1, 1}, 1),
            new PairTransformation(new int[]{1, 1, 0, 0}, 1),
            new PairTransformation(new int[]{1, 1, 0, 1}, 1),
            new PairTransformation(new int[]{1, 1, 1, 0}, 1),
            new PairTransformation(new int[]{1, 1, 1, 1}, 1),
            // length 5
            new PairTransformation(new int[]{0, 0, 0, 0, 0}, 0),
            new PairTransformation(new int[]{0, 0, 0, 0, 1}, 0),
            new PairTransformation(new int[]{0, 0, 0, 1, 0}, 0),
            new PairTransformation(new int[]{0, 0, 0, 1, 1}, 1),
            new PairTransformation(new int[]{0, 0, 1, 0, 0}, 0),
            new PairTransformation(new int[]{0, 0, 1, 0, 1}, 0),
            new PairTransformation(new int[]{0, 0, 1, 1, 0}, 1),
            new PairTransformation(new int[]{0, 0, 1, 1, 1}, 1),
            new PairTransformation(new int[]{0, 1, 0, 0, 0}, 0),
            new PairTransformation(new int[]{0, 1, 0, 0, 1}, 0),
            new PairTransformation(new int[]{0, 1, 0, 1, 0}, 0),
            new PairTransformation(new int[]{0, 1, 0, 1, 1}, 1),
            new PairTransformation(new int[]{0, 1, 1, 0, 0}, 1),
            new PairTransformation(new int[]{0, 1, 1, 0, 1}, 1),
            new PairTransformation(new int[]{0, 1, 1, 1, 0}, 1),
            new PairTransformation(new int[]{0, 1, 1, 1, 1}, 1),
            new PairTransformation(new int[]{1, 0, 0, 0, 0}, 0),
            new PairTransformation(new int[]{1, 0, 0, 0, 1}, 0),
            new PairTransformation(new int[]{1, 0, 0, 1, 0}, 0),
            new PairTransformation(new int[]{1, 0, 0, 1, 1}, 1),
            new PairTransformation(new int[]{1, 0, 1, 0, 0}, 0),
            new PairTransformation(new int[]{1, 0, 1, 0, 1}, 0),
            new PairTransformation(new int[]{1, 0, 1, 1, 0}, 1),
            new PairTransformation(new int[]{1, 0, 1, 1, 1}, 1),
            new PairTransformation(new int[]{1, 1, 0, 0, 0}, 1),
            new PairTransformation(new int[]{1, 1, 0, 0, 1}, 1),
            new PairTransformation(new int[]{1, 1, 0, 1, 0}, 1),
            new PairTransformation(new int[]{1, 1, 0, 1, 1}, 1),
            new PairTransformation(new int[]{1, 1, 1, 0, 0}, 1),
            new PairTransformation(new int[]{1, 1, 1, 0, 1}, 1),
            new PairTransformation(new int[]{1, 1, 1, 1, 0}, 1),
            new PairTransformation(new int[]{1, 1, 1, 1, 1}, 1)
        }));
    }
}
