Refactor test_models to use pytest #3697

Merged
2 changes: 1 addition & 1 deletion .circleci/unittest/linux/scripts/install.sh
@@ -24,7 +24,7 @@ else
fi

printf "Installing PyTorch with %s\n" "${cudatoolkit}"
conda install -y -c "pytorch-${UPLOAD_CHANNEL}" -c conda-forge "pytorch-${UPLOAD_CHANNEL}"::pytorch "${cudatoolkit}"
conda install -y -c "pytorch-${UPLOAD_CHANNEL}" -c conda-forge "pytorch-${UPLOAD_CHANNEL}"::pytorch "${cudatoolkit}" pytest

if [ $PYTHON_VERSION == "3.6" ]; then
printf "Installing minimal PILLOW version\n"
2 changes: 1 addition & 1 deletion .circleci/unittest/windows/scripts/install.sh
@@ -26,7 +26,7 @@ else
fi

printf "Installing PyTorch with %s\n" "${cudatoolkit}"
conda install -y -c "pytorch-${UPLOAD_CHANNEL}" -c conda-forge "pytorch-${UPLOAD_CHANNEL}"::pytorch "${cudatoolkit}"
conda install -y -c "pytorch-${UPLOAD_CHANNEL}" -c conda-forge "pytorch-${UPLOAD_CHANNEL}"::pytorch "${cudatoolkit}" pytest

if [ $PYTHON_VERSION == "3.6" ]; then
printf "Installing minimal PILLOW version\n"
31 changes: 7 additions & 24 deletions test/common_utils.py
@@ -100,28 +100,20 @@ def is_iterable(obj):
class TestCase(unittest.TestCase):
precision = 1e-5

def _get_expected_file(self, subname=None, strip_suffix=None):
def remove_prefix_suffix(text, prefix, suffix):
if text.startswith(prefix):
text = text[len(prefix):]
if suffix is not None and text.endswith(suffix):
text = text[:len(text) - len(suffix)]
return text
def _get_expected_file(self, name=None):
# NB: we take __file__ from the module that defined the test
# class, so we place the expect directory where the test script
# lives, NOT where test/common_utils.py lives.
module_id = self.__class__.__module__
munged_id = remove_prefix_suffix(self.id(), module_id + ".", strip_suffix)

# Determine expected file based on environment
expected_file_base = get_relative_path(
os.path.realpath(sys.modules[module_id].__file__),
"expect")

# Set expected_file based on subname.
expected_file = os.path.join(expected_file_base, munged_id)
if subname:
expected_file += "_" + subname
# Note: for legacy reasons, the reference file names all had "ModelTester.test_" in their names
# We hardcode it here to avoid having to re-generate the reference files
expected_file = os.path.join(expected_file_base, 'ModelTester.test_' + name)
expected_file += "_expect.pkl"

if not ACCEPT and not os.path.exists(expected_file):
@@ -132,25 +124,16 @@ def remove_prefix_suffix(text, prefix, suffix):

Contributor: The munged_id variable on the line just above should also have been removed.

Member Author: I'll submit a fix.

return expected_file

def assertExpected(self, output, subname=None, prec=None, strip_suffix=None):
def assertExpected(self, output, name, prec=None):
Member Author: I had to tweak this a little because in the previous version the check name was based on self.id(), which changes now that the tests are parametrized with pytest.

In fact, I simplified it a bit: all that assertExpected needs to know is the name of the model, so I removed strip_suffix (which we no longer need) and subname (which was never actually used).
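(Editorial sketch, not part of the diff: a worked example of the reference-file name this scheme produces, assuming a hypothetical model name "resnet18" and an expect/ directory next to the test script.)

import os

# Hypothetical walk-through of the naming scheme above: the legacy
# "ModelTester.test_" prefix is kept so existing reference files
# do not have to be regenerated.
name = "resnet18"                   # model name passed to assertExpected
expected_file_base = "test/expect"  # assumed location of the expect directory
expected_file = os.path.join(expected_file_base, "ModelTester.test_" + name)
expected_file += "_expect.pkl"
print(expected_file)  # test/expect/ModelTester.test_resnet18_expect.pkl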

Member: As a general note, I would like us to re-use the testing utils from PyTorch as much as possible, where it makes sense. This helper was originally taken from the PyTorch tests, and I believe it is now exposed via torch.testing.
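(Editorial sketch, not part of the PR: the kind of comparison helper exposed via torch.testing; assert_close is assumed to be available, which depends on the installed PyTorch version.)

import torch
import torch.testing

# Hypothetical example: compare an output tensor against a stored reference
# with explicit tolerances, roughly what assertExpected does with prec.
actual = torch.tensor([0.999, 2.005])
expected = torch.tensor([1.0, 2.0])
torch.testing.assert_close(actual, expected, rtol=0.0, atol=1e-2)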

r"""
Test that a python value matches the recorded contents of a file
derived from the name of this test and subname. The value must be
based on a "check" name. The value must be
pickable with `torch.save`. This file
is placed in the 'expect' directory in the same directory
as the test script. You can automatically update the recorded test
output using --accept.

If you call this multiple times in a single function, you must
give a unique subname each time.

strip_suffix allows different tests that expect similar numerics, e.g.
"test_xyz_cuda" and "test_xyz_cpu", to use the same pickled data.
test_xyz_cuda would pass strip_suffix="_cuda", test_xyz_cpu would pass
strip_suffix="_cpu", and they would both use a data file name based on
"test_xyz".
"""
expected_file = self._get_expected_file(subname, strip_suffix)
expected_file = self._get_expected_file(name)

if ACCEPT:
filename = {os.path.basename(expected_file)}
74 changes: 30 additions & 44 deletions test/test_models.py
@@ -9,6 +9,8 @@
import unittest
import warnings

import pytest


def get_available_classification_models():
# TODO add a registration mechanism to torchvision.models
@@ -79,7 +81,7 @@ def _test_classification_model(self, name, input_shape, dev):
# RNG always on CPU, to ensure x in cuda tests is bitwise identical to x in cpu tests
x = torch.rand(input_shape).to(device=dev)
out = model(x)
self.assertExpected(out.cpu(), prec=0.1, strip_suffix=f"_{dev}")
self.assertExpected(out.cpu(), name, prec=0.1)
self.assertEqual(out.shape[-1], 50)
self.check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(name, None))

@@ -88,7 +90,7 @@ def _test_classification_model(self, name, input_shape, dev):
out = model(x)
# See autocast_flaky_numerics comment at top of file.
if name not in autocast_flaky_numerics:
self.assertExpected(out.cpu(), prec=0.1, strip_suffix=f"_{dev}")
self.assertExpected(out.cpu(), name, prec=0.1)
self.assertEqual(out.shape[-1], 50)

def _test_segmentation_model(self, name, dev):
@@ -104,17 +106,16 @@ def _test_segmentation_model(self, name, dev):

def check_out(out):
prec = 0.01
strip_suffix = f"_{dev}"
try:
# We first try to assert the entire output if possible. This is not
# only the best way to assert results but also handles the cases
# where we need to create a new expected result.
self.assertExpected(out.cpu(), prec=prec, strip_suffix=strip_suffix)
self.assertExpected(out.cpu(), name, prec=prec)
except AssertionError:
# Unfortunately some segmentation models are flaky with autocast
# so instead of validating the probability scores, check that the class
# predictions match.
expected_file = self._get_expected_file(strip_suffix=strip_suffix)
expected_file = self._get_expected_file(name)
expected = torch.load(expected_file)
self.assertEqual(out.argmax(dim=1), expected.argmax(dim=1), prec=prec)
return False # Partial validation performed
@@ -189,18 +190,18 @@ def compute_mean_std(tensor):

output = map_nested_tensor_object(out, tensor_map_fn=compact)
prec = 0.01
strip_suffix = f"_{dev}"
try:
# We first try to assert the entire output if possible. This is not
# only the best way to assert results but also handles the cases
# where we need to create a new expected result.
self.assertExpected(output, prec=prec, strip_suffix=strip_suffix)
self.assertExpected(output, name, prec=prec)
raise AssertionError
Contributor: @NicolasHug I think this was left here by accident. It will cause the tests to be considered only partially validated and marked as skipped.
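(Editorial sketch, not part of the PR, illustrating the point above: because the stray raise AssertionError sits inside the try block, control always reaches the except branch, so only the partial scores-only check ever runs.)

# Hypothetical reduction of the control flow in _test_detection_model.
def check_out(full_output_matches=True):
    try:
        if not full_output_matches:
            raise AssertionError("full output mismatch")
        raise AssertionError  # the line left in by accident
    except AssertionError:
        # fall back to the partial (scores-only) validation
        return False  # "partial validation performed"
    return True  # full validation

print(check_out())  # always False, even when the full output matches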

except AssertionError:
# Unfortunately detection models are flaky due to the unstable sort
# in NMS. If matching across all outputs fails, use the same approach
# as in NMSTester.test_nms_cuda to see if this is caused by duplicate
# scores.
expected_file = self._get_expected_file(strip_suffix=strip_suffix)
expected_file = self._get_expected_file(name)
expected = torch.load(expected_file)
self.assertEqual(output[0]["scores"], expected[0]["scores"], prec=prec)

@@ -430,50 +431,35 @@ def test_generalizedrcnn_transform_repr(self):
_devs = [torch.device("cpu"), torch.device("cuda")] if torch.cuda.is_available() else [torch.device("cpu")]


for model_name in get_available_classification_models():
for dev in _devs:
# for-loop bodies don't define scopes, so we have to save the variables
# we want to close over in some way
def do_test(self, model_name=model_name, dev=dev):
input_shape = (1, 3, 224, 224)
if model_name in ['inception_v3']:
input_shape = (1, 3, 299, 299)
self._test_classification_model(model_name, input_shape, dev)

setattr(ModelTester, f"test_{model_name}_{dev}", do_test)


for model_name in get_available_segmentation_models():
for dev in _devs:
# for-loop bodies don't define scopes, so we have to save the variables
# we want to close over in some way
def do_test(self, model_name=model_name, dev=dev):
self._test_segmentation_model(model_name, dev)
@pytest.mark.parametrize('model_name', get_available_classification_models())
@pytest.mark.parametrize('dev', _devs)
def test_classification_model(model_name, dev):
input_shape = (1, 3, 299, 299) if model_name == 'inception_v3' else (1, 3, 224, 224)
ModelTester()._test_classification_model(model_name, input_shape, dev)

setattr(ModelTester, f"test_{model_name}_{dev}", do_test)

@pytest.mark.parametrize('model_name', get_available_segmentation_models())
@pytest.mark.parametrize('dev', _devs)
def test_segmentation_model(model_name, dev):
ModelTester()._test_segmentation_model(model_name, dev)

for model_name in get_available_detection_models():
for dev in _devs:
# for-loop bodies don't define scopes, so we have to save the variables
# we want to close over in some way
def do_test(self, model_name=model_name, dev=dev):
self._test_detection_model(model_name, dev)

setattr(ModelTester, f"test_{model_name}_{dev}", do_test)
@pytest.mark.parametrize('model_name', get_available_detection_models())
@pytest.mark.parametrize('dev', _devs)
def test_detection_model(model_name, dev):
ModelTester()._test_detection_model(model_name, dev)

def do_validation_test(self, model_name=model_name):
self._test_detection_model_validation(model_name)

setattr(ModelTester, "test_" + model_name + "_validation", do_validation_test)
@pytest.mark.parametrize('model_name', get_available_detection_models())
def test_detection_model_validation(model_name):
ModelTester()._test_detection_model_validation(model_name)


for model_name in get_available_video_models():
for dev in _devs:
def do_test(self, model_name=model_name, dev=dev):
self._test_video_model(model_name, dev)
@pytest.mark.parametrize('model_name', get_available_video_models())
@pytest.mark.parametrize('dev', _devs)
def test_video_model(model_name, dev):
ModelTester()._test_video_model(model_name, dev)

setattr(ModelTester, f"test_{model_name}_{dev}", do_test)

if __name__ == '__main__':
unittest.main()
pytest.main([__file__])
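(Usage note, not part of the diff: with the parametrized functions above, pytest generates one test per model/device combination, and a subset can be selected by keyword. The path and keyword below are illustrative.)

import pytest

# Hypothetical invocation: run only the generated tests whose id mentions
# "resnet18", verbosely. Adjust the path to wherever test_models.py lives.
pytest.main(["test/test_models.py", "-v", "-k", "resnet18"])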