diff --git a/tests/python/relay/test_pass_alter_op_layout.py b/tests/python/relay/test_pass_alter_op_layout.py index 3cee84f6a4dd5..c8b5f3f8da60f 100644 --- a/tests/python/relay/test_pass_alter_op_layout.py +++ b/tests/python/relay/test_pass_alter_op_layout.py @@ -1353,6 +1353,48 @@ def alter_conv2d(attrs, inputs, tinfos, out_type): assert before.body.attrs.layout == "NCHW" +def test_alter_op_dense_packed_data(): + def before(): + x = relay.var("x", shape=(1, 32, 8, 8)) + weight = relay.var("conv2d_weight", shape=(32, 32, 3, 3)) + conv = relay.nn.conv2d(x, weight, channels=32, kernel_size=(3, 3), padding=(1, 1)) + pool = relay.nn.avg_pool2d(conv, pool_size=[8, 8], padding=[0, 0, 0, 0]) + squeeze = relay.squeeze(pool, axis=[2, 3]) + dense = relay.nn.dense(squeeze, relay.var("dense_weight", shape=(16, 32))) + return relay.Function(analysis.free_vars(dense), dense) + + def expected(): + x = relay.var("x", shape=(1, 32, 8, 8)) + conv_weight = relay.var("conv2d_weight", shape=(32, 32, 3, 3)) + dense_weight = relay.var("dense_weight", shape=(16, 32)) + conv = relay.nn.contrib_conv2d_nchwc( + relay.layout_transform(x, "NCHW", "NCHW8c"), + relay.layout_transform(conv_weight, "OIHW", "OIHW8i8o"), + channels=32, + kernel_size=(3, 3), + padding=(1, 1), + data_layout="NCHW8c", + kernel_layout="OIHW8i8o", + out_layout="NCHW8c", + ) + pool = relay.nn.avg_pool2d(conv, pool_size=[8, 8], padding=[0, 0, 0, 0], layout="NCHW8c") + squeeze = relay.squeeze(pool, axis=[2, 3]) + dense = relay.nn.contrib_dense_pack( + relay.layout_transform(squeeze, "NC8c", "NC"), + relay.layout_transform(dense_weight, "NK", "NK16n"), + out_dtype="float32", + ) + return relay.Function(analysis.free_vars(dense), dense) + + with tvm.target.Target("llvm"): + with TempOpAttr( + "nn.dense", "FTVMAlterOpLayout", topi.x86.dense_alter_op._alter_dense_layout + ): + a = run_opt_pass(before(), transform.AlterOpLayout()) + b = run_opt_pass(expected(), transform.InferType()) + assert tvm.ir.structural_equal(a, b) + + if __name__ == "__main__": test_alter_op() test_alter_return_none() @@ -1377,3 +1419,4 @@ def alter_conv2d(attrs, inputs, tinfos, out_type): test_alter_op_dense() test_alter_layout_strided_slice_axes_nhwc() test_not_inplace_modify() + test_alter_op_dense_packed_data()