Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[UnitTests] Require cached fixtures to be copy-able, with opt-in. #8451

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 72 additions & 14 deletions python/tvm/testing/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ def test_something():
"""
import collections
import copy
import copyreg
import ctypes
import functools
import logging
import os
Expand Down Expand Up @@ -386,8 +388,14 @@ def _get_targets(target_str=None):
targets = []
for target in target_names:
target_kind = target.split()[0]
is_enabled = tvm.runtime.enabled(target_kind)
is_runnable = is_enabled and tvm.device(target_kind).exist

if target_kind == "cuda" and "cudnn" in tvm.target.Target(target).attrs.get("libs", []):
is_enabled = tvm.support.libinfo()["USE_CUDNN"].lower() in ["on", "true", "1"]
is_runnable = is_enabled and cudnn.exists()
else:
is_enabled = tvm.runtime.enabled(target_kind)
is_runnable = is_enabled and tvm.device(target_kind).exist

targets.append(
{
"target": target,
Expand Down Expand Up @@ -1251,6 +1259,60 @@ def wraps(func):
return wraps(func)


class _DeepCopyAllowedClasses(dict):
def __init__(self, allowed_class_list):
self.allowed_class_list = allowed_class_list
super().__init__()

def get(self, key, *args, **kwargs):
"""Overrides behavior of copy.deepcopy to avoid implicit copy.

By default, copy.deepcopy uses a dict of id->object to track
all objects that it has seen, which is passed as the second
argument to all recursive calls. This class is intended to be
passed in instead, and inspects the type of all objects being
copied.

Where copy.deepcopy does a best-effort attempt at copying an
object, for unit tests we would rather have all objects either
be copied correctly, or to throw an error. Classes that
define an explicit method to perform a copy are allowed, as
are any explicitly listed classes. Classes that would fall
back to using object.__reduce__, and are not explicitly listed
as safe, will throw an exception.

"""
obj = ctypes.cast(key, ctypes.py_object).value
cls = type(obj)
if (
cls in copy._deepcopy_dispatch
or issubclass(cls, type)
or getattr(obj, "__deepcopy__", None)
or copyreg.dispatch_table.get(cls)
or cls.__reduce__ is not object.__reduce__
or cls.__reduce_ex__ is not object.__reduce_ex__
or cls in self.allowed_class_list
):
return super().get(key, *args, **kwargs)

rfc_url = (
"https://github.com/apache/tvm-rfcs/blob/main/rfcs/0007-parametrized-unit-tests.md"
)
raise TypeError(
(
f"Cannot copy fixture of type {cls.__name__}. TVM fixture caching "
"is limited to objects that explicitly provide the ability "
"to be copied (e.g. through __deepcopy__, __getstate__, or __setstate__),"
"and forbids the use of the default `object.__reduce__` and "
"`object.__reduce_ex__`. For third-party classes that are "
"safe to use with copy.deepcopy, please add the class to "
"the arguments of _DeepCopyAllowedClasses in tvm.testing._fixture_cache.\n"
"\n"
f"For discussion on this restriction, please see {rfc_url}."
)
)


def _fixture_cache(func):
cache = {}

Expand Down Expand Up @@ -1290,18 +1352,14 @@ def wrapper(*args, **kwargs):
except KeyError:
cached_value = cache[cache_key] = func(*args, **kwargs)

try:
yield copy.deepcopy(cached_value)
except TypeError as e:
rfc_url = (
"https://github.com/apache/tvm-rfcs/blob/main/rfcs/"
"0007-parametrized-unit-tests.md#unresolved-questions"
)
message = (
"TVM caching of fixtures can only be used on serializable data types, not {}.\n"
"Please see {} for details/discussion."
).format(type(cached_value), rfc_url)
raise TypeError(message) from e
yield copy.deepcopy(
cached_value,
# allowed_class_list should be a list of classes that
# are safe to copy using copy.deepcopy, but do not
# implement __deepcopy__, __reduce__, or
# __reduce_ex__.
_DeepCopyAllowedClasses(allowed_class_list=[]),
)

finally:
# Clear the cache once all tests that use a particular fixture
Expand Down
55 changes: 55 additions & 0 deletions tests/python/unittest/test_tvm_testing_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,22 @@ def test_cached_count(self):
assert self.cached_calls == len(self.param1_vals)


class TestCachedFixtureIsCopy:
param = tvm.testing.parameter(1, 2, 3, 4)

@tvm.testing.fixture(cache_return_value=True)
def cached_mutable_fixture(self):
return {"val": 0}

def test_modifies_fixture(self, param, cached_mutable_fixture):
assert cached_mutable_fixture["val"] == 0

# The tests should receive a copy of the fixture value. If
# the test receives the original and not a copy, then this
# will cause the next parametrization to fail.
cached_mutable_fixture["val"] = param


class TestBrokenFixture:
# Tests that use a fixture that throws an exception fail, and are
# marked as setup failures. The tests themselves are never run.
Expand Down Expand Up @@ -210,5 +226,44 @@ def test_pytest_mark_covariant(self, request, target, other_param):
self.check_marks(request, target)


@pytest.mark.skipif(
bool(int(os.environ.get("TVM_TEST_DISABLE_CACHE", "0"))),
reason="Cannot test cache behavior while caching is disabled",
)
class TestCacheableTypes:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is there any way to assert you're getting a copy in each test case?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good call, and I've added a TestCachedFixtureIsCopy to verify that the cached fixture received is a copy.

class EmptyClass:
pass

@tvm.testing.fixture(cache_return_value=True)
def uncacheable_fixture(self):
return self.EmptyClass()

def test_uses_uncacheable(self, request):
with pytest.raises(TypeError):
request.getfixturevalue("uncacheable_fixture")

class ImplementsReduce:
def __reduce__(self):
return super().__reduce__()

@tvm.testing.fixture(cache_return_value=True)
def fixture_with_reduce(self):
return self.ImplementsReduce()

def test_uses_reduce(self, fixture_with_reduce):
pass

class ImplementsDeepcopy:
def __deepcopy__(self, memo):
return type(self)()

@tvm.testing.fixture(cache_return_value=True)
def fixture_with_deepcopy(self):
return self.ImplementsDeepcopy()

def test_uses_deepcopy(self, fixture_with_deepcopy):
pass


if __name__ == "__main__":
sys.exit(pytest.main(sys.argv))