diff --git a/mmedit/models/losses/clip_loss.py b/mmedit/models/losses/clip_loss.py index eaad01f4a4..27f5f27cba 100644 --- a/mmedit/models/losses/clip_loss.py +++ b/mmedit/models/losses/clip_loss.py @@ -1,4 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional + import torch import torch.nn as nn @@ -25,10 +27,10 @@ class CLIPLossModel(torch.nn.Module): """ def __init__(self, - in_size=1024, - scale_factor=7, - pool_size=224, - clip_type='ViT-B/32'): + in_size: int = 1024, + scale_factor: int = 7, + pool_size: int = 224, + clip_type: str = 'ViT-B/32') -> None: super(CLIPLossModel, self).__init__() try: import clip @@ -43,7 +45,9 @@ def __init__(self, self.avg_pool = torch.nn.AvgPool2d( kernel_size=(scale_factor * in_size // pool_size)) - def forward(self, image=None, text=None): + def forward(self, + image: torch.Tensor = None, + text: torch.Tensor = None) -> torch.Tensor: """Forward function.""" assert image is not None assert text is not None @@ -85,10 +89,10 @@ class CLIPLoss(nn.Module): """ def __init__(self, - loss_weight=1.0, - data_info=None, - clip_model=dict(), - loss_name='loss_clip'): + loss_weight: float = 1.0, + data_info: Optional[dict] = None, + clip_model: dict = dict(), + loss_name: str = 'loss_clip') -> None: super(CLIPLoss, self).__init__() self.loss_weight = loss_weight @@ -96,7 +100,7 @@ def __init__(self, self.net = CLIPLossModel(**clip_model) self._loss_name = loss_name - def forward(self, image, text): + def forward(self, image: torch.Tensor, text: torch.Tensor) -> torch.Tensor: """Forward function. If ``self.data_info`` is not ``None``, a dictionary containing all of diff --git a/mmedit/models/losses/composition_loss.py b/mmedit/models/losses/composition_loss.py index e38e7b2751..5aea6768f8 100644 --- a/mmedit/models/losses/composition_loss.py +++ b/mmedit/models/losses/composition_loss.py @@ -1,4 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional + +import torch import torch.nn as nn from mmedit.registry import LOSSES @@ -22,7 +25,10 @@ class L1CompositionLoss(nn.Module): Default: False. """ - def __init__(self, loss_weight=1.0, reduction='mean', sample_wise=False): + def __init__(self, + loss_weight: float = 1.0, + reduction: str = 'mean', + sample_wise: bool = False) -> None: super().__init__() if reduction not in ['none', 'mean', 'sum']: raise ValueError(f'Unsupported reduction mode: {reduction}. ' @@ -32,7 +38,13 @@ def __init__(self, loss_weight=1.0, reduction='mean', sample_wise=False): self.reduction = reduction self.sample_wise = sample_wise - def forward(self, pred_alpha, fg, bg, ori_merged, weight=None, **kwargs): + def forward(self, + pred_alpha: torch.Tensor, + fg: torch.Tensor, + bg: torch.Tensor, + ori_merged: torch.Tensor, + weight: Optional[torch.Tensor] = None, + **kwargs) -> torch.Tensor: """ Args: pred_alpha (Tensor): of shape (N, 1, H, W). Predicted alpha matte. @@ -69,7 +81,10 @@ class MSECompositionLoss(nn.Module): Default: False. """ - def __init__(self, loss_weight=1.0, reduction='mean', sample_wise=False): + def __init__(self, + loss_weight: float = 1.0, + reduction: str = 'mean', + sample_wise: bool = False) -> None: super().__init__() if reduction not in ['none', 'mean', 'sum']: raise ValueError(f'Unsupported reduction mode: {reduction}. ' @@ -79,7 +94,13 @@ def __init__(self, loss_weight=1.0, reduction='mean', sample_wise=False): self.reduction = reduction self.sample_wise = sample_wise - def forward(self, pred_alpha, fg, bg, ori_merged, weight=None, **kwargs): + def forward(self, + pred_alpha: torch.Tensor, + fg: torch.Tensor, + bg: torch.Tensor, + ori_merged: torch.Tensor, + weight: Optional[torch.Tensor] = None, + **kwargs) -> torch.Tensor: """ Args: pred_alpha (Tensor): of shape (N, 1, H, W). Predicted alpha matte. @@ -119,10 +140,10 @@ class CharbonnierCompLoss(nn.Module): """ def __init__(self, - loss_weight=1.0, - reduction='mean', - sample_wise=False, - eps=1e-12): + loss_weight: float = 1.0, + reduction: str = 'mean', + sample_wise: bool = False, + eps: bool = 1e-12) -> None: super().__init__() if reduction not in ['none', 'mean', 'sum']: raise ValueError(f'Unsupported reduction mode: {reduction}. ' @@ -133,7 +154,13 @@ def __init__(self, self.sample_wise = sample_wise self.eps = eps - def forward(self, pred_alpha, fg, bg, ori_merged, weight=None, **kwargs): + def forward(self, + pred_alpha: torch.Tensor, + fg: torch.Tensor, + bg: torch.Tensor, + ori_merged: torch.Tensor, + weight: Optional[torch.Tensor] = None, + **kwargs) -> torch.Tensor: """ Args: pred_alpha (Tensor): of shape (N, 1, H, W). Predicted alpha matte. diff --git a/mmedit/models/losses/face_id_loss.py b/mmedit/models/losses/face_id_loss.py index 7bb1a78fec..a7b82ef301 100644 --- a/mmedit/models/losses/face_id_loss.py +++ b/mmedit/models/losses/face_id_loss.py @@ -1,4 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional + +import torch import torch.nn as nn from mmedit.registry import MODULES @@ -37,10 +40,10 @@ class FaceIdLoss(nn.Module): """ def __init__(self, - loss_weight=1.0, - data_info=None, - facenet=dict(type='ArcFace', ir_se50_weights=None), - loss_name='loss_id'): + loss_weight: float = 1.0, + data_info: Optional[dict] = None, + facenet: dict = dict(type='ArcFace', ir_se50_weights=None), + loss_name: str = 'loss_id') -> None: super(FaceIdLoss, self).__init__() self.loss_weight = loss_weight @@ -48,7 +51,9 @@ def __init__(self, self.net = MODULES.build(facenet) self._loss_name = loss_name - def forward(self, pred=None, gt=None): + def forward(self, + pred: torch.Tensor = None, + gt: torch.Tensor = None) -> torch.Tensor: """Forward function.""" # NOTE: only return the loss term diff --git a/mmedit/models/losses/feature_loss.py b/mmedit/models/losses/feature_loss.py index bb0758e291..e26d709251 100644 --- a/mmedit/models/losses/feature_loss.py +++ b/mmedit/models/losses/feature_loss.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. import warnings +from typing import Optional import torch import torch.nn as nn @@ -23,7 +24,7 @@ def __init__(self) -> None: self.features = nn.Sequential(*list(model.features.children())) self.features.requires_grad_(False) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: """Forward function. Args: @@ -35,7 +36,9 @@ def forward(self, x): return self.features(x) - def init_weights(self, pretrained=None, strict=True): + def init_weights(self, + pretrained: Optional[str] = None, + strict: bool = True) -> None: """Init weights for models. Args: @@ -63,7 +66,10 @@ class LightCNNFeatureLoss(nn.Module): Default: 'l1'. """ - def __init__(self, pretrained, loss_weight=1.0, criterion='l1'): + def __init__(self, + pretrained: str, + loss_weight: float = 1.0, + criterion: str = 'l1') -> None: super().__init__() self.model = LightCNNFeature() if not isinstance(pretrained, str): @@ -80,7 +86,7 @@ def __init__(self, pretrained, loss_weight=1.0, criterion='l1'): raise ValueError("'criterion' should be 'l1' or 'mse', " f'but got {criterion}') - def forward(self, pred, gt): + def forward(self, pred: torch.Tensor, gt: torch.Tensor) -> torch.Tensor: """Forward function. Args: diff --git a/mmedit/models/losses/gan_loss.py b/mmedit/models/losses/gan_loss.py index c4fccac48a..da9c4bf3b8 100644 --- a/mmedit/models/losses/gan_loss.py +++ b/mmedit/models/losses/gan_loss.py @@ -1,4 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional, Tuple, Union + import numpy as np import torch import torch.autograd as autograd @@ -6,6 +8,7 @@ import torch.nn as nn import torch.nn.functional as F from mmengine.dist import is_distributed +from torch.cuda.amp.grad_scaler import GradScaler from torch.nn.functional import conv2d from mmedit.registry import LOSSES @@ -25,10 +28,10 @@ class GANLoss(nn.Module): """ def __init__(self, - gan_type, - real_label_val=1.0, - fake_label_val=0.0, - loss_weight=1.0): + gan_type: str, + real_label_val: float = 1.0, + fake_label_val: float = 0.0, + loss_weight: float = 1.0) -> None: super().__init__() self.gan_type = gan_type self.real_label_val = real_label_val @@ -49,7 +52,7 @@ def __init__(self, raise NotImplementedError( f'GAN type {self.gan_type} is not implemented.') - def _wgan_loss(self, input, target): + def _wgan_loss(self, input: torch.Tensor, target: bool) -> torch.Tensor: """wgan loss. Args: @@ -62,7 +65,8 @@ def _wgan_loss(self, input, target): return -input.mean() if target else input.mean() - def get_target_label(self, input, target_is_real): + def get_target_label(self, input: torch.Tensor, + target_is_real: bool) -> Union[bool, torch.Tensor]: """Get target label. Args: @@ -80,7 +84,11 @@ def get_target_label(self, input, target_is_real): self.real_label_val if target_is_real else self.fake_label_val) return input.new_ones(input.size()) * target_val - def forward(self, input, target_is_real, is_disc=False, mask=None): + def forward(self, + input: torch.Tensor, + target_is_real: bool, + is_disc: bool = False, + mask: torch.Tensor = None) -> torch.Tensor: """ Args: input (Tensor): The input for the loss module, i.e., the network @@ -157,7 +165,11 @@ class GaussianBlur(nn.Module): - output: Tensor with shape of (n, c, h, w) """ - def __init__(self, kernel_size=(71, 71), sigma=(10.0, 10.0)): + def __init__( + self, + kernel_size: Tuple[int, int] = (71, 71), + sigma: Tuple[float, float] = (10.0, 10.0) + ) -> None: super(GaussianBlur, self).__init__() self.kernel_size = kernel_size self.sigma = sigma @@ -165,7 +177,7 @@ def __init__(self, kernel_size=(71, 71), sigma=(10.0, 10.0)): self.kernel = self.get_2d_gaussian_kernel(kernel_size, sigma) @staticmethod - def compute_zero_padding(kernel_size): + def compute_zero_padding(kernel_size: Tuple[int, int]) -> tuple: """Compute zero padding tuple. Args: @@ -179,7 +191,8 @@ def compute_zero_padding(kernel_size): return padding[0], padding[1] - def get_2d_gaussian_kernel(self, kernel_size, sigma): + def get_2d_gaussian_kernel(self, kernel_size: Tuple[int, int], + sigma: Tuple[float, float]) -> torch.Tensor: """Get the two-dimensional Gaussian filter matrix coefficients. Args: @@ -213,7 +226,8 @@ def get_2d_gaussian_kernel(self, kernel_size, sigma): return kernel_2d - def get_1d_gaussian_kernel(self, kernel_size, sigma): + def get_1d_gaussian_kernel(self, kernel_size: int, + sigma: float) -> torch.Tensor: """Get the Gaussian filter coefficients in one dimension (x or y direction). @@ -236,7 +250,7 @@ def get_1d_gaussian_kernel(self, kernel_size, sigma): kernel_1d = self.gaussian(kernel_size, sigma) return kernel_1d - def gaussian(self, kernel_size, sigma): + def gaussian(self, kernel_size: int, sigma: float) -> torch.Tensor: """Gaussian function. Args: @@ -257,7 +271,7 @@ def gauss_arg(x): ]) return gauss / gauss.sum() - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: """Forward function. Args: @@ -280,11 +294,11 @@ def forward(self, x): return conv2d(x, kernel, padding=self.padding, stride=1, groups=c) -def gradient_penalty_loss(discriminator, - real_data, - fake_data, - mask=None, - norm_mode='pixel'): +def gradient_penalty_loss(discriminator: nn.Module, + real_data: torch.Tensor, + fake_data: torch.Tensor, + mask: Optional[torch.Tensor] = None, + norm_mode: str = 'pixel') -> torch.Tensor: """Calculate gradient penalty for wgan-gp. Args: @@ -339,11 +353,15 @@ class GradientPenaltyLoss(nn.Module): loss_weight (float): Loss weight. Default: 1.0. """ - def __init__(self, loss_weight=1.): + def __init__(self, loss_weight: float = 1.) -> None: super().__init__() self.loss_weight = loss_weight - def forward(self, discriminator, real_data, fake_data, mask=None): + def forward(self, + discriminator: nn.Module, + real_data: torch.Tensor, + fake_data: torch.Tensor, + mask: Optional[torch.Tensor] = None) -> torch.Tensor: """Forward function. Args: @@ -361,7 +379,7 @@ def forward(self, discriminator, real_data, fake_data, mask=None): return loss * self.loss_weight -def disc_shift_loss(pred): +def disc_shift_loss(pred: torch.Tensor) -> torch.Tensor: """Disc Shift loss. This loss is proposed in PGGAN as an auxiliary loss for discriminator. @@ -383,11 +401,11 @@ class DiscShiftLoss(nn.Module): loss_weight (float, optional): Loss weight. Defaults to 1.0. """ - def __init__(self, loss_weight=0.1): + def __init__(self, loss_weight: float = 0.1) -> None: super().__init__() self.loss_weight = loss_weight - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: """Forward function. Args: @@ -401,12 +419,12 @@ def forward(self, x): return loss * self.loss_weight -def r1_gradient_penalty_loss(discriminator, - real_data, - mask=None, - norm_mode='pixel', - loss_scaler=None, - use_apex_amp=False): +def r1_gradient_penalty_loss(discriminator: nn.Module, + real_data: torch.Tensor, + mask: Optional[torch.Tensor] = None, + norm_mode: str = 'pixel', + loss_scaler: Optional[GradScaler] = None, + use_apex_amp: bool = False) -> torch.Tensor: """Calculate R1 gradient penalty for WGAN-GP. R1 regularizer comes from: @@ -472,16 +490,16 @@ def r1_gradient_penalty_loss(discriminator, return gradients_penalty -def gen_path_regularizer(generator, - num_batches, - mean_path_length, - pl_batch_shrink=1, - decay=0.01, - weight=1., - pl_batch_size=None, - sync_mean_buffer=False, - loss_scaler=None, - use_apex_amp=False): +def gen_path_regularizer(generator: nn.Module, + num_batches: int, + mean_path_length: torch.Tensor, + pl_batch_shrink: int = 1, + decay: float = 0.01, + weight: float = 1., + pl_batch_size: Optional[int] = None, + sync_mean_buffer: bool = False, + loss_scaler: Optional[GradScaler] = None, + use_apex_amp: bool = False) -> Tuple[torch.Tensor]: """Generator Path Regularization. Path regularization is proposed in StyelGAN2, which can help the improve diff --git a/mmedit/models/losses/gradient_loss.py b/mmedit/models/losses/gradient_loss.py index 12114b4482..f2b1ebe1f1 100644 --- a/mmedit/models/losses/gradient_loss.py +++ b/mmedit/models/losses/gradient_loss.py @@ -1,4 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional + import torch import torch.nn as nn import torch.nn.functional as F @@ -19,7 +21,9 @@ class GradientLoss(nn.Module): Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'. """ - def __init__(self, loss_weight=1.0, reduction='mean'): + def __init__(self, + loss_weight: float = 1.0, + reduction: str = 'mean') -> None: super().__init__() self.loss_weight = loss_weight self.reduction = reduction @@ -27,7 +31,10 @@ def __init__(self, loss_weight=1.0, reduction='mean'): raise ValueError(f'Unsupported reduction mode: {self.reduction}. ' f'Supported ones are: {_reduction_modes}') - def forward(self, pred, target, weight=None): + def forward(self, + pred: torch.Tensor, + target: torch.Tensor, + weight: Optional[torch.Tensor] = None) -> torch.Tensor: """ Args: pred (Tensor): of shape (N, C, H, W). Predicted tensor. diff --git a/mmedit/models/losses/loss_comps/clip_loss_comps.py b/mmedit/models/losses/loss_comps/clip_loss_comps.py index 3e652afb90..6695dedf08 100644 --- a/mmedit/models/losses/loss_comps/clip_loss_comps.py +++ b/mmedit/models/losses/loss_comps/clip_loss_comps.py @@ -1,4 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional + +import torch import torch.nn as nn from mmedit.registry import MODULES @@ -38,10 +41,10 @@ class CLIPLossComps(nn.Module): """ def __init__(self, - loss_weight=1.0, - data_info=None, - clip_model=dict(), - loss_name='loss_clip'): + loss_weight: float = 1.0, + data_info: Optional[dict] = None, + clip_model: dict = dict(), + loss_name: str = 'loss_clip') -> None: super().__init__() self.loss_weight = loss_weight @@ -49,7 +52,7 @@ def __init__(self, self.net = CLIPLossModel(**clip_model) self._loss_name = loss_name - def forward(self, *args, **kwargs): + def forward(self, *args, **kwargs) -> torch.Tensor: """Forward function. If ``self.data_info`` is not ``None``, a dictionary containing all of @@ -88,7 +91,7 @@ def forward(self, *args, **kwargs): return self.net(*args, **kwargs) * self.loss_weight @staticmethod - def loss_name(): + def loss_name() -> str: """Loss Name. This function must be implemented and will return the name of this diff --git a/mmedit/models/losses/loss_comps/disc_auxiliary_loss_comps.py b/mmedit/models/losses/loss_comps/disc_auxiliary_loss_comps.py index 5080ee5941..f24ac75a6d 100644 --- a/mmedit/models/losses/loss_comps/disc_auxiliary_loss_comps.py +++ b/mmedit/models/losses/loss_comps/disc_auxiliary_loss_comps.py @@ -1,4 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional + +import torch import torch.nn as nn from mmedit.registry import MODULES @@ -63,15 +66,15 @@ class DiscShiftLossComps(nn.Module): """ def __init__(self, - loss_weight=1.0, - data_info=None, - loss_name='loss_disc_shift'): + loss_weight: float = 1.0, + data_info: Optional[dict] = None, + loss_name: str = 'loss_disc_shift') -> None: super().__init__() self.loss_weight = loss_weight self.data_info = data_info self._loss_name = loss_name - def forward(self, *args, **kwargs): + def forward(self, *args, **kwargs) -> torch.Tensor: """Forward function. If ``self.data_info`` is not ``None``, a dictionary containing all of @@ -112,7 +115,7 @@ def forward(self, *args, **kwargs): # module will just directly return the loss as usual. return disc_shift_loss(*args, **kwargs) * self.loss_weight - def loss_name(self): + def loss_name(self) -> str: """Loss Name. This function must be implemented and will return the name of this @@ -187,17 +190,17 @@ class GradientPenaltyLossComps(nn.Module): """ def __init__(self, - loss_weight=1.0, - norm_mode='pixel', - data_info=None, - loss_name='loss_gp'): + loss_weight: float = 1.0, + norm_mode: str = 'pixel', + data_info: Optional[dict] = None, + loss_name: str = 'loss_gp') -> None: super().__init__() self.loss_weight = loss_weight self.norm_mode = norm_mode self.data_info = data_info self._loss_name = loss_name - def forward(self, *args, **kwargs): + def forward(self, *args, **kwargs) -> torch.Tensor: """Forward function. If ``self.data_info`` is not ``None``, a dictionary containing all of @@ -244,7 +247,7 @@ def forward(self, *args, **kwargs): # *args, weight=self.loss_weight, **kwargs) return gradient_penalty_loss(*args, **kwargs) * self.loss_weight - def loss_name(self): + def loss_name(self) -> str: """Loss Name. This function must be implemented and will return the name of this @@ -321,12 +324,12 @@ class R1GradientPenaltyComps(nn.Module): """ def __init__(self, - loss_weight=1.0, - norm_mode='pixel', - interval=1, - data_info=None, - use_apex_amp=False, - loss_name='loss_r1_gp'): + loss_weight: float = 1.0, + norm_mode: str = 'pixel', + interval: int = 1, + data_info: Optional[dict] = None, + use_apex_amp: bool = False, + loss_name: str = 'loss_r1_gp') -> None: super().__init__() self.loss_weight = loss_weight self.norm_mode = norm_mode @@ -335,7 +338,7 @@ def __init__(self, self.use_apex_amp = use_apex_amp self._loss_name = loss_name - def forward(self, *args, **kwargs): + def forward(self, *args, **kwargs) -> torch.Tensor: """Forward function. If ``self.data_info`` is not ``None``, a dictionary containing all of @@ -385,7 +388,7 @@ def forward(self, *args, **kwargs): return r1_gradient_penalty_loss( *args, norm_mode=self.norm_mode, **kwargs) * self.loss_weight - def loss_name(self): + def loss_name(self) -> str: """Loss Name. This function must be implemented and will return the name of this diff --git a/mmedit/models/losses/loss_comps/face_id_loss_comps.py b/mmedit/models/losses/loss_comps/face_id_loss_comps.py index d22762b948..a38e84b8c4 100644 --- a/mmedit/models/losses/loss_comps/face_id_loss_comps.py +++ b/mmedit/models/losses/loss_comps/face_id_loss_comps.py @@ -1,4 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional + +import torch import torch.nn as nn from mmedit.registry import MODULES @@ -37,10 +40,10 @@ class FaceIdLossComps(nn.Module): """ def __init__(self, - loss_weight=1.0, - data_info=None, - facenet=dict(type='ArcFace', ir_se50_weights=None), - loss_name='loss_id'): + loss_weight: float = 1.0, + data_info: Optional[dict] = None, + facenet: dict = dict(type='ArcFace', ir_se50_weights=None), + loss_name: str = 'loss_id') -> None: super().__init__() self.loss_weight = loss_weight @@ -48,7 +51,7 @@ def __init__(self, self.net = MODULES.build(facenet) self._loss_name = loss_name - def forward(self, *args, **kwargs): + def forward(self, *args, **kwargs) -> torch.Tensor: """Forward function. If ``self.data_info`` is not ``None``, a dictionary containing all of @@ -88,7 +91,7 @@ def forward(self, *args, **kwargs): # NOTE: only return the loss term return self.net(*args, **kwargs)[0] * self.loss_weight - def loss_name(self): + def loss_name(self) -> str: """Loss Name. This function must be implemented and will return the name of this diff --git a/mmedit/models/losses/loss_comps/gan_loss_comps.py b/mmedit/models/losses/loss_comps/gan_loss_comps.py index 83de0c9e94..cd0d06a0f1 100644 --- a/mmedit/models/losses/loss_comps/gan_loss_comps.py +++ b/mmedit/models/losses/loss_comps/gan_loss_comps.py @@ -1,4 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. +from typing import Union + +import torch import torch.nn as nn import torch.nn.functional as F @@ -20,10 +23,10 @@ class GANLossComps(nn.Module): """ def __init__(self, - gan_type, - real_label_val=1.0, - fake_label_val=0.0, - loss_weight=1.0): + gan_type: str, + real_label_val: float = 1.0, + fake_label_val: float = 0.0, + loss_weight: float = 1.0) -> None: super().__init__() self.gan_type = gan_type self.loss_weight = loss_weight @@ -44,7 +47,7 @@ def __init__(self, raise NotImplementedError( f'GAN type {self.gan_type} is not implemented.') - def _wgan_loss(self, input, target): + def _wgan_loss(self, input: torch.Tensor, target: bool) -> torch.Tensor: """wgan loss. Args: @@ -56,7 +59,8 @@ def _wgan_loss(self, input, target): """ return -input.mean() if target else input.mean() - def _wgan_logistic_ns_loss(self, input, target): + def _wgan_logistic_ns_loss(self, input: torch.Tensor, + target: bool) -> torch.Tensor: """WGAN loss in logistically non-saturating mode. This loss is widely used in StyleGANv2. @@ -72,7 +76,8 @@ def _wgan_logistic_ns_loss(self, input, target): return F.softplus(-input).mean() if target else F.softplus( input).mean() - def get_target_label(self, input, target_is_real): + def get_target_label(self, input: torch.Tensor, + target_is_real: bool) -> Union[bool, torch.Tensor]: """Get target label. Args: @@ -90,7 +95,10 @@ def get_target_label(self, input, target_is_real): self.real_label_val if target_is_real else self.fake_label_val) return input.new_ones(input.size()) * target_val - def forward(self, input, target_is_real, is_disc=False): + def forward(self, + input: torch.Tensor, + target_is_real: bool, + is_disc: bool = False) -> torch.Tensor: """ Args: input (Tensor): The input for the loss module, i.e., the network diff --git a/mmedit/models/losses/loss_comps/gen_auxiliary_loss_comps.py b/mmedit/models/losses/loss_comps/gen_auxiliary_loss_comps.py index 227cf7fec9..d9b566ce71 100644 --- a/mmedit/models/losses/loss_comps/gen_auxiliary_loss_comps.py +++ b/mmedit/models/losses/loss_comps/gen_auxiliary_loss_comps.py @@ -1,4 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional + import torch import torch.nn as nn @@ -74,15 +76,15 @@ class GeneratorPathRegularizerComps(nn.Module): """ def __init__(self, - loss_weight=1., - pl_batch_shrink=1, - decay=0.01, - pl_batch_size=None, - sync_mean_buffer=False, - interval=1, - data_info=None, - use_apex_amp=False, - loss_name='loss_path_regular'): + loss_weight: float = 1., + pl_batch_shrink: int = 1, + decay: float = 0.01, + pl_batch_size: Optional[int] = None, + sync_mean_buffer: bool = False, + interval: int = 1, + data_info: Optional[dict] = None, + use_apex_amp: bool = False, + loss_name: str = 'loss_path_regular') -> None: super().__init__() self.loss_weight = loss_weight self.pl_batch_shrink = pl_batch_shrink @@ -96,7 +98,7 @@ def __init__(self, self.register_buffer('mean_path_length', torch.tensor(0.)) - def forward(self, *args, **kwargs): + def forward(self, *args, **kwargs) -> torch.Tensor: """Forward function. If ``self.data_info`` is not ``None``, a dictionary containing all of @@ -155,7 +157,7 @@ def forward(self, *args, **kwargs): *args, **kwargs) return path_penalty * self.loss_weight - def loss_name(self): + def loss_name(self) -> str: """Loss Name. This function must be implemented and will return the name of this diff --git a/mmedit/models/losses/loss_wrapper.py b/mmedit/models/losses/loss_wrapper.py index 859bc36c0c..de5649b509 100644 --- a/mmedit/models/losses/loss_wrapper.py +++ b/mmedit/models/losses/loss_wrapper.py @@ -1,10 +1,12 @@ # Copyright (c) OpenMMLab. All rights reserved. import functools +from typing import Optional +import torch import torch.nn.functional as F -def reduce_loss(loss, reduction): +def reduce_loss(loss: torch.Tensor, reduction: str) -> torch.Tensor: """Reduce loss as specified. Args: @@ -25,7 +27,10 @@ def reduce_loss(loss, reduction): raise ValueError(f'reduction type {reduction} not supported') -def mask_reduce_loss(loss, weight=None, reduction='mean', sample_wise=False): +def mask_reduce_loss(loss: torch.Tensor, + weight: Optional[torch.Tensor] = None, + reduction: str = 'mean', + sample_wise: bool = False) -> torch.Tensor: """Apply element-wise weight and reduce loss. Args: @@ -102,12 +107,12 @@ def masked_loss(loss_func): """ @functools.wraps(loss_func) - def wrapper(pred, - target, - weight=None, - reduction='mean', - sample_wise=False, - **kwargs): + def wrapper(pred: torch.Tensor, + target: torch.Tensor, + weight: Optional[torch.Tensor] = None, + reduction: str = 'mean', + sample_wise: bool = False, + **kwargs) -> torch.Tensor: # get element-wise loss loss = loss_func(pred, target, **kwargs) loss = mask_reduce_loss(loss, weight, reduction, sample_wise) diff --git a/mmedit/models/losses/perceptual_loss.py b/mmedit/models/losses/perceptual_loss.py index 7f0e9de4c7..e05e50e51a 100644 --- a/mmedit/models/losses/perceptual_loss.py +++ b/mmedit/models/losses/perceptual_loss.py @@ -1,4 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional, Tuple + import torch import torch.nn as nn import torchvision.models.vgg as vgg @@ -30,10 +32,10 @@ class PerceptualVGG(nn.Module): """ def __init__(self, - layer_name_list, - vgg_type='vgg19', - use_input_norm=True, - pretrained='torchvision://vgg19'): + layer_name_list: List[str], + vgg_type: str = 'vgg19', + use_input_norm: bool = True, + pretrained: str = 'torchvision://vgg19') -> None: super().__init__() if pretrained.startswith('torchvision://'): assert vgg_type in pretrained @@ -62,7 +64,7 @@ def __init__(self, for v in self.vgg_layers.parameters(): v.requires_grad = False - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: """Forward function. Args: @@ -82,7 +84,7 @@ def forward(self, x): output[name] = x.clone() return output - def init_weights(self, model, pretrained): + def init_weights(self, model: nn.Module, pretrained: str) -> None: """Init weights. Args: @@ -126,15 +128,15 @@ class PerceptualLoss(nn.Module): """ def __init__(self, - layer_weights, - layer_weights_style=None, - vgg_type='vgg19', - use_input_norm=True, - perceptual_weight=1.0, - style_weight=1.0, - norm_img=True, - pretrained='torchvision://vgg19', - criterion='l1'): + layer_weights: dict, + layer_weights_style: Optional[dict] = None, + vgg_type: str = 'vgg19', + use_input_norm: bool = True, + perceptual_weight: float = 1.0, + style_weight: float = 1.0, + norm_img: bool = True, + pretrained: str = 'torchvision://vgg19', + criterion: str = 'l1') -> None: super().__init__() self.norm_img = norm_img self.perceptual_weight = perceptual_weight @@ -169,7 +171,8 @@ def __init__(self, f'{criterion} criterion has not been supported in' ' this version.') - def forward(self, x, gt): + def forward(self, x: torch.Tensor, + gt: torch.Tensor) -> Tuple[torch.Tensor]: """Forward function. Args: @@ -215,7 +218,7 @@ def forward(self, x, gt): return percep_loss, style_loss - def _gram_mat(self, x): + def _gram_mat(self, x: torch.Tensor) -> torch.Tensor: """Calculate Gram matrix. Args: @@ -242,7 +245,10 @@ class TransferalPerceptualLoss(nn.Module): Default: 'mse'. """ - def __init__(self, loss_weight=1.0, use_attention=True, criterion='mse'): + def __init__(self, + loss_weight: float = 1.0, + use_attention: bool = True, + criterion: str = 'mse') -> None: super().__init__() self.use_attention = use_attention self.loss_weight = loss_weight @@ -255,7 +261,8 @@ def __init__(self, loss_weight=1.0, use_attention=True, criterion='mse'): raise ValueError( f"criterion should be 'l1' or 'mse', but got {criterion}") - def forward(self, maps, soft_attention, textures): + def forward(self, maps: Tuple[torch.Tensor], soft_attention: torch.Tensor, + textures: Tuple[torch.Tensor]) -> torch.Tensor: """Forward function. Args: diff --git a/mmedit/models/losses/pixelwise_loss.py b/mmedit/models/losses/pixelwise_loss.py index 3ba36d5fba..3f9548c242 100644 --- a/mmedit/models/losses/pixelwise_loss.py +++ b/mmedit/models/losses/pixelwise_loss.py @@ -1,4 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional + import torch import torch.nn as nn import torch.nn.functional as F @@ -10,7 +12,7 @@ @masked_loss -def l1_loss(pred, target): +def l1_loss(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """L1 loss. Args: @@ -24,7 +26,7 @@ def l1_loss(pred, target): @masked_loss -def mse_loss(pred, target): +def mse_loss(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """MSE loss. Args: @@ -38,7 +40,9 @@ def mse_loss(pred, target): @masked_loss -def charbonnier_loss(pred, target, eps=1e-12): +def charbonnier_loss(pred: torch.Tensor, + target: torch.Tensor, + eps: float = 1e-12) -> torch.Tensor: """Charbonnier loss. Args: @@ -53,7 +57,7 @@ def charbonnier_loss(pred, target, eps=1e-12): return torch.sqrt((pred - target)**2 + eps) -def tv_loss(input): +def tv_loss(input: torch.Tensor) -> torch.Tensor: """L2 total variation loss, as in Mahendran et al.""" input = F.pad(input, (0, 1, 0, 1), 'replicate') x_diff = input[..., :-1, 1:] - input[..., :-1, :-1] @@ -76,7 +80,10 @@ class L1Loss(nn.Module): Default: False. """ - def __init__(self, loss_weight=1.0, reduction='mean', sample_wise=False): + def __init__(self, + loss_weight: float = 1.0, + reduction: str = 'mean', + sample_wise: bool = False) -> None: super().__init__() if reduction not in ['none', 'mean', 'sum']: raise ValueError(f'Unsupported reduction mode: {reduction}. ' @@ -86,7 +93,11 @@ def __init__(self, loss_weight=1.0, reduction='mean', sample_wise=False): self.reduction = reduction self.sample_wise = sample_wise - def forward(self, pred, target, weight=None, **kwargs): + def forward(self, + pred: torch.Tensor, + target: torch.Tensor, + weight: Optional[torch.Tensor] = None, + **kwargs) -> torch.Tensor: """Forward Function. Args: @@ -118,7 +129,10 @@ class MSELoss(nn.Module): Default: False. """ - def __init__(self, loss_weight=1.0, reduction='mean', sample_wise=False): + def __init__(self, + loss_weight: float = 1.0, + reduction: str = 'mean', + sample_wise: bool = False) -> None: super().__init__() if reduction not in ['none', 'mean', 'sum']: raise ValueError(f'Unsupported reduction mode: {reduction}. ' @@ -128,7 +142,11 @@ def __init__(self, loss_weight=1.0, reduction='mean', sample_wise=False): self.reduction = reduction self.sample_wise = sample_wise - def forward(self, pred, target, weight=None, **kwargs): + def forward(self, + pred: torch.Tensor, + target: torch.Tensor, + weight: Optional[torch.Tensor] = None, + **kwargs) -> torch.Tensor: """Forward Function. Args: @@ -167,10 +185,10 @@ class CharbonnierLoss(nn.Module): """ def __init__(self, - loss_weight=1.0, - reduction='mean', - sample_wise=False, - eps=1e-12): + loss_weight: float = 1.0, + reduction: str = 'mean', + sample_wise: bool = False, + eps: float = 1e-12) -> None: super().__init__() if reduction not in ['none', 'mean', 'sum']: raise ValueError(f'Unsupported reduction mode: {reduction}. ' @@ -181,7 +199,11 @@ def __init__(self, self.sample_wise = sample_wise self.eps = eps - def forward(self, pred, target, weight=None, **kwargs): + def forward(self, + pred: torch.Tensor, + target: torch.Tensor, + weight: Optional[torch.Tensor] = None, + **kwargs) -> torch.Tensor: """Forward Function. Args: @@ -207,10 +229,12 @@ class MaskedTVLoss(L1Loss): loss_weight (float, optional): Loss weight. Defaults to 1.0. """ - def __init__(self, loss_weight=1.0): + def __init__(self, loss_weight: float = 1.0) -> None: super().__init__(loss_weight=loss_weight) - def forward(self, pred, mask=None): + def forward(self, + pred: torch.Tensor, + mask: Optional[torch.Tensor] = None) -> torch.Tensor: """Forward function. Args: @@ -242,7 +266,7 @@ class PSNRLoss(nn.Module): toY: change to calculate the PSNR of Y channel in YCbCr format """ - def __init__(self, loss_weight=1.0, toY=False): + def __init__(self, loss_weight: float = 1.0, toY: bool = False) -> None: super(PSNRLoss, self).__init__() self.loss_weight = loss_weight import numpy as np @@ -251,7 +275,8 @@ def __init__(self, loss_weight=1.0, toY=False): self.coef = torch.tensor([65.481, 128.553, 24.966]).reshape(1, 3, 1, 1) self.first = True - def forward(self, pred, target): + def forward(self, pred: torch.Tensor, + target: torch.Tensor) -> torch.Tensor: assert len(pred.size()) == 4 return self.loss_weight * self.scale * torch.log((