Skip to content
This repository has been archived by the owner on May 22, 2023. It is now read-only.

Commit

Permalink
enhance parser error reporting (#242)
Browse files Browse the repository at this point in the history
  • Loading branch information
Hzfengsy authored Sep 8, 2022
1 parent 97ffe01 commit f1709ac
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 10 deletions.
25 changes: 15 additions & 10 deletions python/tvm/script/parser/core/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from collections import defaultdict
from contextlib import contextmanager
from typing import Any, Callable, Dict, List, Optional, Set, Union
from tvm._ffi.base import TVMError

from tvm.error import DiagnosticError

Expand Down Expand Up @@ -100,7 +101,7 @@ def _wrapper(self: "Parser", node: doc.AST) -> None:
except DiagnosticError:
raise
except Exception as e: # pylint: disable=broad-except,invalid-name
self.report_error(node, str(e))
self.report_error(node, e)
raise

return _wrapper
Expand Down Expand Up @@ -185,7 +186,14 @@ def eval_assign(
self.var_table.add(k, var)
return var_values

def report_error(self, node: doc.AST, msg: str) -> None: # pylint: disable=no-self-use
def report_error(
self, node: doc.AST, err: Union[Exception, str]
) -> None: # pylint: disable=no-self-use
# Only take the last line of the error message
if isinstance(err, TVMError):
msg = list(filter(None, str(err).split("\n")))[-1]
else:
msg = str(err)
self.diag.error(node, msg)

def visit(self, node: doc.AST) -> None:
Expand All @@ -204,8 +212,11 @@ def visit(self, node: doc.AST) -> None:
raise NotImplementedError(f"Visitor of AST node is not implemented: {name}")
try:
func(node)
except DiagnosticError:
raise
except Exception as e: # pylint: disable=broad-except,invalid-name
self.report_error(node, str(e))
raise

def visit_body(self, node: List[doc.stmt]) -> Any:
for stmt in node:
Expand All @@ -225,19 +236,13 @@ def visit_FunctionDef(self, node: doc.FunctionDef) -> Any: # pylint: disable=in
func = dispatch.get(token=token, type_name="FunctionDef", default=None)
if func is None:
self.report_error(node, "The parser does not understand the decorator")
try:
func(self, node)
except Exception as e: # pylint: disable=broad-except,invalid-name
self.report_error(node, str(e))
_dispatch_wrapper(func)(self, node)

def visit_ClassDef(self, node: doc.ClassDef) -> Any: # pylint: disable=invalid-name
func = dispatch.get(token="ir", type_name="ClassDef", default=None)
if func is None:
self.report_error(node, "The parser does not understand the decorator")
try:
func(self, node)
except Exception as e: # pylint: disable=broad-except,invalid-name
self.report_error(node, str(e))
_dispatch_wrapper(func)(self, node)

def visit_arguments(self, node: doc.arguments) -> Any:
return _dispatch(self, "arguments")(self, node)
Expand Down
11 changes: 11 additions & 0 deletions tests/python/relax/test_tvmscript_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import pytest
from typing import Union
import tvm
import tvm.testing
Expand Down Expand Up @@ -47,5 +49,14 @@ def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor(None, "float32", ndim=2)
_check(foo, bb.get()["foo"])


def test_error_report():
with pytest.raises(tvm.error.DiagnosticError):

@R.function
def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor(None, "float32", ndim=2):
gv0 = gv1 = R.call_tir("extern_func", x, (128, 128), dtype="float32")
return gv0


if __name__ == "__main__":
tvm.testing.main()

0 comments on commit f1709ac

Please sign in to comment.