package scorch.autograd;

import botkop.numsca.Tensor;
import botkop.numsca.Tensor$;
import botkop.numsca.package$;
import com.typesafe.scalalogging.LazyLogging;
import com.typesafe.scalalogging.Logger;
import java.util.Objects;
import scala.Function1;
import scala.Option;
import scala.Predef$;
import scala.Product;
import scala.Serializable;
import scala.Some;
import scala.Tuple2;
import scala.collection.Iterator;
import scala.collection.immutable.List;
import scala.collection.mutable.ArrayOps;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxesRunTime;
import scala.runtime.ScalaRunTime$;

/* compiled from: Function.scala */
@ScalaSignature(bytes = "\u0006\u0001\u0005Me\u0001B\u0001\u0003\u0001\u001e\u00111bU8gi6\f\u0007\u0010T8tg*\u00111\u0001B\u0001\tCV$xn\u001a:bI*\tQ!\u0001\u0004tG>\u00148\r[\u0002\u0001'\u0015\u0001\u0001B\u0004\n\u0016!\tIA\"D\u0001\u000b\u0015\u0005Y\u0011!B:dC2\f\u0017BA\u0007\u000b\u0005\u0019\te.\u001f*fMB\u0011q\u0002E\u0007\u0002\u0005%\u0011\u0011C\u0001\u0002\t\rVt7\r^5p]B\u0011\u0011bE\u0005\u0003))\u0011q\u0001\u0015:pIV\u001cG\u000f\u0005\u0002\n-%\u0011qC\u0003\u0002\r'\u0016\u0014\u0018.\u00197ju\u0006\u0014G.\u001a\u0005\t3\u0001\u0011)\u001a!C\u00015\u00051\u0011m\u0019;vC2,\u0012a\u0007\t\u0003\u001fqI!!\b\u0002\u0003\u0011Y\u000b'/[1cY\u0016D\u0001b\b\u0001\u0003\u0012\u0003\u0006IaG\u0001\bC\u000e$X/\u00197!\u0011!\t\u0003A!f\u0001\n\u0003Q\u0012A\u0002;be\u001e,G\u000f\u0003\u0005$\u0001\tE\t\u0015!\u0003\u001c\u0003\u001d!\u0018M]4fi\u0002BQ!\n\u0001\u0005\u0002\u0019\na\u0001P5oSRtDcA\u0014)SA\u0011q\u0002\u0001\u0005\u00063\u0011\u0002\ra\u0007\u0005\u0006C\u0011\u0002\ra\u0007\u0005\bW\u0001\u0011\r\u0011\"\u0001-\u0003\u0005AX#A\u0017\u0011\u00059\u001aT\"A\u0018\u000b\u0005A\n\u0014A\u00028v[N\u001c\u0017MC\u00013\u0003\u0019\u0011w\u000e^6pa&\u0011Ag\f\u0002\u0007)\u0016t7o\u001c:\t\rY\u0002\u0001\u0015!\u0003.\u0003\tA\b\u0005C\u00049\u0001\t\u0007I\u0011\u0001\u0017\u0002\u0003eDaA\u000f\u0001!\u0002\u0013i\u0013AA=!\u0011\u001da\u0004A1A\u0005\u00021\nQb\u001d5jMR,G\rT8hSR\u001c\bB\u0002 
\u0001A\u0003%Q&\u0001\btQ&4G/\u001a3M_\u001eLGo\u001d\u0011\t\u000f\u0001\u0003!\u0019!C\u0001Y\u0005\t!\u0010\u0003\u0004C\u0001\u0001\u0006I!L\u0001\u0003u\u0002Bq\u0001\u0012\u0001C\u0002\u0013\u0005A&\u0001\u0005m_\u001e\u0004&o\u001c2t\u0011\u00191\u0005\u0001)A\u0005[\u0005IAn\\4Qe>\u00147\u000f\t\u0005\b\u0011\u0002\u0011\r\u0011\"\u0001J\u0003\u0005qW#\u0001&\u0011\u0005%Y\u0015B\u0001'\u000b\u0005\rIe\u000e\u001e\u0005\u0007\u001d\u0002\u0001\u000b\u0011\u0002&\u0002\u00059\u0004\u0003b\u0002)\u0001\u0005\u0004%\t!U\u0001\u0005Y>\u001c8/F\u0001S!\tI1+\u0003\u0002U\u0015\t1Ai\\;cY\u0016DaA\u0016\u0001!\u0002\u0013\u0011\u0016!\u00027pgN\u0004\u0003\"\u0002-\u0001\t\u0003J\u0016a\u00024pe^\f'\u000f\u001a\u000b\u00027!)1\f\u0001C!9\u0006A!-Y2lo\u0006\u0014H\r\u0006\u0002^AB\u0011\u0011BX\u0005\u0003?*\u0011A!\u00168ji\")\u0011M\u0017a\u00017\u0005QqM]1e\u001fV$\b/\u001e;\t\u000f\r\u0004\u0011\u0011!C\u0001I\u0006!1m\u001c9z)\r9SM\u001a\u0005\b3\t\u0004\n\u00111\u0001\u001c\u0011\u001d\t#\r%AA\u0002mAq\u0001\u001b\u0001\u0012\u0002\u0013\u0005\u0011.\u0001\bd_BLH\u0005Z3gCVdG\u000fJ\u0019\u0016\u0003)T#aG6,\u00031\u0004\"!\u001c:\u000e\u00039T!a\u001c9\u0002\u0013Ut7\r[3dW\u0016$'BA9\u000b\u0003)\tgN\\8uCRLwN\\\u0005\u0003g:\u0014\u0011#\u001e8dQ\u0016\u001c7.\u001a3WCJL\u0017M\\2f\u0011\u001d)\b!%A\u0005\u0002%\fabY8qs\u0012\"WMZ1vYR$#\u0007C\u0004x\u0001\u0005\u0005I\u0011\t=\u0002\u001bA\u0014x\u000eZ;diB\u0013XMZ5y+\u0005I\bC\u0001>��\u001b\u0005Y(B\u0001?~\u0003\u0011a\u0017M\\4\u000b\u0003y\fAA[1wC&\u0019\u0011\u0011A>\u0003\rM#(/\u001b8h\u0011!\t)\u0001AA\u0001\n\u0003I\u0015\u0001\u00049s_\u0012,8\r^!sSRL\b\"CA\u0005\u0001\u0005\u0005I\u0011AA\u0006\u00039\u0001(o\u001c3vGR,E.Z7f]R$B!!\u0004\u0002\u0014A\u0019\u0011\"a\u0004\n\u0007\u0005E!BA\u0002B]fD\u0011\"!\u0006\u0002\b\u0005\u0005\t\u0019\u0001&\u0002\u0007a$\u0013\u0007C\u0005\u0002\u001a\u0001\t\t\u0011\"\u0011\u0002\u001c\u0005y\u0001O]8ek\u000e$\u0018\n^3sCR|'/\u0006\u0002\u0002\u001eA1\u0011
qDA\u0013\u0003\u001bi!!!\t\u000b\u0007\u0005\r\"\"\u0001\u0006d_2dWm\u0019;j_:LA!a\n\u0002\"\tA\u0011\n^3sCR|'\u000fC\u0005\u0002,\u0001\t\t\u0011\"\u0001\u0002.\u0005A1-\u00198FcV\fG\u000e\u0006\u0003\u00020\u0005U\u0002cA\u0005\u00022%\u0019\u00111\u0007\u0006\u0003\u000f\t{w\u000e\\3b]\"Q\u0011QCA\u0015\u0003\u0003\u0005\r!!\u0004\t\u0013\u0005e\u0002!!A\u0005B\u0005m\u0012\u0001\u00035bg\"\u001cu\u000eZ3\u0015\u0003)C\u0011\"a\u0010\u0001\u0003\u0003%\t%!\u0011\u0002\u0011Q|7\u000b\u001e:j]\u001e$\u0012!\u001f\u0005\n\u0003\u000b\u0002\u0011\u0011!C!\u0003\u000f\na!Z9vC2\u001cH\u0003BA\u0018\u0003\u0013B!\"!\u0006\u0002D\u0005\u0005\t\u0019AA\u0007\u000f%\tiEAA\u0001\u0012\u0003\ty%A\u0006T_\u001a$X.\u0019=M_N\u001c\bcA\b\u0002R\u0019A\u0011AAA\u0001\u0012\u0003\t\u0019fE\u0003\u0002R\u0005US\u0003E\u0004\u0002X\u0005u3dG\u0014\u000e\u0005\u0005e#bAA.\u0015\u00059!/\u001e8uS6,\u0017\u0002BA0\u00033\u0012\u0011#\u00112tiJ\f7\r\u001e$v]\u000e$\u0018n\u001c83\u0011\u001d)\u0013\u0011\u000bC\u0001\u0003G\"\"!a\u0014\t\u0015\u0005}\u0012\u0011KA\u0001\n\u000b\n\t\u0005\u0003\u0006\u0002j\u0005E\u0013\u0011!CA\u0003W\nQ!\u00199qYf$RaJA7\u0003_Ba!GA4\u0001\u0004Y\u0002BB\u0011\u0002h\u0001\u00071\u0004\u0003\u0006\u0002t\u0005E\u0013\u0011!CA\u0003k\nq!\u001e8baBd\u0017\u0010\u0006\u0003\u0002x\u0005\r\u0005#B\u0005\u0002z\u0005u\u0014bAA>\u0015\t1q\n\u001d;j_:\u0004R!CA@7mI1!!!\u000b\u0005\u0019!V\u000f\u001d7fe!I\u0011QQA9\u0003\u0003\u0005\raJ\u0001\u0004q\u0012\u0002\u0004BCAE\u0003#\n\t\u0011\"\u0003\u0002\f\u0006Y!/Z1e%\u0016\u001cx\u000e\u001c<f)\t\ti\tE\u0002{\u0003\u001fK1!!%|\u0005\u0019y%M[3di\u0002")
/* loaded from: input_file:scorch/autograd/SoftmaxLoss.class */
public class SoftmaxLoss implements Function, Product, Serializable {
    private final Variable actual;
    private final Variable target;
    private final Tensor x;
    private final Tensor y;
    private final Tensor shiftedLogits;
    private final Tensor z;
    private final Tensor logProbs;
    private final int n;
    private final double loss;
    private Logger logger;
    private volatile boolean bitmap$0;

    public static Option<Tuple2<Variable, Variable>> unapply(SoftmaxLoss softmaxLoss) {
        return SoftmaxLoss$.MODULE$.unapply(softmaxLoss);
    }

    public static SoftmaxLoss apply(Variable variable, Variable variable2) {
        return SoftmaxLoss$.MODULE$.apply(variable, variable2);
    }

    public static Function1<Tuple2<Variable, Variable>, SoftmaxLoss> tupled() {
        return SoftmaxLoss$.MODULE$.tupled();
    }

    public static Function1<Variable, Function1<Variable, SoftmaxLoss>> curried() {
        return SoftmaxLoss$.MODULE$.curried();
    }

    @Override // scorch.autograd.Function
    public Variable unbroadcast(Variable variable, List<Object> list) {
        Variable unbroadcast;
        unbroadcast = unbroadcast(variable, (List<Object>) list);
        return unbroadcast;
    }

    @Override // scorch.autograd.Function
    public Variable unbroadcast(Tensor tensor, List<Object> list) {
        Variable unbroadcast;
        unbroadcast = unbroadcast(tensor, (List<Object>) list);
        return unbroadcast;
    }

    /* JADX WARN: Multi-variable type inference failed */
    /* JADX WARN: Type inference failed for: r0v0 */
    /* JADX WARN: Type inference failed for: r0v1, types: [java.lang.Throwable] */
    /* JADX WARN: Type inference failed for: r0v8, types: [scorch.autograd.SoftmaxLoss] */
    private Logger logger$lzycompute() {
        ?? r0 = this;
        synchronized (r0) {
            if (!this.bitmap$0) {
                this.logger = LazyLogging.logger$(this);
                r0 = this;
                r0.bitmap$0 = true;
            }
        }
        return this.logger;
    }

    public Logger logger() {
        return !this.bitmap$0 ? logger$lzycompute() : this.logger;
    }

    public Variable actual() {
        return this.actual;
    }

    public Variable target() {
        return this.target;
    }

    public Tensor x() {
        return this.x;
    }

    public Tensor y() {
        return this.y;
    }

    public Tensor shiftedLogits() {
        return this.shiftedLogits;
    }

    public Tensor z() {
        return this.z;
    }

    public Tensor logProbs() {
        return this.logProbs;
    }

    public int n() {
        return this.n;
    }

    public double loss() {
        return this.loss;
    }

    @Override // scorch.autograd.Function
    public Variable forward() {
        return new Variable(Tensor$.MODULE$.apply(Predef$.MODULE$.wrapDoubleArray(new double[]{loss()})), new Some(this), Variable$.MODULE$.apply$default$3());
    }

    @Override // scorch.autograd.Function
    public void backward(Variable variable) {
        Tensor exp = package$.MODULE$.exp(logProbs());
        exp.apply(Predef$.MODULE$.wrapRefArray(new Tensor[]{package$.MODULE$.arange(n()), y()})).$minus$eq(1.0d);
        exp.$div$eq(n());
        actual().backward(new Variable(exp, Variable$.MODULE$.apply$default$2(), Variable$.MODULE$.apply$default$3()));
    }

    public SoftmaxLoss copy(Variable variable, Variable variable2) {
        return new SoftmaxLoss(variable, variable2);
    }

    public Variable copy$default$1() {
        return actual();
    }

    public Variable copy$default$2() {
        return target();
    }

    public String productPrefix() {
        return "SoftmaxLoss";
    }

    public int productArity() {
        return 2;
    }

    public Object productElement(int i) {
        switch (i) {
            case 0:
                return actual();
            case 1:
                return target();
            default:
                throw new IndexOutOfBoundsException(BoxesRunTime.boxToInteger(i).toString());
        }
    }

    public Iterator<Object> productIterator() {
        return ScalaRunTime$.MODULE$.typedProductIterator(this);
    }

    public boolean canEqual(Object obj) {
        return obj instanceof SoftmaxLoss;
    }

    public int hashCode() {
        return ScalaRunTime$.MODULE$._hashCode(this);
    }

    public String toString() {
        return ScalaRunTime$.MODULE$._toString(this);
    }

    public boolean equals(Object obj) {
        boolean z;
        if (this != obj) {
            if (obj instanceof SoftmaxLoss) {
                SoftmaxLoss softmaxLoss = (SoftmaxLoss) obj;
                Variable actual = actual();
                Variable actual2 = softmaxLoss.actual();
                if (actual != null ? actual.equals(actual2) : actual2 == null) {
                    Variable target = target();
                    Variable target2 = softmaxLoss.target();
                    if (target != null ? target.equals(target2) : target2 == null) {
                        if (softmaxLoss.canEqual(this)) {
                            z = true;
                            if (!z) {
                            }
                        }
                    }
                }
                z = false;
                if (!z) {
                }
            }
            return false;
        }
        return true;
    }

    public SoftmaxLoss(Variable variable, Variable variable2) {
        this.actual = variable;
        this.target = variable2;
        LazyLogging.$init$(this);
        Function.$init$(this);
        Product.$init$(this);
        this.x = variable.data();
        this.y = variable2.data().T();
        this.shiftedLogits = x().$minus(package$.MODULE$.max(x(), 1));
        this.z = package$.MODULE$.sum(package$.MODULE$.exp(shiftedLogits()), 1);
        this.logProbs = shiftedLogits().$minus(package$.MODULE$.log(z()));
        this.n = BoxesRunTime.unboxToInt(new ArrayOps.ofInt(Predef$.MODULE$.intArrayOps(x().shape())).head());
        this.loss = (-package$.MODULE$.sum(package$.MODULE$.selectionToTensor(logProbs().apply(Predef$.MODULE$.wrapRefArray(new Tensor[]{package$.MODULE$.arange(n()), y()}))))) / n();
    }
}
