From c9719b2f284f67fe73f33f7e4be4d543371347c3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Rozet?= Date: Fri, 4 Dec 2020 19:38:37 +0100 Subject: [PATCH] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20Remove=20repetitive=20redu?= =?UTF-8?q?ction=20code=20segment?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- spiq/lpips.py | 11 +++-------- spiq/psnr.py | 11 ++++------- spiq/ssim.py | 18 ++++-------------- spiq/tv.py | 11 ++++------- spiq/utils.py | 23 ++++++++++++++++++++++- 5 files changed, 37 insertions(+), 37 deletions(-) diff --git a/spiq/lpips.py b/spiq/lpips.py index 8b02899..86cf0d6 100644 --- a/spiq/lpips.py +++ b/spiq/lpips.py @@ -17,7 +17,7 @@ import torch.nn as nn import torchvision.models as models -from spiq.utils import normalize_tensor, Intermediary +from spiq.utils import build_reduce, normalize_tensor, Intermediary _SHIFT = torch.Tensor([0.485, 0.456, 0.406]) _SCALE = torch.Tensor([0.229, 0.224, 0.225]) @@ -85,7 +85,7 @@ def __init__( for y in x: y.requires_grad = False - self.reduction = reduction + self.reduce = build_reduce(reduction) def forward( self, @@ -106,9 +106,4 @@ def forward( l = torch.cat(residuals, dim=-1).sum(dim=-1) - if self.reduction == 'mean': - l = l.mean() - elif self.reduction == 'sum': - l = l.sum() - - return l + return self.reduce(l) diff --git a/spiq/psnr.py b/spiq/psnr.py index df976a1..8777185 100644 --- a/spiq/psnr.py +++ b/spiq/psnr.py @@ -9,6 +9,8 @@ import torch import torch.nn as nn +from spiq.utils import build_reduce + from typing import Tuple @@ -51,7 +53,7 @@ def __init__(self, value_range: float = 1., reduction: str = 'mean'): super().__init__() self.value_range = value_range - self.reduction = reduction + self.reduce = build_reduce(reduction) def forward( self, @@ -65,9 +67,4 @@ def forward( value_range=self.value_range, ) - if self.reduction == 'mean': - l = l.mean() - elif self.reduction == 'sum': - l = l.sum() - - return l + return self.reduce(l) diff --git a/spiq/ssim.py b/spiq/ssim.py index 657a5fb..c36b2cd 100644 --- a/spiq/ssim.py +++ b/spiq/ssim.py @@ -22,7 +22,7 @@ import torch.nn as nn import torch.nn.functional as F -from spiq.utils import gaussian_kernel +from spiq.utils import build_reduce, gaussian_kernel _SIGMA = 1.5 _K1, _K2 = 0.01, 0.03 @@ -191,7 +191,7 @@ def __init__( self.register_buffer('window', create_window(window_size, n_channels)) self.value_range = value_range - self.reduction = reduction + self.reduce = build_reduce(reduction) def forward( self, @@ -205,12 +205,7 @@ def forward( value_range=self.value_range, )[0].mean(-1) - if self.reduction == 'mean': - l = l.mean() - elif self.reduction == 'sum': - l = l.sum() - - return l + return self.reduce(l) class MSSSIM(SSIM): @@ -244,9 +239,4 @@ def forward( weights=self.weights, ).mean(-1) - if self.reduction == 'mean': - l = l.mean() - elif self.reduction == 'sum': - l = l.sum() - - return l + return self.reduce(l) diff --git a/spiq/tv.py b/spiq/tv.py index e318f94..53d3735 100644 --- a/spiq/tv.py +++ b/spiq/tv.py @@ -9,6 +9,8 @@ import torch import torch.nn as nn +from spiq.utils import build_reduce + def tv(x: torch.Tensor, norm: str = 'L2') -> torch.Tensor: r"""Returns the TV of `x`. @@ -51,14 +53,9 @@ def __init__(self, norm: str = 'L2', reduction: str = 'mean'): super().__init__() self.norm = norm - self.reduction = reduction + self.reduce = build_reduce(reduction) def forward(self, input: torch.Tensor) -> torch.Tensor: l = tv(input, norm=self.norm) - if self.reduction == 'mean': - l = l.mean() - elif self.reduction == 'sum': - l = l.sum() - - return l + return self.reduce(l) diff --git a/spiq/utils.py b/spiq/utils.py index 6eedebd..946c9dd 100644 --- a/spiq/utils.py +++ b/spiq/utils.py @@ -4,7 +4,28 @@ import torch import torch.nn as nn -from typing import List, Tuple +from typing import Callable, List, Tuple + + +def build_reduce( + reduction: str = 'mean', + dim: Tuple[int, ...] = (), + keepdim: bool = False, +) -> Callable[[torch.Tensor], torch.Tensor]: + r"""Returns a reduce function. + + Args: + reduction: A reduction type (`'mean'`, `'sum'` or `'none'`). + dim: The dimension(s) along which to reduce. + keepdim: Whether the output tensor has `dim` retained or not. + """ + + if reduction == 'mean': + return lambda x: x.mean(dim=dim, keepdim=keepdim) + elif reduction == 'sum': + return lambda x: x.sum(dim=dim, keepdim=keepdim) + + return lambda x: x def gaussian_kernel(