[prototype] Speed improvement for normalize op #6821

Merged
merged 7 commits into from
Oct 24, 2022

37 changes: 36 additions & 1 deletion torchvision/prototype/transforms/functional/_misc.py
@@ -8,7 +8,42 @@
from torchvision.transforms import functional_tensor as _FT
from torchvision.transforms.functional import pil_to_tensor, to_pil_image

normalize_image_tensor = _FT.normalize

def normalize_image_tensor(
    image: torch.Tensor, mean: List[float], std: List[float], inplace: bool = False
) -> torch.Tensor:
    if not image.is_floating_point():
        raise TypeError(f"Input tensor should be a float tensor. Got {image.dtype}.")

    if image.ndim < 3:
        raise ValueError(
            f"Expected tensor to be a tensor image of size (..., C, H, W). Got tensor.size() = {image.size()}"
        )

    if isinstance(std, (tuple, list)):
        divzero = not all(std)
    elif isinstance(std, (int, float)):
        divzero = std == 0
    else:
        divzero = False
    if divzero:
        raise ValueError("std evaluated to zero, leading to division by zero.")

    dtype = image.dtype
    device = image.device
    mean = torch.as_tensor(mean, dtype=dtype, device=device)
    std = torch.as_tensor(std, dtype=dtype, device=device)
    if mean.ndim == 1:
        mean = mean.view(-1, 1, 1)
    if std.ndim == 1:
        std = std.view(-1, 1, 1)
Comment on lines +36 to +39
Collaborator

I was also looking into this earlier, and one thing I asked myself is: when would this branch not trigger? The tensor should always have one dimension unless we allow scalars. See above for that.

Contributor Author

I think this is purely for broadcasting in case someone passes lists rather than scalars, e.g. [0.5, 0.5, 0.5]. It is needed because otherwise the following sub/div fails.
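
To make the broadcasting point concrete, here is a minimal sketch. The shapes and the names `image`, `mean`, and `flat_ok` are illustrative assumptions, not taken from the PR. It suggests that a per-channel list has to be reshaped to (C, 1, 1) before subtracting, and that a scalar produces a 0-dim tensor, which is the case where the `ndim == 1` branch would not trigger.

```python
import torch

# Illustrative input: a batch of two 3-channel images, shape (N, C, H, W).
image = torch.rand(2, 3, 4, 5)
mean = torch.as_tensor([0.5, 0.5, 0.5])  # shape (3,)

# A 1-D tensor of length C is aligned with the last dimension (W), not with C,
# so the subtraction fails here because W = 5 does not match 3.
try:
    image.sub(mean)
    flat_ok = True
except RuntimeError:
    flat_ok = False

# Reshaping to (C, 1, 1) aligns the values with the channel dimension.
out = image.sub(mean.view(-1, 1, 1))

# A Python scalar becomes a 0-dim tensor, so an `ndim == 1` check skips it;
# a 0-dim tensor broadcasts against any shape without reshaping.
scalar_mean = torch.as_tensor(0.5)
out_scalar = image.sub(scalar_mean)
```

Note that if the width happened to equal the number of channels, the un-reshaped subtraction would not raise and would instead subtract along the wrong dimension, which is another reason to reshape explicitly.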


    if inplace:
        image = image.sub_(mean)
    else:
        image = image.sub(mean)

    return image.div_(std)


def normalize_video(video: torch.Tensor, mean: List[float], std: List[float], inplace: bool = False) -> torch.Tensor: