diff --git a/python/tvm/ffi/container.py b/python/tvm/ffi/container.py index 6ababe25577a..66038976f5d2 100644 --- a/python/tvm/ffi/container.py +++ b/python/tvm/ffi/container.py @@ -78,6 +78,9 @@ def __len__(self): return _ffi_api.ArraySize(self) def __repr__(self): + # exception safety handling for chandle=None + if self.__chandle__() == 0: + return type(self).__name__ + "(chandle=None)" return "[" + ", ".join([x.__repr__() for x in self]) + "]" @@ -197,4 +200,7 @@ def get(self, key, default=None): return self[key] if key in self else default def __repr__(self): + # exception safety handling for chandle=None + if self.__chandle__() == 0: + return type(self).__name__ + "(chandle=None)" return "{" + ", ".join([f"{k.__repr__()}: {v.__repr__()}" for k, v in self.items()]) + "}" diff --git a/python/tvm/ffi/cython/object.pxi b/python/tvm/ffi/cython/object.pxi index f971ca8f5a4b..4efedf35d8f4 100644 --- a/python/tvm/ffi/cython/object.pxi +++ b/python/tvm/ffi/cython/object.pxi @@ -85,9 +85,15 @@ cdef class Object: """ cdef void* chandle + def __cinit__(self): + # initialize chandle to NULL to avoid leak in + # case of error before chandle is set + self.chandle = NULL + def __dealloc__(self): if self.chandle != NULL: CHECK_CALL(TVMFFIObjectFree(self.chandle)) + self.chandle = NULL def __ctypes_handle__(self): return ctypes_handle(self.chandle) @@ -116,16 +122,23 @@ cdef class Object: self.chandle = NULL def __getattr__(self, name): + if self.chandle == NULL: + raise AttributeError(f"{type(self)} has no attribute {name}") try: return __object_getattr__(self, name) except AttributeError: raise AttributeError(f"{type(self)} has no attribute {name}") def __dir__(self): + # exception safety handling for chandle=None + if self.chandle == NULL: + return [] return __object_dir__(self) def __repr__(self): - # make sure repr is a raw string + # exception safety handling for chandle=None + if self.chandle == NULL: + return type(self).__name__ + "(chandle=None)" return str(__object_repr__(self)) def __eq__(self, other): diff --git a/python/tvm/runtime/ndarray.py b/python/tvm/runtime/ndarray.py index 9d49d9c51db2..538fa15c8a49 100644 --- a/python/tvm/runtime/ndarray.py +++ b/python/tvm/runtime/ndarray.py @@ -164,6 +164,9 @@ def copyfrom(self, source_array): return self def __repr__(self): + # exception safety handling for chandle=None + if self.__chandle__() == 0: + return type(self).__name__ + "(chandle=None)" res = f"\n" res += self.numpy().__repr__() return res diff --git a/tests/python/ffi/test_container.py b/tests/python/ffi/test_container.py index b20c221b4f40..5ac3af179956 100644 --- a/tests/python/ffi/test_container.py +++ b/tests/python/ffi/test_container.py @@ -27,6 +27,20 @@ def test_array(): assert (a_slice[0], a_slice[1]) == (1, 2) +def test_bad_constructor_init_state(): + """Test when error is raised before __init_handle_by_constructor + + This case we need the FFI binding to gracefully handle both repr + and dealloc by ensuring the chandle is initialized and there is + proper repr code + """ + with pytest.raises(TypeError): + tvm_ffi.Array(1) + + with pytest.raises(AttributeError): + tvm_ffi.Map(1) + + def test_array_of_array_map(): a = tvm_ffi.convert([[1, 2, 3], {"A": 5, "B": 6}]) assert isinstance(a, tvm_ffi.Array)