From 928d6ab2ff1e99a7a82985ec95f9e5a116ba8d19 Mon Sep 17 00:00:00 2001 From: tqchen Date: Sat, 16 Mar 2019 15:51:14 -0700 Subject: [PATCH] [RUNTIME] Scaffold structured error handling. --- 3rdparty/dmlc-core | 2 +- docs/api/python/error.rst | 5 + docs/api/python/index.rst | 1 + docs/contribute/error_handling.rst | 118 +++++++++++ docs/contribute/index.rst | 1 + python/tvm/__init__.py | 1 + python/tvm/_ffi/_ctypes/function.py | 28 ++- python/tvm/_ffi/_cython/base.pxi | 4 +- python/tvm/_ffi/_cython/function.pxi | 3 +- python/tvm/_ffi/base.py | 218 ++++++++++++++++++-- python/tvm/error.py | 96 +++++++++ src/api/api_test.cc | 38 ++++ src/runtime/c_runtime_api.cc | 180 +++++++++++++++- src/runtime/runtime_base.h | 5 +- tests/python/unittest/test_runtime_error.py | 55 +++++ 15 files changed, 710 insertions(+), 45 deletions(-) create mode 100644 docs/api/python/error.rst create mode 100644 docs/contribute/error_handling.rst create mode 100644 python/tvm/error.py create mode 100644 tests/python/unittest/test_runtime_error.py diff --git a/3rdparty/dmlc-core b/3rdparty/dmlc-core index 9acddddfc349..2b5b1ba9c110 160000 --- a/3rdparty/dmlc-core +++ b/3rdparty/dmlc-core @@ -1 +1 @@ -Subproject commit 9acddddfc349eda4ef99552d11cb905afeafed39 +Subproject commit 2b5b1ba9c1103f438d164aca32da7cffd8cd48e8 diff --git a/docs/api/python/error.rst b/docs/api/python/error.rst new file mode 100644 index 000000000000..3a1a5579fbc4 --- /dev/null +++ b/docs/api/python/error.rst @@ -0,0 +1,5 @@ +tvm.error +--------- +.. 
automodule:: tvm.error + :members: + :imported-members: diff --git a/docs/api/python/index.rst b/docs/api/python/index.rst index ddad9d10f8f9..16e03a6ca359 100644 --- a/docs/api/python/index.rst +++ b/docs/api/python/index.rst @@ -11,6 +11,7 @@ Python API target build module + error ndarray container function diff --git a/docs/contribute/error_handling.rst b/docs/contribute/error_handling.rst new file mode 100644 index 000000000000..00958ac5b5f7 --- /dev/null +++ b/docs/contribute/error_handling.rst @@ -0,0 +1,118 @@ +.. _error_guide: + +Error Handling Guide +==================== +TVM contains structured error classes to indicate specific types of error. +Please raise a specific error type when possible, so that users can +write code to handle a specific error category if necessary. + +All the error types are defined in :any:`tvm.error` namespace. +You can directly raise the specific error object in python. +In other languages like c++, you simply add ``<ErrorType>:`` prefix to +the error message (see below). + +Raise a Specific Error in C++ +----------------------------- +You can add ``<ErrorType>:`` prefix to your error message to +raise an error of the corresponding type. +Note that you do not have to add a new type; +:any:`tvm.error.TVMError` will be raised by default when +there is no error type prefix in the message. +This mechanism works for both ``LOG(FATAL)`` and ``CHECK`` macros. +The following code gives an example on how to do so. + +.. code:: c + + // src/api_test.cc + void ErrorTest(int x, int y) { + CHECK_EQ(x, y) << "ValueError: expect x and y to be equal."; + if (x == 1) { + LOG(FATAL) << "InternalError: cannot reach here"; + } + } + +The above function is registered as PackedFunc into the python frontend, +under the name ``tvm._api_internal._ErrorTest``. +Here is what will happen if we call the registered function: + +.. 
code:: + + >>> import tvm + >>> tvm._api_internal._ErrorTest(0, 1) + Traceback (most recent call last): + File "<stdin>", line 1, in <module> + File "/path/to/tvm/python/tvm/_ffi/_ctypes/function.py", line 190, in __call__ + raise get_last_ffi_error() + ValueError: Traceback (most recent call last): + [bt] (3) /path/to/tvm/build/libtvm.so(TVMFuncCall+0x48) [0x7fab500b8ca8] + [bt] (2) /path/to/tvm/build/libtvm.so(+0x1c4126) [0x7fab4f7f5126] + [bt] (1) /path/to/tvm/build/libtvm.so(+0x1ba2f8) [0x7fab4f7eb2f8] + [bt] (0) /path/to/tvm/build/libtvm.so(+0x177d12) [0x7fab4f7a8d12] + File "/path/to/tvm/src/api/api_test.cc", line 80 + ValueError: Check failed: x == y (0 vs. 1) : expect x and y to be equal. + >>> + >>> tvm._api_internal._ErrorTest(1, 1) + Traceback (most recent call last): + File "<stdin>", line 1, in <module> + File "/path/to/tvm/python/tvm/_ffi/_ctypes/function.py", line 190, in __call__ + raise get_last_ffi_error() + tvm.error.InternalError: Traceback (most recent call last): + [bt] (3) /path/to/tvm/build/libtvm.so(TVMFuncCall+0x48) [0x7fab500b8ca8] + [bt] (2) /path/to/tvm/build/libtvm.so(+0x1c4126) [0x7fab4f7f5126] + [bt] (1) /path/to/tvm/build/libtvm.so(+0x1ba35c) [0x7fab4f7eb35c] + [bt] (0) /path/to/tvm/build/libtvm.so(+0x177d12) [0x7fab4f7a8d12] + File "/path/to/tvm/src/api/api_test.cc", line 83 + InternalError: cannot reach here + TVM hint: You hit an internal error. Please open a thread on https://discuss.tvm.ai/ to report it. + +As you can see in the above example, TVM's ffi system combines +both the python and c++'s stacktrace into a single message, and generates the +corresponding error class automatically. + + +How to choose an Error Type +--------------------------- +You can go through the error types listed below, try to use common +sense and also refer to the choices in the existing code. +We try to keep a reasonable number of error types. 
+If you feel there is a need to add a new error type, do the following steps: + +- Send an RFC proposal with a description and usage examples in the current codebase. +- Add the new error type to :any:`tvm.error` with clear documentation. +- Update the list in this file to include the new error type. +- Change the code to use the new error type. + +We also recommend using less abstraction when creating the short error messages. +The code is more readable in this way, and also opens a path to craft specific +error messages when necessary. + +.. code:: python + + def preferred(): + # very clear about what is being raised and what is the error message. + raise OpNotImplemented("Operator relu is not implemented in the MXNet frontend") + + def _op_not_implemented(op_name): + return OpNotImplemented("Operator {} is not implemented.".format(op_name)) + + def not_preferred(): + raise _op_not_implemented("relu") + + +System-wide Errors +------------------ + +.. autoclass:: tvm.error.TVMError + +.. autoclass:: tvm.error.InternalError + + +Frontend Errors +--------------- +.. autoclass:: tvm.error.OpNotImplemented + +.. autoclass:: tvm.error.OpAttributeInvalid + +.. autoclass:: tvm.error.OpAttributeRequired + +.. autoclass:: tvm.error.OpAttributeUnimplemented diff --git a/docs/contribute/index.rst b/docs/contribute/index.rst index 57217089724d..a8a227b1e1c8 100644 --- a/docs/contribute/index.rst +++ b/docs/contribute/index.rst @@ -28,5 +28,6 @@ Here are guidelines for contributing to various aspect of the project: committer_guide document code_guide + error_handling pull_request git_howto diff --git a/python/tvm/__init__.py b/python/tvm/__init__.py index 9c09dc5a4ac3..e17e6203ed49 100644 --- a/python/tvm/__init__.py +++ b/python/tvm/__init__.py @@ -19,6 +19,7 @@ from . import generic from . import hybrid from . import testing +from . import error from . 
import ndarray as nd from .ndarray import context, cpu, gpu, opencl, cl, vulkan, metal, mtl diff --git a/python/tvm/_ffi/_ctypes/function.py b/python/tvm/_ffi/_ctypes/function.py index 5c176f819105..2ead24af3b25 100644 --- a/python/tvm/_ffi/_ctypes/function.py +++ b/python/tvm/_ffi/_ctypes/function.py @@ -7,7 +7,7 @@ import traceback from numbers import Number, Integral -from ..base import _LIB, check_call +from ..base import _LIB, get_last_ffi_error, py2cerror from ..base import c_str, string_types from ..node_generic import convert_to_node, NodeGeneric from ..runtime_ctypes import TVMType, TVMByteArray, TVMContext @@ -55,6 +55,7 @@ def cfun(args, type_codes, num_args, ret, _): rv = local_pyfunc(*pyargs) except Exception: msg = traceback.format_exc() + msg = py2cerror(msg) _LIB.TVMAPISetLastError(c_str(msg)) return -1 @@ -65,7 +66,8 @@ def cfun(args, type_codes, num_args, ret, _): values, tcodes, _ = _make_tvm_args((rv,), temp_args) if not isinstance(ret, TVMRetValueHandle): ret = TVMRetValueHandle(ret) - check_call(_LIB.TVMCFuncSetReturn(ret, values, tcodes, ctypes.c_int(1))) + if _LIB.TVMCFuncSetReturn(ret, values, tcodes, ctypes.c_int(1)) != 0: + raise get_last_ffi_error() _ = temp_args _ = rv return 0 @@ -76,8 +78,9 @@ def cfun(args, type_codes, num_args, ret, _): # TVM_FREE_PYOBJ will be called after it is no longer needed. 
pyobj = ctypes.py_object(f) ctypes.pythonapi.Py_IncRef(pyobj) - check_call(_LIB.TVMFuncCreateFromCFunc( - f, pyobj, TVM_FREE_PYOBJ, ctypes.byref(handle))) + if _LIB.TVMFuncCreateFromCFunc( + f, pyobj, TVM_FREE_PYOBJ, ctypes.byref(handle)) != 0: + raise get_last_ffi_error() return _CLASS_FUNCTION(handle, False) @@ -168,7 +171,8 @@ def __init__(self, handle, is_global): def __del__(self): if not self.is_global and _LIB is not None: - check_call(_LIB.TVMFuncFree(self.handle)) + if _LIB.TVMFuncFree(self.handle) != 0: + raise get_last_ffi_error() def __call__(self, *args): """Call the function with positional arguments @@ -180,9 +184,10 @@ def __call__(self, *args): values, tcodes, num_args = _make_tvm_args(args, temp_args) ret_val = TVMValue() ret_tcode = ctypes.c_int() - check_call(_LIB.TVMFuncCall( - self.handle, values, tcodes, ctypes.c_int(num_args), - ctypes.byref(ret_val), ctypes.byref(ret_tcode))) + if _LIB.TVMFuncCall( + self.handle, values, tcodes, ctypes.c_int(num_args), + ctypes.byref(ret_val), ctypes.byref(ret_tcode)) != 0: + raise get_last_ffi_error() _ = temp_args _ = args return RETURN_SWITCH[ret_tcode.value](ret_val) @@ -194,9 +199,10 @@ def __init_handle_by_constructor__(fconstructor, args): values, tcodes, num_args = _make_tvm_args(args, temp_args) ret_val = TVMValue() ret_tcode = ctypes.c_int() - check_call(_LIB.TVMFuncCall( - fconstructor.handle, values, tcodes, ctypes.c_int(num_args), - ctypes.byref(ret_val), ctypes.byref(ret_tcode))) + if _LIB.TVMFuncCall( + fconstructor.handle, values, tcodes, ctypes.c_int(num_args), + ctypes.byref(ret_val), ctypes.byref(ret_tcode)) != 0: + raise get_last_ffi_error() _ = temp_args _ = args assert ret_tcode.value == TypeCode.NODE_HANDLE diff --git a/python/tvm/_ffi/_cython/base.pxi b/python/tvm/_ffi/_cython/base.pxi index feb2fffebd23..22234953ae97 100644 --- a/python/tvm/_ffi/_cython/base.pxi +++ b/python/tvm/_ffi/_cython/base.pxi @@ -1,4 +1,4 @@ -from ..base import TVMError +from ..base import get_last_ffi_error 
from libcpp.vector cimport vector from cpython.version cimport PY_MAJOR_VERSION from cpython cimport pycapsule @@ -148,7 +148,7 @@ cdef inline c_str(pystr): cdef inline CALL(int ret): if ret != 0: - raise TVMError(py_str(TVMGetLastError())) + raise get_last_ffi_error() cdef inline object ctypes_handle(void* chandle): diff --git a/python/tvm/_ffi/_cython/function.pxi b/python/tvm/_ffi/_cython/function.pxi index 9995aea6357a..b72d0b694752 100644 --- a/python/tvm/_ffi/_cython/function.pxi +++ b/python/tvm/_ffi/_cython/function.pxi @@ -2,7 +2,7 @@ import ctypes import traceback from cpython cimport Py_INCREF, Py_DECREF from numbers import Number, Integral -from ..base import string_types +from ..base import string_types, py2cerror from ..node_generic import convert_to_node, NodeGeneric from ..runtime_ctypes import TVMType, TVMContext, TVMByteArray @@ -38,6 +38,7 @@ cdef int tvm_callback(TVMValue* args, rv = local_pyfunc(*pyargs) except Exception: msg = traceback.format_exc() + msg = py2cerror(msg) TVMAPISetLastError(c_str(msg)) return -1 if rv is not None: diff --git a/python/tvm/_ffi/base.py b/python/tvm/_ffi/base.py index 98229c092792..ec37b468bd6d 100644 --- a/python/tvm/_ffi/base.py +++ b/python/tvm/_ffi/base.py @@ -1,6 +1,6 @@ # coding: utf-8 # pylint: disable=invalid-name -""" ctypes library of nnvm and helper functions """ +"""Base library for TVM FFI.""" from __future__ import absolute_import import sys @@ -30,10 +30,6 @@ py_str = lambda x: x -class TVMError(Exception): - """Error thrown by TVM function""" - - def _load_lib(): """Load libary by searching possible path.""" lib_path = libinfo.find_lib_path() @@ -56,21 +52,6 @@ def _load_lib(): #---------------------------- # helper function in ctypes. #---------------------------- -def check_call(ret): - """Check the return value of C API call - - This function will raise exception when error occurs. 
- Wrap every API call with this function - - Parameters - ---------- - ret : int - return value from API calls - """ - if ret != 0: - raise TVMError(py_str(_LIB.TVMGetLastError())) - - def c_str(string): """Create ctypes char * from a python string Parameters @@ -118,3 +99,200 @@ def decorate(func, fwrapped): """ import decorator return decorator.decorate(func, fwrapped) + + +#----------------------------------------- +# Base code for structured error handling. +#----------------------------------------- +# Maps error type to its constructor +ERROR_TYPE = {} + + +class TVMError(RuntimeError): + """Default error thrown by TVM functions. + + TVMError will be raised if you do not give any error type specification. + """ + + +def register_error(func_name=None, cls=None): + """Register an error class so it can be recognized by the ffi error handler. + + Parameters + ---------- + func_name : str or function or class + The name of the error, or the error class itself when used as a bare decorator. + + cls : class or function + The class (or factory function) used to create the error. + + Returns + ------- + fregister : function + Register function if cls is not specified. + + Examples + -------- + .. code-block:: python + + @tvm.error.register_error + class MyError(RuntimeError): + pass + + err_inst = tvm.error.create_ffi_error("MyError: xyz") + assert isinstance(err_inst, MyError) + """ + if callable(func_name): + cls = func_name + func_name = cls.__name__ + + def register(mycls): + """internal register function""" + err_name = func_name if isinstance(func_name, str) else mycls.__name__ + ERROR_TYPE[err_name] = mycls + return mycls + if cls is None: + return register + return register(cls) + + +def _valid_error_name(name): + """Check whether name is a valid error name.""" + return all(x.isalnum() or x == '_' or x == '.' for x in name) + + +def _find_error_type(line): + """Find the error name given the first line of the error message. + + Parameters + ---------- + line : str + The first line of error message. 
+ + Returns + ------- + name : str The error name + """ + end_pos = line.find(":") + if end_pos == -1: + return None + err_name = line[:end_pos] + if _valid_error_name(err_name): + return err_name + return None + + +def c2pyerror(err_msg): + """Translate C API error message to python style. + + Parameters + ---------- + err_msg : str + The error message. + + Returns + ------- + new_msg : str + Translated message. + + err_type : str or None + Detected error type, None if the message has no valid error type prefix. + """ + arr = err_msg.split("\n") + if arr[-1] == "": + arr.pop() + err_type = _find_error_type(arr[0]) if arr else None + trace_mode = False + stack_trace = [] + message = [] + for line in arr: + if trace_mode: + if line.startswith(" "): + stack_trace.append(line) + else: + trace_mode = False + if not trace_mode: + if line.startswith("Stack trace"): + trace_mode = True + else: + message.append(line) + out_msg = "" + if stack_trace: + out_msg += "Traceback (most recent call last):\n" + out_msg += "\n".join(reversed(stack_trace)) + "\n" + out_msg += "\n".join(message) + return out_msg, err_type + + +def py2cerror(err_msg): + """Translate python style error message to C style. + + Parameters + ---------- + err_msg : str + The error message. + + Returns + ------- + new_msg : str + Translated message. + """ + arr = err_msg.split("\n") + if arr[-1] == "": + arr.pop() + trace_mode = False + stack_trace = [] + message = [] + for line in arr: + if trace_mode: + if line.startswith(" "): + stack_trace.append(line) + else: + trace_mode = False + if not trace_mode: + if line.find("Traceback") != -1: + trace_mode = True + else: + message.append(line) + # Remove the first error name if there are two of them. + # RuntimeError: MyErrorName: message => MyErrorName: message + head_arr = message[0].split(":", 3) if message else [] + if len(head_arr) >= 3 and _valid_error_name(head_arr[1].strip()): + head_arr[1] = head_arr[1].strip() + message[0] = ":".join(head_arr[1:]) + # reverse the stack trace. 
+ out_msg = "\n".join(message) + if stack_trace: + out_msg += "\nStack trace:\n" + out_msg += "\n".join(reversed(stack_trace)) + "\n" + return out_msg + + +def get_last_ffi_error(): + """Create error object given result of TVMGetLastError. + + Returns + ------- + err : object + The error object based on the err_msg; TVMError when no registered error type is detected. + """ + c_err_msg = py_str(_LIB.TVMGetLastError()) + py_err_msg, err_type = c2pyerror(c_err_msg) + if err_type is not None and err_type.startswith("tvm.error."): + err_type = err_type[10:] + return ERROR_TYPE.get(err_type, TVMError)(py_err_msg) + + +def check_call(ret): + """Check the return value of C API call + + This function will raise exception when error occurs. + Wrap every API call with this function + + Parameters + ---------- + ret : int + return value from API calls + """ + if ret != 0: + raise get_last_ffi_error() diff --git a/python/tvm/error.py b/python/tvm/error.py new file mode 100644 index 000000000000..60c3b955f206 --- /dev/null +++ b/python/tvm/error.py @@ -0,0 +1,96 @@ +"""Structured error classes in TVM. + +Each error class takes an error message as its input. +See the example sections for suggested message conventions. +To make the code more readable, we recommend developers to +copy the examples and raise errors with the same message convention. +""" +from ._ffi.base import register_error, TVMError + +@register_error +class InternalError(TVMError): + """Internal error in the system. + + Examples + -------- + .. code :: c++ + + // Example code C++ + LOG(FATAL) << "InternalError: internal error detail."; + + .. code :: python + + # Example code in python + raise InternalError("internal error detail") + """ + def __init__(self, msg): + # Patch up additional hint message. + if "TVM hint:" not in msg: + msg += ("\nTVM hint: You hit an internal error. 
" + + "Please open a thread on https://discuss.tvm.ai/ to report it.") + super(InternalError, self).__init__(msg) + + +register_error("ValueError", ValueError) +register_error("TypeError", TypeError) + + +@register_error +class OpError(TVMError): + """Base class of all operator errors in frontends.""" + + +@register_error +class OpNotImplemented(OpError): + """Operator is not implemented. + + Example + ------- + .. code:: python + + raise OpNotImplemented( + "Operator {} is not supported in {} frontend".format( + missing_op, frontend_name)) + """ + + +@register_error +class OpAttributeRequired(OpError): + """Required attribute is not found. + + Example + ------- + .. code:: python + + raise OpAttributeRequired( + "Required attribute {} not found in operator {}".format( + attr_name, op_name)) + """ + + +@register_error +class OpAttributeInvalid(OpError): + """Attribute value is invalid when taking in a frontend operator. + + Example + ------- + .. code:: python + + raise OpAttributeInvalid( + "Value {} in attribute {} of operator {} is not valid".format( + value, attr_name, op_name)) + """ + + +@register_error +class OpAttributeUnimplemented(OpError): + """Attribute is not supported in a certain frontend. + + Example + ------- + .. 
code:: python + + raise OpAttributeUnimplemented( + "Attribute {} is not supported in operator {}".format( + attr_name, op_name)) + """ diff --git a/src/api/api_test.cc b/src/api/api_test.cc index 2c637a28f01a..4bd480b3cc7e 100644 --- a/src/api/api_test.cc +++ b/src/api/api_test.cc @@ -39,6 +39,30 @@ TVM_REGISTER_API("_nop") .set_body([](TVMArgs args, TVMRetValue *ret) { }); +TVM_REGISTER_API("_test_wrap_callback") +.set_body([](TVMArgs args, TVMRetValue *ret) { + PackedFunc pf = args[0]; + *ret = runtime::TypedPackedFunc([pf](){ + pf(); + }); + }); + +TVM_REGISTER_API("_test_raise_error_callback") +.set_body([](TVMArgs args, TVMRetValue *ret) { + std::string msg = args[0]; + *ret = runtime::TypedPackedFunc([msg](){ + LOG(FATAL) << msg; + }); + }); + +TVM_REGISTER_API("_test_check_eq_callback") +.set_body([](TVMArgs args, TVMRetValue *ret) { + std::string msg = args[0]; + *ret = runtime::TypedPackedFunc([msg](int x, int y){ + CHECK_EQ(x, y) << msg; + }); + }); + TVM_REGISTER_API("_context_test") .set_body([](TVMArgs args, TVMRetValue *ret) { DLContext ctx = args[0]; @@ -49,6 +73,20 @@ TVM_REGISTER_API("_context_test") *ret = ctx; }); + +// in src/api_test.cc +void ErrorTest(int x, int y) { + // raise ValueError + CHECK_EQ(x, y) << "ValueError: expect x and y to be equal."; + if (x == 1) { + // raise InternalError. 
+ LOG(FATAL) << "InternalError: cannot reach here"; + } +} + +TVM_REGISTER_API("_ErrorTest") +.set_body_typed(ErrorTest); + // internal function used for debug and testing purposes TVM_REGISTER_API("_ndarray_use_count") .set_body([](TVMArgs args, TVMRetValue *ret) { diff --git a/src/runtime/c_runtime_api.cc b/src/runtime/c_runtime_api.cc index d9435d33903d..3d74391a9ea5 100644 --- a/src/runtime/c_runtime_api.cc +++ b/src/runtime/c_runtime_api.cc @@ -13,10 +13,14 @@ #ifdef _LIBCPP_SGX_CONFIG #include "sgx/trusted/runtime.h" #endif +#ifndef _LIBCPP_SGX_NO_IOSTREAMS +#include <sstream> +#endif #include #include #include #include +#include <cctype> #include "runtime_base.h" namespace tvm { @@ -104,6 +108,169 @@ void DeviceAPI::SyncStreamFromTo(TVMContext ctx, TVMStreamHandle event_dst) { LOG(FATAL) << "Device does not support stream api."; } + + +#ifndef _LIBCPP_SGX_NO_IOSTREAMS +//-------------------------------------------------------- +// Error handling mechanism +// ------------------------------------------------------- +// Standard error message format, {} means optional +//-------------------------------------------------------- +// {error_type:} {message0} +// {message1} +// {message2} +// {Stack trace:} // stack traces follow by this line +// {trace 0} // two spaces in the beginning. +// {trace 1} +// {trace 2} +//-------------------------------------------------------- +/*! + * \brief Normalize error message + * + * Parse the header generated by LOG(FATAL) and CHECK + * and reformat the message into the standard format. + * + * This function will also merge all the stack traces into + * one trace and trim them. + * + * \param err_msg The error message. + * \return normalized message. 
+ */ +std::string NormalizeError(std::string err_msg) { + // ------------------------------------------------------------------------ + // log with header, {} indicates optional + //------------------------------------------------------------------------- + // [timestamp] file_name:line_number: {check_msg:} {error_type:} {message0} + // {message1} + // Stack trace: + // {stack trace 0} + // {stack trace 1} + //------------------------------------------------------------------------- + // Normalized version + //------------------------------------------------------------------------- + // error_type: check_msg message0 + // {message1} + // Stack trace: + // File file_name, line lineno + // {stack trace 0} + // {stack trace 1} + //------------------------------------------------------------------------- + int line_number = 0; + std::istringstream is(err_msg); + std::string line, file_name, error_type, check_msg; + + // Parse log header and set the fields, + // Return true if the log is in correct format, + // return false if something is wrong. + auto parse_log_header = [&]() { + // skip timestamp + if (is.peek() != '[') { + getline(is, line); + return true; + } + if (!(is >> line)) return false; + // get filename + while (is.peek() == ' ') is.get(); + if (!getline(is, file_name, ':')) return false; + // get line number + if (!(is >> line_number)) return false; + // get rest of the message. + while (is.peek() == ' ' || is.peek() == ':') is.get(); + if (!getline(is, line)) return false; + // detect check message, rewrite to remove extra : + if (line.compare(0, 13, "Check failed:") == 0) { + size_t end_pos = line.find(':', 13); + if (end_pos == std::string::npos) return false; + check_msg = line.substr(0, end_pos + 1) + ' '; + line = line.substr(end_pos + 1); + } + return true; + }; + // if not in correct format, do not do any rewrite. + if (!parse_log_header()) return err_msg; + // Parse error type. 
+ { + size_t start_pos = 0, end_pos; + for (; start_pos < line.length() && line[start_pos] == ' '; ++start_pos) {} + for (end_pos = start_pos; end_pos < line.length(); ++end_pos) { + unsigned char ch = static_cast<unsigned char>(line[end_pos]); + if (ch == ':') { + error_type = line.substr(start_pos, end_pos - start_pos); + break; + } + // [A-Z0-9a-z_.] + if (!std::isalpha(ch) && !std::isdigit(ch) && ch != '_' && ch != '.') break; + } + if (error_type.length() != 0) { + // if we successfully detected error_type: trim the following space. + for (start_pos = end_pos + 1; + start_pos < line.length() && line[start_pos] == ' '; ++start_pos) {} + line = line.substr(start_pos); + } else { + // did not detect error_type, use default value. + line = line.substr(start_pos); + error_type = "TVMError"; + } + } + // Separate out stack trace. + std::ostringstream os; + os << error_type << ": " << check_msg << line << '\n'; + + bool trace_mode = true; + std::vector<std::string> stack_trace; + while (getline(is, line)) { + if (trace_mode) { + if (line.compare(0, 2, " ") == 0) { + stack_trace.push_back(line); + } else { + trace_mode = false; + // remove EOL trailing stacktrace. + if (line.length() == 0) continue; + } + } + if (!trace_mode) { + if (line.compare(0, 11, "Stack trace") == 0) { + trace_mode = true; + } else { + os << line << '\n'; + } + } + } + if (stack_trace.size() != 0 || file_name.length() != 0) { + os << "Stack trace:\n"; + if (file_name.length() != 0) { + os << " File \"" << file_name << "\", line " << line_number << "\n"; + } + // Print out stack traces, optionally trim the c++ traces + // about the frontends (as they will be provided by the frontends). + bool ffi_boundary = false; + for (const auto& line : stack_trace) { + // Heuristic to detect python ffi. 
+ if (line.find("libffi.so") != std::string::npos || + line.find("core.cpython") != std::string::npos) { + ffi_boundary = true; + } + // If the backtrace is not c++ backtrace with the prefix " [bt]", + // then we can stop trimming. 
+ if (ffi_boundary && line.compare(0, 6, " [bt]") != 0) { + ffi_boundary = false; + } + if (!ffi_boundary) { + os << line << '\n'; + } + // The line after TVMFuncCall could be in FFI. + if (line.find("(TVMFuncCall") != std::string::npos) { + ffi_boundary = true; + } + } + } + return os.str(); +} + +#else +std::string NormalizeError(std::string err_msg) { + return err_msg; +} +#endif } // namespace runtime } // namespace tvm @@ -121,6 +288,11 @@ const char *TVMGetLastError() { return TVMAPIRuntimeStore::Get()->last_error.c_str(); } +int TVMAPIHandleException(const std::runtime_error &e) { + TVMAPISetLastError(NormalizeError(e.what()).c_str()); + return -1; +} + void TVMAPISetLastError(const char* msg) { #ifndef _LIBCPP_SGX_CONFIG TVMAPIRuntimeStore::Get()->last_error = msg; @@ -279,9 +451,7 @@ int TVMFuncCreateFromCFunc(TVMPackedCFunc func, int ret = func((TVMValue*)args.values, (int*)args.type_codes, // NOLINT(*) args.num_args, rv, resource_handle); if (ret != 0) { - std::string err = "TVMCall CFunc Error:\n"; - err += TVMGetLastError(); - throw dmlc::Error(err); + throw dmlc::Error(TVMGetLastError() + ::dmlc::StackTrace()); } }); } else { @@ -293,9 +463,7 @@ int TVMFuncCreateFromCFunc(TVMPackedCFunc func, int ret = func((TVMValue*)args.values, (int*)args.type_codes, // NOLINT(*) args.num_args, rv, rpack.get()); if (ret != 0) { - std::string err = "TVMCall CFunc Error:\n"; - err += TVMGetLastError(); - throw dmlc::Error(err); + throw dmlc::Error(TVMGetLastError() + ::dmlc::StackTrace()); } }); } diff --git a/src/runtime/runtime_base.h b/src/runtime/runtime_base.h index b47024adca91..459538796e6a 100644 --- a/src/runtime/runtime_base.h +++ b/src/runtime/runtime_base.h @@ -26,9 +26,6 @@ * \param e the exception * \return the return value of API after exception is handled */ -inline int TVMAPIHandleException(const std::runtime_error &e) { - TVMAPISetLastError(e.what()); - return -1; -} +int TVMAPIHandleException(const std::runtime_error &e); #endif // 
TVM_RUNTIME_RUNTIME_BASE_H_ diff --git a/tests/python/unittest/test_runtime_error.py b/tests/python/unittest/test_runtime_error.py new file mode 100644 index 000000000000..2cf2f10ab1de --- /dev/null +++ b/tests/python/unittest/test_runtime_error.py @@ -0,0 +1,55 @@ +"""Test runtime error handling""" +import tvm + +def test_op_translation(): + ferror = tvm._api_internal._test_raise_error_callback( + "OpNotImplemented: myop") + try: + ferror() + assert False + except tvm.error.OpNotImplemented as e: + msg = str(e) + assert msg.find("api_test.cc") != -1 + + fchk_eq = tvm._api_internal._test_check_eq_callback( + "InternalError: myop") + try: + fchk_eq(0, 1) + assert False + except tvm.error.InternalError as e: + msg = str(e) + assert msg.find("api_test.cc") != -1 + + try: + tvm._api_internal._ErrorTest(0, 1) + assert False + except ValueError as e: + msg = str(e) + assert msg.find("api_test.cc") != -1 + + +def test_deep_callback(): + def error_callback(): + raise ValueError("callback error") + wrap1 = tvm._api_internal._test_wrap_callback(error_callback) + def flevel2(): + wrap1() + wrap2 = tvm._api_internal._test_wrap_callback(flevel2) + def flevel3(): + wrap2() + wrap3 = tvm._api_internal._test_wrap_callback(flevel3) + + try: + wrap3() + assert False + except ValueError as e: + msg = str(e) + idx2 = msg.find("in flevel2") + idx3 = msg.find("in flevel3") + assert idx2 != -1 and idx3 != -1 + assert idx2 > idx3 + + +if __name__ == "__main__": + test_op_translation() + test_deep_callback()