55import torch
66from torch .nn .functional import conv2d , pad as torch_pad
77from torchvision .prototype import features
8- from torchvision .transforms import functional_tensor as _FT
98from torchvision .transforms .functional import pil_to_tensor , to_pil_image
109
1110
@@ -68,9 +67,9 @@ def normalize(
6867
6968
7069def _get_gaussian_kernel1d (kernel_size : int , sigma : float , dtype : torch .dtype , device : torch .device ) -> torch .Tensor :
71- lim = (kernel_size - 1 ) / (2 * math .sqrt (2 ) * sigma )
70+ lim = (kernel_size - 1 ) / (2.0 * math .sqrt (2.0 ) * sigma )
7271 x = torch .linspace (- lim , lim , steps = kernel_size , dtype = dtype , device = device )
73- kernel1d = torch .softmax (- x .pow_ (2 ), dim = 0 )
72+ kernel1d = torch .softmax (x .pow_ (2 ). neg_ ( ), dim = 0 )
7473 return kernel1d
7574
7675
@@ -89,54 +88,61 @@ def gaussian_blur_image_tensor(
8988 # TODO: consider deprecating integers from sigma on the future
9089 if isinstance (kernel_size , int ):
9190 kernel_size = [kernel_size , kernel_size ]
92- if len (kernel_size ) != 2 :
91+ elif len (kernel_size ) != 2 :
9392 raise ValueError (f"If kernel_size is a sequence its length should be 2. Got { len (kernel_size )} " )
9493 for ksize in kernel_size :
9594 if ksize % 2 == 0 or ksize < 0 :
9695 raise ValueError (f"kernel_size should have odd and positive integers. Got { kernel_size } " )
9796
9897 if sigma is None :
9998 sigma = [ksize * 0.15 + 0.35 for ksize in kernel_size ]
100-
101- if sigma is not None and not isinstance (sigma , (int , float , list , tuple )):
102- raise TypeError (f"sigma should be either float or sequence of floats. Got { type (sigma )} " )
103- if isinstance (sigma , (int , float )):
104- sigma = [float (sigma ), float (sigma )]
105- if isinstance (sigma , (list , tuple )) and len (sigma ) == 1 :
106- sigma = [sigma [0 ], sigma [0 ]]
107- if len (sigma ) != 2 :
108- raise ValueError (f"If sigma is a sequence, its length should be 2. Got { len (sigma )} " )
99+ else :
100+ if isinstance (sigma , (list , tuple )):
101+ length = len (sigma )
102+ if length == 1 :
103+ s = float (sigma [0 ])
104+ sigma = [s , s ]
105+ elif length != 2 :
106+ raise ValueError (f"If sigma is a sequence, its length should be 2. Got { length } " )
107+ elif isinstance (sigma , (int , float )):
108+ s = float (sigma )
109+ sigma = [s , s ]
110+ else :
111+ raise TypeError (f"sigma should be either float or sequence of floats. Got { type (sigma )} " )
109112 for s in sigma :
110113 if s <= 0.0 :
111114 raise ValueError (f"sigma should have positive values. Got { sigma } " )
112115
113116 if image .numel () == 0 :
114117 return image
115118
119+ dtype = image .dtype
116120 shape = image .shape
117-
118- if image .ndim > 4 :
121+ ndim = image .ndim
122+ if ndim == 3 :
123+ image = image .unsqueeze (dim = 0 )
124+ elif ndim > 4 :
119125 image = image .reshape ((- 1 ,) + shape [- 3 :])
120- needs_unsquash = True
121- else :
122- needs_unsquash = False
123126
124- dtype = image . dtype if torch .is_floating_point (image ) else torch . float32
125- kernel = _get_gaussian_kernel2d (kernel_size , sigma , dtype = dtype , device = image .device )
126- kernel = kernel .expand (image . shape [- 3 ], 1 , kernel .shape [0 ], kernel .shape [1 ])
127+ fp = torch .is_floating_point (image )
128+ kernel = _get_gaussian_kernel2d (kernel_size , sigma , dtype = dtype if fp else torch . float32 , device = image .device )
129+ kernel = kernel .expand (shape [- 3 ], 1 , kernel .shape [0 ], kernel .shape [1 ])
127130
128- image , need_cast , need_squeeze , out_dtype = _FT . _cast_squeeze_in ( image , [ kernel . dtype ] )
131+ output = image if fp else image . to ( dtype = torch . float32 )
129132
130133 # padding = (left, right, top, bottom)
131134 padding = [kernel_size [0 ] // 2 , kernel_size [0 ] // 2 , kernel_size [1 ] // 2 , kernel_size [1 ] // 2 ]
132- output = torch_pad (image , padding , mode = "reflect" )
133- output = conv2d (output , kernel , groups = output .shape [- 3 ])
134-
135- output = _FT ._cast_squeeze_out (output , need_cast , need_squeeze , out_dtype )
135+ output = torch_pad (output , padding , mode = "reflect" )
136+ output = conv2d (output , kernel , groups = shape [- 3 ])
136137
137- if needs_unsquash :
138+ if ndim == 3 :
139+ output = output .squeeze (dim = 0 )
140+ elif ndim > 4 :
138141 output = output .reshape (shape )
139142
143+ if not fp :
144+ output = output .round_ ().to (dtype = dtype )
145+
140146 return output
141147
142148
0 commit comments