Skip to content

Commit

Permalink
add test
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Aug 5, 2021
1 parent 7425fca commit 1c2ec67
Showing 1 changed file with 43 additions and 0 deletions.
43 changes: 43 additions & 0 deletions tests/python/relay/test_pass_alter_op_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -1353,6 +1353,48 @@ def alter_conv2d(attrs, inputs, tinfos, out_type):
assert before.body.attrs.layout == "NCHW"


def test_alter_op_dense_packed_data():
def before():
x = relay.var("x", shape=(1, 32, 8, 8))
weight = relay.var("conv2d_weight", shape=(32, 32, 3, 3))
conv = relay.nn.conv2d(x, weight, channels=32, kernel_size=(3, 3), padding=(1, 1))
pool = relay.nn.avg_pool2d(conv, pool_size=[8, 8], padding=[0, 0, 0, 0])
squeeze = relay.squeeze(pool, axis=[2, 3])
dense = relay.nn.dense(squeeze, relay.var("dense_weight", shape=(16, 32)))
return relay.Function(analysis.free_vars(dense), dense)

def expected():
x = relay.var("x", shape=(1, 32, 8, 8))
conv_weight = relay.var("conv2d_weight", shape=(32, 32, 3, 3))
dense_weight = relay.var("dense_weight", shape=(16, 32))
conv = relay.nn.contrib_conv2d_nchwc(
relay.layout_transform(x, "NCHW", "NCHW8c"),
relay.layout_transform(conv_weight, "OIHW", "OIHW8i8o"),
channels=32,
kernel_size=(3, 3),
padding=(1, 1),
data_layout="NCHW8c",
kernel_layout="OIHW8i8o",
out_layout="NCHW8c",
)
pool = relay.nn.avg_pool2d(conv, pool_size=[8, 8], padding=[0, 0, 0, 0], layout="NCHW8c")
squeeze = relay.squeeze(pool, axis=[2, 3])
dense = relay.nn.contrib_dense_pack(
relay.layout_transform(squeeze, "NC8c", "NC"),
relay.layout_transform(dense_weight, "NK", "NK16n"),
out_dtype="float32",
)
return relay.Function(analysis.free_vars(dense), dense)

with tvm.target.Target("llvm"):
with TempOpAttr(
"nn.dense", "FTVMAlterOpLayout", topi.x86.dense_alter_op._alter_dense_layout
):
a = run_opt_pass(before(), transform.AlterOpLayout())
b = run_opt_pass(expected(), transform.InferType())
assert tvm.ir.structural_equal(a, b)


if __name__ == "__main__":
test_alter_op()
test_alter_return_none()
Expand All @@ -1377,3 +1419,4 @@ def alter_conv2d(attrs, inputs, tinfos, out_type):
test_alter_op_dense()
test_alter_layout_strided_slice_axes_nhwc()
test_not_inplace_modify()
test_alter_op_dense_packed_data()

0 comments on commit 1c2ec67

Please sign in to comment.