/*
 * Decompiled with CFR 0.152.
 */
package org.clulab.dynet;

import edu.cmu.dynet.Expression;
import edu.cmu.dynet.ExpressionVector;
import edu.cmu.dynet.ParameterCollection;
import java.io.Serializable;
import org.clulab.dynet.AnnotatedSentence;
import org.clulab.dynet.EmbeddingLayer;
import org.clulab.dynet.EmbeddingLayer$;
import org.clulab.dynet.FinalLayer;
import org.clulab.dynet.ForwardLayer;
import org.clulab.dynet.ForwardLayer$;
import org.clulab.dynet.InitialLayer;
import org.clulab.dynet.IntermediateLayer;
import org.clulab.dynet.Layers;
import org.clulab.dynet.RnnLayer;
import org.clulab.dynet.RnnLayer$;
import org.clulab.dynet.Utils;
import org.clulab.dynet.Utils$;
import org.clulab.fatdynet.utils.Synchronizer$;
import org.clulab.struct.Counter;
import org.clulab.utils.Configured;
import scala.Function0;
import scala.Function1;
import scala.None$;
import scala.Option;
import scala.Predef$;
import scala.Some;
import scala.StringContext;
import scala.Tuple2;
import scala.collection.BufferedIterator;
import scala.collection.IndexedSeq;
import scala.collection.Iterable;
import scala.collection.Seq;
import scala.collection.mutable.ArrayBuffer;
import scala.runtime.BooleanRef;
import scala.runtime.BoxesRunTime;
import scala.runtime.ObjectRef;
import scala.runtime.RichInt$;
import scala.runtime.java8.JFunction0;
import scala.runtime.java8.JFunction1;

public final class Layers$ {
    public static Layers$ MODULE$;
    private final int MAX_INTERMEDIATE_LAYERS;

    static {
        new Layers$();
    }

    public Layers apply(Configured config, String paramPrefix, ParameterCollection parameters, Counter<String> wordCounter, Option<Counter<String>> labelCounterOpt, boolean isDual, Option<Object> providedInputSize) {
        Option<ForwardLayer> option;
        Option<InitialLayer> initialLayer = EmbeddingLayer$.MODULE$.initialize(config, paramPrefix + ".initial", parameters, wordCounter);
        ObjectRef inputSize = ObjectRef.create((Object)(initialLayer.nonEmpty() ? new Some((Object)BoxesRunTime.boxToInteger((int)((InitialLayer)initialLayer.get()).outDim())) : (providedInputSize.nonEmpty() ? providedInputSize : None$.MODULE$)));
        ArrayBuffer intermediateLayers = new ArrayBuffer();
        BooleanRef done = BooleanRef.create((boolean)false);
        RichInt$.MODULE$.to$extension0(Predef$.MODULE$.intWrapper(1), this.MAX_INTERMEDIATE_LAYERS()).withFilter((Function1)(JFunction1.mcZI.sp & Serializable & scala.Serializable)i -> !done$1.elem).foreach((Function1)(JFunction1.mcVI.sp & Serializable & scala.Serializable)i -> {
            if (((Option)inputSize$1.elem).isEmpty()) {
                throw new RuntimeException("ERROR: trying to construct an intermediate layer without a known input size!");
            }
            Option<IntermediateLayer> intermediateLayer = RnnLayer$.MODULE$.initialize(config, paramPrefix + new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{".intermediate", ""})).s((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{BoxesRunTime.boxToInteger((int)i)})), parameters, BoxesRunTime.unboxToInt((Object)((Option)inputSize$1.elem).get()));
            if (intermediateLayer.nonEmpty()) {
                intermediateLayers.$plus$eq(intermediateLayer.get());
                inputSize$1.elem = new Some((Object)BoxesRunTime.boxToInteger((int)((IntermediateLayer)intermediateLayer.get()).outDim()));
            } else {
                done$1.elem = true;
            }
        });
        if (labelCounterOpt.nonEmpty()) {
            if (((Option)inputSize.elem).isEmpty()) {
                throw new RuntimeException("ERROR: trying to construct a final layer without a known input size!");
            }
            option = ForwardLayer$.MODULE$.initialize(config, paramPrefix + ".final", parameters, (Counter)labelCounterOpt.get(), isDual, BoxesRunTime.unboxToInt((Object)((Option)inputSize.elem).get()));
        } else {
            option = None$.MODULE$;
        }
        Option<ForwardLayer> finalLayer = option;
        return new Layers(initialLayer, (IndexedSeq<IntermediateLayer>)intermediateLayers, (Option<FinalLayer>)finalLayer);
    }

    public int MAX_INTERMEDIATE_LAYERS() {
        return this.MAX_INTERMEDIATE_LAYERS;
    }

    public Layers loadX2i(ParameterCollection parameters, BufferedIterator<String> lines) {
        None$ none$;
        None$ none$2;
        Utils.ByLineIntBuilder byLineIntBuilder = new Utils.ByLineIntBuilder();
        int hasInitial = BoxesRunTime.unboxToInt(byLineIntBuilder.build(lines, "hasInitial"));
        if (hasInitial == 1) {
            EmbeddingLayer layer = EmbeddingLayer$.MODULE$.load(parameters, lines);
            none$2 = new Some((Object)layer);
        } else {
            none$2 = None$.MODULE$;
        }
        None$ initialLayer = none$2;
        ArrayBuffer intermediateLayers = new ArrayBuffer();
        int intermCount = BoxesRunTime.unboxToInt(byLineIntBuilder.build(lines, "intermediateCount"));
        RichInt$.MODULE$.until$extension0(Predef$.MODULE$.intWrapper(0), intermCount).foreach((Function1 & Serializable & scala.Serializable)_ -> Layers$.$anonfun$loadX2i$1(parameters, lines, intermediateLayers, BoxesRunTime.unboxToInt((Object)_)));
        int hasFinal = BoxesRunTime.unboxToInt(byLineIntBuilder.build(lines, "hasFinal"));
        if (hasFinal == 1) {
            ForwardLayer layer = ForwardLayer$.MODULE$.load(parameters, lines);
            none$ = new Some((Object)layer);
        } else {
            none$ = None$.MODULE$;
        }
        None$ finalLayer = none$;
        return new Layers((Option<InitialLayer>)initialLayer, (IndexedSeq<IntermediateLayer>)intermediateLayers, (Option<FinalLayer>)finalLayer);
    }

    /*
     * WARNING - void declaration
     */
    public IndexedSeq<IndexedSeq<String>> predictJointly(IndexedSeq<Layers> layers, AnnotatedSentence sentence) {
        void var3_3;
        ArrayBuffer labelsPerTask = new ArrayBuffer();
        Synchronizer$.MODULE$.withComputationGraph((Object)"Layers.predictJointly()", (Function0)(JFunction0.mcV.sp & Serializable & scala.Serializable)() -> {
            if (((Layers)layers.apply(0)).nonEmpty()) {
                ExpressionVector sharedStates = ((Layers)layers.apply(0)).forward(sentence, false);
                RichInt$.MODULE$.until$extension0(Predef$.MODULE$.intWrapper(1), layers.length()).foreach((Function1 & Serializable & scala.Serializable)i -> Layers$.$anonfun$predictJointly$2(layers, sentence, labelsPerTask, sharedStates, BoxesRunTime.unboxToInt((Object)i)));
            } else {
                RichInt$.MODULE$.until$extension0(Predef$.MODULE$.intWrapper(1), layers.length()).foreach((Function1 & Serializable & scala.Serializable)i -> Layers$.$anonfun$predictJointly$3(layers, sentence, labelsPerTask, BoxesRunTime.unboxToInt((Object)i)));
            }
        });
        return var3_3;
    }

    private ExpressionVector forwardForTask(IndexedSeq<Layers> layers, int taskId, AnnotatedSentence sentence, boolean doDropout) {
        ExpressionVector expressionVector;
        if (((Layers)layers.apply(0)).nonEmpty()) {
            ExpressionVector sharedStates = ((Layers)layers.apply(0)).forward(sentence, doDropout);
            expressionVector = ((Layers)layers.apply(taskId + 1)).forwardFrom(sharedStates, sentence.headPositions(), doDropout);
        } else {
            expressionVector = ((Layers)layers.apply(taskId + 1)).forward(sentence, doDropout);
        }
        ExpressionVector states = expressionVector;
        return states;
    }

    public IndexedSeq<String> predict(IndexedSeq<Layers> layers, int taskId, AnnotatedSentence sentence) {
        IndexedSeq labelsForTask = (IndexedSeq)Synchronizer$.MODULE$.withComputationGraph((Object)"Layers.predict()", (Function0 & Serializable & scala.Serializable)() -> {
            ExpressionVector states = MODULE$.forwardForTask(layers, taskId, sentence, false);
            float[][] emissionScores = Utils$.MODULE$.emissionScoresToArrays((Iterable<Expression>)states);
            IndexedSeq<String> out = ((FinalLayer)((Layers)layers.apply(taskId + 1)).finalLayer().get()).inference(emissionScores);
            return out;
        });
        return labelsForTask;
    }

    public IndexedSeq<IndexedSeq<Tuple2<String, Object>>> predictWithScores(IndexedSeq<Layers> layers, int taskId, AnnotatedSentence sentence) {
        IndexedSeq labelsForTask = (IndexedSeq)Synchronizer$.MODULE$.withComputationGraph((Object)"Layers.predictWithScores()", (Function0 & Serializable & scala.Serializable)() -> {
            ExpressionVector states = MODULE$.forwardForTask(layers, taskId, sentence, false);
            float[][] emissionScores = Utils$.MODULE$.emissionScoresToArrays((Iterable<Expression>)states);
            IndexedSeq<IndexedSeq<Tuple2<String, Object>>> out = ((FinalLayer)((Layers)layers.apply(taskId + 1)).finalLayer().get()).inferenceWithScores(emissionScores);
            return out;
        });
        return labelsForTask;
    }

    public Expression loss(IndexedSeq<Layers> layers, int taskId, AnnotatedSentence sentence, IndexedSeq<String> goldLabels) {
        ExpressionVector states = this.forwardForTask(layers, taskId, sentence, true);
        return ((FinalLayer)((Layers)layers.apply(taskId + 1)).finalLayer().get()).loss(states, goldLabels);
    }

    public static final /* synthetic */ ArrayBuffer $anonfun$loadX2i$1(ParameterCollection parameters$2, BufferedIterator lines$1, ArrayBuffer intermediateLayers$2, int _) {
        RnnLayer il = RnnLayer$.MODULE$.load(parameters$2, (BufferedIterator<String>)lines$1);
        return intermediateLayers$2.$plus$eq((Object)il);
    }

    public static final /* synthetic */ ArrayBuffer $anonfun$predictJointly$2(IndexedSeq layers$1, AnnotatedSentence sentence$1, ArrayBuffer labelsPerTask$1, ExpressionVector sharedStates$1, int i) {
        ExpressionVector states = ((Layers)layers$1.apply(i)).forwardFrom(sharedStates$1, sentence$1.headPositions(), false);
        float[][] emissionScores = Utils$.MODULE$.emissionScoresToArrays((Iterable<Expression>)states);
        IndexedSeq<String> labels = ((FinalLayer)((Layers)layers$1.apply(i)).finalLayer().get()).inference(emissionScores);
        return labelsPerTask$1.$plus$eq(labels);
    }

    public static final /* synthetic */ ArrayBuffer $anonfun$predictJointly$3(IndexedSeq layers$1, AnnotatedSentence sentence$1, ArrayBuffer labelsPerTask$1, int i) {
        ExpressionVector states = ((Layers)layers$1.apply(i)).forward(sentence$1, false);
        float[][] emissionScores = Utils$.MODULE$.emissionScoresToArrays((Iterable<Expression>)states);
        IndexedSeq<String> labels = ((FinalLayer)((Layers)layers$1.apply(i)).finalLayer().get()).inference(emissionScores);
        return labelsPerTask$1.$plus$eq(labels);
    }

    private Layers$() {
        MODULE$ = this;
        this.MAX_INTERMEDIATE_LAYERS = 10;
    }
}

