package org.clulab.dynet;

import edu.cmu.dynet.Expression;
import edu.cmu.dynet.ExpressionVector;
import edu.cmu.dynet.ParameterCollection;
import org.clulab.dynet.Utils;
import org.clulab.fatdynet.utils.Synchronizer$;
import org.clulab.scala.WrappedArrayBuffer$;
import org.clulab.struct.Counter;
import org.clulab.utils.Configured;
import org.clulab.utils.MathUtils$;
import scala.None$;
import scala.Option;
import scala.Predef$;
import scala.Some;
import scala.Tuple2;
import scala.collection.BufferedIterator;
import scala.collection.StringOps$;
import scala.collection.immutable.IndexedSeq;
import scala.collection.mutable.ArrayBuffer;
import scala.runtime.BooleanRef;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.ObjectRef;
import scala.runtime.RichInt$;

/* compiled from: Layers.scala */
/* loaded from: input_file:org/clulab/dynet/Layers$.class */
/**
 * Companion object of {@code Layers}: factories and helpers for building, loading and running
 * stacks of layers (optional initial layer, zero or more intermediate layers, optional final layer).
 *
 * <p>NOTE(review): this class is decompiler output of Scala bytecode. The source of truth is
 * presumably {@code Layers.scala}. Fixes made here (invalid-Java decompiler artifacts) should be
 * mirrored there if this file is not regenerated.
 */
public final class Layers$ {
    /** Singleton instance, mirroring the Scala {@code object Layers}. */
    public static final Layers$ MODULE$ = new Layers$();
    /** Maximum number of {@code <prefix>.intermediateN} layers that {@link #apply} tries to build. */
    private static final int MAX_INTERMEDIATE_LAYERS = 10;

    /**
     * Builds a {@code Layers} stack from configuration.
     *
     * <p>Tries {@code <str>.initial}, then {@code <str>.intermediate1..MAX_INTERMEDIATE_LAYERS}
     * (stopping at the first one that is not configured), then {@code <str>.final}. Each layer's
     * input size is the previous layer's {@code outDim()}.
     *
     * @param configured configuration source passed to every layer initializer
     * @param str parameter-name prefix for this stack
     * @param parameterCollection DyNet parameters the layers are registered in
     * @param counter counter passed to {@code EmbeddingLayer.initialize}
     * @param option counter passed to {@code ForwardLayer.initialize}; when empty, no final layer is built
     * @param z flag passed through to {@code ForwardLayer.initialize} (meaning not visible here — TODO confirm)
     * @param option2 input size used for the first intermediate/final layer when there is no initial layer
     * @return the assembled layers
     * @throws RuntimeException if an intermediate or final layer must be built but no input size is known.
     *     NOTE(review): the intermediate check runs before {@code RnnLayer.initialize}, so with no
     *     initial layer and an empty {@code option2} this throws even if no intermediate layer is configured.
     */
    public Layers apply(Configured configured, String str, ParameterCollection parameterCollection, Counter<String> counter, Option<Counter<String>> option, boolean z, Option<Object> option2) {
        Option<InitialLayer> initialLayer = EmbeddingLayer$.MODULE$.initialize(configured, new StringBuilder(8).append(str).append(".initial").toString(), parameterCollection, counter);
        // Output size of the most recently built layer. It starts from the initial layer's outDim,
        // or from option2 (which is already None when empty) if there is no initial layer.
        ObjectRef inputSize = ObjectRef.create(initialLayer.nonEmpty() ? new Some<>(BoxesRunTime.boxToInteger(initialLayer.get().outDim())) : option2);
        ArrayBuffer intermediateLayers = new ArrayBuffer();
        // Set once an ".intermediateN" layer is missing; the filter then skips every later N.
        BooleanRef done = BooleanRef.create(false);
        RichInt$.MODULE$.to$extension(Predef$.MODULE$.intWrapper(1), MAX_INTERMEDIATE_LAYERS()).withFilter(i -> {
            return !done.elem;
        }).foreach(layerIndex -> {
            if (((Option) inputSize.elem).isEmpty()) {
                throw new RuntimeException("ERROR: trying to construct an intermediate layer without a known input size!");
            }
            Option<IntermediateLayer> intermediate = RnnLayer$.MODULE$.initialize(configured, new StringBuilder(13).append(str).append(".intermediate").append(layerIndex).toString(), parameterCollection, BoxesRunTime.unboxToInt(((Option) inputSize.elem).get()));
            if (!intermediate.nonEmpty()) {
                done.elem = true;
            } else {
                intermediateLayers.$plus$eq(intermediate.get());
                inputSize.elem = new Some(BoxesRunTime.boxToInteger(intermediate.get().outDim()));
            }
            // scala.Function1 must return a value; BoxedUnit.UNIT is Scala's Unit.
            return BoxedUnit.UNIT;
        });
        Option<ForwardLayer> finalLayer;
        if (!option.nonEmpty()) {
            // Option.empty() is the None singleton, typed for Java generics (None$ is Option<Nothing$>).
            finalLayer = Option.empty();
        } else {
            if (((Option) inputSize.elem).isEmpty()) {
                throw new RuntimeException("ERROR: trying to construct a final layer without a known input size!");
            }
            finalLayer = ForwardLayer$.MODULE$.initialize(configured, new StringBuilder(6).append(str).append(".final").toString(), parameterCollection, (Counter) option.get(), z, BoxesRunTime.unboxToInt(((Option) inputSize.elem).get()));
        }
        return new Layers(initialLayer, WrappedArrayBuffer$.MODULE$._toIndexedSeq(intermediateLayers), finalLayer);
    }

    /** Returns the maximum number of intermediate layers {@link #apply} will try to build (10). */
    public int MAX_INTERMEDIATE_LAYERS() {
        return MAX_INTERMEDIATE_LAYERS;
    }

    /**
     * Loads a {@code Layers} stack from a line-based model stream.
     *
     * <p>Reads, in order: the {@code hasInitial} flag (then the embedding layer if it is 1), the
     * {@code intermediateCount} (then that many RNN layers), and the {@code hasFinal} flag (then the
     * forward layer if it is 1). The reads consume {@code bufferedIterator} sequentially, so their
     * order must not change.
     *
     * @param parameterCollection DyNet parameters the loaded layers are registered in
     * @param bufferedIterator the model lines, positioned at this stack's header
     * @return the loaded layers
     */
    public Layers loadX2i(ParameterCollection parameterCollection, BufferedIterator<String> bufferedIterator) {
        Utils.ByLineIntBuilder byLineIntBuilder = new Utils.ByLineIntBuilder();
        Option initialLayer = BoxesRunTime.unboxToInt(byLineIntBuilder.build(bufferedIterator, "hasInitial")) == 1 ? new Some(EmbeddingLayer$.MODULE$.load(parameterCollection, bufferedIterator)) : Option.empty();
        ArrayBuffer intermediateLayers = new ArrayBuffer();
        int intermediateCount = BoxesRunTime.unboxToInt(byLineIntBuilder.build(bufferedIterator, "intermediateCount"));
        RichInt$.MODULE$.until$extension(Predef$.MODULE$.intWrapper(0), intermediateCount).foreach(obj -> {
            return $anonfun$loadX2i$1(parameterCollection, bufferedIterator, intermediateLayers, BoxesRunTime.unboxToInt(obj));
        });
        Option finalLayer = BoxesRunTime.unboxToInt(byLineIntBuilder.build(bufferedIterator, "hasFinal")) == 1 ? new Some(ForwardLayer$.MODULE$.load(parameterCollection, bufferedIterator)) : Option.empty();
        return new Layers(initialLayer, WrappedArrayBuffer$.MODULE$._toIndexedSeq(intermediateLayers), finalLayer);
    }

    /**
     * Predicts labels for every task on one sentence inside a single computation graph.
     *
     * <p>Index 0 of {@code indexedSeq} holds the shared layers and indices 1..n-1 hold the tasks.
     * If the shared layers are non-empty, they run once and their output feeds each task through
     * {@code forwardFrom}. Otherwise each task runs {@code forward} on the sentence itself.
     *
     * @return one label sequence per task, in task order (index 1 first)
     */
    public IndexedSeq<IndexedSeq<String>> predictJointly(IndexedSeq<Layers> indexedSeq, AnnotatedSentence annotatedSentence, Option<IndexedSeq<ModifierHeadPair>> option, ConstEmbeddingParameters constEmbeddingParameters) {
        ArrayBuffer arrayBuffer = new ArrayBuffer();
        Synchronizer$.MODULE$.withComputationGraph("Layers.predictJointly()", () -> {
            if (!((Layers) indexedSeq.apply(0)).nonEmpty()) {
                RichInt$.MODULE$.until$extension(Predef$.MODULE$.intWrapper(1), indexedSeq.length()).foreach(obj -> {
                    return $anonfun$predictJointly$3(indexedSeq, annotatedSentence, option, constEmbeddingParameters, arrayBuffer, BoxesRunTime.unboxToInt(obj));
                });
            } else {
                ExpressionVector forward = ((Layers) indexedSeq.apply(0)).forward(annotatedSentence, option, constEmbeddingParameters, false);
                RichInt$.MODULE$.until$extension(Predef$.MODULE$.intWrapper(1), indexedSeq.length()).foreach(obj2 -> {
                    return $anonfun$predictJointly$2(indexedSeq, forward, option, arrayBuffer, BoxesRunTime.unboxToInt(obj2));
                });
            }
            // The by-name block compiles to scala.Function0, which must return a value.
            return BoxedUnit.UNIT;
        });
        return WrappedArrayBuffer$.MODULE$._toIndexedSeq(arrayBuffer);
    }

    /**
     * Runs the layers for task {@code i}, which are stored at index {@code i + 1}.
     *
     * <p>When the shared layers (index 0) are non-empty, they run first and feed the task's layers
     * through {@code forwardFrom}. Otherwise the task's layers run {@code forward} directly.
     *
     * @param z last flag forwarded to {@code forward}/{@code forwardFrom} ({@code true} from
     *     {@link #loss}, {@code false} at prediction; presumably a training/dropout switch — TODO confirm)
     */
    private ExpressionVector forwardForTask(IndexedSeq<Layers> indexedSeq, int i, AnnotatedSentence annotatedSentence, Option<IndexedSeq<ModifierHeadPair>> option, ConstEmbeddingParameters constEmbeddingParameters, boolean z) {
        ExpressionVector forward;
        if (((Layers) indexedSeq.apply(0)).nonEmpty()) {
            forward = ((Layers) indexedSeq.apply(i + 1)).forwardFrom(((Layers) indexedSeq.apply(0)).forward(annotatedSentence, option, constEmbeddingParameters, z), option, z);
        } else {
            forward = ((Layers) indexedSeq.apply(i + 1)).forward(annotatedSentence, option, constEmbeddingParameters, z);
        }
        return forward;
    }

    /**
     * Predicts the labels of task {@code i} for one sentence, using the task's final layer's
     * {@code inference} on the emission scores.
     */
    public IndexedSeq<String> predict(IndexedSeq<Layers> indexedSeq, int i, AnnotatedSentence annotatedSentence, Option<IndexedSeq<ModifierHeadPair>> option, ConstEmbeddingParameters constEmbeddingParameters) {
        return (IndexedSeq) Synchronizer$.MODULE$.withComputationGraph("Layers.predict()", () -> {
            return ((FinalLayer) ((Layers) indexedSeq.apply(i + 1)).finalLayer().get()).inference(Utils$.MODULE$.emissionScoresToArrays(MODULE$.forwardForTask(indexedSeq, i, annotatedSentence, option, constEmbeddingParameters, false)));
        });
    }

    /**
     * Like {@link #predict}, but returns (label, score) pairs from {@code inferenceWithScores}.
     *
     * @param z when {@code true} (the default, see {@link #predictWithScores$default$6}), the scores
     *     of each position are passed through {@link #softmax}
     */
    public IndexedSeq<IndexedSeq<Tuple2<String, Object>>> predictWithScores(IndexedSeq<Layers> indexedSeq, int i, AnnotatedSentence annotatedSentence, Option<IndexedSeq<ModifierHeadPair>> option, ConstEmbeddingParameters constEmbeddingParameters, boolean z) {
        return (IndexedSeq) Synchronizer$.MODULE$.withComputationGraph("Layers.predictWithScores()", () -> {
            IndexedSeq<IndexedSeq<Tuple2<String, Object>>> inferenceWithScores = ((FinalLayer) ((Layers) indexedSeq.apply(i + 1)).finalLayer().get()).inferenceWithScores(Utils$.MODULE$.emissionScoresToArrays(MODULE$.forwardForTask(indexedSeq, i, annotatedSentence, option, constEmbeddingParameters, false)));
            return z ? MODULE$.softmax(inferenceWithScores) : inferenceWithScores;
        });
    }

    /** Scala default value of {@code predictWithScores}'s 6th parameter (apply softmax): {@code true}. */
    public boolean predictWithScores$default$6() {
        return true;
    }

    /**
     * Replaces the scores of each position with {@code MathUtils.softmaxFloat} of that position's
     * scores. Labels stay paired with their scores in their original order.
     */
    public IndexedSeq<IndexedSeq<Tuple2<String, Object>>> softmax(IndexedSeq<IndexedSeq<Tuple2<String, Object>>> indexedSeq) {
        ArrayBuffer arrayBuffer = new ArrayBuffer();
        indexedSeq.foreach(indexedSeq2 -> {
            return arrayBuffer.$plus$eq(((IndexedSeq) indexedSeq2.map(tuple2 -> {
                return (String) tuple2._1();
            })).zip(MathUtils$.MODULE$.softmaxFloat((IndexedSeq) indexedSeq2.map(tuple22 -> {
                return BoxesRunTime.boxToFloat($anonfun$softmax$2(tuple22));
            }), MathUtils$.MODULE$.softmaxFloat$default$2())));
        });
        return WrappedArrayBuffer$.MODULE$._toIndexedSeq(arrayBuffer);
    }

    /**
     * Head-then-label parse of one sentence (presumably a dependency parse — TODO confirm).
     *
     * <p>Index 0 holds the shared layers (asserted non-empty). Index 1 predicts a head for each
     * token as a relative offset (see {@link #$anonfun$parse$2}). Index 2 labels the resulting
     * (modifier, head) pairs.
     *
     * @return for each token, (head index or -1, label)
     */
    public IndexedSeq<Tuple2<Object, String>> parse(IndexedSeq<Layers> indexedSeq, AnnotatedSentence annotatedSentence, ConstEmbeddingParameters constEmbeddingParameters) {
        return WrappedArrayBuffer$.MODULE$._toIndexedSeq((ArrayBuffer) Synchronizer$.MODULE$.withComputationGraph("Layers.parse()", () -> {
            Predef$.MODULE$.assert(((Layers) indexedSeq.apply(0)).nonEmpty());
            ExpressionVector forward = ((Layers) indexedSeq.apply(0)).forward(annotatedSentence, Option.empty(), constEmbeddingParameters, false);
            IndexedSeq<IndexedSeq<Tuple2<String, Object>>> inferenceWithScores = ((FinalLayer) ((Layers) indexedSeq.apply(1)).finalLayer().get()).inferenceWithScores(Utils$.MODULE$.emissionScoresToArrays(((Layers) indexedSeq.apply(1)).forwardFrom(forward, Option.empty(), false)));
            // heads(i) = predicted head of token i, or -1.
            ArrayBuffer heads = new ArrayBuffer();
            inferenceWithScores.indices().foreach(obj -> {
                return $anonfun$parse$2(inferenceWithScores, heads, annotatedSentence, BoxesRunTime.unboxToInt(obj));
            });
            ArrayBuffer modHeadPairs = new ArrayBuffer();
            heads.indices().foreach(obj2 -> {
                return $anonfun$parse$5(modHeadPairs, heads, BoxesRunTime.unboxToInt(obj2));
            });
            IndexedSeq<String> labels = ((FinalLayer) ((Layers) indexedSeq.apply(2)).finalLayer().get()).inference(Utils$.MODULE$.emissionScoresToArrays(((Layers) indexedSeq.apply(2)).forwardFrom(forward, new Some(WrappedArrayBuffer$.MODULE$._toIndexedSeq(modHeadPairs)), false)));
            Predef$.MODULE$.assert(labels.size() == heads.size());
            return (ArrayBuffer) heads.zip(labels);
        }));
    }

    /**
     * Computes the loss of task {@code i} on one sentence against gold labels {@code indexedSeq2}.
     * Runs the forward pass with the last flag set to {@code true}, with (modifier, head) pairs
     * taken from the gold labels and constant embeddings built from the sentence's words.
     */
    public Expression loss(IndexedSeq<Layers> indexedSeq, int i, AnnotatedSentence annotatedSentence, IndexedSeq<Label> indexedSeq2) {
        return ((FinalLayer) ((Layers) indexedSeq.apply(i + 1)).finalLayer().get()).loss(forwardForTask(indexedSeq, i, annotatedSentence, Utils$.MODULE$.getModHeadPairs(indexedSeq2), ConstEmbeddingsGlove$.MODULE$.mkConstLookupParams(annotatedSentence.words()), true), indexedSeq2);
    }

    /** Loop body of {@link #loadX2i}: loads one RNN layer and appends it to {@code arrayBuffer}. */
    public static final /* synthetic */ ArrayBuffer $anonfun$loadX2i$1(ParameterCollection parameterCollection, BufferedIterator bufferedIterator, ArrayBuffer arrayBuffer, int i) {
        return arrayBuffer.$plus$eq(RnnLayer$.MODULE$.load(parameterCollection, bufferedIterator));
    }

    /** Loop body of {@link #predictJointly} (shared layers present): decodes task {@code i} from the shared output. */
    public static final /* synthetic */ ArrayBuffer $anonfun$predictJointly$2(IndexedSeq indexedSeq, ExpressionVector expressionVector, Option option, ArrayBuffer arrayBuffer, int i) {
        return arrayBuffer.$plus$eq(((FinalLayer) ((Layers) indexedSeq.apply(i)).finalLayer().get()).inference(Utils$.MODULE$.emissionScoresToArrays(((Layers) indexedSeq.apply(i)).forwardFrom(expressionVector, option, false))));
    }

    /** Loop body of {@link #predictJointly} (no shared layers): runs task {@code i} on the sentence directly. */
    public static final /* synthetic */ ArrayBuffer $anonfun$predictJointly$3(IndexedSeq indexedSeq, AnnotatedSentence annotatedSentence, Option option, ConstEmbeddingParameters constEmbeddingParameters, ArrayBuffer arrayBuffer, int i) {
        return arrayBuffer.$plus$eq(((FinalLayer) ((Layers) indexedSeq.apply(i)).finalLayer().get()).inference(Utils$.MODULE$.emissionScoresToArrays(((Layers) indexedSeq.apply(i)).forward(annotatedSentence, option, constEmbeddingParameters, false))));
    }

    /** Extracts the (boxed float) score of a (label, score) pair. */
    public static final /* synthetic */ float $anonfun$softmax$2(Tuple2 tuple2) {
        return BoxesRunTime.unboxToFloat(tuple2._2());
    }

    /**
     * Picks the head of token {@code i} and appends it to {@code arrayBuffer}.
     *
     * <p>Scans the candidate labels of token {@code i} in order and parses each as an int offset.
     * Offset 0 yields head -1. Otherwise {@code i + offset} is accepted if it lies inside the
     * sentence. Labels that are not integers are skipped. If no candidate is accepted, -1 is appended.
     */
    public static final /* synthetic */ Object $anonfun$parse$2(IndexedSeq indexedSeq, ArrayBuffer arrayBuffer, AnnotatedSentence annotatedSentence, int i) {
        IndexedSeq candidates = (IndexedSeq) indexedSeq.apply(i);
        BooleanRef found = BooleanRef.create(false);
        candidates.indices().withFilter(j -> {
            return !found.elem;
        }).foreach(j -> {
            try {
                int offset = StringOps$.MODULE$.toInt$extension(Predef$.MODULE$.augmentString((String) ((Tuple2) candidates.apply(j))._1()));
                if (offset == 0) {
                    arrayBuffer.$plus$eq(BoxesRunTime.boxToInteger(-1));
                    found.elem = true;
                } else {
                    // Renamed from a decompiler-produced local that redeclared the lambda parameter (invalid Java).
                    int head = i + offset;
                    if (head >= 0 && head < annotatedSentence.size()) {
                        arrayBuffer.$plus$eq(BoxesRunTime.boxToInteger(head));
                        found.elem = true;
                    }
                }
            } catch (NumberFormatException e) {
                // Non-numeric label: keep scanning the next candidate.
                found.elem = false;
            }
            return BoxedUnit.UNIT;
        });
        return !found.elem ? arrayBuffer.$plus$eq(BoxesRunTime.boxToInteger(-1)) : BoxedUnit.UNIT;
    }

    /** Appends {@code ModifierHeadPair(i, heads(i))} to {@code arrayBuffer}, where {@code arrayBuffer2} holds the heads. */
    public static final /* synthetic */ ArrayBuffer $anonfun$parse$5(ArrayBuffer arrayBuffer, ArrayBuffer arrayBuffer2, int i) {
        return arrayBuffer.$plus$eq(new ModifierHeadPair(i, BoxesRunTime.unboxToInt(arrayBuffer2.apply(i))));
    }

    private Layers$() {
    }
}
