11import math
22import numbers
3- from typing import List , Optional , Sequence , Tuple , Union
3+ from typing import List , Optional , Sequence , Tuple , Union , cast
44
55import numpy as np
66import torch
77import torch .nn as nn
88import torch .nn .functional as F
99
1010from captum .optim ._utils .image .common import nchannels_to_rgb
11- from captum .optim ._utils .typing import TransformSize , TransformVal , TransformValList
11+ from captum .optim ._utils .typing import IntSeqOrIntType , NumSeqOrTensorType
1212
1313device = torch .device ("cuda:0" if torch .cuda .is_available () else "cpu" )
1414
@@ -46,14 +46,19 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
4646
4747class ToRGB (nn .Module ):
4848 """Transforms arbitrary channels to RGB. We use this to ensure our
49- image parameteriaztion itself can be decorrelated. So this goes between
50- the image parameterization and the normalization/sigmoid step.
51- We offer two transforms: Karhunen-Loève (KLT) and I1I2I3.
49+ image parametrization itself can be decorrelated. So this goes between
50+ the image parametrization and the normalization/sigmoid step.
51+ We offer two precalculated transforms: Karhunen-Loève (KLT) and I1I2I3.
5252 KLT corresponds to the empirically measured channel correlations on imagenet.
53- I1I2I3 corresponds to an aproximation for natural images from Ohta et al.[0]
53+ I1I2I3 corresponds to an approximation for natural images from Ohta et al.[0]
5454 [0] Y. Ohta, T. Kanade, and T. Sakai, "Color information for region segmentation,"
5555 Computer Graphics and Image Processing, vol. 13, no. 3, pp. 222–241, 1980
5656 https://www.sciencedirect.com/science/article/pii/0146664X80900477
57+
58+ Arguments:
59+ transform (str or tensor): Either a string for one of the precalculated
60+ transform matrices, or a 3x3 matrix for the 3 RGB channels of input
61+ tensors.
5762 """
5863
5964 @staticmethod
@@ -73,15 +78,21 @@ def i1i2i3_transform() -> torch.Tensor:
7378 ]
7479 return torch .Tensor (i1i2i3_matrix )
7580
76- def __init__ (self , transform_name : str = "klt" ) -> None :
81+ def __init__ (self , transform : Union [ str , torch . Tensor ] = "klt" ) -> None :
7782 super ().__init__ ()
78-
79- if transform_name == "klt" :
83+ assert isinstance (transform , str ) or torch .is_tensor (transform )
84+ if torch .is_tensor (transform ):
85+ transform = cast (torch .Tensor , transform )
86+ assert list (transform .shape ) == [3 , 3 ]
87+ self .register_buffer ("transform" , transform )
88+ elif transform == "klt" :
8089 self .register_buffer ("transform" , ToRGB .klt_transform ())
81- elif transform_name == "i1i2i3" :
90+ elif transform == "i1i2i3" :
8291 self .register_buffer ("transform" , ToRGB .i1i2i3_transform ())
8392 else :
84- raise ValueError ("transform_name has to be either 'klt' or 'i1i2i3'" )
93+ raise ValueError (
94+ "transform has to be either 'klt', 'i1i2i3'," + " or a matrix tensor."
95+ )
8596
8697 def forward (self , x : torch .Tensor , inverse : bool = False ) -> torch .Tensor :
8798 assert x .dim () == 3 or x .dim () == 4
@@ -118,60 +129,74 @@ def forward(self, x: torch.Tensor, inverse: bool = False) -> torch.Tensor:
118129
class CenterCrop(torch.nn.Module):
    """
    Center crop a specified amount from a tensor.

    Arguments:
        size (int or sequence of int): If pixels_from_edges is True, the number
            of pixels to crop away from each edge; otherwise the exact
            (height, width) of the output's center region. A single value is
            used for both dimensions.
        pixels_from_edges (bool, optional): Whether to treat crop size values
            as the number of pixels from the tensor's edge, or an exact shape
            in the center. Default: False.
    """

    def __init__(
        self, size: "IntSeqOrIntType" = 0, pixels_from_edges: bool = False
    ) -> None:
        super(CenterCrop, self).__init__()
        self.crop_vals = size
        self.pixels_from_edges = pixels_from_edges

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """
        Center crop an input.

        Arguments:
            input (torch.Tensor): A CHW or NCHW image tensor to center crop.

        Returns:
            tensor (torch.Tensor): A center cropped tensor.
        """
        # All the actual cropping logic lives in the module-level helper so it
        # can also be used without an nn.Module wrapper.
        return center_crop(input, self.crop_vals, self.pixels_from_edges)
150158
def center_crop(
    input: torch.Tensor,
    crop_vals: "IntSeqOrIntType",
    pixels_from_edges: bool = False,
) -> torch.Tensor:
    """
    Center crop a specified amount from a tensor.

    Arguments:
        input (torch.Tensor): A CHW or NCHW image tensor to center crop.
        crop_vals (int or sequence of int): If pixels_from_edges is True, the
            number of pixels to crop away from each edge; otherwise the exact
            (height, width) of the output's center region. A single value is
            used for both dimensions.
        pixels_from_edges (bool, optional): Whether to treat crop values as the
            number of pixels from the tensor's edge, or an exact shape in the
            center. Default: False.

    Returns:
        tensor (torch.Tensor): A center cropped tensor.
    """

    assert input.dim() == 3 or input.dim() == 4
    crop_vals = [crop_vals] if not hasattr(crop_vals, "__iter__") else crop_vals
    crop_vals = cast(Union[List[int], Tuple[int], Tuple[int, int]], crop_vals)
    assert len(crop_vals) == 1 or len(crop_vals) == 2
    # A single value applies to both height and width.
    crop_vals = crop_vals * 2 if len(crop_vals) == 1 else crop_vals

    if input.dim() == 4:
        h, w = input.size(2), input.size(3)
    else:  # input.dim() == 3, guaranteed by the assert above
        h, w = input.size(1), input.size(2)

    if pixels_from_edges:
        # crop_vals is the total number of rows / columns to remove.
        h_crop = h - crop_vals[0]
        w_crop = w - crop_vals[1]
        sw, sh = w // 2 - (w_crop // 2), h // 2 - (h_crop // 2)
        x = input[..., sh : sh + h_crop, sw : sw + w_crop]
    else:
        # crop_vals is the exact output (height, width).
        h_crop = h - int(round((h - crop_vals[0]) / 2.0))
        w_crop = w - int(round((w - crop_vals[1]) / 2.0))
        x = input[..., h_crop - crop_vals[0] : h_crop, w_crop - crop_vals[1] : w_crop]
    return x
172195
173196
174- def rand_select (transform_values : TransformValList ) -> TransformVal :
197+ def rand_select (
198+ transform_values : NumSeqOrTensorType ,
199+ ) -> Union [int , float , torch .Tensor ]:
175200 """
176201 Randomly return a value from the provided tuple or list
177202 """
@@ -186,19 +211,21 @@ class RandomScale(nn.Module):
186211 scale (float, sequence): Tuple of rescaling values to randomly select from.
187212 """
188213
    def __init__(self, scale: NumSeqOrTensorType) -> None:
        """
        Arguments:
            scale (sequence or tensor): Rescaling values to randomly select
                from (see the class docstring).
        """
        super(RandomScale, self).__init__()
        self.scale = scale
192217
193218 def get_scale_mat (
194- self , m : TransformVal , device : torch .device , dtype : torch .dtype
219+ self , m : IntSeqOrIntType , device : torch .device , dtype : torch .dtype
195220 ) -> torch .Tensor :
196221 scale_mat = torch .tensor (
197222 [[m , 0.0 , 0.0 ], [0.0 , m , 0.0 ]], device = device , dtype = dtype
198223 )
199224 return scale_mat
200225
201- def scale_tensor (self , x : torch .Tensor , scale : TransformVal ) -> torch .Tensor :
226+ def scale_tensor (
227+ self , x : torch .Tensor , scale : Union [int , float , torch .Tensor ]
228+ ) -> torch .Tensor :
202229 scale_matrix = self .get_scale_mat (scale , x .device , x .dtype )[None , ...].repeat (
203230 x .shape [0 ], 1 , 1
204231 )
0 commit comments