package org.platanios.tensorflow.api.learn.hooks;

import org.platanios.tensorflow.api.core.Graph$Keys$GLOBAL_EPOCH$;
import org.platanios.tensorflow.api.core.Graph$Keys$GLOBAL_STEP$;
import org.platanios.tensorflow.api.core.Graph$Keys$LOSSES$;
import org.platanios.tensorflow.api.core.client.Executable;
import org.platanios.tensorflow.api.core.client.Executable$;
import org.platanios.tensorflow.api.core.client.Fetchable;
import org.platanios.tensorflow.api.core.client.Fetchable$;
import org.platanios.tensorflow.api.core.client.Session;
import org.platanios.tensorflow.api.learn.Counter$;
import org.platanios.tensorflow.api.learn.StopCriteria;
import org.platanios.tensorflow.api.learn.hooks.Hook;
import org.platanios.tensorflow.api.ops.Math$;
import org.platanios.tensorflow.api.ops.Op;
import org.platanios.tensorflow.api.ops.Op$;
import org.platanios.tensorflow.api.ops.Output;
import org.platanios.tensorflow.api.ops.variables.Variable;
import org.platanios.tensorflow.api.tensors.Tensor;
import scala.None$;
import scala.Option;
import scala.Predef$;
import scala.Some;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.collection.Traversable;
import scala.collection.mutable.ListBuffer;
import scala.collection.mutable.ListBuffer$;
import scala.reflect.ScalaSignature;
import scala.runtime.BooleanRef;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;

/* compiled from: Stopper.scala */
@ScalaSignature(bytes = "\u0006\u0001\t]b!B\u0001\u0003\u0001\u0011q!aB*u_B\u0004XM\u001d\u0006\u0003\u0007\u0011\tQ\u0001[8pWNT!!\u0002\u0004\u0002\u000b1,\u0017M\u001d8\u000b\u0005\u001dA\u0011aA1qS*\u0011\u0011BC\u0001\u000bi\u0016t7o\u001c:gY><(BA\u0006\r\u0003%\u0001H.\u0019;b]&|7OC\u0001\u000e\u0003\ry'oZ\n\u0003\u0001=\u0001\"\u0001E\t\u000e\u0003\tI!A\u0005\u0002\u0003\t!{wn\u001b\u0005\t)\u0001\u0011\t\u0019!C\t-\u0005A1M]5uKJL\u0017m\u0001\u0001\u0016\u0003]\u0001\"\u0001G\r\u000e\u0003\u0011I!A\u0007\u0003\u0003\u0019M#x\u000e]\"sSR,'/[1\t\u0011q\u0001!\u00111A\u0005\u0012u\tAb\u0019:ji\u0016\u0014\u0018.Y0%KF$\"A\b\u0013\u0011\u0005}\u0011S\"\u0001\u0011\u000b\u0003\u0005\nQa]2bY\u0006L!a\t\u0011\u0003\tUs\u0017\u000e\u001e\u0005\bKm\t\t\u00111\u0001\u0018\u0003\rAH%\r\u0005\tO\u0001\u0011\t\u0011)Q\u0005/\u0005I1M]5uKJL\u0017\r\t\u0005\u0006S\u0001!\tBK\u0001\u0007y%t\u0017\u000e\u001e \u0015\u0005-b\u0003C\u0001\t\u0001\u0011\u0015!\u0002\u00061\u0001\u0018\u0011%q\u0003\u00011A\u0001B\u0003&q&A\u0003fa>\u001c\u0007\u000e\u0005\u00021k5\t\u0011G\u0003\u00023g\u0005Ia/\u0019:jC\ndWm\u001d\u0006\u0003i\u0019\t1a\u001c9t\u0013\t1\u0014G\u0001\u0005WCJL\u0017M\u00197f\u0011%A\u0004\u00011A\u0001B\u0003&q&\u0001\u0003ti\u0016\u0004\b\"\u0003\u001e\u0001\u0001\u0004\u0005\t\u0015)\u0003<\u0003\u0011awn]:\u0011\u0005qjT\"A\u001a\n\u0005y\u001a$AB(viB,H\u000f\u0003\u0004A\u0001\u0001\u0006K!Q\u0001\ngR\f'\u000f\u001e+j[\u0016\u0004\"a\b\"\n\u0005\r\u0003#\u0001\u0002'p]\u001eDa!\u0012\u0001!B\u00131\u0015!\u00037bgR,\u0005o\\2i!\ryr)Q\u0005\u0003\u0011\u0002\u0012aa\u00149uS>t\u0007B\u0002&\u0001A\u0003&a)\u0001\u0005mCN$8\u000b^3q\u0011\u0019a\u0005\u0001)Q\u0005\u001b\u0006AA.Y:u\u0019>\u001c8\u000f\u0005\u0002 
\u001d&\u0011q\n\t\u0002\u0006\r2|\u0017\r\u001e\u0005\u0007#\u0002\u0001\u000b\u0015\u0002*\u0002!9,Xn\u0015;faN\u0014U\r\\8x)>d\u0007CA\u0010T\u0013\t!\u0006EA\u0002J]RD\u0011B\u0016\u0001A\u0002\u0003\u0005\u000b\u0015B,\u0002\u001dM,7o]5p]\u001a+Go\u00195fgB\u0019\u0001\fY\u001e\u000f\u0005esfB\u0001.^\u001b\u0005Y&B\u0001/\u0016\u0003\u0019a$o\\8u}%\t\u0011%\u0003\u0002`A\u00059\u0001/Y2lC\u001e,\u0017BA1c\u0005\r\u0019V-\u001d\u0006\u0003?\u0002B\u0011\u0002\u001a\u0001A\u0002\u0003\u0005\u000b\u0015\u0002*\u0002\u001f\u0015\u0004xn\u00195GKR\u001c\u0007.\u00138eKbD\u0011B\u001a\u0001A\u0002\u0003\u0005\u000b\u0015\u0002*\u0002\u001dM$X\r\u001d$fi\u000eD\u0017J\u001c3fq\"I\u0001\u000e\u0001a\u0001\u0002\u0003\u0006KAU\u0001\u000fY>\u001c8OR3uG\"Le\u000eZ3y\u0011\u0015Q\u0007\u0001\"\u0001l\u00039)\b\u000fZ1uK\u000e\u0013\u0018\u000e^3sS\u0006$\"A\b7\t\u000bQI\u0007\u0019A\f\t\u000b9\u0004A\u0011A8\u0002\u000bI,7/\u001a;\u0015\u0005y\u0001\b\"B9n\u0001\u0004\u0011\u0018aB:fgNLwN\u001c\t\u0003gbl\u0011\u0001\u001e\u0006\u0003kZ\faa\u00197jK:$(BA<\u0007\u0003\u0011\u0019wN]3\n\u0005e$(aB*fgNLwN\u001c\u0005\u0006w\u0002!\t\u0006`\u0001\u0006E\u0016<\u0017N\u001c\u000b\u0002=!)a\u0010\u0001C)\u007f\u0006!\u0012M\u001a;feN+7o]5p]\u000e\u0013X-\u0019;j_:$2AHA\u0001\u0011\u0015\tX\u00101\u0001s\u0011\u001d\t)\u0001\u0001C)\u0003\u000f\t\u0001CY3g_J,7+Z:tS>t'+\u001e8\u0016\u0011\u0005%\u00111NA#\u0003c\"B!a\u0003\u0002vQ1\u0011QBA\u001c\u0003/\u0002BaH$\u0002\u0010AI\u0011\u0011CA\f/\u0006u\u0011\u0011\u0006\b\u0004!\u0005M\u0011bAA\u000b\u0005\u0005!\u0001j\\8l\u0013\u0011\tI\"a\u0007\u0003\u001dM+7o]5p]J+h.\u0011:hg*\u0019\u0011Q\u0003\u0002\u0011\u000ba\u000by\"a\t\n\u0007\u0005\u0005\"MA\u0006Ue\u00064XM]:bE2,\u0007c\u0001\u001f\u0002&%\u0019\u0011qE\u001a\u0003\u0005=\u0003\b\u0003\u0002-a\u0003W\u0001B!!\f\u000245\u0011\u0011q\u0006\u0006\u0004\u0003c1\u0011a\u0002;f]N|'o]\u0005\u0005\u0003k\tyC\u0001\u0004UK:\u001cxN\u001d\u0005\t\u0003s\t\u0019\u0001q\u0001\u00
02<\u0005aQ\r_3dkR\f'\r\\3FmB)1/!\u0010\u0002B%\u0019\u0011q\b;\u0003\u0015\u0015CXmY;uC\ndW\r\u0005\u0003\u0002D\u0005\u0015C\u0002\u0001\u0003\t\u0003\u000f\n\u0019A1\u0001\u0002J\t\tQ)\u0005\u0003\u0002L\u0005E\u0003cA\u0010\u0002N%\u0019\u0011q\n\u0011\u0003\u000f9{G\u000f[5oOB\u0019q$a\u0015\n\u0007\u0005U\u0003EA\u0002B]fD\u0001\"!\u0017\u0002\u0004\u0001\u000f\u00111L\u0001\fM\u0016$8\r[1cY\u0016,e\u000f\u0005\u0005\u0002^\u0005\r\u0014\u0011NA8\u001d\r\u0019\u0018qL\u0005\u0004\u0003C\"\u0018!\u0003$fi\u000eD\u0017M\u00197f\u0013\u0011\t)'a\u001a\u0003\u0007\u0005+\bPC\u0002\u0002bQ\u0004B!a\u0011\u0002l\u0011A\u0011QNA\u0002\u0005\u0004\tIEA\u0001G!\u0011\t\u0019%!\u001d\u0005\u0011\u0005M\u00141\u0001b\u0001\u0003\u0013\u0012\u0011A\u0015\u0005\t\u0003o\n\u0019\u00011\u0001\u0002z\u0005Q!/\u001e8D_:$X\r\u001f;\u0011\u0015\u0005E\u00111PA5\u0003\u0003\ny'\u0003\u0003\u0002~\u0005m!!E*fgNLwN\u001c*v]\u000e{g\u000e^3yi\"9\u0011\u0011\u0011\u0001\u0005R\u0005\r\u0015aD1gi\u0016\u00148+Z:tS>t'+\u001e8\u0016\u0011\u0005\u0015\u0015qSAH\u00037#b!a\"\u0002\u001e\u0006\u0005F#\u0002\u0010\u0002\n\u0006E\u0005\u0002CA\u001d\u0003\u007f\u0002\u001d!a#\u0011\u000bM\fi$!$\u0011\t\u0005\r\u0013q\u0012\u0003\t\u0003\u000f\nyH1\u0001\u0002J!A\u0011\u0011LA@\u0001\b\t\u0019\n\u0005\u0005\u0002^\u0005\r\u0014QSAM!\u0011\t\u0019%a&\u0005\u0011\u00055\u0014q\u0010b\u0001\u0003\u0013\u0002B!a\u0011\u0002\u001c\u0012A\u00111OA@\u0005\u0004\tI\u0005\u0003\u0005\u0002x\u0005}\u0004\u0019AAP!)\t\t\"a\u001f\u0002\u0016\u00065\u0015\u0011\u0014\u0005\t\u0003G\u000by\b1\u0001\u0002&\u0006I!/\u001e8SKN,H\u000e\u001e\t\b\u0003#\t9kVA\u0015\u0013\u0011\tI+a\u0007\u0003!M+7o]5p]J+hNU3tk2$\bFBA@\u0003[\u000b\u0019\rE\u0003 
\u0003_\u000b\u0019,C\u0002\u00022\u0002\u0012a\u0001\u001e5s_^\u001c\b\u0003BA[\u0003\u007fk!!a.\u000b\t\u0005e\u00161X\u0001\u0005Y\u0006twM\u0003\u0002\u0002>\u0006!!.\u0019<b\u0013\u0011\t\t-a.\u0003+%cG.Z4bYN#\u0018\r^3Fq\u000e,\u0007\u000f^5p]F:a$!2\u0002V\u0006m\b\u0003BAd\u0003\u001ftA!!3\u0002LB\u0011!\fI\u0005\u0004\u0003\u001b\u0004\u0013A\u0002)sK\u0012,g-\u0003\u0003\u0002R\u0006M'AB*ue&twMC\u0002\u0002N\u0002\n\u0014bIAl\u0003?\f\t0!9\u0016\t\u0005e\u00171\\\u000b\u0003\u0003\u000b$q!!8\u0016\u0005\u0004\t9OA\u0001U\u0013\u0011\t\t/a9\u00027\u0011bWm]:j]&$He\u001a:fCR,'\u000f\n3fM\u0006,H\u000e\u001e\u00132\u0015\r\t)\u000fI\u0001\u0007i\"\u0014xn^:\u0012\t\u0005-\u0013\u0011\u001e\t\u0005\u0003W\fiO\u0004\u0002 =&\u0019\u0011q\u001e2\u0003\u0013QC'o\\<bE2,\u0017'C\u0012\u0002t\u0006U\u0018q_As\u001d\ry\u0012Q_\u0005\u0004\u0003K\u0004\u0013'\u0002\u0012 A\u0005e(!B:dC2\f\u0017g\u0001\u0014\u00024\u001e9\u0011q \u0002\t\u0002\t\u0005\u0011aB*u_B\u0004XM\u001d\t\u0004!\t\raAB\u0001\u0003\u0011\u0003\u0011)a\u0005\u0003\u0003\u0004\t\u001d\u0001cA\u0010\u0003\n%\u0019!1\u0002\u0011\u0003\r\u0005s\u0017PU3g\u0011\u001dI#1\u0001C\u0001\u0005\u001f!\"A!\u0001\t\u0019\tM!1\u0001b\u0001\n\u0003\u0011\u0019A!\u0006\u0002\r1|wmZ3s+\t\u00119\u0002\u0005\u0003\u0003\u001a\t\u001dRB\u0001B\u000e\u0015\u0011\u0011iBa\b\u0002\u0019M\u001c\u0017\r\\1m_\u001e<\u0017N\\4\u000b\t\t\u0005\"1E\u0001\tif\u0004Xm]1gK*\u0011!QE\u0001\u0004G>l\u0017\u0002\u0002B\u0015\u00057\u0011a\u0001T8hO\u0016\u0014\b\"\u0003B\u0017\u0005\u0007\u0001\u000b\u0011\u0002B\f\u0003\u001dawnZ4fe\u0002B\u0001B!\r\u0003\u0004\u0011\u0005!1G\u0001\u0006CB\u0004H.\u001f\u000b\u0004W\tU\u0002B\u0002\u000b\u00030\u0001\u0007q\u0003")
/* loaded from: input_file:org/platanios/tensorflow/api/learn/hooks/Stopper.class */
/**
 * Training hook that requests a stop of the session run loop once any of the configured
 * {@link StopCriteria} is satisfied.
 *
 * <p>After every session run it checks the epoch counter, the step counter, the elapsed
 * wall-clock time, and (when a loss-change tolerance is configured) whether the change in loss
 * between consecutive runs has stayed within tolerance for more than
 * {@code criteria().maxStepBelowTol()} runs in a row. If any check fires,
 * {@code SessionRunContext.requestStop()} is called.
 */
public class Stopper extends Hook {
    /** Stop criteria currently in effect. */
    private StopCriteria criteria;
    /** Global epoch counter variable; resolved in {@link #begin()} only when the criteria need it. */
    private Variable epoch;
    /** Global step counter variable; resolved in {@link #begin()} only when the criteria need it. */
    private Variable step;
    /** Sum of the graph's LOSSES collection; built in {@link #begin()} only when the criteria need it. */
    private Output loss;
    /** Reference time, in milliseconds from {@code System.currentTimeMillis()}, for the time limit. */
    private long startTime = 0;
    /** Epoch value at or beyond which a stop is requested (empty when unbounded). */
    private Option<Object> lastEpoch = None$.MODULE$;
    /** Step value at or beyond which a stop is requested (empty when unbounded). */
    private Option<Object> lastStep = None$.MODULE$;
    /** Loss observed at the previous convergence check (or at the last reset). */
    private float lastLoss = Float.MAX_VALUE;
    /** Number of consecutive runs whose loss change was within tolerance. */
    private int numStepsBelowTol = 0;
    /** Extra tensors this hook asks to be fetched on every session run. */
    private Seq<Output> sessionFetches;
    /** Positions of the epoch, step and loss values inside {@link #sessionFetches}. */
    private int epochFetchIndex;
    private int stepFetchIndex;
    private int lossFetchIndex;

    /** Creates a new {@code Stopper} for the given criteria (delegates to the companion object). */
    public static Stopper apply(StopCriteria stopCriteria) {
        return Stopper$.MODULE$.apply(stopCriteria);
    }

    /** Returns the stop criteria currently in effect. */
    public StopCriteria criteria() {
        return this.criteria;
    }

    /** Scala-style setter backing {@link #criteria()}. */
    public void criteria_$eq(StopCriteria stopCriteria) {
        this.criteria = stopCriteria;
    }

    /**
     * Replaces the stop criteria. The epoch/step thresholds derived from the criteria are only
     * recomputed by {@link #reset(Session)}.
     */
    public void updateCriteria(StopCriteria stopCriteria) {
        criteria_$eq(stopCriteria);
    }

    /**
     * Restarts the stop bookkeeping using the current values stored in {@code session}.
     *
     * <p>Restarts the wall-clock timer and recomputes the epoch/step thresholds. With
     * {@code restartCounting} the thresholds are the configured maxima added to the current
     * counter values; otherwise the configured maxima are used as absolute counter values. The
     * current loss becomes the baseline for the loss-change check, and the consecutive
     * below-tolerance count is cleared.
     *
     * <p>Requires {@link #begin()} to have run first, since it reads the counters and loss that
     * {@code begin()} resolves.
     *
     * @param session session used to read the current counter and loss values
     */
    public void reset(Session session) {
        this.startTime = System.currentTimeMillis();
        if (criteria().needEpoch()) {
            long currentEpoch = BoxesRunTime.unboxToLong(fetchScalar(session, this.epoch.value()));
            this.lastEpoch = criteria().restartCounting() ? criteria().maxEpochs().map(j -> {
                return j + currentEpoch;
            }) : criteria().maxEpochs();
        }
        if (criteria().needStep()) {
            long currentStep = BoxesRunTime.unboxToLong(fetchScalar(session, this.step.value()));
            this.lastStep = criteria().restartCounting() ? criteria().maxSteps().map(j2 -> {
                return j2 + currentStep;
            }) : criteria().maxSteps();
        }
        if (criteria().needLoss()) {
            this.lastLoss = BoxesRunTime.unboxToFloat(fetchScalar(session, this.loss));
        }
        this.numStepsBelowTol = 0;
    }

    /** Evaluates a single output in {@code session} and returns its scalar value (boxed). */
    private static Object fetchScalar(Session session, Output output) {
        return ((Tensor) session.run(session.run$default$1(), output, session.run$default$3(), session.run$default$4(), Executable$.MODULE$.traversableExecutable(Executable$.MODULE$.opExecutable()), Fetchable$.MODULE$.outputFetchable())).scalar();
    }

    /**
     * Resolves the graph elements required by the criteria and records where each one sits in the
     * per-run fetch list.
     *
     * @throws IllegalStateException if a required global epoch/step counter variable does not
     *     exist in the current graph
     */
    @Override // org.platanios.tensorflow.api.learn.hooks.Hook
    public void begin() {
        ListBuffer fetches = ListBuffer$.MODULE$.empty();
        if (criteria().maxSeconds().isDefined()) {
            this.startTime = System.currentTimeMillis();
        }
        if (criteria().needEpoch()) {
            this.epoch = (Variable) Counter$.MODULE$.get(Graph$Keys$GLOBAL_EPOCH$.MODULE$, false, Op$.MODULE$.currentGraph()).getOrElse(() -> {
                throw new IllegalStateException(new StringBuilder(61).append("A ").append(Graph$Keys$GLOBAL_EPOCH$.MODULE$.name()).append(" variable should be created in order to use the 'StopHook'.").toString());
            });
            this.epochFetchIndex = fetches.size();
            fetches.append(Predef$.MODULE$.wrapRefArray(new Output[]{this.epoch.value()}));
        }
        if (criteria().needStep()) {
            this.step = (Variable) Counter$.MODULE$.get(Graph$Keys$GLOBAL_STEP$.MODULE$, false, Op$.MODULE$.currentGraph()).getOrElse(() -> {
                throw new IllegalStateException(new StringBuilder(61).append("A ").append(Graph$Keys$GLOBAL_STEP$.MODULE$.name()).append(" variable should be created in order to use the 'StopHook'.").toString());
            });
            this.stepFetchIndex = fetches.size();
            fetches.append(Predef$.MODULE$.wrapRefArray(new Output[]{this.step.value()}));
        }
        if (criteria().needLoss()) {
            // The monitored loss is the sum of every output in the graph's LOSSES collection.
            this.loss = Math$.MODULE$.addN(Op$.MODULE$.currentGraph().getCollection(Graph$Keys$LOSSES$.MODULE$).toSeq(), Math$.MODULE$.addN$default$2());
            this.lossFetchIndex = fetches.size();
            fetches.append(Predef$.MODULE$.wrapRefArray(new Output[]{this.loss}));
        }
        this.sessionFetches = fetches;
    }

    /** Initializes the thresholds and loss baseline as soon as a session exists. */
    @Override // org.platanios.tensorflow.api.learn.hooks.Hook
    public void afterSessionCreation(Session session) {
        reset(session);
    }

    /** Asks for the epoch/step/loss values resolved in {@link #begin()} to be fetched on each run. */
    @Override // org.platanios.tensorflow.api.learn.hooks.Hook
    public <F, E, R> Option<Hook.SessionRunArgs<Seq<Output>, Traversable<Op>, Seq<Tensor>>> beforeSessionRun(Hook.SessionRunContext<F, E, R> sessionRunContext, Executable<E> executable, Fetchable<F> fetchable) {
        return new Some(new Hook.SessionRunArgs(Hook$SessionRunArgs$.MODULE$.apply$default$1(), this.sessionFetches, Hook$SessionRunArgs$.MODULE$.apply$default$3(), Hook$SessionRunArgs$.MODULE$.apply$default$4(), Hook$SessionRunArgs$.MODULE$.apply$default$5(), Executable$.MODULE$.traversableExecutable(Executable$.MODULE$.opExecutable()), Fetchable$.MODULE$.fetchableSeq(Fetchable$.MODULE$.outputFetchable(), Seq$.MODULE$.canBuildFrom())));
    }

    /**
     * Checks every configured criterion against the values fetched by this run and requests a
     * stop if at least one of them is met.
     *
     * <p>NOTE(review): the epoch/step checks are gated on {@code maxEpochs()/maxSteps()} while
     * {@link #begin()} gates the fetches on {@code needEpoch()/needStep()}; this assumes the two
     * agree (otherwise the fetch index would be stale) — confirm against {@code StopCriteria}.
     */
    @Override // org.platanios.tensorflow.api.learn.hooks.Hook
    public <F, E, R> void afterSessionRun(Hook.SessionRunContext<F, E, R> sessionRunContext, Hook.SessionRunResult<Seq<Output>, Seq<Tensor>> sessionRunResult, Executable<E> executable, Fetchable<F> fetchable) throws IllegalStateException {
        boolean stopRequested = false;
        if (criteria().maxEpochs().isDefined()) {
            long currentEpoch = BoxesRunTime.unboxToLong(((Tensor) sessionRunResult.values().apply(this.epochFetchIndex)).scalar());
            if (this.lastEpoch.exists(j -> {
                return currentEpoch >= j;
            })) {
                logStopRequest("Stop requested: Exceeded maximum number of epochs.");
                stopRequested = true;
            }
        }
        if (criteria().maxSteps().isDefined()) {
            long currentStep = BoxesRunTime.unboxToLong(((Tensor) sessionRunResult.values().apply(this.stepFetchIndex)).scalar());
            if (this.lastStep.exists(j2 -> {
                return currentStep >= j2;
            })) {
                logStopRequest("Stop requested: Exceeded maximum number of steps.");
                stopRequested = true;
            }
        }
        if (criteria().maxSeconds().isDefined()) {
            // startTime is in milliseconds (System.currentTimeMillis), while the limit is in
            // seconds. For non-negative elapsed time, floor(ms / 1000) >= s  <=>  ms >= 1000 * s,
            // so this is exact and cannot overflow the way 1000 * s could.
            long elapsedSeconds = (System.currentTimeMillis() - this.startTime) / 1000L;
            if (criteria().maxSeconds().exists(j3 -> {
                return elapsedSeconds >= j3;
            })) {
                logStopRequest("Stop requested: Exceeded maximum number of seconds.");
                stopRequested = true;
            }
        }
        if (criteria().absLossChangeTol().isDefined() || criteria().relLossChangeTol().isDefined()) {
            float currentLoss = BoxesRunTime.unboxToFloat(((Tensor) sessionRunResult.values().apply(this.lossFetchIndex)).scalar());
            float previousLoss = this.lastLoss;
            float lossChange = scala.math.package$.MODULE$.abs(previousLoss - currentLoss);
            // Remember this run's loss so that the next check measures the change between
            // consecutive runs rather than the drift since the last reset.
            this.lastLoss = currentLoss;
            if (criteria().absLossChangeTol().exists(d -> {
                return ((double) lossChange) < d;
            }) || criteria().relLossChangeTol().exists(d2 -> {
                return ((double) scala.math.package$.MODULE$.abs(lossChange / previousLoss)) < d2;
            })) {
                this.numStepsBelowTol++;
            } else {
                this.numStepsBelowTol = 0;
            }
            if (this.numStepsBelowTol > criteria().maxStepBelowTol()) {
                logStopRequest("Stop requested: Loss value converged.");
                stopRequested = true;
            }
        }
        if (stopRequested) {
            sessionRunContext.requestStop();
        }
    }

    /** Emits {@code reason} at debug level through the companion object's logger. */
    private static void logStopRequest(String reason) {
        if (Stopper$.MODULE$.logger().underlying().isDebugEnabled()) {
            Stopper$.MODULE$.logger().underlying().debug(reason);
        }
    }

    /** Creates a hook enforcing {@code stopCriteria}. */
    public Stopper(StopCriteria stopCriteria) {
        this.criteria = stopCriteria;
    }
}
