package org.platanios.tensorflow.api.ops;

import org.junit.Test;
import org.platanios.tensorflow.api.core.Graph$;
import org.platanios.tensorflow.api.core.Indexer;
import org.platanios.tensorflow.api.core.Shape;
import org.platanios.tensorflow.api.core.Shape$;
import org.platanios.tensorflow.api.core.client.Executable$;
import org.platanios.tensorflow.api.core.client.FeedMap$;
import org.platanios.tensorflow.api.core.client.Feedable$;
import org.platanios.tensorflow.api.core.client.Fetchable$;
import org.platanios.tensorflow.api.core.client.Session;
import org.platanios.tensorflow.api.core.client.Session$;
import org.platanios.tensorflow.api.package$;
import org.platanios.tensorflow.api.tensors.Tensor;
import org.platanios.tensorflow.api.tensors.Tensor$;
import org.platanios.tensorflow.api.tensors.TensorConvertible$;
import org.platanios.tensorflow.api.tensors.ops.Math$;
import org.platanios.tensorflow.api.types.DataType;
import org.platanios.tensorflow.api.types.SupportedType$;
import org.scalactic.Bool$;
import org.scalactic.Equality$;
import org.scalactic.Prettifier$;
import org.scalactic.TripleEqualsSupport;
import org.scalactic.source.Position;
import org.scalatest.junit.JUnitSuite;
import scala.MatchError;
import scala.Predef$;
import scala.Predef$ArrowAssoc$;
import scala.Tuple2;
import scala.Tuple4;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxesRunTime;

/* compiled from: CallbackSuite.scala */
@ScalaSignature(bytes = "\u0006\u0001%3A!\u0001\u0002\u0001\u001b\ti1)\u00197mE\u0006\u001c7nU;ji\u0016T!a\u0001\u0003\u0002\u0007=\u00048O\u0003\u0002\u0006\r\u0005\u0019\u0011\r]5\u000b\u0005\u001dA\u0011A\u0003;f]N|'O\u001a7po*\u0011\u0011BC\u0001\na2\fG/\u00198j_NT\u0011aC\u0001\u0004_J<7\u0001A\n\u0003\u00019\u0001\"a\u0004\u000b\u000e\u0003AQ!!\u0005\n\u0002\u000b),h.\u001b;\u000b\u0005MQ\u0011!C:dC2\fG/Z:u\u0013\t)\u0002C\u0001\u0006K+:LGoU;ji\u0016DQa\u0006\u0001\u0005\u0002a\ta\u0001P5oSRtD#A\r\u0011\u0005i\u0001Q\"\u0001\u0002\t\u000bq\u0001A\u0011A\u000f\u0002\rM\fX/\u0019:f)\tqB\u0005\u0005\u0002 E5\t\u0001E\u0003\u0002\"\t\u00059A/\u001a8t_J\u001c\u0018BA\u0012!\u0005\u0019!VM\\:pe\")Qe\u0007a\u0001=\u0005)\u0011N\u001c9vi\")q\u0005\u0001C\u0001Q\u0005\u0019\u0011\r\u001a3\u0015\u0005yI\u0003\"\u0002\u0016'\u0001\u0004Y\u0013AB5oaV$8\u000fE\u0002-myq!!L\u001a\u000f\u00059\nT\"A\u0018\u000b\u0005Ab\u0011A\u0002\u001fs_>$h(C\u00013\u0003\u0015\u00198-\u00197b\u0013\t!T'A\u0004qC\u000e\\\u0017mZ3\u000b\u0003IJ!a\u000e\u001d\u0003\u0007M+\u0017O\u0003\u00025k!)!\b\u0001C\u0001w\u0005YC/Z:u\u0013\u0012,g\u000e^5usNKgn\u001a7f\u0013:\u0004X\u000f^*j]\u001edWmT;uaV$8)\u00197mE\u0006\u001c7\u000eF\u0001=!\tid(D\u00016\u0013\tyTG\u0001\u0003V]&$\bFA\u001dB!\t\u0011E)D\u0001D\u0015\t\t\"\"\u0003\u0002F\u0007\n!A+Z:u\u0011\u00159\u0005\u0001\"\u0001<\u00035\"Xm\u001d;JI\u0016tG/\u001b;z\u001bVdG/\u001b9mK&s\u0007/\u001e;TS:<G.Z(viB,HoQ1mY\n\f7m\u001b\u0015\u0003\r\u0006\u0003")
/* loaded from: input_file:org/platanios/tensorflow/api/ops/CallbackSuite.class */
/**
 * Decompiled form of the Scala test suite {@code CallbackSuite.scala}: a {@link JUnitSuite} whose
 * two {@code @Test} methods build a graph containing a {@code Callback$.MODULE$.callback(...)} op,
 * run it in a {@link Session}, and assert on the fetched tensor's data type, shape and elements.
 *
 * <p>The {@code square} and {@code add} methods are the functions handed to the callback op by the
 * tests.
 *
 * <p>Reading notes for the decompiled code: calls named {@code xxx$default$N()} are the accessors
 * the Scala compiler emits for default arguments, and trailing arguments such as
 * {@code package$.MODULE$.opCreationContext()} or {@code Feedable$.MODULE$.outputFeedable()} are
 * values passed explicitly in the compiled form.
 *
 * <p>NOTE(review): neither test calls any close/cleanup method on the {@code Session} it creates;
 * only the graph is handed to {@code using(...)}. Confirm whether the session needs closing.
 */
public class CallbackSuite extends JUnitSuite {
    /**
     * Squares a tensor by wrapping it with {@code tensorToMathOps} and calling {@code square()}.
     * Used as the callback function in {@link #testIdentitySingleInputSingleOutputCallback()},
     * where inputs 2 and 5 are expected to produce 4 and 25.
     *
     * @param tensor the tensor to square
     * @return the result of {@code square()} on the wrapped tensor
     */
    public Tensor square(Tensor tensor) {
        return package$.MODULE$.tensorToMathOps(tensor).square();
    }

    /**
     * Combines a sequence of tensors through {@code Math$.MODULE$.addN}, passing the API package
     * object's {@code tensorEagerExecutionContext()}. Used as the callback function in
     * {@link #testIdentityMultipleInputSingleOutputCallback()}.
     *
     * @param seq the tensors passed to {@code addN}
     * @return the tensor returned by {@code addN}
     */
    public Tensor add(Seq<Tensor> seq) {
        return Math$.MODULE$.addN(seq, package$.MODULE$.tensorEagerExecutionContext());
    }

    /**
     * Runs a callback op with one input and one output: a FLOAT32 placeholder is passed through
     * {@link #square(Tensor)}, the placeholder is fed a tensor built from 2.0f and 5.0f, and the
     * fetched result is asserted to be FLOAT32, of shape [2], with elements 4.0f and 25.0f.
     */
    @Test
    public void testIdentitySingleInputSingleOutputCallback() {
        // The freshly created graph is handed to using(...) together with the test body.
        org.platanios.tensorflow.api.utilities.package$.MODULE$.using(Graph$.MODULE$.apply(), graph -> {
            // Build the ops inside createWith(graph, ...): returns (placeholder, callback output).
            Tuple2 tuple2 = (Tuple2) Op$.MODULE$.createWith(graph, Op$.MODULE$.createWith$default$2(), Op$.MODULE$.createWith$default$3(), Op$.MODULE$.createWith$default$4(), Op$.MODULE$.createWith$default$5(), Op$.MODULE$.createWith$default$6(), Op$.MODULE$.createWith$default$7(), Op$.MODULE$.createWith$default$8(), () -> {
                Output placeholder = Basic$.MODULE$.placeholder(org.platanios.tensorflow.api.types.package$.MODULE$.FLOAT32(), Basic$.MODULE$.placeholder$default$2(), Basic$.MODULE$.placeholder$default$3());
                // The callback op applies this.square to the placeholder and is declared FLOAT32.
                return new Tuple2(placeholder, (Output) Callback$.MODULE$.callback(tensor -> {
                    return this.square(tensor);
                }, placeholder, org.platanios.tensorflow.api.types.package$.MODULE$.FLOAT32(), Callback$.MODULE$.callback$default$4(), Callback$.MODULE$.callback$default$5(), Callback$ArgType$.MODULE$.tensorArgType(), Callback$ArgType$.MODULE$.tensorArgType()));
            }, package$.MODULE$.opCreationContext());
            // Null guard emitted for the tuple destructuring; a null result raises MatchError.
            if (tuple2 == null) {
                throw new MatchError(tuple2);
            }
            Tuple2 tuple22 = new Tuple2((Output) tuple2._1(), (Output) tuple2._2());
            // output = the placeholder (feed target), output2 = the callback result (fetch target).
            Output output = (Output) tuple22._1();
            Output output2 = (Output) tuple22._2();
            Session apply = Session$.MODULE$.apply(graph, Session$.MODULE$.apply$default$2(), Session$.MODULE$.apply$default$3());
            // Feed the placeholder a tensor built from 2.0f and 5.0f, then fetch the callback output.
            Tensor tensor = (Tensor) apply.run(FeedMap$.MODULE$.feedMap(Predef$.MODULE$.Map().apply(Predef$.MODULE$.wrapRefArray(new Tuple2[]{Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc(output), Tensor$.MODULE$.apply(BoxesRunTime.boxToFloat(2.0f), Predef$.MODULE$.wrapFloatArray(new float[]{5.0f}), TensorConvertible$.MODULE$.supportedTypeTensorConvertible(SupportedType$.MODULE$.floatIsSupportedType())))})), Feedable$.MODULE$.outputFeedable()), output2, apply.run$default$3(), apply.run$default$4(), Executable$.MODULE$.traversableExecutable(Executable$.MODULE$.opExecutable()), Fetchable$.MODULE$.outputFetchable());
            // Assert: result data type === FLOAT32.
            TripleEqualsSupport.Equalizer convertToEqualizer = this.convertToEqualizer(tensor.dataType());
            DataType.Aux FLOAT32 = org.platanios.tensorflow.api.types.package$.MODULE$.FLOAT32();
            this.assertionsHelper().macroAssert(Bool$.MODULE$.binaryMacroBool(convertToEqualizer, "===", FLOAT32, convertToEqualizer.$eq$eq$eq(FLOAT32, Equality$.MODULE$.default()), Prettifier$.MODULE$.default()), "", Prettifier$.MODULE$.default(), new Position("CallbackSuite.scala", "Please set the environment variable SCALACTIC_FILL_FILE_PATHNAMES to yes at compile time to enable this feature.", 44));
            // Assert: result shape === Shape(2).
            TripleEqualsSupport.Equalizer convertToEqualizer2 = this.convertToEqualizer(tensor.shape());
            Shape apply2 = Shape$.MODULE$.apply(Predef$.MODULE$.wrapIntArray(new int[]{2}));
            this.assertionsHelper().macroAssert(Bool$.MODULE$.binaryMacroBool(convertToEqualizer2, "===", apply2, convertToEqualizer2.$eq$eq$eq(apply2, Equality$.MODULE$.default()), Prettifier$.MODULE$.default()), "", Prettifier$.MODULE$.default(), new Position("CallbackSuite.scala", "Please set the environment variable SCALACTIC_FILL_FILE_PATHNAMES to yes at compile time to enable this feature.", 45));
            // Assert: element at index 0 === 4.0f.
            TripleEqualsSupport.Equalizer convertToEqualizer3 = this.convertToEqualizer(tensor.apply(Predef$.MODULE$.wrapRefArray(new Indexer[]{package$.MODULE$.intToIndex(0)})).scalar());
            this.assertionsHelper().macroAssert(Bool$.MODULE$.binaryMacroBool(convertToEqualizer3, "===", BoxesRunTime.boxToFloat(4.0f), convertToEqualizer3.$eq$eq$eq(BoxesRunTime.boxToFloat(4.0f), Equality$.MODULE$.default()), Prettifier$.MODULE$.default()), "", Prettifier$.MODULE$.default(), new Position("CallbackSuite.scala", "Please set the environment variable SCALACTIC_FILL_FILE_PATHNAMES to yes at compile time to enable this feature.", 46));
            // Assert: element at index 1 === 25.0f.
            TripleEqualsSupport.Equalizer convertToEqualizer4 = this.convertToEqualizer(tensor.apply(Predef$.MODULE$.wrapRefArray(new Indexer[]{package$.MODULE$.intToIndex(1)})).scalar());
            return this.assertionsHelper().macroAssert(Bool$.MODULE$.binaryMacroBool(convertToEqualizer4, "===", BoxesRunTime.boxToFloat(25.0f), convertToEqualizer4.$eq$eq$eq(BoxesRunTime.boxToFloat(25.0f), Equality$.MODULE$.default()), Prettifier$.MODULE$.default()), "", Prettifier$.MODULE$.default(), new Position("CallbackSuite.scala", "Please set the environment variable SCALACTIC_FILL_FILE_PATHNAMES to yes at compile time to enable this feature.", 47));
        });
    }

    /**
     * Runs a callback op with three inputs and one output: three FLOAT64 placeholders are passed
     * as a sequence through {@link #add(Seq)}. The placeholders are fed tensors built from
     * (2.0, 5.0), (-1.3, 3.1) and (8.9, -4.1), and the fetched result is asserted to be FLOAT64,
     * of shape [2], with elements 9.6 and 4.0.
     *
     * <p>NOTE(review): the element checks compare doubles through {@code ===} with the default
     * {@code Equality}, i.e. without an explicit tolerance visible here.
     */
    @Test
    public void testIdentityMultipleInputSingleOutputCallback() {
        // The freshly created graph is handed to using(...) together with the test body.
        org.platanios.tensorflow.api.utilities.package$.MODULE$.using(Graph$.MODULE$.apply(), graph -> {
            // Build the ops inside createWith(graph, ...): returns (input1, input2, input3, callback output).
            Tuple4 tuple4 = (Tuple4) Op$.MODULE$.createWith(graph, Op$.MODULE$.createWith$default$2(), Op$.MODULE$.createWith$default$3(), Op$.MODULE$.createWith$default$4(), Op$.MODULE$.createWith$default$5(), Op$.MODULE$.createWith$default$6(), Op$.MODULE$.createWith$default$7(), Op$.MODULE$.createWith$default$8(), () -> {
                Output placeholder = Basic$.MODULE$.placeholder(org.platanios.tensorflow.api.types.package$.MODULE$.FLOAT64(), Basic$.MODULE$.placeholder$default$2(), Basic$.MODULE$.placeholder$default$3());
                Output placeholder2 = Basic$.MODULE$.placeholder(org.platanios.tensorflow.api.types.package$.MODULE$.FLOAT64(), Basic$.MODULE$.placeholder$default$2(), Basic$.MODULE$.placeholder$default$3());
                Output placeholder3 = Basic$.MODULE$.placeholder(org.platanios.tensorflow.api.types.package$.MODULE$.FLOAT64(), Basic$.MODULE$.placeholder$default$2(), Basic$.MODULE$.placeholder$default$3());
                // The callback op applies this.add to the Seq of three placeholders and is declared FLOAT64.
                return new Tuple4(placeholder, placeholder2, placeholder3, (Output) Callback$.MODULE$.callback(seq -> {
                    return this.add(seq);
                }, Seq$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new Output[]{placeholder, placeholder2, placeholder3})), org.platanios.tensorflow.api.types.package$.MODULE$.FLOAT64(), Callback$.MODULE$.callback$default$4(), Callback$.MODULE$.callback$default$5(), Callback$ArgType$.MODULE$.tensorSeqArgType(Seq$.MODULE$.canBuildFrom(), Seq$.MODULE$.canBuildFrom()), Callback$ArgType$.MODULE$.tensorArgType()));
            }, package$.MODULE$.opCreationContext());
            // Null guard emitted for the tuple destructuring; a null result raises MatchError.
            if (tuple4 == null) {
                throw new MatchError(tuple4);
            }
            Tuple4 tuple42 = new Tuple4((Output) tuple4._1(), (Output) tuple4._2(), (Output) tuple4._3(), (Output) tuple4._4());
            // output..output3 = the three placeholders (feed targets), output4 = the callback result.
            Output output = (Output) tuple42._1();
            Output output2 = (Output) tuple42._2();
            Output output3 = (Output) tuple42._3();
            Output output4 = (Output) tuple42._4();
            Session apply = Session$.MODULE$.apply(graph, Session$.MODULE$.apply$default$2(), Session$.MODULE$.apply$default$3());
            // Feed (2.0, 5.0), (-1.3, 3.1) and (8.9, -4.1) to the three placeholders, then fetch the callback output.
            Tensor tensor = (Tensor) apply.run(FeedMap$.MODULE$.feedMap(Predef$.MODULE$.Map().apply(Predef$.MODULE$.wrapRefArray(new Tuple2[]{Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc(output), Tensor$.MODULE$.apply(BoxesRunTime.boxToDouble(2.0d), Predef$.MODULE$.wrapDoubleArray(new double[]{5.0d}), TensorConvertible$.MODULE$.supportedTypeTensorConvertible(SupportedType$.MODULE$.doubleIsSupportedType()))), Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc(output2), Tensor$.MODULE$.apply(BoxesRunTime.boxToDouble(-1.3d), Predef$.MODULE$.wrapDoubleArray(new double[]{3.1d}), TensorConvertible$.MODULE$.supportedTypeTensorConvertible(SupportedType$.MODULE$.doubleIsSupportedType()))), Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc(output3), Tensor$.MODULE$.apply(BoxesRunTime.boxToDouble(8.9d), Predef$.MODULE$.wrapDoubleArray(new double[]{-4.1d}), TensorConvertible$.MODULE$.supportedTypeTensorConvertible(SupportedType$.MODULE$.doubleIsSupportedType())))})), Feedable$.MODULE$.outputFeedable()), output4, apply.run$default$3(), apply.run$default$4(), Executable$.MODULE$.traversableExecutable(Executable$.MODULE$.opExecutable()), Fetchable$.MODULE$.outputFetchable());
            // Assert: result data type === FLOAT64.
            TripleEqualsSupport.Equalizer convertToEqualizer = this.convertToEqualizer(tensor.dataType());
            DataType.Aux FLOAT64 = org.platanios.tensorflow.api.types.package$.MODULE$.FLOAT64();
            this.assertionsHelper().macroAssert(Bool$.MODULE$.binaryMacroBool(convertToEqualizer, "===", FLOAT64, convertToEqualizer.$eq$eq$eq(FLOAT64, Equality$.MODULE$.default()), Prettifier$.MODULE$.default()), "", Prettifier$.MODULE$.default(), new Position("CallbackSuite.scala", "Please set the environment variable SCALACTIC_FILL_FILE_PATHNAMES to yes at compile time to enable this feature.", 64));
            // Assert: result shape === Shape(2).
            TripleEqualsSupport.Equalizer convertToEqualizer2 = this.convertToEqualizer(tensor.shape());
            Shape apply2 = Shape$.MODULE$.apply(Predef$.MODULE$.wrapIntArray(new int[]{2}));
            this.assertionsHelper().macroAssert(Bool$.MODULE$.binaryMacroBool(convertToEqualizer2, "===", apply2, convertToEqualizer2.$eq$eq$eq(apply2, Equality$.MODULE$.default()), Prettifier$.MODULE$.default()), "", Prettifier$.MODULE$.default(), new Position("CallbackSuite.scala", "Please set the environment variable SCALACTIC_FILL_FILE_PATHNAMES to yes at compile time to enable this feature.", 65));
            // Assert: element at index 0 === 9.6 (inputs 2.0, -1.3, 8.9).
            TripleEqualsSupport.Equalizer convertToEqualizer3 = this.convertToEqualizer(tensor.apply(Predef$.MODULE$.wrapRefArray(new Indexer[]{package$.MODULE$.intToIndex(0)})).scalar());
            this.assertionsHelper().macroAssert(Bool$.MODULE$.binaryMacroBool(convertToEqualizer3, "===", BoxesRunTime.boxToDouble(9.6d), convertToEqualizer3.$eq$eq$eq(BoxesRunTime.boxToDouble(9.6d), Equality$.MODULE$.default()), Prettifier$.MODULE$.default()), "", Prettifier$.MODULE$.default(), new Position("CallbackSuite.scala", "Please set the environment variable SCALACTIC_FILL_FILE_PATHNAMES to yes at compile time to enable this feature.", 66));
            // Assert: element at index 1 === 4.0 (inputs 5.0, 3.1, -4.1).
            TripleEqualsSupport.Equalizer convertToEqualizer4 = this.convertToEqualizer(tensor.apply(Predef$.MODULE$.wrapRefArray(new Indexer[]{package$.MODULE$.intToIndex(1)})).scalar());
            return this.assertionsHelper().macroAssert(Bool$.MODULE$.binaryMacroBool(convertToEqualizer4, "===", BoxesRunTime.boxToDouble(4.0d), convertToEqualizer4.$eq$eq$eq(BoxesRunTime.boxToDouble(4.0d), Equality$.MODULE$.default()), Prettifier$.MODULE$.default()), "", Prettifier$.MODULE$.default(), new Position("CallbackSuite.scala", "Please set the environment variable SCALACTIC_FILL_FILE_PATHNAMES to yes at compile time to enable this feature.", 67));
        });
    }
}
