Skip to content
This repository has been archived by the owner on May 22, 2023. It is now read-only.

[TVMScript] Enhance parser error reporting #242

Merged
merged 1 commit into from
Sep 8, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 15 additions & 10 deletions python/tvm/script/parser/core/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from collections import defaultdict
from contextlib import contextmanager
from typing import Any, Callable, Dict, List, Optional, Set, Union
from tvm._ffi.base import TVMError

from tvm.error import DiagnosticError

Expand Down Expand Up @@ -100,7 +101,7 @@ def _wrapper(self: "Parser", node: doc.AST) -> None:
except DiagnosticError:
raise
except Exception as e: # pylint: disable=broad-except,invalid-name
self.report_error(node, str(e))
self.report_error(node, e)
raise

return _wrapper
Expand Down Expand Up @@ -185,7 +186,14 @@ def eval_assign(
self.var_table.add(k, var)
return var_values

def report_error(self, node: doc.AST, msg: str) -> None: # pylint: disable=no-self-use
def report_error(
self, node: doc.AST, err: Union[Exception, str]
) -> None: # pylint: disable=no-self-use
# Only take the last line of the error message
if isinstance(err, TVMError):
msg = list(filter(None, str(err).split("\n")))[-1]
else:
msg = str(err)
self.diag.error(node, msg)

def visit(self, node: doc.AST) -> None:
Expand All @@ -204,8 +212,11 @@ def visit(self, node: doc.AST) -> None:
raise NotImplementedError(f"Visitor of AST node is not implemented: {name}")
try:
func(node)
except DiagnosticError:
raise
except Exception as e: # pylint: disable=broad-except,invalid-name
self.report_error(node, str(e))
raise

def visit_body(self, node: List[doc.stmt]) -> Any:
for stmt in node:
Expand All @@ -225,19 +236,13 @@ def visit_FunctionDef(self, node: doc.FunctionDef) -> Any: # pylint: disable=in
func = dispatch.get(token=token, type_name="FunctionDef", default=None)
if func is None:
self.report_error(node, "The parser does not understand the decorator")
try:
func(self, node)
except Exception as e: # pylint: disable=broad-except,invalid-name
self.report_error(node, str(e))
_dispatch_wrapper(func)(self, node)

def visit_ClassDef(self, node: doc.ClassDef) -> Any: # pylint: disable=invalid-name
func = dispatch.get(token="ir", type_name="ClassDef", default=None)
if func is None:
self.report_error(node, "The parser does not understand the decorator")
try:
func(self, node)
except Exception as e: # pylint: disable=broad-except,invalid-name
self.report_error(node, str(e))
_dispatch_wrapper(func)(self, node)

def visit_arguments(self, node: doc.arguments) -> Any:
return _dispatch(self, "arguments")(self, node)
Expand Down
11 changes: 11 additions & 0 deletions tests/python/relax/test_tvmscript_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import pytest
from typing import Union
import tvm
import tvm.testing
Expand Down Expand Up @@ -47,5 +49,14 @@ def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor(None, "float32", ndim=2)
_check(foo, bb.get()["foo"])


def test_error_report():
with pytest.raises(tvm.error.DiagnosticError):

@R.function
def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor(None, "float32", ndim=2):
gv0 = gv1 = R.call_tir("extern_func", x, (128, 128), dtype="float32")
return gv0


if __name__ == "__main__":
tvm.testing.main()