package org.platanios.tensorflow.examples.python2scala;

import com.typesafe.scalalogging.Logger;
import com.typesafe.scalalogging.Logger$;
import java.io.BufferedInputStream;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileWriter;
import java.nio.file.Path;
import java.nio.file.Paths;
import org.platanios.tensorflow.api.core.Graph$;
import org.platanios.tensorflow.api.core.client.FeedMap$;
import org.platanios.tensorflow.api.core.client.Session;
import org.platanios.tensorflow.api.core.types.package$TF$;
import org.platanios.tensorflow.api.implicits.helpers.OutputToTensor$;
import org.platanios.tensorflow.api.implicits.helpers.TensorStructure$;
import org.platanios.tensorflow.api.ops.Op;
import org.platanios.tensorflow.api.ops.Output;
import org.platanios.tensorflow.api.package$;
import org.platanios.tensorflow.api.package$tf$;
import org.platanios.tensorflow.api.package$tfi$;
import org.platanios.tensorflow.api.tensors.Tensor;
import org.platanios.tensorflow.api.utilities.DefaultsTo$;
import org.platanios.tensorflow.proto.MetaGraphDef;
import org.slf4j.LoggerFactory;
import scala.$less$colon$less$;
import scala.MatchError;
import scala.Predef$;
import scala.Predef$ArrowAssoc$;
import scala.Tuple2;
import scala.collection.ArrayOps$;
import scala.collection.IterableOnceOps;
import scala.collection.StringOps$;
import scala.collection.immutable.Map;
import scala.collection.immutable.Seq;
import scala.collection.immutable.Seq$;
import scala.collection.mutable.ArrayBuffer;
import scala.collection.mutable.ArrayBuffer$;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.RichInt$;
import scala.runtime.ScalaRunTime$;
import scala.util.Random;

/* compiled from: LinearRegressionFromRestoredPythonModel.scala */
/* loaded from: input_file:org/platanios/tensorflow/examples/python2scala/LinearRegressionFromRestoredPythonModel$.class */
public final class LinearRegressionFromRestoredPythonModel$ {
    // Sole instance of this class (the constructor is private). It must be initialized first because
    // the initializers below call MODULE$.random() and MODULE$.getClass().
    public static final LinearRegressionFromRestoredPythonModel$ MODULE$ = new LinearRegressionFromRestoredPythonModel$();
    // Logger wrapping the SLF4J logger named "Examples / Linear Regression".
    private static final Logger logger = Logger$.MODULE$.apply(LoggerFactory.getLogger("Examples / Linear Regression"));
    // Random source created with the fixed seed 22.
    private static final Random random = new Random(22);
    // First float drawn from `random` during class initialization, so `random` must be initialized
    // on the line above. $anonfun$batch$1 multiplies by this value to produce the output samples.
    private static final float weight = MODULE$.random().nextFloat();
    // Checkpoint location that main() turns into a Path and hands to the restored saver.
    private static final String checkpoint = "examples/src/main/resources/python2scala/linear-regression";
    // Classpath resource "python2scala/linear-regression.meta" as a File; main() parses it as a MetaGraphDef.
    // NOTE(review): ClassLoader.getResource returns null for a missing resource, which would surface here
    // as a NullPointerException during class initialization — confirm the resource is always packaged.
    private static final File meta = new File(MODULE$.getClass().getClassLoader().getResource("python2scala/linear-regression.meta").getFile());
    // Classpath resource "python2scala/MetaGraphDef.txt" as a File; main() overwrites it with the text form
    // of the parsed MetaGraphDef. The same null-resource caveat as for `meta` applies.
    private static final File metaGraphDefFile = new File(MODULE$.getClass().getClassLoader().getResource("python2scala/MetaGraphDef.txt").getFile());

    /**
     * Accessor for the shared logger.
     *
     * @return the logger held in the static {@code logger} field
     */
    private Logger logger() {
        return LinearRegressionFromRestoredPythonModel$.logger;
    }

    /**
     * Accessor for the shared random source.
     *
     * @return the generator held in the static {@code random} field
     */
    private Random random() {
        return LinearRegressionFromRestoredPythonModel$.random;
    }

    /**
     * Accessor for the weight drawn once during class initialization.
     *
     * @return the value held in the static {@code weight} field
     */
    private float weight() {
        return LinearRegressionFromRestoredPythonModel$.weight;
    }

    /**
     * Accessor for the checkpoint location.
     *
     * @return the path string held in the static {@code checkpoint} field
     */
    private String checkpoint() {
        return LinearRegressionFromRestoredPythonModel$.checkpoint;
    }

    /**
     * Accessor for the serialized meta-graph file.
     *
     * @return the file held in the static {@code meta} field
     */
    private File meta() {
        return LinearRegressionFromRestoredPythonModel$.meta;
    }

    /**
     * Accessor for the file that receives the textual MetaGraphDef dump.
     *
     * @return the file held in the static {@code metaGraphDefFile} field
     */
    private File metaGraphDefFile() {
        return LinearRegressionFromRestoredPythonModel$.metaGraphDefFile;
    }

    /**
     * Restores the linear-regression model from its meta graph and checkpoint, prints the
     * restored graph, and runs further training steps while logging the results.
     *
     * <p>Steps, in order:
     * <ol>
     *   <li>parse the {@code MetaGraphDef} from {@link #meta()};</li>
     *   <li>write its text form to {@link #metaGraphDefFile()};</li>
     *   <li>inside a fresh graph, create a session, build a saver from the meta graph and restore
     *       the checkpoint into the session;</li>
     *   <li>print every operation of the session graph and the restored nodes;</li>
     *   <li>for each {@code i} in {@code 0 to 50} (inclusive), feed a batch of 10000 samples,
     *       run the train op, and log loss, weight and bias at INFO level.</li>
     * </ol>
     *
     * <p>Both file streams are opened with try-with-resources so they are closed even when
     * parsing or writing fails.
     *
     * @param strArr command-line arguments; not read
     * @throws java.io.IOException if the meta file cannot be read or parsed, or the dump file
     *                             cannot be written
     */
    public void main(String[] strArr) throws java.io.IOException {
        // Read the serialized meta graph; the stream is released as soon as parsing finishes.
        final MetaGraphDef metaGraphDef;
        try (BufferedInputStream metaStream = new BufferedInputStream(new FileInputStream(meta()))) {
            metaGraphDef = MetaGraphDef.parseFrom(metaStream);
        }
        final Path checkpointPath = Paths.get(checkpoint());

        // Dump the meta graph in text form; close() (and therefore the flush) is guaranteed.
        try (BufferedWriter writer = new BufferedWriter(new FileWriter(metaGraphDefFile()))) {
            writer.write(metaGraphDef.toString());
        }

        package$tf$.MODULE$.createWith(Graph$.MODULE$.apply(), package$tf$.MODULE$.createWith$default$2(), package$tf$.MODULE$.createWith$default$3(), package$tf$.MODULE$.createWith$default$4(), package$tf$.MODULE$.createWith$default$5(), package$tf$.MODULE$.createWith$default$6(), package$tf$.MODULE$.createWith$default$7(), package$tf$.MODULE$.createWith$default$8(), () -> {
            // NOTE(review): the session is never closed in this method — confirm whether it needs explicit release.
            Session session = package$.MODULE$.Session().apply(package$.MODULE$.Session().apply$default$1(), package$.MODULE$.Session().apply$default$2(), package$.MODULE$.Session().apply$default$3());
            package$tf$.MODULE$.Saver().fromMetaGraphDef(metaGraphDef, package$tf$.MODULE$.Saver().fromMetaGraphDef$default$2(), package$tf$.MODULE$.Saver().fromMetaGraphDef$default$3(), package$tf$.MODULE$.Saver().fromMetaGraphDef$default$4(), package$tf$.MODULE$.Saver().fromMetaGraphDef$default$5(), package$tf$.MODULE$.Saver().fromMetaGraphDef$default$6(), package$tf$.MODULE$.Saver().fromMetaGraphDef$default$7(), package$tf$.MODULE$.Saver().fromMetaGraphDef$default$8(), package$tf$.MODULE$.Saver().fromMetaGraphDef$default$9(), package$tf$.MODULE$.Saver().fromMetaGraphDef$default$10()).restore(session, checkpointPath);

            // List every operation of the restored graph between two separator lines.
            Predef$.MODULE$.println(StringOps$.MODULE$.$times$extension(Predef$.MODULE$.augmentString(" -"), 40));
            ArrayOps$.MODULE$.foreach$extension(Predef$.MODULE$.refArrayOps(session.graph().ops()), op -> {
                $anonfun$main$2(op);
                return BoxedUnit.UNIT;
            });
            Predef$.MODULE$.println(StringOps$.MODULE$.$times$extension(Predef$.MODULE$.augmentString(" -"), 40));

            // Look up the nodes of the restored graph by name.
            Output<Object> input = session.graph().getOutputByName("p2s_input:0");
            Output<Object> output = session.graph().getOutputByName("p2s_output:0");
            Output<Object> weights = session.graph().getOutputByName("p2s_weights_w/Read/ReadVariableOp:0");
            Output<Object> bias = session.graph().getOutputByName("p2s_weights_b/Read/ReadVariableOp:0");
            Output<Object> prediction = session.graph().getOutputByName("p2s_prediction:0");
            Output<Object> loss = session.graph().getOutputByName("p2s_loss:0");
            Op<Seq<Output<Object>>, Seq<Output<Object>>> trainOp = session.graph().getOpByName("p2s_train_op");
            MODULE$.printRestoredNodesAndOperations(session, input, output, weights, bias, prediction, loss, trainOp);

            RichInt$.MODULE$.to$extension(Predef$.MODULE$.intWrapper(0), 50).foreach$mVc$sp(i -> {
                Tuple2<Tensor<Object>, Tensor<Object>> batch = MODULE$.batch(10000);
                if (batch == null) {
                    throw new MatchError(batch);
                }
                // Feed the batch, fetch (loss, weights, bias) in that order, and run the train op as the target.
                Seq fetched = (Seq) session.run(
                        FeedMap$.MODULE$.apply(
                                (Map) Predef$.MODULE$.Map().apply(ScalaRunTime$.MODULE$.wrapRefArray(new Tuple2[]{
                                        Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc(input), batch._1()),
                                        Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc(output), batch._2())})),
                                package$.MODULE$.evStructureUntyped(), OutputToTensor$.MODULE$.fromOutput(), TensorStructure$.MODULE$.fromTensor()),
                        Seq$.MODULE$.apply(ScalaRunTime$.MODULE$.wrapRefArray(new Output[]{loss, weights, bias})),
                        trainOp,
                        session.run$default$4(),
                        DefaultsTo$.MODULE$.defaultDefaultsTo(),
                        package$.MODULE$.evStructureSeqUntyped(),
                        DefaultsTo$.MODULE$.fallback(),
                        package$.MODULE$.evStructureUntypedOp(),
                        OutputToTensor$.MODULE$.fromSeq(OutputToTensor$.MODULE$.fromOutput()));
                Tensor lossValue = (Tensor) fetched.apply(0);
                Tensor weightValue = (Tensor) fetched.apply(1);
                Tensor biasValue = (Tensor) fetched.apply(2);
                if (MODULE$.logger().underlying().isInfoEnabled()) {
                    // Iterations are reported 1-based.
                    Object iteration = BoxesRunTime.boxToInteger(i + 1);
                    MODULE$.logger().underlying().info("\nTrain loss at iteration {} = {}\nTrain weight at iteration {} = {}\nTrain bias at iteration {} = {}\n", new Object[]{iteration, lossValue.scalar(), iteration, weightValue.scalar(), iteration, biasValue.scalar()});
                }
            });
        });
    }

    /**
     * Builds one batch of samples as a pair of tensors.
     *
     * <p>Two buffers are populated by calling {@code $anonfun$batch$1} once per index in
     * {@code [0, i)}. Every buffered value is then wrapped in a one-element tensor, the
     * wrapped values are stacked, and the stack is reshaped with the shape {@code (-1, 1)}.
     *
     * @param i number of samples to generate
     * @return a tuple whose first element is built from the first buffer and whose second
     *         element is built from the second buffer
     */
    public Tuple2<Tensor<Object>, Tensor<Object>> batch(int i) {
        ArrayBuffer inputs = ArrayBuffer$.MODULE$.empty();
        ArrayBuffer outputs = ArrayBuffer$.MODULE$.empty();
        RichInt$.MODULE$.until$extension(Predef$.MODULE$.intWrapper(0), i)
                .foreach(index -> $anonfun$batch$1(inputs, outputs, BoxesRunTime.unboxToInt(index)));

        // First tuple element: stack the wrapped inputs, then reshape to (-1, 1).
        Seq inputRows = ((IterableOnceOps) inputs.map(value -> $anonfun$batch$2(BoxesRunTime.unboxToDouble(value)))).toSeq();
        Tensor stackedInputs = package$tfi$.MODULE$.stack(inputRows, package$tfi$.MODULE$.stack$default$2(), package$TF$.MODULE$.doubleEvTF());
        Tensor inputShape = package$.MODULE$.Tensor().apply(
                ScalaRunTime$.MODULE$.wrapRefArray(new Tensor[]{package$.MODULE$.intToTensor(-1), package$.MODULE$.intToTensor(1)}),
                package$TF$.MODULE$.intEvTF());
        Tensor inputColumn = package$tfi$.MODULE$.reshape(stackedInputs, inputShape, package$TF$.MODULE$.doubleEvTF(), package$TF$.MODULE$.intEvTF(), $less$colon$less$.MODULE$.refl());

        // Second tuple element: same pipeline applied to the outputs buffer.
        Seq outputRows = ((IterableOnceOps) outputs.map(value -> $anonfun$batch$3(BoxesRunTime.unboxToDouble(value)))).toSeq();
        Tensor stackedOutputs = package$tfi$.MODULE$.stack(outputRows, package$tfi$.MODULE$.stack$default$2(), package$TF$.MODULE$.doubleEvTF());
        Tensor outputShape = package$.MODULE$.Tensor().apply(
                ScalaRunTime$.MODULE$.wrapRefArray(new Tensor[]{package$.MODULE$.intToTensor(-1), package$.MODULE$.intToTensor(1)}),
                package$TF$.MODULE$.intEvTF());
        Tensor outputColumn = package$tfi$.MODULE$.reshape(stackedOutputs, outputShape, package$TF$.MODULE$.doubleEvTF(), package$TF$.MODULE$.intEvTF(), $less$colon$less$.MODULE$.refl());

        return new Tuple2<>(inputColumn, outputColumn);
    }

    /**
     * Prints a framed summary of the restored model to standard output.
     *
     * <p>Between two banner lines it prints the values obtained by running {@code output3}
     * (labelled "Trained weight value") and {@code output4} (labelled "Trained bias value")
     * in the session, followed by the string form of each of the six outputs. A blank line
     * is printed after the closing banner.
     *
     * @param session session used to evaluate {@code output3} and {@code output4}
     * @param output  first output to print
     * @param output2 second output to print
     * @param output3 output evaluated as the trained weight, and printed
     * @param output4 output evaluated as the trained bias, and printed
     * @param output5 fifth output to print
     * @param output6 sixth output to print
     * @param op      not used by this method
     */
    public void printRestoredNodesAndOperations(Session session, Output<Object> output, Output<Object> output2, Output<Object> output3, Output<Object> output4, Output<Object> output5, Output<Object> output6, Op<Seq<Output<Object>>, Seq<Output<Object>>> op) {
        String banner = StringOps$.MODULE$.$times$extension(Predef$.MODULE$.augmentString(" *"), 60);
        Predef$.MODULE$.println(banner);

        Tensor trainedWeight = (Tensor) session.run(session.run$default$1(), output3, session.run$default$3(), session.run$default$4(), DefaultsTo$.MODULE$.fallback(), package$.MODULE$.evStructureUntyped(), DefaultsTo$.MODULE$.defaultDefaultsTo(), package$.MODULE$.evStructureSetUntypedOps(), OutputToTensor$.MODULE$.fromOutput());
        myPrintln("Trained weight value: " + trainedWeight.scalar(), 90);

        Tensor trainedBias = (Tensor) session.run(session.run$default$1(), output4, session.run$default$3(), session.run$default$4(), DefaultsTo$.MODULE$.fallback(), package$.MODULE$.evStructureUntyped(), DefaultsTo$.MODULE$.defaultDefaultsTo(), package$.MODULE$.evStructureSetUntypedOps(), OutputToTensor$.MODULE$.fromOutput());
        myPrintln("Trained bias value: " + trainedBias.scalar(), 90);

        // Print the outputs in parameter order, one table row each.
        Object[] restoredNodes = {output, output2, output3, output4, output5, output6};
        for (Object node : restoredNodes) {
            myPrintln(String.valueOf(node), 116);
        }

        Predef$.MODULE$.println(banner);
        Predef$.MODULE$.println("");
    }

    /**
     * Prints {@code str} as a table row: left-justified, padded to {@code i} characters and
     * enclosed in {@code "| "} and {@code " |"}.
     *
     * @param str text to print
     * @param i   field width used in the format specifier
     */
    public void myPrintln(String str, int i) {
        // Build a specifier such as "| %1$-76s |": '-' left-justifies within the given width.
        String rowFormat = "| %1$-" + i + "s |";
        String row = String.format(rowFormat, str);
        Predef$.MODULE$.println(row);
    }

    /**
     * Lambda body used by {@code main} for each graph operation: prints the operation's
     * string form as a 76-character-wide table row.
     *
     * @param op operation to print
     */
    public static final /* synthetic */ void $anonfun$main$2(Op op) {
        String row = "OPERATION name:    " + op;
        MODULE$.myPrintln(row, 76);
    }

    /**
     * Lambda body used by {@code batch} for each sample index: draws one float from the
     * shared random source, appends it to the first buffer, and appends that same value
     * multiplied by {@code weight()} to the second buffer.
     *
     * @param arrayBuffer  buffer receiving the drawn sample
     * @param arrayBuffer2 buffer receiving {@code weight() * sample}
     * @param i            loop index; not used
     * @return the result of appending to {@code arrayBuffer2}
     */
    public static final /* synthetic */ ArrayBuffer $anonfun$batch$1(ArrayBuffer arrayBuffer, ArrayBuffer arrayBuffer2, int i) {
        // Draw exactly once and keep the value in a local, because both buffers must see the same
        // sample. The previous code referenced an undefined temporary for the second use.
        float sample = MODULE$.random().nextFloat();
        arrayBuffer.$plus$eq(BoxesRunTime.boxToDouble(sample));
        // The product is computed in float precision and only then widened to double for boxing.
        return arrayBuffer2.$plus$eq(BoxesRunTime.boxToDouble(MODULE$.weight() * sample));
    }

    /**
     * Lambda body used by {@code batch} for the first buffer: wraps a single double in a
     * tensor built from a one-element array.
     *
     * @param d value to wrap
     * @return tensor created from the one-element array containing {@code d}
     */
    public static final /* synthetic */ Tensor $anonfun$batch$2(double d) {
        Tensor[] singleElement = {package$.MODULE$.doubleToTensor(d)};
        return package$.MODULE$.Tensor().apply(
                ScalaRunTime$.MODULE$.wrapRefArray(singleElement),
                package$TF$.MODULE$.doubleEvTF());
    }

    /**
     * Lambda body used by {@code batch} for the second buffer: wraps a single double in a
     * tensor built from a one-element array.
     *
     * @param d value to wrap
     * @return tensor created from the one-element array containing {@code d}
     */
    public static final /* synthetic */ Tensor $anonfun$batch$3(double d) {
        Tensor[] singleElement = {package$.MODULE$.doubleToTensor(d)};
        return package$.MODULE$.Tensor().apply(
                ScalaRunTime$.MODULE$.wrapRefArray(singleElement),
                package$TF$.MODULE$.doubleEvTF());
    }

    /** Private so that the only instance is the one stored in {@code MODULE$}. */
    private LinearRegressionFromRestoredPythonModel$() {}
}
