Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BC-breaking] ColorJitter gets its random params by calling get_params() #3001

Merged
merged 5 commits into from
Nov 16, 2020
Merged
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 30 additions & 42 deletions torchvision/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1051,38 +1051,35 @@ def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_firs
return value

@staticmethod
@torch.jit.unused
def get_params(brightness, contrast, saturation, hue):
"""Get a randomized transform to be applied on image.
def get_params(brightness: Optional[List[float]],
contrast: Optional[List[float]],
saturation: Optional[List[float]],
hue: Optional[List[float]]
) -> Tuple[Tensor, Optional[float], Optional[float], Optional[float], Optional[float]]:
"""Get the parameters for the randomized transform to be applied on image.

Arguments are same as that of __init__.
Args:
brightness (tuple of float (min, max) or None): The range from which the brightness_factor is chosen
datumbox marked this conversation as resolved.
Show resolved Hide resolved
uniformly. Pass None to turn off the transformation.
contrast (tuple of float (min, max) or None): The range from which the contrast_factor is chosen
uniformly. Pass None to turn off the transformation.
saturation (tuple of float (min, max) or None): The range from which the saturation_factor is chosen
uniformly. Pass None to turn off the transformation.
hue (tuple of float (min, max) or None): The range from which the hue_factor is chosen uniformly.
Pass None to turn off the transformation.

Returns:
Transform which randomly adjusts brightness, contrast and
saturation in a random order.
tuple: The parameters used to apply the randomized transform
along with their random order.
"""
transforms = []

if brightness is not None:
brightness_factor = random.uniform(brightness[0], brightness[1])
transforms.append(Lambda(lambda img: F.adjust_brightness(img, brightness_factor)))

if contrast is not None:
contrast_factor = random.uniform(contrast[0], contrast[1])
transforms.append(Lambda(lambda img: F.adjust_contrast(img, contrast_factor)))

if saturation is not None:
saturation_factor = random.uniform(saturation[0], saturation[1])
transforms.append(Lambda(lambda img: F.adjust_saturation(img, saturation_factor)))

if hue is not None:
hue_factor = random.uniform(hue[0], hue[1])
transforms.append(Lambda(lambda img: F.adjust_hue(img, hue_factor)))
fn_idx = torch.randperm(4)

random.shuffle(transforms)
transform = Compose(transforms)
b = None if brightness is None else float(torch.empty(1).uniform_(brightness[0], brightness[1]))
c = None if contrast is None else float(torch.empty(1).uniform_(contrast[0], contrast[1]))
s = None if saturation is None else float(torch.empty(1).uniform_(saturation[0], saturation[1]))
h = None if hue is None else float(torch.empty(1).uniform_(hue[0], hue[1]))

return transform
return fn_idx, b, c, s, h

def forward(self, img):
"""
Expand All @@ -1092,26 +1089,17 @@ def forward(self, img):
Returns:
PIL Image or Tensor: Color jittered image.
"""
fn_idx = torch.randperm(4)
fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor = \
self.get_params(self.brightness, self.contrast, self.saturation, self.hue)

for fn_id in fn_idx:
if fn_id == 0 and self.brightness is not None:
brightness = self.brightness
brightness_factor = torch.tensor(1.0).uniform_(brightness[0], brightness[1]).item()
if fn_id == 0 and brightness_factor is not None:
img = F.adjust_brightness(img, brightness_factor)

if fn_id == 1 and self.contrast is not None:
contrast = self.contrast
contrast_factor = torch.tensor(1.0).uniform_(contrast[0], contrast[1]).item()
elif fn_id == 1 and contrast_factor is not None:
img = F.adjust_contrast(img, contrast_factor)

if fn_id == 2 and self.saturation is not None:
saturation = self.saturation
saturation_factor = torch.tensor(1.0).uniform_(saturation[0], saturation[1]).item()
elif fn_id == 2 and saturation_factor is not None:
img = F.adjust_saturation(img, saturation_factor)

if fn_id == 3 and self.hue is not None:
hue = self.hue
hue_factor = torch.tensor(1.0).uniform_(hue[0], hue[1]).item()
elif fn_id == 3 and hue_factor is not None:
img = F.adjust_hue(img, hue_factor)

return img
Expand Down