diff --git a/python/tvm/script/parser.py b/python/tvm/script/parser.py index 9acf21b6ba3a..60fc49678866 100644 --- a/python/tvm/script/parser.py +++ b/python/tvm/script/parser.py @@ -536,7 +536,7 @@ def transform_SubscriptAssign(self, node): if len(indexes) != 1: self.report_error( f"Store is only allowed with one index, but {len(indexes)} were provided.", - tvm.ir.Span.union([x.span for x in indexes]), + node.params[1].span, ) # Store return tvm.tir.Store( diff --git a/tests/python/unittest/test_tvmscript_error_report.py b/tests/python/unittest/test_tvmscript_error_report.py index 7aeceeccfa89..70a2aea11293 100644 --- a/tests/python/unittest/test_tvmscript_error_report.py +++ b/tests/python/unittest/test_tvmscript_error_report.py @@ -14,6 +14,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. + +import pytest +import sys import tvm from tvm import tir from tvm.script import ty, from_source @@ -380,6 +383,17 @@ def test_match_buffer_shape_mismatch(): check_error(buffer_shape_mismatch, 7) +def high_dim_store() -> None: + with tir.block([], "root"): + B = tir.allocate([256], "float32", "global") + for i, j in tir.grid(16, 16): + B[i, j] = 1.0 # error: Store is only allowed with one index + + +def test_high_dim_store(): + check_error(high_dim_store, 5) + + def check_error(module, rel_lineno): # Override the default renderer to accumulate errors _, start_line = inspect.getsourcelines(module) @@ -404,31 +418,4 @@ def render(e): if __name__ == "__main__": - test_buffer_bind() - test_range_missing_args() - test_undefined_buffer() - test_unsupported_stmt() - test_unsupported_function_call() - test_missing_type_annotation() - test_invalid_expr_stmt() - test_invalid_for_function() - test_invalid_block_function() - test_return_not_allowed() - test_tir_assert() - test_no_body() - test_allocate_with_buffers() - test_inconsistent_binding() - test_invalid_block_axes() - test_miss_block_bind() - test_invalid_loop_var() - test_inconsistent_grid() - test_invalid_match_buffer_region() - test_duplicate_buffer() - test_duplicate_block_signature() - test_opaque_access_during_complete() - test_convert_slice_to_bufferload() - test_error_index_type() - test_error_index_with_stop_slice() - test_mismatch_args() - test_tvm_exception_catch() - test_match_buffer_shape_mismatch() + sys.exit(pytest.main([__file__] + sys.argv[1:]))