99
1010import unittest
1111
12- from common_utils import TransformsTester , get_tmp_dir
12+ from common_utils import TransformsTester , get_tmp_dir , int_dtypes , float_dtypes
1313
1414
1515class Tester (TransformsTester ):
@@ -27,26 +27,26 @@ def _test_functional_op(self, func, fn_kwargs):
2727 transformed_pil_img = f (pil_img , ** fn_kwargs )
2828 self .compareTensorToPIL (transformed_tensor , transformed_pil_img )
2929
30- def _test_transform_vs_scripted (self , transform , s_transform , tensor ):
30+ def _test_transform_vs_scripted (self , transform , s_transform , tensor , msg = None ):
3131 torch .manual_seed (12 )
3232 out1 = transform (tensor )
3333 torch .manual_seed (12 )
3434 out2 = s_transform (tensor )
35- self .assertTrue (out1 .equal (out2 ))
35+ self .assertTrue (out1 .equal (out2 ), msg = msg )
3636
37- def _test_transform_vs_scripted_on_batch (self , transform , s_transform , batch_tensors ):
37+ def _test_transform_vs_scripted_on_batch (self , transform , s_transform , batch_tensors , msg = None ):
3838 torch .manual_seed (12 )
3939 transformed_batch = transform (batch_tensors )
4040
4141 for i in range (len (batch_tensors )):
4242 img_tensor = batch_tensors [i , ...]
4343 torch .manual_seed (12 )
4444 transformed_img = transform (img_tensor )
45- self .assertTrue (transformed_img .equal (transformed_batch [i , ...]))
45+ self .assertTrue (transformed_img .equal (transformed_batch [i , ...]), msg = msg )
4646
4747 torch .manual_seed (12 )
4848 s_transformed_batch = s_transform (batch_tensors )
49- self .assertTrue (transformed_batch .equal (s_transformed_batch ))
49+ self .assertTrue (transformed_batch .equal (s_transformed_batch ), msg = msg )
5050
5151 def _test_class_op (self , method , meth_kwargs = None , test_exact_match = True , ** match_kwargs ):
5252 if meth_kwargs is None :
@@ -492,6 +492,32 @@ def test_random_erasing(self):
492492 self ._test_transform_vs_scripted (fn , scripted_fn , tensor )
493493 self ._test_transform_vs_scripted_on_batch (fn , scripted_fn , batch_tensors )
494494
495+ def test_convert_image_dtype (self ):
496+ tensor , _ = self ._create_data (26 , 34 , device = self .device )
497+ batch_tensors = torch .rand (4 , 3 , 44 , 56 , device = self .device )
498+
499+ for in_dtype in int_dtypes () + float_dtypes ():
500+ in_tensor = tensor .to (in_dtype )
501+ in_batch_tensors = batch_tensors .to (in_dtype )
502+ for out_dtype in int_dtypes () + float_dtypes ():
503+
504+ fn = T .ConvertImageDtype (dtype = out_dtype )
505+ scripted_fn = torch .jit .script (fn )
506+
507+ if (in_dtype == torch .float32 and out_dtype in (torch .int32 , torch .int64 )) or \
508+ (in_dtype == torch .float64 and out_dtype == torch .int64 ):
509+ with self .assertRaisesRegex (RuntimeError , r"cannot be performed safely" ):
510+ self ._test_transform_vs_scripted (fn , scripted_fn , in_tensor )
511+ with self .assertRaisesRegex (RuntimeError , r"cannot be performed safely" ):
512+ self ._test_transform_vs_scripted_on_batch (fn , scripted_fn , in_batch_tensors )
513+ continue
514+
515+ self ._test_transform_vs_scripted (fn , scripted_fn , in_tensor )
516+ self ._test_transform_vs_scripted_on_batch (fn , scripted_fn , in_batch_tensors )
517+
518+ with get_tmp_dir () as tmp_dir :
519+ scripted_fn .save (os .path .join (tmp_dir , "t_convert_dtype.pt" ))
520+
495521
496522@unittest .skipIf (not torch .cuda .is_available (), reason = "Skip if no CUDA device" )
497523class CUDATester (Tester ):
0 commit comments