-
Notifications
You must be signed in to change notification settings - Fork 7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[BC-breaking] ColorJitter gets its random params by calling get_params() #3001
Conversation
Codecov Report
@@ Coverage Diff @@
## master #3001 +/- ##
==========================================
+ Coverage 73.39% 73.51% +0.12%
==========================================
Files 99 99
Lines 8825 8806 -19
Branches 1391 1383 -8
==========================================
- Hits 6477 6474 -3
+ Misses 1929 1913 -16
Partials 419 419
Continue to review full report at Codecov.
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@datumbox looks awesome !
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks a lot!
…s() (pytorch#3001) * ColorJitter gets its random params by calling get_params(). * Update arguments. * Styles. * Add description for Nones. * Chainging Nones to optional.
…s() (pytorch#3001) * ColorJitter gets its random params by calling get_params(). * Update arguments. * Styles. * Add description for Nones. * Chainging Nones to optional.
This PR is not backward compatible. In previous versions,
when used like color_jitter = transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1)
transform = transforms.ColorJitter.get_params(
color_jitter.brightness, color_jitter.contrast, color_jitter.saturation,
color_jitter.hue)
img_trans = transform(img) |
@zshn25 indeed this is a backward incompatible change, but it's probably for the best: the previous use of [EDIT: below is wrong] color_jitter = transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1)
img_trans = color_jitter(img) to get the same results as before |
@NicolasHug, Thanks for clarification. Yes, indeed the following works
but I'm interested in having the same |
Hi @zshn25 Thanks for bringing this up. Indeed this is a BC-breaking change, sorry for breaking your code. I would propose to use the following helper function for what you want to achieve: def get_random_color_jitter(
brightness: Optional[List[float]],
contrast: Optional[List[float]],
saturation: Optional[List[float]],
hue: Optional[List[float]]
) -> Tuple[Tensor, Optional[float], Optional[float], Optional[float], Optional[float]]:
transforms = []
if brightness is not None:
brightness_factor = random.uniform(brightness[0], brightness[1])
transforms.append(Lambda(lambda img: F.adjust_brightness(img, brightness_factor)))
if contrast is not None:
contrast_factor = random.uniform(contrast[0], contrast[1])
transforms.append(Lambda(lambda img: F.adjust_contrast(img, contrast_factor)))
if saturation is not None:
saturation_factor = random.uniform(saturation[0], saturation[1])
transforms.append(Lambda(lambda img: F.adjust_saturation(img, saturation_factor)))
if hue is not None:
hue_factor = random.uniform(hue[0], hue[1])
transforms.append(Lambda(lambda img: F.adjust_hue(img, hue_factor)))
random.shuffle(transforms)
transform = Compose(transforms)
return transform This is basically a re-implementation of the previous method, and would also easily allow you to add more color augmentations if you wanted (like solalize, equalize, etc, which have been recently introduced). Please let us know if this solution doesn't fit your needs. |
As per @fmassa's comment, I'm rewriting ColorJitter so that it receives its random parameters from the
get_params()
method. The ColorJitter remains torch-scriptable.Affects #2672 and #2669.