diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index be835bdd213..af74a3188f3 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -1051,38 +1051,35 @@ def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_firs return value @staticmethod - @torch.jit.unused - def get_params(brightness, contrast, saturation, hue): - """Get a randomized transform to be applied on image. + def get_params(brightness: Optional[List[float]], + contrast: Optional[List[float]], + saturation: Optional[List[float]], + hue: Optional[List[float]] + ) -> Tuple[Tensor, Optional[float], Optional[float], Optional[float], Optional[float]]: + """Get the parameters for the randomized transform to be applied on image. - Arguments are same as that of __init__. + Args: + brightness (tuple of float (min, max), optional): The range from which the brightness_factor is chosen + uniformly. Pass None to turn off the transformation. + contrast (tuple of float (min, max), optional): The range from which the contrast_factor is chosen + uniformly. Pass None to turn off the transformation. + saturation (tuple of float (min, max), optional): The range from which the saturation_factor is chosen + uniformly. Pass None to turn off the transformation. + hue (tuple of float (min, max), optional): The range from which the hue_factor is chosen uniformly. + Pass None to turn off the transformation. Returns: - Transform which randomly adjusts brightness, contrast and - saturation in a random order. + tuple: The parameters used to apply the randomized transform + along with their random order. """ - transforms = [] - - if brightness is not None: - brightness_factor = random.uniform(brightness[0], brightness[1]) - transforms.append(Lambda(lambda img: F.adjust_brightness(img, brightness_factor))) - - if contrast is not None: - contrast_factor = random.uniform(contrast[0], contrast[1]) - transforms.append(Lambda(lambda img: F.adjust_contrast(img, contrast_factor))) - - if saturation is not None: - saturation_factor = random.uniform(saturation[0], saturation[1]) - transforms.append(Lambda(lambda img: F.adjust_saturation(img, saturation_factor))) - - if hue is not None: - hue_factor = random.uniform(hue[0], hue[1]) - transforms.append(Lambda(lambda img: F.adjust_hue(img, hue_factor))) + fn_idx = torch.randperm(4) - random.shuffle(transforms) - transform = Compose(transforms) + b = None if brightness is None else float(torch.empty(1).uniform_(brightness[0], brightness[1])) + c = None if contrast is None else float(torch.empty(1).uniform_(contrast[0], contrast[1])) + s = None if saturation is None else float(torch.empty(1).uniform_(saturation[0], saturation[1])) + h = None if hue is None else float(torch.empty(1).uniform_(hue[0], hue[1])) - return transform + return fn_idx, b, c, s, h def forward(self, img): """ @@ -1092,26 +1089,17 @@ def forward(self, img): Returns: PIL Image or Tensor: Color jittered image. """ - fn_idx = torch.randperm(4) + fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor = \ + self.get_params(self.brightness, self.contrast, self.saturation, self.hue) + for fn_id in fn_idx: - if fn_id == 0 and self.brightness is not None: - brightness = self.brightness - brightness_factor = torch.tensor(1.0).uniform_(brightness[0], brightness[1]).item() + if fn_id == 0 and brightness_factor is not None: img = F.adjust_brightness(img, brightness_factor) - - if fn_id == 1 and self.contrast is not None: - contrast = self.contrast - contrast_factor = torch.tensor(1.0).uniform_(contrast[0], contrast[1]).item() + elif fn_id == 1 and contrast_factor is not None: img = F.adjust_contrast(img, contrast_factor) - - if fn_id == 2 and self.saturation is not None: - saturation = self.saturation - saturation_factor = torch.tensor(1.0).uniform_(saturation[0], saturation[1]).item() + elif fn_id == 2 and saturation_factor is not None: img = F.adjust_saturation(img, saturation_factor) - - if fn_id == 3 and self.hue is not None: - hue = self.hue - hue_factor = torch.tensor(1.0).uniform_(hue[0], hue[1]).item() + elif fn_id == 3 and hue_factor is not None: img = F.adjust_hue(img, hue_factor) return img