Skip to content

Commit

Permalink
[fbsync] Test some flaky detection models on float64 instead of float…
Browse files Browse the repository at this point in the history
…32 (#7204)

Reviewed By: vmoens

Differential Revision: D44416251

fbshipit-source-id: e80c2c1ebde2bad25ab5b6f29e2ac2278e0f4b90
  • Loading branch information
NicolasHug authored and facebook-github-bot committed Mar 28, 2023
1 parent 0afc675 commit f3cfc49
Showing 1 changed file with 14 additions and 5 deletions.
19 changes: 14 additions & 5 deletions test/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def list_model_fns(module):
return [get_model_builder(name) for name in list_models(module)]


def _get_image(input_shape, real_image, device):
def _get_image(input_shape, real_image, device, dtype=None):
"""This routine loads a real or random image based on `real_image` argument.
Currently, the real image is utilized for the following list of models:
- `retinanet_resnet50_fpn`,
Expand Down Expand Up @@ -60,10 +60,10 @@ def _get_image(input_shape, real_image, device):
convert_tensor = transforms.ToTensor()
image = convert_tensor(img)
assert tuple(image.size()) == input_shape
return image.to(device=device)
return image.to(device=device, dtype=dtype)

# RNG always on CPU, to ensure x in cuda tests is bitwise identical to x in cpu tests
return torch.rand(input_shape).to(device=device)
return torch.rand(input_shape).to(device=device, dtype=dtype)


@pytest.fixture
Expand Down Expand Up @@ -278,6 +278,11 @@ def _check_input_backprop(model, inputs):
# tests under test_quantized_classification_model will be skipped for the following models.
quantized_flaky_models = ("inception_v3", "resnet50")

# The tests for the following detection models are flaky.
# We run those tests on float64 to avoid floating point errors.
# FIXME: we shouldn't have to do that :'/
detection_flaky_models = ("keypointrcnn_resnet50_fpn", "maskrcnn_resnet50_fpn", "maskrcnn_resnet50_fpn_v2")


# The following contains configuration parameters for all models which are used by
# the _test_*_model methods.
Expand Down Expand Up @@ -777,13 +782,17 @@ def test_detection_model(model_fn, dev):
"input_shape": (3, 300, 300),
}
model_name = model_fn.__name__
if model_name in detection_flaky_models:
dtype = torch.float64
else:
dtype = torch.get_default_dtype()
kwargs = {**defaults, **_model_params.get(model_name, {})}
input_shape = kwargs.pop("input_shape")
real_image = kwargs.pop("real_image", False)

model = model_fn(**kwargs)
model.eval().to(device=dev)
x = _get_image(input_shape=input_shape, real_image=real_image, device=dev)
model.eval().to(device=dev, dtype=dtype)
x = _get_image(input_shape=input_shape, real_image=real_image, device=dev, dtype=dtype)
model_input = [x]
with torch.no_grad(), freeze_rng_state():
out = model(model_input)
Expand Down

0 comments on commit f3cfc49

Please sign in to comment.