@@ -5181,6 +5181,9 @@ def test_kernel_video(self):
51815181 make_segmentation_mask ,
51825182 make_video ,
51835183 make_keypoints ,
5184+ pytest .param (
5185+ make_image_cvcuda , marks = pytest .mark .skipif (not CVCUDA_AVAILABLE , reason = "CVCUDA not available" )
5186+ ),
51845187 ],
51855188 )
51865189 def test_functional (self , make_input ):
@@ -5196,9 +5199,16 @@ def test_functional(self, make_input):
51965199 (F .perspective_mask , tv_tensors .Mask ),
51975200 (F .perspective_video , tv_tensors .Video ),
51985201 (F .perspective_keypoints , tv_tensors .KeyPoints ),
5202+ pytest .param (
5203+ F ._geometry ._perspective_cvcuda ,
5204+ "cvcuda.Tensor" ,
5205+ marks = pytest .mark .skipif (not CVCUDA_AVAILABLE , reason = "CVCUDA not available" ),
5206+ ),
51995207 ],
52005208 )
52015209 def test_functional_signature (self , kernel , input_type ):
5210+ if input_type == "cvcuda.Tensor" :
5211+ input_type = _import_cvcuda ().Tensor
52025212 check_functional_kernel_signature_match (F .perspective , kernel = kernel , input_type = input_type )
52035213
52045214 @pytest .mark .parametrize ("distortion_scale" , [0.5 , 0.0 , 1.0 ])
@@ -5212,6 +5222,9 @@ def test_functional_signature(self, kernel, input_type):
52125222 make_segmentation_mask ,
52135223 make_video ,
52145224 make_keypoints ,
5225+ pytest .param (
5226+ make_image_cvcuda , marks = pytest .mark .skipif (not CVCUDA_AVAILABLE , reason = "CVCUDA not available" )
5227+ ),
52155228 ],
52165229 )
52175230 def test_transform (self , distortion_scale , make_input ):
@@ -5227,12 +5240,28 @@ def test_transform_error(self, distortion_scale):
52275240 "interpolation" , [transforms .InterpolationMode .NEAREST , transforms .InterpolationMode .BILINEAR ]
52285241 )
52295242 @pytest .mark .parametrize ("fill" , CORRECTNESS_FILLS )
5230- def test_image_functional_correctness (self , coefficients , interpolation , fill ):
5231- image = make_image (dtype = torch .uint8 , device = "cpu" )
5243+ @pytest .mark .parametrize (
5244+ "make_input" ,
5245+ [
5246+ make_image ,
5247+ pytest .param (
5248+ make_image_cvcuda , marks = pytest .mark .skipif (not CVCUDA_AVAILABLE , reason = "CVCUDA not available" )
5249+ ),
5250+ ],
5251+ )
5252+ def test_image_functional_correctness (self , coefficients , interpolation , fill , make_input ):
5253+ image = make_input (dtype = torch .uint8 , device = "cpu" )
52325254
52335255 actual = F .perspective (
52345256 image , startpoints = None , endpoints = None , coefficients = coefficients , interpolation = interpolation , fill = fill
52355257 )
5258+ if make_input is make_image_cvcuda :
5259+ actual = F .cvcuda_to_tensor (actual ).to (device = "cpu" )
5260+ actual = actual .squeeze (0 )
5261+ # drop the batch dimension
5262+ image = F .cvcuda_to_tensor (image ).to (device = "cpu" )
5263+ image = image .squeeze (0 )
5264+
52365265 expected = F .to_image (
52375266 F .perspective (
52385267 F .to_pil_image (image ),
@@ -5244,13 +5273,20 @@ def test_image_functional_correctness(self, coefficients, interpolation, fill):
52445273 )
52455274 )
52465275
5247- if interpolation is transforms .InterpolationMode .BILINEAR :
5248- abs_diff = (actual .float () - expected .float ()).abs ()
5249- assert (abs_diff > 1 ).float ().mean () < 7e-2
5250- mae = abs_diff .mean ()
5251- assert mae < 3
5252- else :
5253- assert_equal (actual , expected )
5276+ if make_input is make_image :
5277+ if interpolation is transforms .InterpolationMode .BILINEAR :
5278+ abs_diff = (actual .float () - expected .float ()).abs ()
5279+ assert (abs_diff > 1 ).float ().mean () < 7e-2
5280+ mae = abs_diff .mean ()
5281+ assert mae < 3
5282+ else :
5283+ assert_equal (actual , expected )
5284+ else : # CV-CUDA
5285+ # just check that the shapes/dtypes are the same, cvcuda warp_perspective uses different algorithm
5286+ # visually the results are the same on real images,
5287+ # realistically, the diff is not visible to the human eye
5288+ tolerance = 255 if interpolation is transforms .InterpolationMode .NEAREST else 125
5289+ torch .testing .assert_close (actual , expected , rtol = 0 , atol = tolerance )
52545290
52555291 def _reference_perspective_bounding_boxes (self , bounding_boxes , * , startpoints , endpoints ):
52565292 format = bounding_boxes .format
0 commit comments