Skip to content

Commit

Permalink
[UnitTests] Require cached fixtures to be copy-able, with opt-in.
Browse files Browse the repository at this point in the history
Previously, any class that doesn't raise a TypeError in copy.deepcopy
could be used as a return value in a @tvm.testing.fixture.  This has
the possibility of incorrectly copying classes inherit the default
object.__reduce__ implementation.  Therefore, only classes that
explicitly implement copy functionality (e.g. __deepcopy__ or
__getstate__/__setstate__), or that are explicitly listed in
tvm.testing._fixture_cache are allowed to be cached.
  • Loading branch information
Lunderberg committed Jul 12, 2021
1 parent 1d7a9e9 commit 6502751
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 12 deletions.
60 changes: 48 additions & 12 deletions python/tvm/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ def test_something():
"""
import collections
import copy
import copyreg
import ctypes
import functools
import logging
import os
Expand Down Expand Up @@ -1160,6 +1162,43 @@ def wraps(func):
return wraps(func)


class _DeepCopyAllowedClasses(dict):
def __init__(self, *allowed_class_list):
self.allowed_class_list = allowed_class_list
super().__init__()

def get(self, key, *args, **kwargs):
obj = ctypes.cast(key, ctypes.py_object).value
cls = type(obj)
if (
cls in copy._deepcopy_dispatch
or issubclass(cls, type)
or getattr(obj, "__deepcopy__", None)
or copyreg.dispatch_table.get(cls)
or cls.__reduce__ is not object.__reduce__
or cls.__reduce_ex__ is not object.__reduce_ex__
or cls in self.allowed_class_list
):
return super().get(key, *args, **kwargs)

rfc_url = (
"https://github.com/apache/tvm-rfcs/blob/main/rfcs/0007-parametrized-unit-tests.md"
)
raise TypeError(
(
"Cannot copy fixture of type {}. TVM fixture caching "
"is limited to objects that explicitly provide the ability "
"to be copied (e.g. through __deepcopy__, __getstate__, or __setstate__),"
"and forbids the use of the default `object.__reduce__` and "
"`object.__reduce_ex__`. For third-party classes that are known to be "
"safe to use with copy.deepcopy, please add the class to "
"the arguments of _DeepCopyAllowedClasses in tvm.testing._fixture_cache.\n"
"\n"
"For discussion on this restriction, please see {}."
).format(cls.__name__, rfc_url)
)


def _fixture_cache(func):
cache = {}

Expand Down Expand Up @@ -1199,18 +1238,15 @@ def wrapper(*args, **kwargs):
except KeyError:
cached_value = cache[cache_key] = func(*args, **kwargs)

try:
yield copy.deepcopy(cached_value)
except TypeError as e:
rfc_url = (
"https://github.com/apache/tvm-rfcs/blob/main/rfcs/"
"0007-parametrized-unit-tests.md#unresolved-questions"
)
message = (
"TVM caching of fixtures can only be used on serializable data types, not {}.\n"
"Please see {} for details/discussion."
).format(type(cached_value), rfc_url)
raise TypeError(message) from e
yield copy.deepcopy(
cached_value,
_DeepCopyAllowedClasses(
# *args should be a list of classes that are known
# to be safe to copy using copy.deepcopy, but do
# not implement __deepcopy__, __reduce__, or
# __reduce_ex__.
),
)

finally:
# Clear the cache once all tests that use a particular fixture
Expand Down
40 changes: 40 additions & 0 deletions tests/python/unittest/test_tvm_testing_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,5 +180,45 @@ def test_num_uses_cached(self):
assert self.num_uses_broken_cached_fixture == 0


@pytest.mark.skipif(
bool(int(os.environ.get("TVM_TEST_DISABLE_CACHE", "0"))),
reason="Cannot test cache behavior while caching is disabled",
)
class TestCacheableTypes:
class EmptyClass:
pass

@tvm.testing.fixture(cache_return_value=True)
def uncacheable_fixture(self):
return self.EmptyClass()

@pytest.mark.xfail(reason="Requests cached fixture of uncacheable type", strict=True)
def test_uses_uncacheable(self, uncacheable_fixture):
print("asdf")
pass

class ImplementsReduce:
def __reduce__(self):
return super().__reduce__()

@tvm.testing.fixture(cache_return_value=True)
def fixture_with_reduce(self):
return self.ImplementsReduce()

def test_uses_reduce(self, fixture_with_reduce):
pass

class ImplementsDeepcopy:
def __deepcopy__(self, memo):
return type(self)()

@tvm.testing.fixture(cache_return_value=True)
def fixture_with_deepcopy(self):
return self.ImplementsDeepcopy()

def test_uses_deepcopy(self, fixture_with_deepcopy):
pass


if __name__ == "__main__":
sys.exit(pytest.main(sys.argv))

0 comments on commit 6502751

Please sign in to comment.