Skip to content

Commit

Permalink
[TOPI-ARM] Do not alter layout if layout is NHWC (apache#5350)
Browse files Browse the repository at this point in the history
* [TOPI-ARM] Do not alter layout if layout is NHWC

* Add test.
  • Loading branch information
anijain2305 authored and trevor-m committed Jun 18, 2020
1 parent 455b4b5 commit 224252c
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 25 deletions.
29 changes: 4 additions & 25 deletions tests/python/relay/test_pass_alter_op_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -940,11 +940,8 @@ def expected_nhwc():
assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)


# TODO(@anijain2305, @icemelon9): We should fix this. This doesn't seem to be the
# right behavior of alter_layout
@pytest.mark.skip
def test_alter_layout_nhwc_nchw_arm():
""" Check NHWC to NHCW conversion for a small sequence of ops."""
def test_alter_layout_nhwc_arm():
""" Check that AlterOplayout does not alter NHWC data layout. """
def alter_conv2d(attrs, inputs, tinfos, out_type):
import topi
with tvm.target.create("llvm -device=arm_cpu"):
Expand Down Expand Up @@ -974,25 +971,7 @@ def before_nhwc():
return y

def expected_nhwc():
x = relay.var("x", shape=(1, 56, 56, 64))
weight1 = relay.var('weight1', shape=(3, 3, 64, 64))
weight2 = relay.var('weight2', shape=(3, 3, 64, 64))
y = relay.layout_transform(x, "NHWC", "NCHW")
weight1 = relay.layout_transform(weight1, "HWIO", "OIHW")
weight2 = relay.layout_transform(weight2, "HWIO", "OIHW")
y = relay.nn.conv2d(y, weight1,
channels=64,
kernel_size=(3, 3))
y = relay.nn.relu(y)
y = relay.nn.avg_pool2d(y,
pool_size=(1,1))
y = relay.nn.conv2d(y, weight2,
channels=64,
kernel_size=(3, 3))
y = relay.nn.relu(y)
y = relay.layout_transform(y, "NCHW", "NHWC")
y = relay.Function(analysis.free_vars(y), y)
return y
return before_nhwc()

with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d):
a = before_nhwc()
Expand Down Expand Up @@ -1060,5 +1039,5 @@ def expected():
test_alter_layout_pad()
test_alter_layout_pool()
test_alter_layout_sum()
# test_alter_layout_nhwc_nchw_arm()
test_alter_layout_nhwc_arm()
test_alter_op_with_global_var()
4 changes: 4 additions & 0 deletions topi/python/topi/arm_cpu/conv2d_alter_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,10 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
data, kernel = tinfos
out_dtype = out_type.dtype

# We only perform layout alteration for NCHW data layout.
if data_layout == "NHWC":
return None

# Extract data types
data_tensor, kernel_tensor = tinfos
data_dtype = data_tensor.dtype
Expand Down

0 comments on commit 224252c

Please sign in to comment.