diff --git a/tests/python/relay/test_pass_alter_op_layout.py b/tests/python/relay/test_pass_alter_op_layout.py index 2a2e265dbe5b..9b18f72cb6e7 100644 --- a/tests/python/relay/test_pass_alter_op_layout.py +++ b/tests/python/relay/test_pass_alter_op_layout.py @@ -940,11 +940,8 @@ def expected_nhwc(): assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) -# TODO(@anijain2305, @icemelon9): We should fix this. This doesn't seem to be the -# right behavior of alter_layout -@pytest.mark.skip -def test_alter_layout_nhwc_nchw_arm(): - """ Check NHWC to NHCW conversion for a small sequence of ops.""" +def test_alter_layout_nhwc_arm(): + """ Check that AlterOplayout does not alter NHWC data layout. """ def alter_conv2d(attrs, inputs, tinfos, out_type): import topi with tvm.target.create("llvm -device=arm_cpu"): @@ -974,25 +971,7 @@ def before_nhwc(): return y def expected_nhwc(): - x = relay.var("x", shape=(1, 56, 56, 64)) - weight1 = relay.var('weight1', shape=(3, 3, 64, 64)) - weight2 = relay.var('weight2', shape=(3, 3, 64, 64)) - y = relay.layout_transform(x, "NHWC", "NCHW") - weight1 = relay.layout_transform(weight1, "HWIO", "OIHW") - weight2 = relay.layout_transform(weight2, "HWIO", "OIHW") - y = relay.nn.conv2d(y, weight1, - channels=64, - kernel_size=(3, 3)) - y = relay.nn.relu(y) - y = relay.nn.avg_pool2d(y, - pool_size=(1,1)) - y = relay.nn.conv2d(y, weight2, - channels=64, - kernel_size=(3, 3)) - y = relay.nn.relu(y) - y = relay.layout_transform(y, "NCHW", "NHWC") - y = relay.Function(analysis.free_vars(y), y) - return y + return before_nhwc() with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d): a = before_nhwc() @@ -1060,5 +1039,5 @@ def expected(): test_alter_layout_pad() test_alter_layout_pool() test_alter_layout_sum() - # test_alter_layout_nhwc_nchw_arm() + test_alter_layout_nhwc_arm() test_alter_op_with_global_var() diff --git a/topi/python/topi/arm_cpu/conv2d_alter_op.py b/topi/python/topi/arm_cpu/conv2d_alter_op.py index 3d194cce6534..221ccce1a2a0 100644 --- a/topi/python/topi/arm_cpu/conv2d_alter_op.py +++ b/topi/python/topi/arm_cpu/conv2d_alter_op.py @@ -59,6 +59,10 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type): data, kernel = tinfos out_dtype = out_type.dtype + # We only perform layout alteration for NCHW data layout. + if data_layout == "NHWC": + return None + # Extract data types data_tensor, kernel_tensor = tinfos data_dtype = data_tensor.dtype