package org.platanios.tensorflow.api.learn;

import org.platanios.tensorflow.api.core.Graph$Keys$GLOBAL_STEP$;
import org.platanios.tensorflow.api.learn.Model;
import org.platanios.tensorflow.api.learn.layers.Input;
import org.platanios.tensorflow.api.learn.layers.Layer;
import org.platanios.tensorflow.api.ops.Math$;
import org.platanios.tensorflow.api.ops.Output;
import org.platanios.tensorflow.api.ops.OutputLike;
import org.platanios.tensorflow.api.ops.OutputOps$;
import org.platanios.tensorflow.api.ops.io.data.Iterator;
import org.platanios.tensorflow.api.ops.metrics.Metric;
import org.platanios.tensorflow.api.ops.training.optimizers.Optimizer;
import org.platanios.tensorflow.api.ops.variables.Variable;
import scala.Some;
import scala.Tuple2;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.reflect.ScalaSignature;

/* compiled from: Model.scala */
@ScalaSignature(bytes = "\u0006\u0001\t=a!B\u0001\u0003\u0001\ta!aI*va\u0016\u0014h/[:fI\u000e{g\u000eZ5uS>t\u0017\r\u001c+sC&t\u0017M\u00197f\u001b>$W\r\u001c\u0006\u0003\u0007\u0011\tQ\u0001\\3be:T!!\u0002\u0004\u0002\u0007\u0005\u0004\u0018N\u0003\u0002\b\u0011\u0005QA/\u001a8t_J4Gn\\<\u000b\u0005%Q\u0011!\u00039mCR\fg.[8t\u0015\u0005Y\u0011aA8sOVYQ\u0002\u0006\u0012&Q-\nDg\u000e\u001e>'\r\u0001a\"\f\t\b\u001fA\u0011\u0012\u0005J\u0014+\u001b\u0005\u0011\u0011BA\t\u0003\u0005Q\u0019\u0016.\u001c9mK&sg-\u001a:f]\u000e,Wj\u001c3fYB\u00111\u0003\u0006\u0007\u0001\t\u0015)\u0002A1\u0001\u0018\u0005\tIEk\u0001\u0001\u0012\u0005aq\u0002CA\r\u001d\u001b\u0005Q\"\"A\u000e\u0002\u000bM\u001c\u0017\r\\1\n\u0005uQ\"a\u0002(pi\"Lgn\u001a\t\u00033}I!\u0001\t\u000e\u0003\u0007\u0005s\u0017\u0010\u0005\u0002\u0014E\u0011)1\u0005\u0001b\u0001/\t\u0011\u0011j\u0014\t\u0003'\u0015\"QA\n\u0001C\u0002]\u0011!!\u0013#\u0011\u0005MAC!B\u0015\u0001\u0005\u00049\"AA%T!\t\u00192\u0006B\u0003-\u0001\t\u0007qCA\u0001J!1yaFE\u0011%O)\u00024GN\u001d=\u0013\ty#A\u0001\rTkB,'O^5tK\u0012$&/Y5oC\ndW-T8eK2\u0004\"aE\u0019\u0005\u000bI\u0002!\u0019A\f\u0003\u0005Q#\u0006CA\n5\t\u0015)\u0004A1\u0001\u0018\u0005\t!v\n\u0005\u0002\u0014o\u0011)\u0001\b\u0001b\u0001/\t\u0011A\u000b\u0012\t\u0003'i\"Qa\u000f\u0001C\u0002]\u0011!\u0001V*\u0011\u0005MiD!\u0002 
\u0001\u0005\u00049\"!\u0001+\t\u0011\u0001\u0003!Q1A\u0005B\u0005\u000bQ!\u001b8qkR,\u0012A\u0011\u0019\u0003\u0007*\u0003r\u0001R$\u0013C%#s%D\u0001F\u0015\t1%!\u0001\u0004mCf,'o]\u0005\u0003\u0011\u0016\u0013Q!\u00138qkR\u0004\"a\u0005&\u0005\u0013-c\u0015\u0011!A\u0001\u0006\u00039\"aA0%k!IQ\n\u0001B\u0001B\u0003%!IT\u0001\u0007S:\u0004X\u000f\u001e\u0011\n\u0005\u0001\u0003\u0002\u0002\u0003)\u0001\u0005\u000b\u0007I\u0011I)\u0002\u000b1\f\u00170\u001a:\u0016\u0003I\u0003B\u0001R*\"U%\u0011A+\u0012\u0002\u0006\u0019\u0006LXM\u001d\u0005\n-\u0002\u0011\t\u0011)A\u0005%^\u000ba\u0001\\1zKJ\u0004\u0013B\u0001)\u0011\u0011!I\u0006A!b\u0001\n\u0003Q\u0016A\u0003;sC&tG*Y=feV\t1\f\u0005\u0003E'rS\u0003\u0003B\r^CMJ!A\u0018\u000e\u0003\rQ+\b\u000f\\33\u0011!\u0001\u0007A!A!\u0002\u0013Y\u0016a\u0003;sC&tG*Y=fe\u0002B\u0001B\u0019\u0001\u0003\u0006\u0004%\taY\u0001\u000biJ\f\u0017N\\%oaV$X#\u000131\u0005\u0015<\u0007c\u0002#HaM2g'\u000f\t\u0003'\u001d$\u0011\u0002[5\u0002\u0002\u0003\u0005)\u0011A\f\u0003\u0007}#c\u0007\u0003\u0005k\u0001\t\u0005\t\u0015!\u0003e\u0003-!(/Y5o\u0013:\u0004X\u000f\u001e\u0011\t\u00111\u0004!Q1A\u0005\u00025\fq\u0002\u001e:bS:Le\u000e];u\u0019\u0006LXM]\u000b\u0002]B!AiU\u001a=\u0011!\u0001\bA!A!\u0002\u0013q\u0017\u0001\u0005;sC&t\u0017J\u001c9vi2\u000b\u00170\u001a:!\u0011!\u0011\bA!b\u0001\n\u0003\u0019\u0018\u0001\u00027pgN,\u0012\u0001\u001e\t\u0005\tN+h\u000f\u0005\u0003\u001a;*b\u0004CA<{\u001b\u0005A(BA=\u0005\u0003\ry\u0007o]\u0005\u0003wb\u0014aaT;uaV$\b\u0002C?\u0001\u0005\u0003\u0005\u000b\u0011\u0002;\u0002\u000b1|7o\u001d\u0011\t\u0013}\u0004!Q1A\u0005\u0002\u0005\u0005\u0011!C8qi&l\u0017N_3s+\t\t\u0019\u0001\u0005\u0003\u0002\u0006\u0005=QBAA\u0004\u0015\u0011\tI!a\u0003\u0002\u0015=\u0004H/[7ju\u0016\u00148OC\u0002\u0002\u000ea\f\u0001\u0002\u001e:bS:LgnZ\u0005\u0005\u0003#\t9AA\u0005PaRLW.\u001b>fe\"Q\u0011Q\u0003\u0001\u0003\u0002\u0003\u0006I!a\u0001\u0002\u0015=\u0004H/[7ju\u0016\u0014\b\u0005\u0003\u0006\u0002\u001a\u00
01\u0011)\u0019!C\u0001\u00037\tQb\u00197ja\u001e\u0013\u0018\rZ5f]R\u001cXCAA\u000f!\ry\u0011qD\u0005\u0004\u0003C\u0011!!D\"mSB<%/\u00193jK:$8\u000f\u0003\u0006\u0002&\u0001\u0011\t\u0011)A\u0005\u0003;\tab\u00197ja\u001e\u0013\u0018\rZ5f]R\u001c\b\u0005\u0003\u0006\u0002*\u0001\u0011)\u0019!C)\u0003W\t\u0001dY8m_\u000e\fG/Z$sC\u0012LWM\u001c;t/&$\bn\u00149t+\t\ti\u0003E\u0002\u001a\u0003_I1!!\r\u001b\u0005\u001d\u0011un\u001c7fC:D!\"!\u000e\u0001\u0005\u0003\u0005\u000b\u0011BA\u0017\u0003e\u0019w\u000e\\8dCR,wI]1eS\u0016tGo],ji\"|\u0005o\u001d\u0011\t\u0011\u0005e\u0002\u0001\"\u0001\u0003\u0003w\ta\u0001P5oSRtD\u0003FA\u001f\u0003\u007f\tI%a\u0013\u0002N\u0005]\u0013\u0011LA.\u0003;\ny\u0006\u0005\u0007\u0010\u0001I\tCe\n\u00161gYJD\bC\u0004A\u0003o\u0001\r!!\u00111\t\u0005\r\u0013q\t\t\t\t\u001e\u0013\u0012%!\u0012%OA\u00191#a\u0012\u0005\u0015-\u000by$!A\u0001\u0002\u000b\u0005q\u0003\u0003\u0004Q\u0003o\u0001\rA\u0015\u0005\u00073\u0006]\u0002\u0019A.\t\u000f\t\f9\u00041\u0001\u0002PA\"\u0011\u0011KA+!!!u\tM\u001a\u0002TYJ\u0004cA\n\u0002V\u0011Q\u0001.!\u0014\u0002\u0002\u0003\u0005)\u0011A\f\t\r1\f9\u00041\u0001o\u0011\u0019\u0011\u0018q\u0007a\u0001i\"9q0a\u000eA\u0002\u0005\r\u0001BCA\r\u0003o\u0001\n\u00111\u0001\u0002\u001e!Q\u0011\u0011FA\u001c!\u0003\u0005\r!!\f\t\u000f\u0005\r\u0004\u0001\"\u0011\u0002f\u0005i!-^5mIR\u0013\u0018-\u001b8PaN$\"!a\u001a\u0011\u001d\u0005%\u0014q\u000e\n\"I\u001dR\u0003g\r\u001c:y9\u0019q\"a\u001b\n\u0007\u00055$!A\u0003N_\u0012,G.\u0003\u0003\u0002r\u0005M$AE*va\u0016\u0014h/[:fIR\u0013\u0018-\u001b8PaNT1!!\u001c\u0003\u0011\u001d\t9\b\u0001C!\u0003s\n\u0001CY;jY\u0012,e/\u00197vCR,w\n]:\u0015\t\u0005m\u0014q\u0011\t\r\u0003S\ni(!!]\u0003\u0007\u000b)IK\u0005\u0005\u0003\u007f\n\u0019HA\u0006Fm\u0006dW/\u0019;f\u001fB\u001c\b\u0003B\r^%A\u0002B!G/%mA!\u0011$X\u0014:\u0011!\tI)!\u001eA\u0002\u0005-\u0015aB7fiJL7m\u001d\t\u0007\u0003\u001b\u000bi*a)\u000f\t\u0005=\u0015\u0011\u0014\b\u0005\u0003#\u000b9*\u0004\u0002\u000
2\u0014*\u0019\u0011Q\u0013\f\u0002\rq\u0012xn\u001c;?\u0013\u0005Y\u0012bAAN5\u00059\u0001/Y2lC\u001e,\u0017\u0002BAP\u0003C\u00131aU3r\u0015\r\tYJ\u0007\t\u0007\u0003K\u000bI+\u001e<\u000e\u0005\u0005\u001d&bAAEq&!\u00111VAT\u0005\u0019iU\r\u001e:jG\u001eQ\u0011q\u0016\u0002\u0002\u0002#\u0005!!!-\u0002GM+\b/\u001a:wSN,GmQ8oI&$\u0018n\u001c8bYR\u0013\u0018-\u001b8bE2,Wj\u001c3fYB\u0019q\"a-\u0007\u0013\u0005\u0011\u0011\u0011!E\u0001\u0005\u0005U6\u0003BAZ\u0003o\u00032!GA]\u0013\r\tYL\u0007\u0002\u0007\u0003:L(+\u001a4\t\u0011\u0005e\u00121\u0017C\u0001\u0003\u007f#\"!!-\t\u0017\u0005\r\u00171WI\u0001\n\u0003\u0011\u0011QY\u0001\u001cI1,7o]5oSR$sM]3bi\u0016\u0014H\u0005Z3gCVdG\u000f\n\u001d\u0016-\u0005\u001d\u0017Q\\Ap\u0003C\f\u0019/!:\u0002h\u0006%\u00181^Aw\u0003_,\"!!3+\t\u0005u\u00111Z\u0016\u0003\u0003\u001b\u0004B!a4\u0002Z6\u0011\u0011\u0011\u001b\u0006\u0005\u0003'\f).A\u0005v]\u000eDWmY6fI*\u0019\u0011q\u001b\u000e\u0002\u0015\u0005tgn\u001c;bi&|g.\u0003\u0003\u0002\\\u0006E'!E;oG\",7m[3e-\u0006\u0014\u0018.\u00198dK\u00121Q#!1C\u0002]!aaIAa\u0005\u00049BA\u0002\u0014\u0002B\n\u0007q\u0003\u0002\u0004*\u0003\u0003\u0014\ra\u0006\u0003\u0007Y\u0005\u0005'\u0019A\f\u0005\rI\n\tM1\u0001\u0018\t\u0019)\u0014\u0011\u0019b\u0001/\u00111\u0001(!1C\u0002]!aaOAa\u0005\u00049BA\u0002 \u0002B\n\u0007q\u0003C\u0006\u0002t\u0006M\u0016\u0013!C\u0001\u0005\u0005U\u0018a\u0007\u0013mKN\u001c\u0018N\\5uI\u001d\u0014X-\u0019;fe\u0012\"WMZ1vYR$\u0013(\u0006\f\u0002x\u0006m\u0018Q`A��\u0005\u0003\u0011\u0019A!\u0002\u0003\b\t%!1\u0002B\u0007+\t\tIP\u000b\u0003\u0002.\u0005-GAB\u000b\u0002r\n\u0007q\u0003\u0002\u0004$\u0003c\u0014\ra\u0006\u0003\u0007M\u0005E(\u0019A\f\u0005\r%\n\tP1\u0001\u0018\t\u0019a\u0013\u0011\u001fb\u0001/\u00111!'!=C\u0002]!a!NAy\u0005\u00049BA\u0002\u001d\u0002r\n\u0007q\u0003\u0002\u0004<\u0003c\u0014\ra\u0006\u0003\u0007}\u0005E(\u0019A\f")
/* loaded from: input_file:org/platanios/tensorflow/api/learn/SupervisedConditionalTrainableModel.class */
/**
 * Supervised trainable model whose training-time network is conditioned on the training targets.
 *
 * <p>Training and evaluation use different layers:
 * <ul>
 *   <li>During training ({@link #buildTrainOps()}), {@code trainLayer} is applied to the zipped
 *       pair {@code (input element, train-input element)}. Its prediction may therefore depend
 *       on the supervision signal.</li>
 *   <li>During evaluation ({@link #buildEvaluateOps(Seq)}), the inference {@code layer()} is
 *       applied to the input element alone.</li>
 * </ul>
 * In both cases {@code trainInputLayer} maps the train-input element to the value compared
 * against the prediction (by the loss during training, by the metrics during evaluation).
 *
 * <p>Type parameters: {@code IT, IO, ID, IS} parameterize the model {@code Input};
 * {@code TT, TO, TD, TS} parameterize the training {@code Input}; {@code I} is the output type of
 * the (train) layer; {@code T} is the output type of {@code trainInputLayer}. NOTE(review): the
 * letters presumably mean tensor / output / data-type / shape types of an {@code Input} --
 * confirm against the {@code Input} declaration.
 *
 * <p>This is decompiled code: several locals are raw or {@code Object}-typed because the decompiler
 * could not infer their generic types (see the JADX warnings below).
 */
public class SupervisedConditionalTrainableModel<IT, IO, ID, IS, I, TT, TO, TD, TS, T> extends SimpleInferenceModel<IT, IO, ID, IS, I> implements SupervisedTrainableModel<IT, IO, ID, IS, I, TT, TO, TD, TS, T> {
    // Layer used only for training; it receives both the input and the training-input outputs.
    private final Layer<Tuple2<IO, TO>, I> trainLayer;
    // Training (supervision) input; zipped with input() when building train/evaluate ops.
    private final Input<TT, TO, ?, TD, TS> trainInput;
    // Maps the training-input output to the value fed to the loss and to the metrics.
    private final Layer<TO, T> trainInputLayer;
    // Loss layer: takes (prediction, processed target) and produces an Output (cast to FLOAT32 in buildTrainOps).
    private final Layer<Tuple2<I, T>, Output> loss;
    // Optimizer used to compute and apply gradients of the loss.
    private final Optimizer optimizer;
    // Transformation applied to the computed gradients before they are applied.
    private final ClipGradients clipGradients;
    // Forwarded to Optimizer.computeGradients as its colocation flag.
    private final boolean colocateGradientsWithOps;

    /** Returns the model input; simply delegates to {@link SimpleInferenceModel}. */
    @Override // org.platanios.tensorflow.api.learn.SimpleInferenceModel
    public Input<IT, IO, ?, ID, IS> input() {
        return super.input();
    }

    /** Returns the inference layer (used by evaluation); simply delegates to {@link SimpleInferenceModel}. */
    @Override // org.platanios.tensorflow.api.learn.SimpleInferenceModel
    public Layer<IO, I> layer() {
        return super.layer();
    }

    /** Returns the training-only layer applied to {@code (input output, train-input output)}. */
    public Layer<Tuple2<IO, TO>, I> trainLayer() {
        return this.trainLayer;
    }

    /** Returns the training (supervision) input. */
    public Input<TT, TO, ?, TD, TS> trainInput() {
        return this.trainInput;
    }

    /** Returns the layer applied to the training-input output before the loss/metrics. */
    public Layer<TO, T> trainInputLayer() {
        return this.trainInputLayer;
    }

    /** Returns the loss layer applied to {@code (prediction, processed target)}. */
    public Layer<Tuple2<I, T>, Output> loss() {
        return this.loss;
    }

    /** Returns the optimizer used in {@link #buildTrainOps()}. */
    public Optimizer optimizer() {
        return this.optimizer;
    }

    /** Returns the gradient transformation applied before {@code applyGradients}. */
    public ClipGradients clipGradients() {
        return this.clipGradients;
    }

    /** Returns the flag passed to the constructor; forwarded to {@code computeGradients} in {@link #buildTrainOps()}. */
    @Override // org.platanios.tensorflow.api.learn.SimpleInferenceModel, org.platanios.tensorflow.api.learn.Model
    public boolean colocateGradientsWithOps() {
        return this.colocateGradientsWithOps;
    }

    /**
     * Builds the training ops.
     *
     * <p>Steps, in graph-construction order:
     * <ol>
     *   <li>Zips {@code input()} with {@code trainInput()}, creates an iterator and takes its next element.</li>
     *   <li>Applies {@code trainLayer} to the whole pair and {@code trainInputLayer} to the
     *       train-input half, both in TRAINING mode.</li>
     *   <li>Computes the loss on {@code (prediction, processed target)} and casts it to FLOAT32
     *       (op name {@code "LossCast"}).</li>
     *   <li>Fetches the global-step counter variable and computes gradients of the loss, passing
     *       {@code colocateGradientsWithOps()} and the optimizer's defaults for the other arguments.</li>
     *   <li>Transforms the gradients with {@code clipGradients()} and applies them, passing the
     *       global step.</li>
     * </ol>
     * Statement order matters here because each call adds ops to the graph.
     *
     * @return the iterator, input element, train-layer output, processed target, loss, raw
     *     (unclipped) gradients and the apply-gradients op, bundled as
     *     {@code Model.SupervisedTrainOps}
     */
    /* JADX WARN: Multi-variable type inference failed */
    @Override // org.platanios.tensorflow.api.learn.TrainableModel
    public Model.SupervisedTrainOps<IT, IO, ID, IS, I, TT, TO, TD, TS, T> buildTrainOps() {
        TRAINING$ training$ = TRAINING$.MODULE$;
        // Single iterator over (input, trainInput) pairs; its next element feeds every layer below.
        Iterator apply = input().zip(trainInput()).apply();
        Tuple2<IO, TO> tuple2 = (Tuple2) apply.next(apply.next$default$1());
        // Conditional part: the train layer sees both the input and the training-input outputs.
        I apply2 = trainLayer().apply(tuple2, training$);
        Object apply3 = trainInputLayer().apply(tuple2._2(), training$);
        Output output = (Output) Math$.MODULE$.cast(loss().apply(new Tuple2<>(apply2, apply3), training$), org.platanios.tensorflow.api.types.package$.MODULE$.FLOAT32(), "LossCast", OutputOps$.MODULE$.outputOps());
        // NOTE(review): the `false` argument's meaning is defined by Counter.getOrCreate (not visible here) -- confirm.
        Variable orCreate = Counter$.MODULE$.getOrCreate(Graph$Keys$GLOBAL_STEP$.MODULE$, false, Counter$.MODULE$.getOrCreate$default$3());
        boolean colocateGradientsWithOps = colocateGradientsWithOps();
        // Only the colocation flag is customized; the remaining computeGradients arguments use the optimizer's defaults.
        Seq<Tuple2<OutputLike, Variable>> computeGradients = optimizer().computeGradients(output, optimizer().computeGradients$default$2(), optimizer().computeGradients$default$3(), optimizer().computeGradients$default$4(), optimizer().computeGradients$default$5(), colocateGradientsWithOps);
        // The returned ops hold the unclipped gradients; clipping only affects what is applied.
        return new Model.SupervisedTrainOps<>(apply, tuple2, apply2, apply3, output, computeGradients, optimizer().applyGradients(clipGradients().apply(computeGradients), new Some(orCreate), optimizer().applyGradients$default$3()));
    }

    /**
     * Builds the evaluation ops for the given metrics.
     *
     * <p>Unlike {@link #buildTrainOps()}, the prediction comes from the inference {@code layer()}
     * applied to the input half of the zipped element only. The train-input half still goes
     * through {@code trainInputLayer}. Both run in EVALUATION mode. For each metric, one streaming
     * instance is created over {@code (prediction, processed target)} using the metric's default
     * streaming arguments.
     *
     * @param seq the metrics to evaluate
     * @return the iterator, input element, prediction, and the per-metric value / update / reset
     *     ops (in the same order as {@code seq}), bundled as {@code Model.EvaluateOps}
     */
    /* JADX WARN: Multi-variable type inference failed */
    @Override // org.platanios.tensorflow.api.learn.SupervisedTrainableModel, org.platanios.tensorflow.api.learn.TrainableModel
    public Model.EvaluateOps<Tuple2<IT, TT>, Tuple2<IO, TO>, Tuple2<ID, TD>, Tuple2<IS, TS>, I> buildEvaluateOps(Seq<Metric<Tuple2<I, T>, Output>> seq) {
        EVALUATION$ evaluation$ = EVALUATION$.MODULE$;
        Iterator apply = input().zip(trainInput()).apply();
        Tuple2 tuple2 = (Tuple2) apply.next(apply.next$default$1());
        // Inference layer on the input only: evaluation does not condition on the targets.
        Object apply2 = layer().apply(tuple2._1(), evaluation$);
        Object apply3 = trainInputLayer().apply(tuple2._2(), evaluation$);
        // One streaming instance per metric; its value/update/reset ops are unpacked below.
        Seq seq2 = (Seq) seq.map(metric -> {
            return metric.streaming(new Tuple2(apply2, apply3), metric.streaming$default$2(), metric.streaming$default$3());
        }, Seq$.MODULE$.canBuildFrom());
        return new Model.EvaluateOps<>(apply, tuple2, apply2, (Seq) seq2.map(streamingInstance -> {
            return (Output) streamingInstance.value();
        }, Seq$.MODULE$.canBuildFrom()), (Seq) seq2.map(streamingInstance2 -> {
            return (Output) streamingInstance2.update();
        }, Seq$.MODULE$.canBuildFrom()), (Seq) seq2.map(streamingInstance3 -> {
            return streamingInstance3.reset();
        }, Seq$.MODULE$.canBuildFrom()));
    }

    /**
     * Creates the model.
     *
     * @param input the model input (passed to {@link SimpleInferenceModel})
     * @param layer the inference layer used by evaluation (passed to {@link SimpleInferenceModel})
     * @param layer2 the training layer, applied to {@code (input output, train-input output)}
     * @param input2 the training (supervision) input
     * @param layer3 the layer applied to the training-input output
     * @param layer4 the loss layer
     * @param optimizer the optimizer
     * @param clipGradients the gradient transformation applied before applying gradients
     * @param z the value returned by {@link #colocateGradientsWithOps()}
     */
    /* JADX WARN: 'super' call moved to the top of the method (can break code semantics) */
    public SupervisedConditionalTrainableModel(Input<IT, IO, ?, ID, IS> input, Layer<IO, I> layer, Layer<Tuple2<IO, TO>, I> layer2, Input<TT, TO, ?, TD, TS> input2, Layer<TO, T> layer3, Layer<Tuple2<I, T>, Output> layer4, Optimizer optimizer, ClipGradients clipGradients, boolean z) {
        // NOTE(review): per the JADX warning, the original bytecode appears to run the field
        // assignments below before the super(...) call. If the SimpleInferenceModel constructor
        // reads state through an overridden accessor (e.g. colocateGradientsWithOps()), this Java
        // form would see unassigned (default) values there. Confirm before relying on it.
        super(input, layer);
        this.trainLayer = layer2;
        this.trainInput = input2;
        this.trainInputLayer = layer3;
        this.loss = layer4;
        this.optimizer = optimizer;
        this.clipGradients = clipGradients;
        this.colocateGradientsWithOps = z;
    }
}
