@@ -5553,17 +5553,17 @@ def test_kernel_image(self, mean, std, device):
55535553
55545554 @pytest .mark .parametrize ("device" , cpu_and_cuda ())
55555555 def test_kernel_image_inplace (self , device ):
5556- input = make_image_tensor (dtype = torch .float32 , device = device )
5557- input_version = input ._version
5556+ inpt = make_image_tensor (dtype = torch .float32 , device = device )
5557+ input_version = inpt ._version
55585558
5559- output_out_of_place = F .normalize_image (input , mean = self .MEAN , std = self .STD )
5560- assert output_out_of_place .data_ptr () != input .data_ptr ()
5561- assert output_out_of_place is not input
5559+ output_out_of_place = F .normalize_image (inpt , mean = self .MEAN , std = self .STD )
5560+ assert output_out_of_place .data_ptr () != inpt .data_ptr ()
5561+ assert output_out_of_place is not inpt
55625562
5563- output_inplace = F .normalize_image (input , mean = self .MEAN , std = self .STD , inplace = True )
5564- assert output_inplace .data_ptr () == input .data_ptr ()
5563+ output_inplace = F .normalize_image (inpt , mean = self .MEAN , std = self .STD , inplace = True )
5564+ assert output_inplace .data_ptr () == inpt .data_ptr ()
55655565 assert output_inplace ._version > input_version
5566- assert output_inplace is input
5566+ assert output_inplace is inpt
55675567
55685568 assert_equal (output_inplace , output_out_of_place )
55695569
@@ -5613,9 +5613,9 @@ def test_functional_error(self):
56135613 with pytest .raises (ValueError , match = "std evaluated to zero, leading to division by zero" ):
56145614 F .normalize_image (make_image (dtype = torch .float32 ), mean = self .MEAN , std = std )
56155615
5616- def _sample_input_adapter (self , transform , input , device ):
5616+ def _sample_input_adapter (self , transform , inpt , device ):
56175617 adapted_input = {}
5618- for key , value in input .items ():
5618+ for key , value in inpt .items ():
56195619 if isinstance (value , PIL .Image .Image ):
56205620 # normalize doesn't support PIL images
56215621 continue
@@ -5669,15 +5669,12 @@ def test_correctness_image(self, mean, std, dtype, make_input, fn):
56695669 actual = fn (image , mean = mean , std = std )
56705670
56715671 if make_input == make_image_cvcuda :
5672- image = F .cvcuda_to_tensor (image ).to (device = "cpu" )
5673- image = image .squeeze (0 )
5674- actual = F .cvcuda_to_tensor (actual ).to (device = "cpu" )
5675- actual = actual .squeeze (0 )
5672+ image = cvcuda_to_pil_compatible_tensor (image )
56765673
56775674 expected = self ._reference_normalize_image (image , mean = mean , std = std )
56785675
56795676 if make_input == make_image_cvcuda :
5680- torch . testing . assert_close (actual , expected , rtol = 0 , atol = 1e-6 )
5677+ assert_close (actual , expected , rtol = 0 , atol = 1e-6 )
56815678 else :
56825679 assert_equal (actual , expected )
56835680
0 commit comments