Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BC-breaking] ColorJitter gets its random params by calling get_params() #3001

Merged
merged 5 commits into from
Nov 16, 2020

Conversation

datumbox
Copy link
Contributor

@datumbox datumbox commented Nov 13, 2020

As per @fmassa's comment, I'm rewriting ColorJitter so that it receives its random parameters from the get_params() method. The ColorJitter remains torch-scriptable.

Affects #2672 and #2669.

@datumbox datumbox changed the title [WIP] ColorJitter gets its random params by calling get_params() [BC-breaking] ColorJitter gets its random params by calling get_params() Nov 13, 2020
@datumbox datumbox requested review from fmassa and vfdev-5 November 13, 2020 14:04
@codecov
Copy link

codecov bot commented Nov 13, 2020

Codecov Report

Merging #3001 (66cdc72) into master (80f41f8) will increase coverage by 0.12%.
The diff coverage is 100.00%.

Impacted file tree graph

@@            Coverage Diff             @@
##           master    #3001      +/-   ##
==========================================
+ Coverage   73.39%   73.51%   +0.12%     
==========================================
  Files          99       99              
  Lines        8825     8806      -19     
  Branches     1391     1383       -8     
==========================================
- Hits         6477     6474       -3     
+ Misses       1929     1913      -16     
  Partials      419      419              
Impacted Files Coverage Δ
torchvision/transforms/transforms.py 82.57% <100.00%> (+2.06%) ⬆️

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 80f41f8...66cdc72. Read the comment docs.

Copy link
Collaborator

@vfdev-5 vfdev-5 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@datumbox looks awesome !

torchvision/transforms/transforms.py Outdated Show resolved Hide resolved
Copy link
Member

@fmassa fmassa left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks a lot!

@fmassa fmassa merged commit 76ebe92 into pytorch:master Nov 16, 2020
@datumbox datumbox deleted the refactor/colorjitter branch November 16, 2020 12:22
bryant1410 pushed a commit to bryant1410/vision-1 that referenced this pull request Nov 22, 2020
…s() (pytorch#3001)

* ColorJitter gets its random params by calling get_params().

* Update arguments.

* Styles.

* Add description for Nones.

* Chainging Nones to optional.
vfdev-5 pushed a commit to Quansight/vision that referenced this pull request Dec 4, 2020
…s() (pytorch#3001)

* ColorJitter gets its random params by calling get_params().

* Update arguments.

* Styles.

* Add description for Nones.

* Chainging Nones to optional.
@zshn25
Copy link
Contributor

zshn25 commented Apr 7, 2021

This PR is not backward compatible. In previous versions, get_params method would return a callable transform. Not it just returns the params. As a result, this gives an error as follows

TypeError: 'tuple' object is not callable

when used like

color_jitter = transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1)
transform = transforms.ColorJitter.get_params(
    color_jitter.brightness, color_jitter.contrast, color_jitter.saturation,
    color_jitter.hue)

img_trans = transform(img)

@NicolasHug
Copy link
Member

NicolasHug commented Apr 7, 2021

@zshn25 indeed this is a backward incompatible change, but it's probably for the best: the previous use of get_params was unconventional in the sense that get_params didn't behave as what you would expect from a "getter".

[EDIT: below is wrong]
I believe your code can be changed to simply

color_jitter = transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1)
img_trans = color_jitter(img)

to get the same results as before

@zshn25
Copy link
Contributor

zshn25 commented Apr 7, 2021

@NicolasHug, Thanks for clarification. Yes, indeed the following works

color_jitter = transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1)
img_trans = color_jitter(img)

but I'm interested in having the same ColorJitter transform on many images. I used to use get_params and use that transform on all images. How would I do this now?

@fmassa
Copy link
Member

fmassa commented Apr 7, 2021

Hi @zshn25

Thanks for bringing this up. Indeed this is a BC-breaking change, sorry for breaking your code.
We wanted to uniformize how the transforms handled get_params, and in some sense the previous implementation of ColorJitter as "wrong".

I would propose to use the following helper function for what you want to achieve:

def get_random_color_jitter(
        brightness: Optional[List[float]],
        contrast: Optional[List[float]],
        saturation: Optional[List[float]],
        hue: Optional[List[float]]
    ) -> Tuple[Tensor, Optional[float], Optional[float], Optional[float], Optional[float]]:

    transforms = []

    if brightness is not None:
        brightness_factor = random.uniform(brightness[0], brightness[1])
        transforms.append(Lambda(lambda img: F.adjust_brightness(img, brightness_factor)))

    if contrast is not None:
        contrast_factor = random.uniform(contrast[0], contrast[1])
        transforms.append(Lambda(lambda img: F.adjust_contrast(img, contrast_factor)))

    if saturation is not None:
        saturation_factor = random.uniform(saturation[0], saturation[1])
        transforms.append(Lambda(lambda img: F.adjust_saturation(img, saturation_factor)))

    if hue is not None:
        hue_factor = random.uniform(hue[0], hue[1])
        transforms.append(Lambda(lambda img: F.adjust_hue(img, hue_factor)))

    random.shuffle(transforms)
    transform = Compose(transforms)

    return transform

This is basically a re-implementation of the previous method, and would also easily allow you to add more color augmentations if you wanted (like solalize, equalize, etc, which have been recently introduced).

Please let us know if this solution doesn't fit your needs.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants