package org.clulab.scala_transformers.encoder;

import breeze.linalg.$times$;
import breeze.linalg.BroadcastedRows;
import breeze.linalg.Broadcaster$;
import breeze.linalg.DenseMatrix;
import breeze.linalg.DenseMatrix$;
import breeze.linalg.DenseVector;
import breeze.linalg.DenseVector$;
import breeze.linalg.ImmutableNumericOps;
import breeze.linalg.NumericOps;
import breeze.linalg.Transpose;
import breeze.linalg.argmax$;
import breeze.linalg.operators.HasOps$;
import breeze.storage.Zero$;
import scala.$less$colon$less$;
import scala.Array$;
import scala.MatchError;
import scala.None$;
import scala.Option;
import scala.Predef$;
import scala.Tuple2;
import scala.Tuple2$;
import scala.collection.ArrayOps$;
import scala.math.Ordering$DeprecatedFloatOrdering$;
import scala.package$;
import scala.reflect.ClassTag$;
import scala.runtime.BoxesRunTime;
import scala.runtime.RichInt$;
import scala.runtime.Scala3RunTime$;
import scala.runtime.ScalaRunTime$;
import scala.util.NotGiven$;

/* compiled from: LinearLayer.scala */
/* loaded from: input_file:org/clulab/scala_transformers/encoder/LinearLayer.class */
/**
 * A linear layer: multiplies input matrices by {@link #weights()}, optionally adds a bias
 * vector to every row, and maps the resulting per-row scores to label strings.
 *
 * <p>This text is decompiler output for {@code LinearLayer.scala}; most calls therefore go
 * through Scala/Breeze runtime singletons ({@code MODULE$}) and explicitly passed type-class
 * instances ({@code HasOps$}, {@code ClassTag$}, ...).
 *
 * <p>NOTE(review): the "dual" mode appears to score (modifier, head) row pairs, judging by the
 * helper names {@code concatenateModifiersAndHeads}/{@code concatenateModifierAndHead} —
 * confirm against the Scala source.
 */
public class LinearLayer {
    private final String name;
    private final boolean dual;
    private final DenseMatrix weights;
    private final Option biasesOpt;
    private final Option labelsOpt;

    /**
     * Delegates to the companion object's {@code fromFiles}.
     *
     * @param str argument forwarded unchanged to {@code LinearLayer$.MODULE$.fromFiles}
     * @return the layer built by the companion object
     */
    public static LinearLayer fromFiles(String str) {
        return LinearLayer$.MODULE$.fromFiles(str);
    }

    /**
     * Delegates to the companion object's {@code fromResources}.
     *
     * @param str argument forwarded unchanged to {@code LinearLayer$.MODULE$.fromResources}
     * @return the layer built by the companion object
     */
    public static LinearLayer fromResources(String str) {
        return LinearLayer$.MODULE$.fromResources(str);
    }

    /**
     * Stores the given values; no validation or copying is performed.
     *
     * @param str         value returned by {@link #name()}
     * @param z           value returned by {@link #dual()}; selects the dual vs. primal prediction path
     * @param denseMatrix value returned by {@link #weights()}; right-hand operand of the multiply in forward
     * @param option      optional bias vector, added to every output row when defined
     * @param option2     optional label array; required by all predict* methods
     */
    public LinearLayer(String str, boolean z, DenseMatrix<Object> denseMatrix, Option<DenseVector<Object>> option, Option<String[]> option2) {
        this.name = str;
        this.dual = z;
        this.weights = denseMatrix;
        this.biasesOpt = option;
        this.labelsOpt = option2;
    }

    /** @return the name given at construction */
    public String name() {
        return this.name;
    }

    /** @return {@code true} when predictions go through the dual path */
    public boolean dual() {
        return this.dual;
    }

    /** @return the weight matrix given at construction */
    public DenseMatrix<Object> weights() {
        return this.weights;
    }

    /** @return the optional bias vector given at construction */
    public Option<DenseVector<Object>> biasesOpt() {
        return this.biasesOpt;
    }

    /** @return the optional label array given at construction */
    public Option<String[]> labelsOpt() {
        return this.labelsOpt;
    }

    /**
     * Forward pass for a single matrix: wraps it in a one-element batch and returns the
     * first (only) result of the batch overload.
     *
     * @param denseMatrix the input matrix
     * @return {@code denseMatrix * weights}, plus the bias on every row when one is defined
     */
    public DenseMatrix<Object> forward(DenseMatrix<Object> denseMatrix) {
        return (DenseMatrix) ArrayOps$.MODULE$.head$extension(Predef$.MODULE$.refArrayOps(forward(new DenseMatrix[]{denseMatrix})));
    }

    /**
     * Forward pass for a batch: each input is multiplied by {@link #weights()} (float
     * matrix-matrix product); when {@link #biasesOpt()} is defined, the bias vector is added
     * in place to every row of the product.
     *
     * @param denseMatrixArr the input matrices
     * @return one freshly computed output matrix per input, in the same order
     */
    public DenseMatrix<Object>[] forward(DenseMatrix<Object>[] denseMatrixArr) {
        // The lambda parameter and the product need distinct names: the decompiler emitted both
        // as `denseMatrix`, which Java rejects and which hid the fact that the bias is added to
        // (and the lambda returns) the product, not the input.
        return (DenseMatrix[]) ArrayOps$.MODULE$.map$extension(Predef$.MODULE$.refArrayOps(denseMatrixArr), input -> {
            DenseMatrix output = (DenseMatrix) input.$times(weights(), HasOps$.MODULE$.impl_OpMulMatrix_DM_DM_eq_DM_Float());
            biasesOpt().foreach(denseVector -> {
                // Broadcast over rows: output(*, ::) :+= bias
                return (BroadcastedRows) ((NumericOps) output.apply($times$.MODULE$, package$.MODULE$.$colon$colon(), Broadcaster$.MODULE$.canBroadcastRows(HasOps$.MODULE$.handholdCanMapCols_DM()))).$colon$plus$eq(denseVector, HasOps$.MODULE$.broadcastInplaceOp2_BRows(HasOps$.MODULE$.handholdCanMapCols_DM(), HasOps$.MODULE$.impl_OpAdd_InPlace_DV_DV_Float(), HasOps$.MODULE$.canTraverseRows_DM()));
            });
            return output;
        }, ClassTag$.MODULE$.apply(DenseMatrix.class));
    }

    /**
     * Predicts labels for a single matrix by wrapping every argument in a one-element batch
     * and returning the first result of the batch overload.
     *
     * @param denseMatrix the input matrix
     * @param option      optional per-row ints, wrapped into a one-element {@code int[][]}
     *                    (presumably relative head offsets — TODO confirm)
     * @param option2     optional per-row booleans, wrapped into a one-element {@code boolean[][]}
     * @return one label per row of the forward output
     */
    public String[] predict(DenseMatrix<Object> denseMatrix, Option<int[]> option, Option<boolean[]> option2) {
        return (String[]) ArrayOps$.MODULE$.head$extension(Predef$.MODULE$.refArrayOps(predict(new DenseMatrix[]{denseMatrix}, option.map(iArr -> {
            // Array(iArr): a one-element array whose element is the int[] itself.
            return (int[][]) Array$.MODULE$.apply(ScalaRunTime$.MODULE$.wrapRefArray((Object[]) new int[][]{iArr}), ClassTag$.MODULE$.apply(Integer.TYPE).wrap());
        }), option2.map(zArr -> {
            // Array(zArr): a one-element array whose element is the boolean[] itself.
            return (boolean[][]) Array$.MODULE$.apply(ScalaRunTime$.MODULE$.wrapRefArray((Object[]) new boolean[][]{zArr}), ClassTag$.MODULE$.apply(Boolean.TYPE).wrap());
        }))));
    }

    /**
     * Predicts labels for a batch, dispatching on {@link #dual()}.
     *
     * @param denseMatrixArr the input matrices
     * @param option         forwarded to {@link #predictDual} when dual; ignored otherwise
     * @param option2        forwarded to {@link #predictDual} when dual; ignored otherwise
     * @return for each input matrix, one label per row
     */
    public String[][] predict(DenseMatrix<Object>[] denseMatrixArr, Option<int[][]> option, Option<boolean[][]> option2) {
        return dual() ? predictDual(denseMatrixArr, option, option2) : predictPrimal(denseMatrixArr);
    }

    /**
     * Scored prediction for a single matrix: wraps every argument in a one-element batch and
     * returns the first result of the batch overload.
     *
     * @param denseMatrix the input matrix
     * @param option      optional per-row int arrays, wrapped into a one-element {@code int[][][]}
     * @param option2     optional per-row booleans, wrapped into a one-element {@code boolean[][]}
     * @return per row, an array of (label, score) pairs
     */
    public Tuple2<String, Object>[][] predictWithScores(DenseMatrix<Object> denseMatrix, Option<int[][]> option, Option<boolean[]> option2) {
        return (Tuple2[][]) ArrayOps$.MODULE$.head$extension(Predef$.MODULE$.refArrayOps(predictWithScores(new DenseMatrix[]{denseMatrix}, option.map(iArr -> {
            // Array(iArr): a one-element array whose element is the int[][] itself.
            return (int[][][]) Array$.MODULE$.apply(ScalaRunTime$.MODULE$.wrapRefArray(new int[][][]{iArr}), ClassTag$.MODULE$.apply(Integer.TYPE).wrap().wrap());
        }), option2.map(zArr -> {
            // Array(zArr): a one-element array whose element is the boolean[] itself.
            return (boolean[][]) Array$.MODULE$.apply(ScalaRunTime$.MODULE$.wrapRefArray((Object[]) new boolean[][]{zArr}), ClassTag$.MODULE$.apply(Boolean.TYPE).wrap());
        }))));
    }

    /**
     * Scored prediction for a batch, dispatching on {@link #dual()}.
     *
     * @param denseMatrixArr the input matrices
     * @param option         forwarded to {@link #predictDualWithScores} when dual; ignored otherwise
     * @param option2        forwarded to {@link #predictDualWithScores} when dual; ignored otherwise
     * @return for each input matrix and each row, an array of (label, score) pairs
     */
    public Tuple2<String, Object>[][][] predictWithScores(DenseMatrix<Object>[] denseMatrixArr, Option<int[][][]> option, Option<boolean[][]> option2) {
        return dual() ? predictDualWithScores(denseMatrixArr, option, option2) : predictPrimalWithScores(denseMatrixArr);
    }

    /**
     * Builds a float matrix with the same number of rows as the input and twice as many
     * columns. Row {@code i} of the result is row {@code i} of the input followed by row
     * {@code i + iArr[i]} of the input; when that second index falls outside
     * {@code [0, rows)}, row {@code i} itself is used instead.
     *
     * @param denseMatrix the input matrix
     * @param iArr        one int per input row, added to the row index to select the second row
     * @return the newly allocated concatenated matrix
     */
    public DenseMatrix<Object> concatenateModifiersAndHeads(DenseMatrix<Object> denseMatrix, int[] iArr) {
        DenseMatrix<Object> zeros = DenseMatrix$.MODULE$.zeros(denseMatrix.rows(), 2 * denseMatrix.cols(), ClassTag$.MODULE$.apply(Float.TYPE), Zero$.MODULE$.FloatZero());
        RichInt$.MODULE$.until$extension(Predef$.MODULE$.intWrapper(0), denseMatrix.rows()).foreach(obj -> {
            return concatenateModifiersAndHeads$$anonfun$1(denseMatrix, iArr, zeros, BoxesRunTime.unboxToInt(obj));
        });
        return zeros;
    }

    /**
     * Single-row variant of {@link #concatenateModifiersAndHeads}: builds a 1-row float matrix
     * holding row {@code i} of the input followed by row {@code i + i2}; when {@code i + i2}
     * falls outside {@code [0, rows)}, row {@code i} is used for the second half too.
     *
     * @param denseMatrix the input matrix
     * @param i           index of the first row
     * @param i2          value added to {@code i} to select the second row
     * @return the newly allocated 1 x (2 * cols) matrix
     */
    public DenseMatrix<Object> concatenateModifierAndHead(DenseMatrix<Object> denseMatrix, int i, int i2) {
        DenseMatrix<Object> zeros = DenseMatrix$.MODULE$.zeros(1, 2 * denseMatrix.cols(), ClassTag$.MODULE$.apply(Float.TYPE), Zero$.MODULE$.FloatZero());
        int i3 = i + i2;
        // Adding the concatenated vector in place to an all-zero row amounts to assigning it.
        ((NumericOps) zeros.apply(BoxesRunTime.boxToInteger(0), package$.MODULE$.$colon$colon(), HasOps$.MODULE$.canSliceRow())).$colon$plus$eq((Transpose) DenseVector$.MODULE$.vertcat(ScalaRunTime$.MODULE$.wrapRefArray(new DenseVector[]{(DenseVector) ((Transpose) denseMatrix.apply(BoxesRunTime.boxToInteger(i), package$.MODULE$.$colon$colon(), HasOps$.MODULE$.canSliceRow())).t(HasOps$.MODULE$.canUntranspose()), (DenseVector) ((Transpose) denseMatrix.apply(BoxesRunTime.boxToInteger((i3 < 0 || i3 >= denseMatrix.rows()) ? i : i3), package$.MODULE$.$colon$colon(), HasOps$.MODULE$.canSliceRow())).t(HasOps$.MODULE$.canUntranspose())}), HasOps$.MODULE$.impl_Op_InPlace_DV_DV_Float_OpSet(), ClassTag$.MODULE$.apply(Float.TYPE), Zero$.MODULE$.FloatZero()).t(HasOps$.MODULE$.transposeTensor($less$colon$less$.MODULE$.refl())), HasOps$.MODULE$.liftInPlaceOps(NotGiven$.MODULE$.value(), HasOps$.MODULE$.canUntranspose(), HasOps$.MODULE$.impl_OpAdd_InPlace_DV_DV_Float()));
        return zeros;
    }

    /**
     * Dual-mode label prediction. Each input matrix is paired with the {@code int[]} at the
     * same position in {@code option}, expanded with {@link #concatenateModifiersAndHeads},
     * passed through {@link #forward}, and every output row is mapped to the label at the
     * index of its maximum score.
     *
     * @param denseMatrixArr the input matrices
     * @param option         must be defined; one {@code int[]} per input matrix
     * @param option2        must be defined; it is only checked for presence here and not
     *                       otherwise read by this method
     * @return for each input matrix, one label per row
     * @throws RuntimeException if {@link #labelsOpt()} is empty
     */
    public String[][] predictDual(DenseMatrix<Object>[] denseMatrixArr, Option<int[][]> option, Option<boolean[][]> option2) {
        if (!option.isDefined()) {
            throw Scala3RunTime$.MODULE$.assertFailed();
        }
        if (!option2.isDefined()) {
            throw Scala3RunTime$.MODULE$.assertFailed();
        }
        String[] strArr = (String[]) labelsOpt().getOrElse(LinearLayer::$anonfun$5);
        return (String[][]) ArrayOps$.MODULE$.map$extension(Predef$.MODULE$.refArrayOps(ArrayOps$.MODULE$.zip$extension(Predef$.MODULE$.refArrayOps(denseMatrixArr), Predef$.MODULE$.wrapRefArray((Object[]) option.get()))), tuple2 -> {
            if (tuple2 == null) {
                throw new MatchError(tuple2);
            }
            DenseMatrix<Object> denseMatrix = forward(new DenseMatrix[]{concatenateModifiersAndHeads((DenseMatrix) tuple2._1(), (int[]) tuple2._2())})[0];
            return (String[]) package$.MODULE$.Range().apply(0, denseMatrix.rows()).map(obj -> {
                return $anonfun$7(strArr, denseMatrix, BoxesRunTime.unboxToInt(obj));
            }).toArray(ClassTag$.MODULE$.apply(String.class));
        }, ClassTag$.MODULE$.apply(String.class).wrap());
    }

    /** @return the Scala default for the 2nd parameter of {@code predictDual}: {@code None} */
    public Option<int[][]> predictDual$default$2() {
        return None$.MODULE$;
    }

    /** @return the Scala default for the 3rd parameter of {@code predictDual}: {@code None} */
    public Option<boolean[][]> predictDual$default$3() {
        return None$.MODULE$;
    }

    /**
     * Dual-mode scored prediction. Each input matrix is paired with the {@code int[][]} at the
     * same position in {@code option}. For every position {@code k} in that array and every
     * int {@code v} in its {@code k}-th element, the pair of rows selected by
     * {@code concatenateModifierAndHead(matrix, k, v)} is passed through {@link #forward} and
     * reduced to the best (label, score) pair.
     *
     * @param denseMatrixArr the input matrices
     * @param option         must be defined; one {@code int[][]} per input matrix
     * @param option2        must be defined; it is only checked for presence here and not
     *                       otherwise read by this method
     * @return for each input matrix, each position and each candidate int, the top (label, score)
     * @throws RuntimeException if {@link #labelsOpt()} is empty
     */
    public Tuple2<String, Object>[][][] predictDualWithScores(DenseMatrix<Object>[] denseMatrixArr, Option<int[][][]> option, Option<boolean[][]> option2) {
        if (!option.isDefined()) {
            throw Scala3RunTime$.MODULE$.assertFailed();
        }
        if (!option2.isDefined()) {
            throw Scala3RunTime$.MODULE$.assertFailed();
        }
        String[] strArr = (String[]) labelsOpt().getOrElse(LinearLayer::$anonfun$8);
        // The two nested lambdas need distinct parameter names; the decompiler emitted both as
        // `tuple2`, which Java rejects.
        return (Tuple2[][][]) ArrayOps$.MODULE$.map$extension(Predef$.MODULE$.refArrayOps(ArrayOps$.MODULE$.zip$extension(Predef$.MODULE$.refArrayOps(denseMatrixArr), Predef$.MODULE$.wrapRefArray((Object[]) option.get()))), matrixAndCandidates -> {
            if (matrixAndCandidates == null) {
                throw new MatchError(matrixAndCandidates);
            }
            DenseMatrix denseMatrix = (DenseMatrix) matrixAndCandidates._1();
            int[][] iArr = (int[][]) matrixAndCandidates._2();
            return (Tuple2[][]) ArrayOps$.MODULE$.map$extension(Predef$.MODULE$.refArrayOps(ArrayOps$.MODULE$.zipWithIndex$extension(Predef$.MODULE$.refArrayOps(iArr))), candidatesAndIndex -> {
                if (candidatesAndIndex == null) {
                    throw new MatchError(candidatesAndIndex);
                }
                int[] iArr2 = (int[]) candidatesAndIndex._1();
                int unboxToInt = BoxesRunTime.unboxToInt(candidatesAndIndex._2());
                return (Tuple2[]) ArrayOps$.MODULE$.map$extension(Predef$.MODULE$.intArrayOps(iArr2), obj -> {
                    return $anonfun$9$$anonfun$1$$anonfun$1(strArr, denseMatrix, unboxToInt, BoxesRunTime.unboxToInt(obj));
                }, ClassTag$.MODULE$.apply(Tuple2.class));
            }, ClassTag$.MODULE$.apply(Tuple2.class).wrap());
        }, ClassTag$.MODULE$.apply(Tuple2.class).wrap().wrap());
    }

    /** @return the Scala default for the 2nd parameter of {@code predictDualWithScores}: {@code None} */
    public Option<int[][][]> predictDualWithScores$default$2() {
        return None$.MODULE$;
    }

    /** @return the Scala default for the 3rd parameter of {@code predictDualWithScores}: {@code None} */
    public Option<boolean[][]> predictDualWithScores$default$3() {
        return None$.MODULE$;
    }

    /**
     * Primal-mode label prediction: runs {@link #forward} on the batch and maps every output
     * row to the label at the index of its maximum score.
     *
     * @param denseMatrixArr the input matrices
     * @return for each input matrix, one label per row
     * @throws RuntimeException if {@link #labelsOpt()} is empty
     */
    public String[][] predictPrimal(DenseMatrix<Object>[] denseMatrixArr) {
        String[] strArr = (String[]) labelsOpt().getOrElse(LinearLayer::$anonfun$10);
        return (String[][]) ArrayOps$.MODULE$.map$extension(Predef$.MODULE$.refArrayOps(forward(denseMatrixArr)), denseMatrix -> {
            return (String[]) package$.MODULE$.Range().apply(0, denseMatrix.rows()).map(obj -> {
                return $anonfun$12(strArr, denseMatrix, BoxesRunTime.unboxToInt(obj));
            }).toArray(ClassTag$.MODULE$.apply(String.class));
        }, ClassTag$.MODULE$.apply(String.class).wrap());
    }

    /**
     * Primal-mode scored prediction: runs {@link #forward} on the batch and, for every output
     * row, pairs each label with its score and orders the pairs by descending score.
     *
     * @param denseMatrixArr the input matrices
     * @return for each input matrix and each row, all (label, score) pairs, best first
     * @throws RuntimeException if {@link #labelsOpt()} is empty
     */
    public Tuple2<String, Object>[][][] predictPrimalWithScores(DenseMatrix<Object>[] denseMatrixArr) {
        String[] strArr = (String[]) labelsOpt().getOrElse(LinearLayer::$anonfun$13);
        return (Tuple2[][][]) ArrayOps$.MODULE$.map$extension(Predef$.MODULE$.refArrayOps(forward(denseMatrixArr)), denseMatrix -> {
            return (Tuple2[][]) package$.MODULE$.Range().apply(0, denseMatrix.rows()).map(obj -> {
                return $anonfun$15(strArr, denseMatrix, BoxesRunTime.unboxToInt(obj));
            }).toArray(ClassTag$.MODULE$.apply(Tuple2.class).wrap());
        }, ClassTag$.MODULE$.apply(Tuple2.class).wrap().wrap());
    }

    /**
     * Loop body of {@link #concatenateModifiersAndHeads}: fills row {@code i} of
     * {@code denseMatrix2} with row {@code i} of {@code denseMatrix} followed by row
     * {@code i + iArr[i]} (or row {@code i} again when that index is out of range).
     */
    /* JADX INFO: Access modifiers changed from: private */
    public static final /* synthetic */ Transpose concatenateModifiersAndHeads$$anonfun$1(DenseMatrix denseMatrix, int[] iArr, DenseMatrix denseMatrix2, int i) {
        Transpose transpose = (Transpose) denseMatrix.apply(BoxesRunTime.boxToInteger(i), package$.MODULE$.$colon$colon(), HasOps$.MODULE$.canSliceRow());
        int i2 = i + iArr[i];
        // The target row starts out as zeros, so the in-place add stores the concatenation.
        return (Transpose) ((NumericOps) denseMatrix2.apply(BoxesRunTime.boxToInteger(i), package$.MODULE$.$colon$colon(), HasOps$.MODULE$.canSliceRow())).$colon$plus$eq((Transpose) DenseVector$.MODULE$.vertcat(ScalaRunTime$.MODULE$.wrapRefArray(new DenseVector[]{(DenseVector) transpose.t(HasOps$.MODULE$.canUntranspose()), (DenseVector) ((Transpose) denseMatrix.apply(BoxesRunTime.boxToInteger((i2 < 0 || i2 >= denseMatrix.rows()) ? i : i2), package$.MODULE$.$colon$colon(), HasOps$.MODULE$.canSliceRow())).t(HasOps$.MODULE$.canUntranspose())}), HasOps$.MODULE$.impl_Op_InPlace_DV_DV_Float_OpSet(), ClassTag$.MODULE$.apply(Float.TYPE), Zero$.MODULE$.FloatZero()).t(HasOps$.MODULE$.transposeTensor($less$colon$less$.MODULE$.refl())), HasOps$.MODULE$.liftInPlaceOps(NotGiven$.MODULE$.value(), HasOps$.MODULE$.canUntranspose(), HasOps$.MODULE$.impl_OpAdd_InPlace_DV_DV_Float()));
    }

    /** Fallback for missing labels in {@link #predictDual}: always throws. */
    private static final String[] $anonfun$5() {
        throw new RuntimeException("ERROR: can't predict without labels!");
    }

    /** Returns the label at the argmax of row {@code i} of {@code denseMatrix} (used by {@link #predictDual}). */
    /* JADX INFO: Access modifiers changed from: private */
    public static final /* synthetic */ String $anonfun$7(String[] strArr, DenseMatrix denseMatrix, int i) {
        return strArr[BoxesRunTime.unboxToInt(argmax$.MODULE$.apply(((Transpose) denseMatrix.apply(BoxesRunTime.boxToInteger(i), package$.MODULE$.$colon$colon(), HasOps$.MODULE$.canSliceRow())).t(HasOps$.MODULE$.canUntranspose()), argmax$.MODULE$.reduce_Float(HasOps$.MODULE$.DV_canTraverseKeyValuePairs())))];
    }

    /** Fallback for missing labels in {@link #predictDualWithScores}: always throws. */
    private static final String[] $anonfun$8() {
        throw new RuntimeException("ERROR: can't predict without labels!");
    }

    /**
     * Scores one (row {@code i}, offset {@code i2}) candidate for {@link #predictDualWithScores}:
     * concatenates the two rows, runs {@link #forward} on that single-row matrix, and returns
     * the label at the argmax together with its float score.
     */
    private final /* synthetic */ Tuple2 $anonfun$9$$anonfun$1$$anonfun$1(String[] strArr, DenseMatrix denseMatrix, int i, int i2) {
        Transpose transpose = (Transpose) forward(new DenseMatrix[]{concatenateModifierAndHead(denseMatrix, i, i2)})[0].apply(BoxesRunTime.boxToInteger(0), package$.MODULE$.$colon$colon(), HasOps$.MODULE$.canSliceRow());
        int unboxToInt = BoxesRunTime.unboxToInt(argmax$.MODULE$.apply(transpose.t(HasOps$.MODULE$.canUntranspose()), argmax$.MODULE$.reduce_Float(HasOps$.MODULE$.DV_canTraverseKeyValuePairs())));
        return Tuple2$.MODULE$.apply(strArr[unboxToInt], BoxesRunTime.boxToFloat(BoxesRunTime.unboxToFloat(HasOps$.MODULE$.LiftApply(transpose).apply(BoxesRunTime.boxToInteger(unboxToInt)))));
    }

    /** Fallback for missing labels in {@link #predictPrimal}: always throws. */
    private static final String[] $anonfun$10() {
        throw new RuntimeException("ERROR: can't predict without labels!");
    }

    /** Returns the label at the argmax of row {@code i} of {@code denseMatrix} (used by {@link #predictPrimal}). */
    /* JADX INFO: Access modifiers changed from: private */
    public static final /* synthetic */ String $anonfun$12(String[] strArr, DenseMatrix denseMatrix, int i) {
        return strArr[BoxesRunTime.unboxToInt(argmax$.MODULE$.apply(((Transpose) denseMatrix.apply(BoxesRunTime.boxToInteger(i), package$.MODULE$.$colon$colon(), HasOps$.MODULE$.canSliceRow())).t(HasOps$.MODULE$.canUntranspose()), argmax$.MODULE$.reduce_Float(HasOps$.MODULE$.DV_canTraverseKeyValuePairs())))];
    }

    /** Fallback for missing labels in {@link #predictPrimalWithScores}: always throws. */
    private static final String[] $anonfun$13() {
        throw new RuntimeException("ERROR: can't predict without labels!");
    }

    /**
     * Pairs every label with the score at the same position in row {@code i} of
     * {@code denseMatrix} and sorts the pairs by descending score (used by
     * {@link #predictPrimalWithScores}).
     */
    /* JADX INFO: Access modifiers changed from: private */
    public static final /* synthetic */ Tuple2[] $anonfun$15(String[] strArr, DenseMatrix denseMatrix, int i) {
        float[] fArr = (float[]) ((DenseVector) ((ImmutableNumericOps) denseMatrix.apply(BoxesRunTime.boxToInteger(i), package$.MODULE$.$colon$colon(), HasOps$.MODULE$.canSliceRow())).t(HasOps$.MODULE$.canUntranspose())).toArray(ClassTag$.MODULE$.apply(Float.TYPE));
        return (Tuple2[]) ArrayOps$.MODULE$.sortBy$extension(Predef$.MODULE$.refArrayOps(ArrayOps$.MODULE$.zip$extension(Predef$.MODULE$.refArrayOps(strArr), Predef$.MODULE$.wrapFloatArray(fArr))), tuple2 -> {
            // Sorting ascending on the negated score yields highest score first.
            return -BoxesRunTime.unboxToFloat(tuple2._2());
        }, Ordering$DeprecatedFloatOrdering$.MODULE$);
    }
}
