diff --git a/test/test_models.py b/test/test_models.py index 5826cc77164..abffe91aaeb 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -29,7 +29,7 @@ def list_model_fns(module): return [get_model_builder(name) for name in list_models(module)] -def _get_image(input_shape, real_image, device): +def _get_image(input_shape, real_image, device, dtype=None): """This routine loads a real or random image based on `real_image` argument. Currently, the real image is utilized for the following list of models: - `retinanet_resnet50_fpn`, @@ -60,10 +60,10 @@ def _get_image(input_shape, real_image, device): convert_tensor = transforms.ToTensor() image = convert_tensor(img) assert tuple(image.size()) == input_shape - return image.to(device=device) + return image.to(device=device, dtype=dtype) # RNG always on CPU, to ensure x in cuda tests is bitwise identical to x in cpu tests - return torch.rand(input_shape).to(device=device) + return torch.rand(input_shape).to(device=device, dtype=dtype) @pytest.fixture @@ -278,6 +278,11 @@ def _check_input_backprop(model, inputs): # tests under test_quantized_classification_model will be skipped for the following models. quantized_flaky_models = ("inception_v3", "resnet50") +# The tests for the following detection models are flaky. +# We run those tests on float64 to avoid floating point errors. +# FIXME: we shouldn't have to do that :'/ +detection_flaky_models = ("keypointrcnn_resnet50_fpn", "maskrcnn_resnet50_fpn", "maskrcnn_resnet50_fpn_v2") + # The following contains configuration parameters for all models which are used by # the _test_*_model methods. @@ -777,13 +782,17 @@ def test_detection_model(model_fn, dev): "input_shape": (3, 300, 300), } model_name = model_fn.__name__ + if model_name in detection_flaky_models: + dtype = torch.float64 + else: + dtype = torch.get_default_dtype() kwargs = {**defaults, **_model_params.get(model_name, {})} input_shape = kwargs.pop("input_shape") real_image = kwargs.pop("real_image", False) model = model_fn(**kwargs) - model.eval().to(device=dev) - x = _get_image(input_shape=input_shape, real_image=real_image, device=dev) + model.eval().to(device=dev, dtype=dtype) + x = _get_image(input_shape=input_shape, real_image=real_image, device=dev, dtype=dtype) model_input = [x] with torch.no_grad(), freeze_rng_state(): out = model(model_input)