Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relay][Frontend] Span Filling TensorFlow 1 #13728

Merged
merged 1 commit into from
Jan 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 29 additions & 27 deletions python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from .common import infer_type as _infer_type
from .common import infer_shape as _infer_shape
from .common import infer_value as _infer_value
from .common import set_span

from .tensorflow_ops import _convert_map
from .tensorflow_ops import _need_prelude_for_shape_inference
Expand Down Expand Up @@ -328,7 +329,7 @@ def _while_loop(self):
`while_loop` construct.
"""
bind_map = {}
wl = tvm.relay.var("while_loop")
wl = set_span(tvm.relay.var("while_loop"), self._loop_name)
sb = tvm.relay.scope_builder.ScopeBuilder()

lv_list = []
Expand All @@ -345,7 +346,7 @@ def _while_loop(self):
if lv not in self._lvar2expr[self._loop_name]:
var_name = "{}_loop_var_{}".format(self._loop_name, i)
var_type = _infer_type(lv, self._mod).checked_type
loop_var = tvm.relay.var(var_name, type_annotation=var_type)
loop_var = set_span(tvm.relay.var(var_name, type_annotation=var_type), var_name)
self._lvar2expr[self._loop_name][loop_var] = lv
bind_map[lv] = loop_var
self.loop_vars[i] = loop_var
Expand All @@ -358,7 +359,7 @@ def _while_loop(self):
self.cond = rewrite_subgraph(self.cond, bind_map)
self.body = [rewrite_subgraph(b, bind_map) for b in self.body]

cond = tvm.relay.op.min(self.cond)
cond = set_span(tvm.relay.op.min(self.cond), self.cond.span)

for lv, exp in self._lvar2expr[self._loop_name].items():
if lv not in self.loop_vars:
Expand Down Expand Up @@ -517,8 +518,11 @@ def _get_relay_func(self, graph, layout="NHWC", shape=None, outputs=None):
self._output_shapes[node.name] = [self._input_shapes[node.name]]
attr = self._parse_attr(node.attr)
self._nodes[node.name] = [
_expr.var(
node.name, shape=self._input_shapes[node.name], dtype=attr["dtype"].name
set_span(
_expr.var(
node.name, shape=self._input_shapes[node.name], dtype=attr["dtype"].name
),
node.name,
)
]

Expand Down Expand Up @@ -708,16 +712,23 @@ def _parse_param(self, key, value, name, shape):
var_shape = shape[name]
else:
var_shape = tensor_util.TensorShapeProtoToList(value.tensor.tensor_shape)
self._nodes[name] = [_expr.var(name, shape=var_shape, dtype="uint8")]
self._nodes[name] = [
set_span(_expr.var(name, shape=var_shape, dtype="uint8"), span=name)
]
return

array_ndim = len(np_array.shape)
if array_ndim == 0:
self._nodes[name] = [tvm.relay.const(np_array, np_array.dtype)]
self._nodes[name] = [set_span(tvm.relay.const(np_array, np_array.dtype), name)]
else:
self._params[name] = tvm.nd.array(np_array)
self._nodes[name] = [
_expr.var(name, shape=self._params[name].shape, dtype=self._params[name].dtype)
set_span(
_expr.var(
name, shape=self._params[name].shape, dtype=self._params[name].dtype
),
name,
)
]
else:
if key not in ("dtype", "_output_shapes", "_class"):
Expand Down Expand Up @@ -998,6 +1009,8 @@ def _convert_operator(
----------
op_name : str
Operator name, such as Conv2D, AvgPool
node_name : str
Node name, predefined by user or default setting of TF
inputs : list of relay.op
List of input symbols.
attrs : dict
Expand Down Expand Up @@ -1028,22 +1041,8 @@ def _convert_operator(
else:
raise NotImplementedError("Operator {} not implemented.".format(op_name))

sym = self._set_span(sym, node_name)

return sym
sym = set_span(sym, node_name)

@staticmethod
def _set_span(sym, node_name):
span = tvm.relay.Span(tvm.relay.SourceName(node_name), 0, 0, 0, 0)
if isinstance(sym, _expr.Call) and sym.span is None:
sym = _expr.Call(sym.op, sym.args, sym.attrs, sym.type_args, span)
elif isinstance(sym, _expr.TupleWrapper):
tuple_value = sym.tuple_value
if isinstance(tuple_value, _expr.Call) and tuple_value.span is None:
tuple_value = _expr.Call(
tuple_value.op, tuple_value.args, tuple_value.attrs, tuple_value.type_args, span
)
sym = _expr.TupleWrapper(tuple_value, sym.size)
return sym

def _licm_construct(self, loop_name, node_name):
Expand Down Expand Up @@ -1079,7 +1078,7 @@ def _licm_construct(self, loop_name, node_name):
if node_name not in self._lname_map[loop_name]:
var_name = "{}_loop_var".format(node_name)
var_type = _infer_type(actual_expr, self._mod).checked_type
loop_var = tvm.relay.var(var_name, type_annotation=var_type)
loop_var = set_span(tvm.relay.var(var_name, type_annotation=var_type), var_name)
try:
extra_param = _infer_value(actual_expr, self._params, self._mod)
self._params[var_name] = extra_param
Expand Down Expand Up @@ -1183,10 +1182,13 @@ def _backtrack_construct(self, node_name):
if isinstance(op, np.ndarray):
self._params[node.name] = tvm.nd.array(op)
op = [
_expr.var(
set_span(
_expr.var(
node.name,
shape=self._params[node.name].shape,
dtype=self._params[node.name].dtype,
),
node.name,
shape=self._params[node.name].shape,
dtype=self._params[node.name].dtype,
)
]

Expand Down
6 changes: 5 additions & 1 deletion tests/python/frontend/tensorflow/test_bn_dynamic.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,11 @@ def verify_fused_batch_norm(shape):
if not tvm.testing.device_enabled(device):
print("Skip because %s is not enabled" % device)
continue
mod, params = relay.frontend.from_tensorflow(constant_graph, outputs=["output"])
with tvm.testing.disable_span_filling():
mod, params = relay.frontend.from_tensorflow(constant_graph, outputs=["output"])
with tvm.testing.enable_span_filling():
mod_with_span, _ = relay.frontend.from_tensorflow(constant_graph, outputs=["output"])
assert tvm.ir.structural_equal(mod["main"], mod_with_span["main"])
with tvm.transform.PassContext(opt_level=3):
graph, lib, params = relay.build(mod, target=device, params=params)
from tvm.contrib import graph_executor
Expand Down
10 changes: 7 additions & 3 deletions tests/python/frontend/tensorflow/test_control_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,17 @@
import tensorflow as tf
from tensorflow.python.ops import control_flow_ops
import numpy as np
from tvm import nd
from tvm import relay
from tvm import nd, relay, ir, testing
from tvm.relay.frontend.tensorflow import from_tensorflow


def check_equal(graph, tf_out, input_map=None):
mod, params = from_tensorflow(graph.as_graph_def(add_shapes=True))
with testing.disable_span_filling():
mod, params = from_tensorflow(graph.as_graph_def(add_shapes=True))
with testing.enable_span_filling():
mod_with_span, _ = from_tensorflow(graph.as_graph_def(add_shapes=True))
assert ir.structural_equal(mod["main"], mod_with_span["main"])

if input_map is not None:
params.update(input_map)
relay_out = relay.create_executor("vm", mod=mod).evaluate()(**params)
Expand Down
11 changes: 9 additions & 2 deletions tests/python/frontend/tensorflow/test_debugging.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,19 @@
except ImportError:
import tensorflow as tf
import numpy as np
from tvm import relay
from tvm import relay, ir, testing
from tvm.relay.frontend.tensorflow import from_tensorflow


def run_relay(graph, shape_dict=None, *vars):
mod, params = from_tensorflow(graph.as_graph_def(add_shapes=True), shape=shape_dict)
with testing.disable_span_filling():
mod, params = from_tensorflow(graph.as_graph_def(add_shapes=True), shape=shape_dict)
with testing.enable_span_filling():
mod_with_span, _ = relay.frontend.from_tensorflow(
graph.as_graph_def(add_shapes=True), shape=shape_dict
)
assert ir.structural_equal(mod["main"], mod_with_span["main"])

return relay.create_executor("debug", mod=mod).evaluate()(*vars)


Expand Down
Loading