Skip to content

Commit

Permalink
Merge 7880bfc into ad24a5a
Browse files Browse the repository at this point in the history
  • Loading branch information
liangzelong authored Jan 16, 2023
2 parents ad24a5a + 7880bfc commit 598b7d8
Show file tree
Hide file tree
Showing 14 changed files with 286 additions and 163 deletions.
24 changes: 14 additions & 10 deletions mmedit/models/losses/clip_loss.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional

import torch
import torch.nn as nn

Expand All @@ -25,10 +27,10 @@ class CLIPLossModel(torch.nn.Module):
"""

def __init__(self,
in_size=1024,
scale_factor=7,
pool_size=224,
clip_type='ViT-B/32'):
in_size: int = 1024,
scale_factor: int = 7,
pool_size: int = 224,
clip_type: str = 'ViT-B/32') -> None:
super(CLIPLossModel, self).__init__()
try:
import clip
Expand All @@ -43,7 +45,9 @@ def __init__(self,
self.avg_pool = torch.nn.AvgPool2d(
kernel_size=(scale_factor * in_size // pool_size))

def forward(self, image=None, text=None):
def forward(self,
image: torch.Tensor = None,
text: torch.Tensor = None) -> torch.Tensor:
"""Forward function."""
assert image is not None
assert text is not None
Expand Down Expand Up @@ -85,18 +89,18 @@ class CLIPLoss(nn.Module):
"""

def __init__(self,
loss_weight=1.0,
data_info=None,
clip_model=dict(),
loss_name='loss_clip'):
loss_weight: float = 1.0,
data_info: Optional[dict] = None,
clip_model: dict = dict(),
loss_name: str = 'loss_clip') -> None:

super(CLIPLoss, self).__init__()
self.loss_weight = loss_weight
self.data_info = data_info
self.net = CLIPLossModel(**clip_model)
self._loss_name = loss_name

def forward(self, image, text):
def forward(self, image: torch.Tensor, text: torch.Tensor) -> torch.Tensor:
"""Forward function.
If ``self.data_info`` is not ``None``, a dictionary containing all of
Expand Down
45 changes: 36 additions & 9 deletions mmedit/models/losses/composition_loss.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional

import torch
import torch.nn as nn

from mmedit.registry import LOSSES
Expand All @@ -22,7 +25,10 @@ class L1CompositionLoss(nn.Module):
Default: False.
"""

def __init__(self, loss_weight=1.0, reduction='mean', sample_wise=False):
def __init__(self,
loss_weight: float = 1.0,
reduction: str = 'mean',
sample_wise: bool = False) -> None:
super().__init__()
if reduction not in ['none', 'mean', 'sum']:
raise ValueError(f'Unsupported reduction mode: {reduction}. '
Expand All @@ -32,7 +38,13 @@ def __init__(self, loss_weight=1.0, reduction='mean', sample_wise=False):
self.reduction = reduction
self.sample_wise = sample_wise

def forward(self, pred_alpha, fg, bg, ori_merged, weight=None, **kwargs):
def forward(self,
pred_alpha: torch.Tensor,
fg: torch.Tensor,
bg: torch.Tensor,
ori_merged: torch.Tensor,
weight: Optional[torch.Tensor] = None,
**kwargs) -> torch.Tensor:
"""
Args:
pred_alpha (Tensor): of shape (N, 1, H, W). Predicted alpha matte.
Expand Down Expand Up @@ -69,7 +81,10 @@ class MSECompositionLoss(nn.Module):
Default: False.
"""

def __init__(self, loss_weight=1.0, reduction='mean', sample_wise=False):
def __init__(self,
loss_weight: float = 1.0,
reduction: str = 'mean',
sample_wise: bool = False) -> None:
super().__init__()
if reduction not in ['none', 'mean', 'sum']:
raise ValueError(f'Unsupported reduction mode: {reduction}. '
Expand All @@ -79,7 +94,13 @@ def __init__(self, loss_weight=1.0, reduction='mean', sample_wise=False):
self.reduction = reduction
self.sample_wise = sample_wise

def forward(self, pred_alpha, fg, bg, ori_merged, weight=None, **kwargs):
def forward(self,
pred_alpha: torch.Tensor,
fg: torch.Tensor,
bg: torch.Tensor,
ori_merged: torch.Tensor,
weight: Optional[torch.Tensor] = None,
**kwargs) -> torch.Tensor:
"""
Args:
pred_alpha (Tensor): of shape (N, 1, H, W). Predicted alpha matte.
Expand Down Expand Up @@ -119,10 +140,10 @@ class CharbonnierCompLoss(nn.Module):
"""

def __init__(self,
loss_weight=1.0,
reduction='mean',
sample_wise=False,
eps=1e-12):
loss_weight: float = 1.0,
reduction: str = 'mean',
sample_wise: bool = False,
eps: bool = 1e-12) -> None:
super().__init__()
if reduction not in ['none', 'mean', 'sum']:
raise ValueError(f'Unsupported reduction mode: {reduction}. '
Expand All @@ -133,7 +154,13 @@ def __init__(self,
self.sample_wise = sample_wise
self.eps = eps

def forward(self, pred_alpha, fg, bg, ori_merged, weight=None, **kwargs):
def forward(self,
pred_alpha: torch.Tensor,
fg: torch.Tensor,
bg: torch.Tensor,
ori_merged: torch.Tensor,
weight: Optional[torch.Tensor] = None,
**kwargs) -> torch.Tensor:
"""
Args:
pred_alpha (Tensor): of shape (N, 1, H, W). Predicted alpha matte.
Expand Down
15 changes: 10 additions & 5 deletions mmedit/models/losses/face_id_loss.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional

import torch
import torch.nn as nn

from mmedit.registry import MODULES
Expand Down Expand Up @@ -37,18 +40,20 @@ class FaceIdLoss(nn.Module):
"""

def __init__(self,
loss_weight=1.0,
data_info=None,
facenet=dict(type='ArcFace', ir_se50_weights=None),
loss_name='loss_id'):
loss_weight: float = 1.0,
data_info: Optional[dict] = None,
facenet: dict = dict(type='ArcFace', ir_se50_weights=None),
loss_name: str = 'loss_id') -> None:

super(FaceIdLoss, self).__init__()
self.loss_weight = loss_weight
self.data_info = data_info
self.net = MODULES.build(facenet)
self._loss_name = loss_name

def forward(self, pred=None, gt=None):
def forward(self,
pred: torch.Tensor = None,
gt: torch.Tensor = None) -> torch.Tensor:
"""Forward function."""

# NOTE: only return the loss term
Expand Down
14 changes: 10 additions & 4 deletions mmedit/models/losses/feature_loss.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from typing import Optional

import torch
import torch.nn as nn
Expand All @@ -23,7 +24,7 @@ def __init__(self) -> None:
self.features = nn.Sequential(*list(model.features.children()))
self.features.requires_grad_(False)

def forward(self, x):
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward function.
Args:
Expand All @@ -35,7 +36,9 @@ def forward(self, x):

return self.features(x)

def init_weights(self, pretrained=None, strict=True):
def init_weights(self,
pretrained: Optional[str] = None,
strict: bool = True) -> None:
"""Init weights for models.
Args:
Expand Down Expand Up @@ -63,7 +66,10 @@ class LightCNNFeatureLoss(nn.Module):
Default: 'l1'.
"""

def __init__(self, pretrained, loss_weight=1.0, criterion='l1'):
def __init__(self,
pretrained: str,
loss_weight: float = 1.0,
criterion: str = 'l1') -> None:
super().__init__()
self.model = LightCNNFeature()
if not isinstance(pretrained, str):
Expand All @@ -80,7 +86,7 @@ def __init__(self, pretrained, loss_weight=1.0, criterion='l1'):
raise ValueError("'criterion' should be 'l1' or 'mse', "
f'but got {criterion}')

def forward(self, pred, gt):
def forward(self, pred: torch.Tensor, gt: torch.Tensor) -> torch.Tensor:
"""Forward function.
Args:
Expand Down
Loading

0 comments on commit 598b7d8

Please sign in to comment.