From 31954e51734f20af4733735dffd9dbd672669129 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Mon, 24 Oct 2022 13:39:26 +0100 Subject: [PATCH 1/5] Avoid GPU-CPU sync on Normalize --- .../prototype/transforms/functional/_misc.py | 31 ++++++++++++++++++- 1 file changed, 30 insertions(+), 1 deletion(-) diff --git a/torchvision/prototype/transforms/functional/_misc.py b/torchvision/prototype/transforms/functional/_misc.py index fa4a6e9be73..d3594a68ed3 100644 --- a/torchvision/prototype/transforms/functional/_misc.py +++ b/torchvision/prototype/transforms/functional/_misc.py @@ -8,7 +8,36 @@ from torchvision.transforms import functional_tensor as _FT from torchvision.transforms.functional import pil_to_tensor, to_pil_image -normalize_image_tensor = _FT.normalize + +def normalize_image_tensor( + image: torch.Tensor, mean: List[float], std: List[float], inplace: bool = False +) -> torch.Tensor: + if not isinstance(image, torch.Tensor): + raise TypeError("Input img should be Tensor image") + + if not image.is_floating_point(): + raise TypeError(f"Input tensor should be a float tensor. Got {image.dtype}.") + + if image.ndim < 3: + raise ValueError( + f"Expected tensor to be a tensor image of size (..., C, H, W). Got tensor.size() = {image.size()}" + ) + + if (isinstance(std, (tuple, list)) and any(std)) or std == 0: + raise ValueError(f"std evaluated to zero after conversion to {image.dtype}, leading to division by zero.") + + if not inplace: + image = image.clone() + + dtype = image.dtype + mean = torch.as_tensor(mean, dtype=dtype, device=image.device) + std = torch.as_tensor(std, dtype=dtype, device=image.device) + if mean.ndim == 1: + mean = mean.view(-1, 1, 1) + if std.ndim == 1: + std = std.view(-1, 1, 1) + + return image.sub_(mean).div_(std) def normalize_video(video: torch.Tensor, mean: List[float], std: List[float], inplace: bool = False) -> torch.Tensor: From 208072556f5ac473bc8d265eee2dd02b9130ca7f Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Mon, 24 Oct 2022 13:54:14 +0100 Subject: [PATCH 2/5] Further optimizations. --- .../prototype/transforms/functional/_misc.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/torchvision/prototype/transforms/functional/_misc.py b/torchvision/prototype/transforms/functional/_misc.py index d3594a68ed3..41363ba37e4 100644 --- a/torchvision/prototype/transforms/functional/_misc.py +++ b/torchvision/prototype/transforms/functional/_misc.py @@ -23,21 +23,24 @@ def normalize_image_tensor( f"Expected tensor to be a tensor image of size (..., C, H, W). Got tensor.size() = {image.size()}" ) - if (isinstance(std, (tuple, list)) and any(std)) or std == 0: + if (isinstance(std, (tuple, list)) and not all(std)) or std == 0: raise ValueError(f"std evaluated to zero after conversion to {image.dtype}, leading to division by zero.") - if not inplace: - image = image.clone() - dtype = image.dtype - mean = torch.as_tensor(mean, dtype=dtype, device=image.device) - std = torch.as_tensor(std, dtype=dtype, device=image.device) + device = image.device + mean = torch.as_tensor(mean, dtype=dtype, device=device) + std = torch.as_tensor(std, dtype=dtype, device=device) if mean.ndim == 1: mean = mean.view(-1, 1, 1) if std.ndim == 1: std = std.view(-1, 1, 1) - return image.sub_(mean).div_(std) + if inplace: + image = image.sub_(mean) + else: + image = image.sub(mean) + + return image.div_(std) def normalize_video(video: torch.Tensor, mean: List[float], std: List[float], inplace: bool = False) -> torch.Tensor: From 8842afa499e6b47cfea9050b4ca03f3ef6398089 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Mon, 24 Oct 2022 14:27:53 +0100 Subject: [PATCH 3/5] Apply code review changes. --- torchvision/prototype/transforms/functional/_misc.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/torchvision/prototype/transforms/functional/_misc.py b/torchvision/prototype/transforms/functional/_misc.py index 41363ba37e4..c258a7e72a8 100644 --- a/torchvision/prototype/transforms/functional/_misc.py +++ b/torchvision/prototype/transforms/functional/_misc.py @@ -12,9 +12,6 @@ def normalize_image_tensor( image: torch.Tensor, mean: List[float], std: List[float], inplace: bool = False ) -> torch.Tensor: - if not isinstance(image, torch.Tensor): - raise TypeError("Input img should be Tensor image") - if not image.is_floating_point(): raise TypeError(f"Input tensor should be a float tensor. Got {image.dtype}.") @@ -24,7 +21,7 @@ def normalize_image_tensor( ) if (isinstance(std, (tuple, list)) and not all(std)) or std == 0: - raise ValueError(f"std evaluated to zero after conversion to {image.dtype}, leading to division by zero.") + raise ValueError(f"std evaluated to zero, leading to division by zero.") dtype = image.dtype device = image.device From afd5b1e2ba8f85c503f08f87dcbbe5e9ac94982a Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Mon, 24 Oct 2022 14:34:41 +0100 Subject: [PATCH 4/5] Fixing JIT. --- torchvision/prototype/transforms/functional/_misc.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/torchvision/prototype/transforms/functional/_misc.py b/torchvision/prototype/transforms/functional/_misc.py index c258a7e72a8..6f9c97f9d5f 100644 --- a/torchvision/prototype/transforms/functional/_misc.py +++ b/torchvision/prototype/transforms/functional/_misc.py @@ -20,7 +20,13 @@ def normalize_image_tensor( f"Expected tensor to be a tensor image of size (..., C, H, W). Got tensor.size() = {image.size()}" ) - if (isinstance(std, (tuple, list)) and not all(std)) or std == 0: + if isinstance(std, (tuple, list)): + divzero = not all(std) + elif isinstance(std, (int, float)): + divzero = std == 0 + else: + divzero = False + if divzero: raise ValueError(f"std evaluated to zero, leading to division by zero.") dtype = image.dtype From 9a1de929d171d8c69c367dd4af4668b67a647d9c Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Mon, 24 Oct 2022 14:35:20 +0100 Subject: [PATCH 5/5] linter fix --- torchvision/prototype/transforms/functional/_misc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/prototype/transforms/functional/_misc.py b/torchvision/prototype/transforms/functional/_misc.py index 6f9c97f9d5f..3a1d8575cd0 100644 --- a/torchvision/prototype/transforms/functional/_misc.py +++ b/torchvision/prototype/transforms/functional/_misc.py @@ -27,7 +27,7 @@ def normalize_image_tensor( else: divzero = False if divzero: - raise ValueError(f"std evaluated to zero, leading to division by zero.") + raise ValueError("std evaluated to zero, leading to division by zero.") dtype = image.dtype device = image.device