/* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not
 * use this file except in compliance with the License. You may obtain a copy of
 * the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations under
 * the License.
 */

package org.platanios.tensorflow.api.core

import org.platanios.tensorflow.api._
import org.platanios.tensorflow.api.core.exception.{GraphMismatchException, InvalidArgumentException}
import org.platanios.tensorflow.api.ops.{Op, UntypedOp}
import org.platanios.tensorflow.api.ops.Op.createWith
import org.platanios.tensorflow.api.ops.basic.Basic
import org.platanios.tensorflow.api.ops.math.Math
import org.platanios.tensorflow.api.tensors.Tensor

import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers

/** Specification for [[Graph]].
  *
  * Covers feeding/fetching restrictions, op and op output lookup by name, and a `MetaGraphDef`
  * export/import round trip.
  *
  * @author Emmanouil Antonios Platanios
  */
class GraphSpec extends AnyFlatSpec with Matchers {
  /** Creates a fresh graph containing four scalar constant ops.
    *
    * The constants hold the values `1.0` to `4.0` and are named `C_1`, `C_2`, `C_3` and `C_4`.
    * `C_3` is created inside the `Nested` name scope.
    *
    * @return The new graph, together with its four ops in creation order.
    */
  private[this] def prepareGraph(): (Graph, Array[UntypedOp]) = {
    val graph = Graph()
    val ops = createWith(graph = graph) {
      val c1 = Basic.constant(Tensor(1.0), name = "C_1")
      val c2 = Basic.constant(Tensor(2.0), name = "C_2")
      val c3 = createWith(nameScope = "Nested") {
        Basic.constant(Tensor(3.0), name = "C_3")
      }
      val c4 = Basic.constant(Tensor(4.0), name = "C_4")
      Array(c1.op, c2.op, c3.op, c4.op)
    }
    (graph, ops)
  }

  // TODO: Add collections specification.

  "'preventFeeding'" must "prevent valid op outputs from being fed" in {
    val (graph, ops) = prepareGraph()
    // Feedability of the first output of each prepared op, in creation order. Comparing whole
    // vectors makes a failure report every op's state at once.
    def feedable: Seq[Boolean] = ops.toSeq.map(op => graph.isFeedable(op.outputsSeq(0)))
    assert(feedable === Seq(true, true, true, true))
    graph.preventFeeding(ops(0).outputsSeq(0))
    assert(feedable === Seq(false, true, true, true))
    graph.preventFeeding(ops(2).outputsSeq(0))
    assert(feedable === Seq(false, true, false, true))
  }

  it must "throw a 'GraphMismatchException' when provided ops from other graphs" in {
    val (graph, _) = prepareGraph()
    // The constants below are created in a different graph than `graph`.
    createWith(graph = Graph()) {
      assert(intercept[GraphMismatchException](graph.isFeedable(Basic.constant(1.0))).getMessage ===
          "The provided op output does not belong to this graph.")
      assert(intercept[GraphMismatchException](graph.preventFeeding(Basic.constant(1.0))).getMessage ===
          "The provided op output does not belong to this graph.")
    }
  }

  "'preventFetching'" must "prevent valid ops from being fetched" in {
    val (graph, ops) = prepareGraph()
    // Fetchability of each prepared op, in creation order.
    def fetchable: Seq[Boolean] = ops.toSeq.map(op => graph.isFetchable(op))
    assert(fetchable === Seq(true, true, true, true))
    graph.preventFetching(ops(0))
    assert(fetchable === Seq(false, true, true, true))
    graph.preventFetching(ops(2))
    assert(fetchable === Seq(false, true, false, true))
  }

  it must "throw a 'GraphMismatchException' when provided ops from other graphs" in {
    val (graph, _) = prepareGraph()
    // The constants below are created in a different graph than `graph`.
    createWith(graph = Graph()) {
      assert(intercept[GraphMismatchException](graph.isFetchable(Basic.constant(1.0).op)).getMessage ===
          "The provided op does not belong to this graph.")
      assert(intercept[GraphMismatchException](graph.preventFetching(Basic.constant(1.0).op)).getMessage ===
          "The provided op does not belong to this graph.")
    }
  }

  "'findOp'" must "return an existing op in a graph" in {
    val (graph, ops) = prepareGraph()
    // Compare the whole option so that a missing op is reported as an assertion failure rather
    // than as a `NoSuchElementException` from `.get`.
    assert(graph.findOp("C_2") === Some(ops(1)))
  }

  it must "return 'None' if an op name does not exist in the graph" in {
    val (graph, _) = prepareGraph()
    assert(graph.findOp("A") === None)
  }

  "'ops'" must "return all the ops in a graph" in {
    val (graph, ops) = prepareGraph()
    assert(graph.ops === ops)
  }

  "'opByName'" must "return an existing op in a graph" in {
    val (graph, ops) = prepareGraph()
    assert(graph.getOpByName("C_2") === ops(1))
  }

  it must "throw an 'InvalidArgumentException' exception with an informative message " +
      "if an op name does not exist in the graph" in {
    val (graph, _) = prepareGraph()
    assert(intercept[InvalidArgumentException](graph.getOpByName("A")).getMessage
        === "Name 'A' refers to an op which does not exist in the graph.")
    assert(intercept[InvalidArgumentException](graph.getOpByName("A:0")).getMessage
        === "Name 'A:0' appears to refer to an op output, but 'allowOutput' was set to 'false'.")
  }

  "'outputByName'" must "return an existing op output in a graph" in {
    val (graph, ops) = prepareGraph()
    // Op outputs are compared with `==` rather than `===` throughout this spec (presumably because
    // `Output` provides its own `===`; TODO confirm).
    assert(graph.getOutputByName("C_2:0") == ops(1).outputsSeq(0))
  }

  it must "throw an 'InvalidArgumentException' exception with an informative message " +
      "if an op output name does not exist in the graph" in {
    val (graph, _) = prepareGraph()
    assert(intercept[InvalidArgumentException](graph.getOutputByName("A:0:3")).getMessage
        === "Name 'A:0:3' looks a like an op output name, but it is not a valid one. " +
        "Op output names must be of the form \"<op_name>:<output_index>\".")
    assert(intercept[InvalidArgumentException](graph.getOutputByName("A:0")).getMessage
        === "Name 'A:0' refers to an op output which does not exist in the graph. " +
        "More specifically, op, 'A', does not exist in the graph.")
    assert(intercept[InvalidArgumentException](graph.getOutputByName("C_2:5")).getMessage
        === "Name 'C_2:5' refers to an op output which does not exist in the graph. " +
        "More specifically, op, 'C_2', does exist in the graph, but it only has 1 output(s).")
    assert(intercept[InvalidArgumentException](graph.getOutputByName("A")).getMessage
        === "Name 'A' looks like an (invalid) op name, and not an op output name. " +
        "Op output names must be of the form \"<op_name>:<output_index>\".")
    assert(intercept[InvalidArgumentException](graph.getOutputByName("C_2")).getMessage
        === "Name 'C_2' appears to refer to an op, but 'allowOp' was set to 'false'.")
  }

  "'graphElementByName'" must "return an existing element in a graph" in {
    val (graph, ops) = prepareGraph()
    // Match exhaustively: a lookup that resolves to the wrong kind of element must fail the test,
    // instead of silently skipping the check as a one-sided `foreach` would.
    graph.getByName("C_2") match {
      case Left(op) => assert(op === ops(1))
      case Right(output) => fail(s"Expected 'C_2' to resolve to an op, but it resolved to op output '$output'.")
    }
    graph.getByName("C_2:0") match {
      case Left(op) => fail(s"Expected 'C_2:0' to resolve to an op output, but it resolved to op '$op'.")
      case Right(output) => assert(output == ops(1).outputsSeq.head)
    }
  }

  it must "throw an 'InvalidArgumentException' exception with an informative message " +
      "if an element name does not exist in the graph" in {
    val (graph, _) = prepareGraph()
    assert(intercept[InvalidArgumentException](
      graph.getByName("A", allowOp = true, allowOutput = true)).getMessage
        === "Name 'A' refers to an op which does not exist in the graph.")
    assert(intercept[InvalidArgumentException](
      graph.getByName("A:0:3", allowOp = true, allowOutput = true)).getMessage
        === "Name 'A:0:3' looks a like an op output name, but it is not a valid one. " +
        "Op output names must be of the form \"<op_name>:<output_index>\".")
    assert(intercept[InvalidArgumentException](
      graph.getByName("A:0", allowOp = true, allowOutput = true)).getMessage
        === "Name 'A:0' refers to an op output which does not exist in the graph. " +
        "More specifically, op, 'A', does not exist in the graph.")
    assert(intercept[InvalidArgumentException](
      graph.getByName("C_2:5", allowOp = true, allowOutput = true)).getMessage
        === "Name 'C_2:5' refers to an op output which does not exist in the graph. " +
        "More specifically, op, 'C_2', does exist in the graph, but it only has 1 output(s).")
    assert(intercept[IllegalArgumentException](
      graph.getByName("A", allowOp = false, allowOutput = false)).getMessage
        === "'allowOutput' and 'allowOp' cannot both be set to 'false'.")
  }

  /** Collection key used to record the input op outputs of the exported graph. */
  object INPUTS extends Graph.Keys.OutputCollectionKey {
    override def name: String = "inputs"
  }

  /** Collection key used to record the output op outputs of the exported graph. */
  object OUTPUTS extends Graph.Keys.OutputCollectionKey {
    override def name: String = "outputs"
  }

  "'Graph.toMetaGraphDef'" must "work when no scope is provided" in {
    val graph = Graph()
    val session = Session(graph)

    // Build, run, and export the original graph. The session is closed even if an assertion fails.
    val metaGraphDef = try {
      Op.createWith(graph) {
        // Create a minimal graph with zero variables.
        val input = Basic.placeholder[Float](Shape(), name = "Input")
        val offset = Basic.constant(42.0f, name = "Offset")
        val output = Math.add(input, offset, name = "AddOffset")

        // Add input and output tensors to graph collections.
        graph.addToCollection(INPUTS)(input.asInstanceOf[Output[Any]])
        graph.addToCollection(OUTPUTS)(output.asInstanceOf[Output[Any]])

        val outputValue = session.run(Map(input -> Tensor(-10f)), output)
        assert(outputValue.scalar == 32.0f)
      }

      // Generate the 'MetaGraphDef' object.
      graph.toMetaGraphDef(collections = Set(INPUTS, OUTPUTS))
    } finally {
      session.close()
    }

    assert(metaGraphDef.hasMetaInfoDef)
    assert(metaGraphDef.getMetaInfoDef.getTensorflowVersion !== "")
    // assert(metaGraphDef.getMetaInfoDef.getTensorflowGitVersion !== "")

    // Create a clean graph and import the 'MetaGraphDef' object.
    val newGraph = Graph()
    val newSession = Session(newGraph)
    try {
      newGraph.importMetaGraphDef(metaGraphDef)

      // Re-export the imported graph state to check that exporting still works on it.
      // TODO: [PROTO] Utility functions for ProtoBuf comparisons, so that the re-exported
      //       'MetaGraphDef' can be compared against the original one.
      newGraph.toMetaGraphDef()

      // Ensure that we can still get a reference to our graph collections.
      val newInput = newGraph.getCollection(INPUTS).head.asInstanceOf[Output[Float]]
      val newOutput = newGraph.getCollection(OUTPUTS).head.asInstanceOf[Output[Float]]

      // Verify that the new graph computes the same result as the original.
      val newOutputValue = newSession.run(Map(newInput -> Tensor(-10f)), newOutput)
      assert(newOutputValue.scalar == 32.0f)
    } finally {
      newSession.close()
    }
  }
}
