diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc index cc7536b48cfd..df02a6906a0c 100644 --- a/src/printer/tvmscript_printer.cc +++ b/src/printer/tvmscript_printer.cc @@ -1069,7 +1069,7 @@ Doc TVMScriptPrinter::PrintLoop(const For& loop) { res << Print(loop->thread_binding.value()->thread_tag); } if (!loop->annotations.empty()) { - res << ", annotation = {"; + res << ", annotations = {"; res << PrintAnnotations(loop->annotations); res << "}"; } diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py index 0566ff5044d9..f9aee67f1d71 100644 --- a/tests/python/unittest/test_tvmscript_roundtrip.py +++ b/tests/python/unittest/test_tvmscript_roundtrip.py @@ -2803,7 +2803,9 @@ def for_thread_binding(a: ty.handle, b: ty.handle) -> None: B = tir.match_buffer(b, (16, 16), "float32") for i in tir.thread_binding(0, 16, thread="threadIdx.x"): - for j in tir.thread_binding(0, 16, thread="threadIdx.y"): + for j in tir.thread_binding( + 0, 16, thread="threadIdx.y", annotations={"attr_key": "attr_value"} + ): A[i, j] = B[i, j] + tir.float32(1) @@ -2818,6 +2820,7 @@ def test_for_thread_binding(): assert isinstance(rt_func.body.body, tir.stmt.For) assert rt_func.body.body.kind == 4 assert rt_func.body.body.thread_binding.thread_tag == "threadIdx.y" + assert rt_func.body.body.annotations["attr_key"] == "attr_value" @tvm.script.tir