[prototype] Speed improvement for normalize op #6821
Conversation
I looked at this earlier and saw no possible optimizations. It seems you have better eyes 😛
LGTM, if CI is green.
f"Expected tensor to be a tensor image of size (..., C, H, W). Got tensor.size() = {image.size()}" | ||
) | ||
|
||
if (isinstance(std, (tuple, list)) and not all(std)) or std == 0: |
Do we need the first part of the check? What input would fail isinstance(std, (tuple, list))? Do we actually allow scalars here? Otherwise, this should be sufficient:

Suggested change:
- if (isinstance(std, (tuple, list)) and not all(std)) or std == 0:
+ if not all(std):
We actually allow scalars. It's not visible due to the JIT-script types, but if you pass mean=0.5, std=0.5 it works. So I'm keeping this for BC and will provide separate benchmarks.
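For context, a minimal sketch of what "scalars work" means here, using the stable torchvision.transforms.functional.normalize entry point as a stand-in for the prototype kernel being edited (the annotated types say lists of floats, but eager mode accepts a plain float):

```python
import torch
import torchvision.transforms.functional as F

img = torch.rand(3, 224, 224)

# A scalar is broadcast across all channels, equivalent to [0.5, 0.5, 0.5].
out_scalar = F.normalize(img, mean=0.5, std=0.5)
out_list = F.normalize(img, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

assert torch.allclose(out_scalar, out_list)
```

Dropping the scalar branch would break callers that rely on this behavior, hence keeping it for backwards compatibility.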
Ugh 🙄 We need to update the tests since they currently don't check scalars:

vision/test/prototype_transforms_kernel_infos.py, lines 1945 to 1956 at 6979888:
    _NORMALIZE_MEANS_STDS = [
        ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ([0.0, 0.0, 0.0], [1.0, 1.0, 1.0]),
    ]


    def sample_inputs_normalize_image_tensor():
        for image_loader, (mean, std) in itertools.product(
            make_image_loaders(sizes=["random"], color_spaces=[features.ColorSpace.RGB], dtypes=[torch.float32]),
            _NORMALIZE_MEANS_STDS,
        ):
            yield ArgsKwargs(image_loader, mean=mean, std=std)
Will send a PR.
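Purely as an illustration of what such a test update could look like (the actual follow-up PR may do this differently), a scalar pair could be appended to the list above:

```python
# Hypothetical addition to _NORMALIZE_MEANS_STDS; the real follow-up PR may differ.
_NORMALIZE_MEANS_STDS = [
    ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ([0.0, 0.0, 0.0], [1.0, 1.0, 1.0]),
    (0.5, 2.0),  # scalar mean/std, exercising the broadcasting branch
]
```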
Good catch. I also had to rewrite the check because JIT couldn't verify that the assertion was correct when written on a single line... This version seems to pass. I've updated the benchmarks and we are still good.
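A minimal sketch of the kind of split being described, handling the scalar and sequence cases on separate branches (illustrative only; the exact code and error message in the PR may differ):

```python
from typing import List, Union


def _check_std(std: Union[float, List[float]]) -> None:
    # Splitting the check lets each branch work with a single, concrete type,
    # which is easier for torch.jit.script to reason about than one combined
    # boolean expression over a union type.
    if isinstance(std, (tuple, list)):
        divisor_is_zero = not all(std)
    else:
        divisor_is_zero = std == 0
    if divisor_is_zero:
        raise ValueError("std evaluated to zero, leading to division by zero.")
```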
    if mean.ndim == 1:
        mean = mean.view(-1, 1, 1)
    if std.ndim == 1:
        std = std.view(-1, 1, 1)
I was also looking into this earlier, and one thing I asked myself is: when would this branch not trigger? The tensor should always have exactly one dimension unless we allow scalars. See above for that.
I think this is purely for broadcasting in case someone passes lists (e.g. [0.5, 0.5, 0.5]), not scalars. Without it, the following div/sub fails.
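To illustrate why the reshape is needed (a standalone sketch, not code from the PR):

```python
import torch

img = torch.rand(3, 224, 224)
mean = torch.tensor([0.485, 0.456, 0.406])  # shape (3,), built from a list

# Without the view, broadcasting aligns trailing dimensions, so (3, 224, 224)
# against (3,) tries to match 224 with 3 and fails.
try:
    _ = img - mean
except RuntimeError as err:
    print("broadcast error:", err)

# Reshaping to (3, 1, 1) lines the values up with the channel dimension.
_ = img - mean.view(-1, 1, 1)

# A scalar (0-dim tensor) broadcasts everywhere, so the branch is skipped for it.
_ = img - torch.tensor(0.5)
```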
Summary:
* Avoid GPU-CPU sync on Normalize
* Further optimizations.
* Apply code review changes.
* Fixing JIT.
* linter fix

Reviewed By: YosuaMichael
Differential Revision: D40722904
fbshipit-source-id: e452d89a42b34be852e3125d25756b3f598e50f4
This PR:
* Removes the (std == 0).any() idiom, which caused a synchronization between CPU and GPU (a 50% improvement on CUDA); a short sketch of why this idiom synchronizes follows the description.

The combined result is:
Modified benchmark script from here
cc @vfdev-5 @bjuncek @pmeier
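For context on why (std == 0).any() is costly on CUDA: using the result of a reduction over a GPU tensor in Python control flow forces a device-to-host copy, which blocks until all queued kernels have finished. A minimal sketch of the difference (illustrative, not the PR's code):

```python
import torch


def has_zero_with_sync(std: torch.Tensor) -> bool:
    # (std == 0).any() yields a bool tensor on the same device as `std`.
    # Converting it to a Python bool (implicitly via `if`, or via bool())
    # copies the value back to the CPU and synchronizes the CUDA stream.
    return bool((std == 0).any())


def has_zero_without_sync(std_values: list) -> bool:
    # Checking the plain Python values before they are ever turned into a
    # GPU tensor never touches the device, so nothing has to synchronize.
    return not all(std_values)
```

The latter matches the not all(std) style of check discussed above: validation happens on the Python-side inputs, and the GPU tensors only ever see the arithmetic.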