Skip to content

Commit e281964

Browse files
committed
update normalize based on PR reviews
1 parent 44c27ae commit e281964

File tree

3 files changed

+32
-15
lines changed

3 files changed

+32
-15
lines changed

test/common_utils.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
IN_OSS_CI = any(os.getenv(var) == "true" for var in ["CIRCLECI", "GITHUB_ACTIONS"])
2929
IN_RE_WORKER = os.environ.get("INSIDE_RE_WORKER") is not None
3030
IN_FBCODE = os.environ.get("IN_FBCODE_TORCHVISION") == "1"
31+
CVCUDA_AVAILABLE = _is_cvcuda_available()
3132
CUDA_NOT_AVAILABLE_MSG = "CUDA device not available"
3233
MPS_NOT_AVAILABLE_MSG = "MPS device not available"
3334
OSS_CI_GPU_NO_CUDA_MSG = "We're in an OSS GPU machine, and this test doesn't need cuda."
@@ -276,6 +277,17 @@ def combinations_grid(**kwargs):
276277
return [dict(zip(kwargs.keys(), values)) for values in itertools.product(*kwargs.values())]
277278

278279

280+
def cvcuda_to_pil_compatible_tensor(tensor: "cvcuda.Tensor") -> torch.Tensor:
281+
tensor = cvcuda_to_tensor(tensor)
282+
if tensor.ndim != 4:
283+
raise ValueError(f"CV-CUDA Tensor should be 4 dimensional. Got {tensor.ndim} dimensions.")
284+
if tensor.shape[0] != 1:
285+
raise ValueError(
286+
f"CV-CUDA Tensor should have batch dimension 1 for comparison with PIL.Image.Image. Got {tensor.shape[0]}."
287+
)
288+
return tensor.squeeze(0).cpu()
289+
290+
279291
class ImagePair(TensorLikePair):
280292
def __init__(
281293
self,
@@ -304,6 +316,11 @@ def __init__(
304316
expected = expected[0]
305317
expected = expected.cpu()
306318

319+
# handle check for CV-CUDA Tensors
320+
if CVCUDA_AVAILABLE and isinstance(actual, _import_cvcuda().Tensor):
321+
# Use the PIL compatible tensor, so we can always compare with PIL.Image.Image
322+
actual = cvcuda_to_pil_compatible_tensor(actual)
323+
307324
super().__init__(actual, expected, **other_parameters)
308325
self.mae = mae
309326

test/test_transforms_v2.py

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5553,17 +5553,17 @@ def test_kernel_image(self, mean, std, device):
55535553

55545554
@pytest.mark.parametrize("device", cpu_and_cuda())
55555555
def test_kernel_image_inplace(self, device):
5556-
input = make_image_tensor(dtype=torch.float32, device=device)
5557-
input_version = input._version
5556+
inpt = make_image_tensor(dtype=torch.float32, device=device)
5557+
input_version = inpt._version
55585558

5559-
output_out_of_place = F.normalize_image(input, mean=self.MEAN, std=self.STD)
5560-
assert output_out_of_place.data_ptr() != input.data_ptr()
5561-
assert output_out_of_place is not input
5559+
output_out_of_place = F.normalize_image(inpt, mean=self.MEAN, std=self.STD)
5560+
assert output_out_of_place.data_ptr() != inpt.data_ptr()
5561+
assert output_out_of_place is not inpt
55625562

5563-
output_inplace = F.normalize_image(input, mean=self.MEAN, std=self.STD, inplace=True)
5564-
assert output_inplace.data_ptr() == input.data_ptr()
5563+
output_inplace = F.normalize_image(inpt, mean=self.MEAN, std=self.STD, inplace=True)
5564+
assert output_inplace.data_ptr() == inpt.data_ptr()
55655565
assert output_inplace._version > input_version
5566-
assert output_inplace is input
5566+
assert output_inplace is inpt
55675567

55685568
assert_equal(output_inplace, output_out_of_place)
55695569

@@ -5613,9 +5613,9 @@ def test_functional_error(self):
56135613
with pytest.raises(ValueError, match="std evaluated to zero, leading to division by zero"):
56145614
F.normalize_image(make_image(dtype=torch.float32), mean=self.MEAN, std=std)
56155615

5616-
def _sample_input_adapter(self, transform, input, device):
5616+
def _sample_input_adapter(self, transform, inpt, device):
56175617
adapted_input = {}
5618-
for key, value in input.items():
5618+
for key, value in inpt.items():
56195619
if isinstance(value, PIL.Image.Image):
56205620
# normalize doesn't support PIL images
56215621
continue
@@ -5669,15 +5669,12 @@ def test_correctness_image(self, mean, std, dtype, make_input, fn):
56695669
actual = fn(image, mean=mean, std=std)
56705670

56715671
if make_input == make_image_cvcuda:
5672-
image = F.cvcuda_to_tensor(image).to(device="cpu")
5673-
image = image.squeeze(0)
5674-
actual = F.cvcuda_to_tensor(actual).to(device="cpu")
5675-
actual = actual.squeeze(0)
5672+
image = cvcuda_to_pil_compatible_tensor(image)
56765673

56775674
expected = self._reference_normalize_image(image, mean=mean, std=std)
56785675

56795676
if make_input == make_image_cvcuda:
5680-
torch.testing.assert_close(actual, expected, rtol=0, atol=1e-6)
5677+
assert_close(actual, expected, rtol=0, atol=1e-6)
56815678
else:
56825679
assert_equal(actual, expected)
56835680

torchvision/transforms/v2/_misc.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
get_bounding_boxes,
1818
get_keypoints,
1919
has_any,
20+
is_cvcuda_tensor,
2021
is_pure_tensor,
2122
)
2223

@@ -160,6 +161,8 @@ class Normalize(Transform):
160161

161162
_v1_transform_cls = _transforms.Normalize
162163

164+
_transformed_types = Transform._transformed_types + (is_cvcuda_tensor,)
165+
163166
def __init__(self, mean: Sequence[float], std: Sequence[float], inplace: bool = False):
164167
super().__init__()
165168
self.mean = list(mean)

0 commit comments

Comments
 (0)