From b0c6a2e18c9e870b5c8ef592ea7e8331f24d3eda Mon Sep 17 00:00:00 2001 From: lixiaoquan Date: Sat, 24 Apr 2021 07:26:48 +0800 Subject: [PATCH] [ConvertLayout] Keep span in ConvertLayout (#7895) --- src/relay/transforms/convert_layout.cc | 2 +- src/relay/transforms/forward_rewrite.cc | 2 +- src/relay/transforms/transform_layout.h | 2 +- tests/python/relay/test_pass_convert_op_layout.py | 13 +++++++++++++ 4 files changed, 16 insertions(+), 3 deletions(-) diff --git a/src/relay/transforms/convert_layout.cc b/src/relay/transforms/convert_layout.cc index 1293be70273e4..f2eb02e1b9bea 100644 --- a/src/relay/transforms/convert_layout.cc +++ b/src/relay/transforms/convert_layout.cc @@ -121,7 +121,7 @@ class ConvertTransformMemorizer : public TransformMemorizer { const CallNode* new_call = new_e.as(); ICHECK(new_call) << "Can only replace the original operator with another call node"; - return GetRef(new_call); + return Call(new_call->op, new_call->args, new_call->attrs, new_call->type_args, ref_call->span); } using ContainerType = ConvertTransformMemorizerNode; diff --git a/src/relay/transforms/forward_rewrite.cc b/src/relay/transforms/forward_rewrite.cc index be2d37477eb6b..1212ad7f19be9 100644 --- a/src/relay/transforms/forward_rewrite.cc +++ b/src/relay/transforms/forward_rewrite.cc @@ -168,7 +168,7 @@ class ForwardRewriter : private MixedModeMutator { } } if (unchanged) return ref_call; - return Call(new_op, call_args, call_node->attrs, call_node->type_args); + return Call(new_op, call_args, call_node->attrs, call_node->type_args, call_node->span); } }; diff --git a/src/relay/transforms/transform_layout.h b/src/relay/transforms/transform_layout.h index 35fb176c6bca2..3ac8ea224b79b 100644 --- a/src/relay/transforms/transform_layout.h +++ b/src/relay/transforms/transform_layout.h @@ -363,7 +363,7 @@ Expr LayoutRewriter(const Call& ref_call, const Array& new_args, const Obj } else { auto rnode = make_object>(); ICHECK_EQ(new_out.size(), 1); - rnode->value = Call(new_call->op, transformed_args, new_call->attrs); + rnode->value = Call(new_call->op, transformed_args, new_call->attrs, {}, ref_call->span); rnode->old_layout = old_out[0]; rnode->new_layout = new_out[0]; rnode->memorizer = memorizer; diff --git a/tests/python/relay/test_pass_convert_op_layout.py b/tests/python/relay/test_pass_convert_op_layout.py index 7eccc4a82c705..dd2dc979a7316 100644 --- a/tests/python/relay/test_pass_convert_op_layout.py +++ b/tests/python/relay/test_pass_convert_op_layout.py @@ -1703,6 +1703,11 @@ def expected(): assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) def _test_conv_reduce_convert_layout2(): + def _set_span(y, text): + return relay.Call( + y.op, y.args, y.attrs, y.type_args, relay.Span(relay.SourceName(text), 0, 0, 0, 0) + ) + def before(): x = relay.var("x", shape=(1, 38, 38, 512)) weight = relay.var("weight", shape=(3, 3, 512, 512)) @@ -1714,9 +1719,13 @@ def before(): data_layout="NHWC", kernel_layout="HWIO", ) + y = _set_span(y, "SpanConv2D") y = relay.nn.relu(y) + y = _set_span(y, "SpanRelu") y = relay.multiply(y, y) + y = _set_span(y, "SpanMultiply") y = relay.sum(y, axis=(3,), keepdims=True) + y = _set_span(y, "SpanSum") return relay.Function(analysis.free_vars(y), y) def expected(): @@ -1733,6 +1742,10 @@ def expected(): a = before() a = run_opt_pass(a, transform.ConvertLayout({"nn.conv2d": ["NCHW", "default"]})) + assert "SpanConv2D" in a.astext() + assert "SpanRelu" in a.astext() + assert "SpanMultiply" in a.astext() + assert "SpanSum" in a.astext() b = run_opt_pass(expected(), transform.InferType()) assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)