Skip to content

Commit

Permalink
fix flaky test for rotate_bounding_box (#7362)
Browse files Browse the repository at this point in the history
Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>
  • Loading branch information
pmeier and NicolasHug authored Mar 1, 2023
1 parent feda8b7 commit 924d373
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 7 deletions.
2 changes: 1 addition & 1 deletion test/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,7 @@ def assert_close(

def parametrized_error_message(*args, **kwargs):
def to_str(obj):
if isinstance(obj, torch.Tensor) and obj.numel() > 10:
if isinstance(obj, torch.Tensor) and obj.numel() > 30:
return f"tensor(shape={list(obj.shape)}, dtype={obj.dtype}, device={obj.device})"
elif isinstance(obj, enum.Enum):
return f"{type(obj).__name__}.{obj.name}"
Expand Down
8 changes: 4 additions & 4 deletions test/test_transforms_v2_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def test_scripted_vs_eager(self, test_id, info, args_kwargs, device):
actual,
expected,
**info.get_closeness_kwargs(test_id, dtype=input.dtype, device=input.device),
msg=parametrized_error_message(*([actual, expected] + other_args), **kwargs),
msg=parametrized_error_message(input, other_args, **kwargs),
)

def _unbatch(self, batch, *, data_dims):
Expand Down Expand Up @@ -204,7 +204,7 @@ def test_batched_vs_single(self, test_id, info, args_kwargs, device):
actual,
expected,
**info.get_closeness_kwargs(test_id, dtype=batched_input.dtype, device=batched_input.device),
msg=parametrized_error_message(*other_args, **kwargs),
msg=parametrized_error_message(batched_input, *other_args, **kwargs),
)

@sample_inputs
Expand Down Expand Up @@ -236,7 +236,7 @@ def test_cuda_vs_cpu(self, test_id, info, args_kwargs):
output_cpu,
check_device=False,
**info.get_closeness_kwargs(test_id, dtype=input_cuda.dtype, device=input_cuda.device),
msg=parametrized_error_message(*other_args, **kwargs),
msg=parametrized_error_message(input_cpu, *other_args, **kwargs),
)

@sample_inputs
Expand Down Expand Up @@ -294,7 +294,7 @@ def test_float32_vs_uint8(self, test_id, info, args_kwargs):
actual,
expected,
**info.get_closeness_kwargs(test_id, dtype=torch.float32, device=input.device),
msg=parametrized_error_message(*other_args, **kwargs),
msg=parametrized_error_message(input, *other_args, **kwargs),
)


Expand Down
4 changes: 2 additions & 2 deletions test/transforms_v2_kernel_infos.py
Original file line number Diff line number Diff line change
Expand Up @@ -860,8 +860,8 @@ def sample_inputs_rotate_video():
reference_fn=reference_rotate_bounding_box,
reference_inputs_fn=reference_inputs_rotate_bounding_box,
closeness_kwargs={
**scripted_vs_eager_float64_tolerances("cpu", atol=1e-6, rtol=1e-6),
**scripted_vs_eager_float64_tolerances("cuda", atol=1e-5, rtol=1e-5),
**scripted_vs_eager_float64_tolerances("cpu", atol=1e-4, rtol=1e-4),
**scripted_vs_eager_float64_tolerances("cuda", atol=1e-4, rtol=1e-4),
},
),
KernelInfo(
Expand Down

0 comments on commit 924d373

Please sign in to comment.