Skip to content

Commit

Permalink
add test
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Dec 15, 2022
1 parent d111237 commit 500cfcf
Showing 1 changed file with 17 additions and 0 deletions.
17 changes: 17 additions & 0 deletions tests/python/unittest/test_meta_schedule_relay_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,23 @@ def test_meta_schedule_integration_extract_from_resnet():
assert t.task_name in expected_task_names, t.task_name


@requires_torch
def test_task_extraction_winograd_tensorcore():
mod, params, _ = get_network(name="resnet_50", input_shape=[16, 3, 224, 224])
seq = tvm.transform.Sequential(
[
relay.transform.ToMixedPrecision("float16"),
relay.transform.ConvertLayout({"nn.conv2d": ["NHWC", "HWIO"]})
]
)
with tvm.transform.PassContext(opt_level=3):
mod = seq(mod)

extracted_tasks = ms.relay_integration.extract_tasks(mod, target="cuda", params=params)

assert len([t for t in extracted_tasks if "winograd" in t.task_name]) == 4


@requires_torch
def test_task_extraction_anchor_block():
mod, params, _ = get_network(name="resnet_18", input_shape=[1, 3, 224, 224])
Expand Down

0 comments on commit 500cfcf

Please sign in to comment.