From 650275194a5676b2077f31fd1170e075dae3efd6 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 28 Jun 2021 15:10:03 -0700 Subject: [PATCH] [UnitTests] Require cached fixtures to be copy-able, with opt-in. Previously, any class that doesn't raise a TypeError in copy.deepcopy could be used as a return value in a @tvm.testing.fixture. This has the possibility of incorrectly copying classes inherit the default object.__reduce__ implementation. Therefore, only classes that explicitly implement copy functionality (e.g. __deepcopy__ or __getstate__/__setstate__), or that are explicitly listed in tvm.testing._fixture_cache are allowed to be cached. --- python/tvm/testing.py | 60 +++++++++++++++---- .../unittest/test_tvm_testing_features.py | 40 +++++++++++++ 2 files changed, 88 insertions(+), 12 deletions(-) diff --git a/python/tvm/testing.py b/python/tvm/testing.py index 4721c0050656c..90011446c9d03 100644 --- a/python/tvm/testing.py +++ b/python/tvm/testing.py @@ -56,6 +56,8 @@ def test_something(): """ import collections import copy +import copyreg +import ctypes import functools import logging import os @@ -1160,6 +1162,43 @@ def wraps(func): return wraps(func) +class _DeepCopyAllowedClasses(dict): + def __init__(self, *allowed_class_list): + self.allowed_class_list = allowed_class_list + super().__init__() + + def get(self, key, *args, **kwargs): + obj = ctypes.cast(key, ctypes.py_object).value + cls = type(obj) + if ( + cls in copy._deepcopy_dispatch + or issubclass(cls, type) + or getattr(obj, "__deepcopy__", None) + or copyreg.dispatch_table.get(cls) + or cls.__reduce__ is not object.__reduce__ + or cls.__reduce_ex__ is not object.__reduce_ex__ + or cls in self.allowed_class_list + ): + return super().get(key, *args, **kwargs) + + rfc_url = ( + "https://github.com/apache/tvm-rfcs/blob/main/rfcs/0007-parametrized-unit-tests.md" + ) + raise TypeError( + ( + "Cannot copy fixture of type {}. TVM fixture caching " + "is limited to objects that explicitly provide the ability " + "to be copied (e.g. through __deepcopy__, __getstate__, or __setstate__)," + "and forbids the use of the default `object.__reduce__` and " + "`object.__reduce_ex__`. For third-party classes that are known to be " + "safe to use with copy.deepcopy, please add the class to " + "the arguments of _DeepCopyAllowedClasses in tvm.testing._fixture_cache.\n" + "\n" + "For discussion on this restriction, please see {}." + ).format(cls.__name__, rfc_url) + ) + + def _fixture_cache(func): cache = {} @@ -1199,18 +1238,15 @@ def wrapper(*args, **kwargs): except KeyError: cached_value = cache[cache_key] = func(*args, **kwargs) - try: - yield copy.deepcopy(cached_value) - except TypeError as e: - rfc_url = ( - "https://github.com/apache/tvm-rfcs/blob/main/rfcs/" - "0007-parametrized-unit-tests.md#unresolved-questions" - ) - message = ( - "TVM caching of fixtures can only be used on serializable data types, not {}.\n" - "Please see {} for details/discussion." - ).format(type(cached_value), rfc_url) - raise TypeError(message) from e + yield copy.deepcopy( + cached_value, + _DeepCopyAllowedClasses( + # *args should be a list of classes that are known + # to be safe to copy using copy.deepcopy, but do + # not implement __deepcopy__, __reduce__, or + # __reduce_ex__. + ), + ) finally: # Clear the cache once all tests that use a particular fixture diff --git a/tests/python/unittest/test_tvm_testing_features.py b/tests/python/unittest/test_tvm_testing_features.py index 07b8c652bf1fd..d31ae1aa2cb1f 100644 --- a/tests/python/unittest/test_tvm_testing_features.py +++ b/tests/python/unittest/test_tvm_testing_features.py @@ -180,5 +180,45 @@ def test_num_uses_cached(self): assert self.num_uses_broken_cached_fixture == 0 +@pytest.mark.skipif( + bool(int(os.environ.get("TVM_TEST_DISABLE_CACHE", "0"))), + reason="Cannot test cache behavior while caching is disabled", +) +class TestCacheableTypes: + class EmptyClass: + pass + + @tvm.testing.fixture(cache_return_value=True) + def uncacheable_fixture(self): + return self.EmptyClass() + + @pytest.mark.xfail(reason="Requests cached fixture of uncacheable type", strict=True) + def test_uses_uncacheable(self, uncacheable_fixture): + print("asdf") + pass + + class ImplementsReduce: + def __reduce__(self): + return super().__reduce__() + + @tvm.testing.fixture(cache_return_value=True) + def fixture_with_reduce(self): + return self.ImplementsReduce() + + def test_uses_reduce(self, fixture_with_reduce): + pass + + class ImplementsDeepcopy: + def __deepcopy__(self, memo): + return type(self)() + + @tvm.testing.fixture(cache_return_value=True) + def fixture_with_deepcopy(self): + return self.ImplementsDeepcopy() + + def test_uses_deepcopy(self, fixture_with_deepcopy): + pass + + if __name__ == "__main__": sys.exit(pytest.main(sys.argv))