@@ -609,21 +609,6 @@ def test_resize_antialias(device, dt, size, interpolation):
609609 assert_equal (resized_tensor , resize_result )
610610
611611
612- @needs_cuda
613- @pytest .mark .parametrize ("interpolation" , [BILINEAR , BICUBIC ])
614- def test_assert_resize_antialias (interpolation ):
615-
616- # Checks implementation on very large scales
617- # and catch TORCH_CHECK inside PyTorch implementation
618- torch .manual_seed (12 )
619- tensor , _ = _create_data (1000 , 1000 , device = "cuda" )
620-
621- # Error message is not yet updated in pytorch nightly
622- # with pytest.raises(RuntimeError, match=r"Provided interpolation parameters can not be handled"):
623- with pytest .raises (RuntimeError , match = r"Too much shared memory required" ):
624- F .resize (tensor , size = (5 , 5 ), interpolation = interpolation , antialias = True )
625-
626-
627612def test_resize_antialias_default_warning ():
628613
629614 img = torch .randint (0 , 256 , size = (3 , 44 , 56 ), dtype = torch .uint8 )
@@ -641,25 +626,6 @@ def test_resize_antialias_default_warning():
641626 F .resized_crop (img , 0 , 0 , 10 , 10 , size = (20 , 20 ), interpolation = NEAREST )
642627
643628
644- @pytest .mark .parametrize ("device" , cpu_and_gpu ())
645- @pytest .mark .parametrize ("dt" , [torch .float32 , torch .float64 , torch .float16 ])
646- @pytest .mark .parametrize ("size" , [[10 , 7 ], [10 , 42 ], [42 , 7 ]])
647- @pytest .mark .parametrize ("interpolation" , [BILINEAR , BICUBIC ])
648- def test_interpolate_antialias_backward (device , dt , size , interpolation ):
649-
650- if dt == torch .float16 and device == "cpu" :
651- # skip float16 on CPU case
652- return
653-
654- torch .manual_seed (12 )
655- x = (torch .rand (1 , 32 , 29 , 3 , dtype = torch .double , device = device ).permute (0 , 3 , 1 , 2 ).requires_grad_ (True ),)
656- resize = partial (F .resize , size = size , interpolation = interpolation , antialias = True )
657- assert torch .autograd .gradcheck (resize , x , eps = 1e-8 , atol = 1e-6 , rtol = 1e-6 , fast_mode = False )
658-
659- x = (torch .rand (1 , 3 , 32 , 29 , dtype = torch .double , device = device , requires_grad = True ),)
660- assert torch .autograd .gradcheck (resize , x , eps = 1e-8 , atol = 1e-6 , rtol = 1e-6 , fast_mode = False )
661-
662-
663629def check_functional_vs_PIL_vs_scripted (
664630 fn , fn_pil , fn_t , config , device , dtype , channels = 3 , tol = 2.0 + 1e-10 , agg_method = "max"
665631):
0 commit comments