
[Refactor] Use mmengine.BaseModule instead of nn.Module #1491

Merged: 11 commits, Jan 18, 2023
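Across all files below the pattern is the same: each class's base switches from torch.nn.Module to mmengine.model.BaseModule, and the hand-rolled init_weights(pretrained=...) methods lose their checkpoint-loading branch, keeping only the random-initialization scheme. A minimal sketch of the resulting shape, using a hypothetical MyConvBlock rather than any class touched by this PR:

import torch.nn as nn
from mmengine.model import BaseModule


class MyConvBlock(BaseModule):  # hypothetical class, not part of this PR
    def __init__(self, in_channels=3, out_channels=16, init_cfg=None):
        # BaseModule stores init_cfg and an _is_init flag for its subclasses.
        super().__init__(init_cfg=init_cfg)
        self.conv = nn.Conv2d(in_channels, out_channels, 3, padding=1)

    def init_weights(self):
        # No `pretrained` argument any more: only the init scheme lives here.
        nn.init.normal_(self.conv.weight, 0.0, 0.02)
        self._is_init = True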
26 changes: 8 additions & 18 deletions mmedit/models/base_archs/smpatch_disc.py
@@ -3,16 +3,15 @@

import torch.nn as nn
from mmcv.cnn import ConvModule
-from mmengine import MMLogger
-from mmengine.runner import load_checkpoint
+from mmengine.model import BaseModule
from torch import Tensor

from mmedit.models.utils import generation_init_weights
from mmedit.registry import COMPONENTS


@COMPONENTS.register_module()
-class SoftMaskPatchDiscriminator(nn.Module):
+class SoftMaskPatchDiscriminator(BaseModule):
"""A Soft Mask-Guided PatchGAN discriminator.

Args:
@@ -118,19 +117,10 @@ def forward(self, x: Tensor) -> Tensor:
"""
return self.model(x)

-    def init_weights(self, pretrained: Optional[str] = None) -> None:
-        """Initialize weights for the model.
+    def init_weights(self) -> None:
+        """Initialize weights for the model."""

-        Args:
-            pretrained (str, optional): Path for pretrained weights. If given
-                None, pretrained weights will not be loaded. Default: None.
-        """
-        if isinstance(pretrained, str):
-            logger = MMLogger.get_current_instance()
-            load_checkpoint(self, pretrained, strict=False, logger=logger)
-        elif pretrained is None:
-            generation_init_weights(
-                self, init_type=self.init_type, init_gain=self.init_gain)
-        else:
-            raise TypeError("'pretrained' must be a str or None. "
-                            f'But received {type(pretrained)}.')
+        generation_init_weights(
+            self, init_type=self.init_type, init_gain=self.init_gain)
+
+        self._is_init = True
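With the `pretrained` branch gone, checkpoint loading happens outside the module. A sketch of the caller-side replacement, assuming mmengine's load_checkpoint utility; the constructor argument and checkpoint path here are illustrative only:

from mmengine.runner import load_checkpoint

model = SoftMaskPatchDiscriminator(in_channels=3)  # args illustrative
model.init_weights()  # applies generation_init_weights, sets _is_init

# Equivalent of what the removed branch did internally:
load_checkpoint(model, 'work_dirs/disc.pth', strict=False)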
3 changes: 2 additions & 1 deletion mmedit/models/editors/aotgan/aot_decoder.py
@@ -2,12 +2,13 @@
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
+from mmengine.model import BaseModule

from mmedit.registry import COMPONENTS


@COMPONENTS.register_module()
-class AOTDecoder(nn.Module):
+class AOTDecoder(BaseModule):
"""Decoder used in AOT-GAN model.

This implementation follows:
3 changes: 2 additions & 1 deletion mmedit/models/editors/aotgan/aot_encoder.py
@@ -1,12 +1,13 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from mmcv.cnn import ConvModule
+from mmengine.model import BaseModule

from mmedit.registry import COMPONENTS


@COMPONENTS.register_module()
-class AOTEncoder(nn.Module):
+class AOTEncoder(BaseModule):
"""Encoder used in AOT-GAN model.

This implementation follows:
1 change: 1 addition & 0 deletions mmedit/models/editors/aotgan/aot_encoder_decoder.py
@@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.

from mmedit.registry import BACKBONES, COMPONENTS
from ..global_local import GLEncoderDecoder

Expand Down
5 changes: 3 additions & 2 deletions mmedit/models/editors/aotgan/aot_neck.py
@@ -2,12 +2,13 @@
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
+from mmengine.model import BaseModule

from mmedit.registry import COMPONENTS


@COMPONENTS.register_module()
-class AOTBlockNeck(nn.Module):
+class AOTBlockNeck(BaseModule):
"""Dilation backbone used in AOT-GAN model.

This implementation follows:
@@ -45,7 +46,7 @@ def forward(self, x):
return x


-class AOTBlock(nn.Module):
+class AOTBlock(BaseModule):
"""AOT Block which constitutes the dilation backbone.

This implementation follows:
6 changes: 3 additions & 3 deletions mmedit/models/editors/basicvsr/basicvsr_net.py
@@ -168,7 +168,7 @@ def forward(self, lrs):
return torch.stack(outputs, dim=1)


-class ResidualBlocksWithInputConv(nn.Module):
+class ResidualBlocksWithInputConv(BaseModule):
"""Residual blocks with a convolution in front.

Args:
@@ -206,7 +206,7 @@ def forward(self, feat):
return self.main(feat)


-class SPyNet(nn.Module):
+class SPyNet(BaseModule):
"""SPyNet network structure.

The difference to the SPyNet in [tof.py] is that
@@ -339,7 +339,7 @@ def forward(self, ref, supp):
return flow


-class SPyNetBasicModule(nn.Module):
+class SPyNetBasicModule(BaseModule):
"""Basic Module for SPyNet.

Paper:
8 changes: 4 additions & 4 deletions mmedit/models/editors/cain/cain_net.py
@@ -147,7 +147,7 @@ def get_padding_functions(x, padding=7):
return padding_function, depadding_function


-class ConvNormWithReflectionPad(nn.Module):
+class ConvNormWithReflectionPad(BaseModule):
"""Apply reflection padding, followed by a convolution, which can be
followed by an optional normalization.

@@ -193,7 +193,7 @@ def forward(self, x):
return out


-class ChannelAttentionLayer(nn.Module):
+class ChannelAttentionLayer(BaseModule):
"""Channel Attention (CA) Layer.

Args:
@@ -236,7 +236,7 @@ def forward(self, x):
return x * y


-class ResidualChannelAttention(nn.Module):
+class ResidualChannelAttention(BaseModule):
"""Residual Channel Attention Module.

Args:
@@ -277,7 +277,7 @@ def forward(self, x):
return out + x


-class ResidualGroup(nn.Module):
+class ResidualGroup(BaseModule):
"""Residual Group, consisting of a stack of residual channel attention,
followed by a convolution.

3 changes: 2 additions & 1 deletion mmedit/models/editors/deepfillv1/contextual_attention.py
@@ -4,9 +4,10 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
+from mmengine.model import BaseModule


-class ContextualAttentionModule(nn.Module):
+class ContextualAttentionModule(BaseModule):
"""Contexture attention module.

The details of this module can be found in:
4 changes: 2 additions & 2 deletions mmedit/models/editors/deepfillv1/contextual_attention_neck.py
@@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from mmcv.cnn import ConvModule
+from mmengine.model import BaseModule

from mmedit.models.base_archs import SimpleGatedConvModule
from mmedit.models.editors.deepfillv1.contextual_attention import \
@@ -9,7 +9,7 @@


@COMPONENTS.register_module()
-class ContextualAttentionNeck(nn.Module):
+class ContextualAttentionNeck(BaseModule):
"""Neck with contextual attention module.

Args:
4 changes: 2 additions & 2 deletions mmedit/models/editors/deepfillv1/deepfill_decoder.py
@@ -3,17 +3,17 @@
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule, build_activation_layer
+from mmengine.model import BaseModule

# from ...modules import SimpleGatedConvModule
from mmedit.models.base_archs import SimpleGatedConvModule
from mmedit.registry import COMPONENTS


@COMPONENTS.register_module()
-class DeepFillDecoder(nn.Module):
+class DeepFillDecoder(BaseModule):
"""Decoder used in DeepFill model.

This implementation follows:
32 changes: 11 additions & 21 deletions mmedit/models/editors/deepfillv1/deepfill_disc.py
@@ -1,14 +1,13 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
-from mmengine import MMLogger
+from mmengine.model import BaseModule
from mmengine.model.weight_init import normal_init
-from mmengine.runner import load_checkpoint

from mmedit.registry import COMPONENTS


@COMPONENTS.register_module()
-class DeepFillv1Discriminators(nn.Module):
+class DeepFillv1Discriminators(BaseModule):
"""Discriminators used in DeepFillv1 model.

In DeepFillv1 model, the discriminators are independent without any
@@ -47,22 +46,13 @@ def forward(self, x):

return global_pred, local_pred

-    def init_weights(self, pretrained=None):
-        """Init weights for models.
+    def init_weights(self):
+        """Init weights for models."""

-        Args:
-            pretrained (str, optional): Path for pretrained weights. If given
-                None, pretrained weights will not be loaded. Defaults to None.
-        """
-        if isinstance(pretrained, str):
-            logger = MMLogger.get_current_instance()
-            load_checkpoint(self, pretrained, strict=False, logger=logger)
-        elif pretrained is None:
-            for m in self.modules():
-                if isinstance(m, nn.Linear):
-                    normal_init(m, 0, std=0.02)
-                elif isinstance(m, nn.Conv2d):
-                    normal_init(m, 0.0, std=0.02)
-        else:
-            raise TypeError('pretrained must be a str or None but got'
-                            f'{type(pretrained)} instead.')
+        for m in self.modules():
+            if isinstance(m, nn.Linear):
+                normal_init(m, 0, std=0.02)
+            elif isinstance(m, nn.Conv2d):
+                normal_init(m, 0.0, std=0.02)
+
+        self._is_init = True
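For context, normal_init here fills a module's weight from N(mean, std) and its bias with a constant (0 by default). A rough plain-torch equivalent, as a sketch of mmengine.model.weight_init.normal_init rather than its exact code:

import torch.nn as nn

def normal_init_sketch(module, mean=0.0, std=1.0, bias=0.0):
    # Approximates mmengine.model.weight_init.normal_init.
    if hasattr(module, 'weight') and module.weight is not None:
        nn.init.normal_(module.weight, mean, std)
    if hasattr(module, 'bias') and module.bias is not None:
        nn.init.constant_(module.bias, bias)

conv = nn.Conv2d(3, 8, 3)
normal_init_sketch(conv, 0.0, std=0.02)  # mirrors the loop above

Setting self._is_init = True at the end mirrors BaseModule's bookkeeping, which uses that flag to avoid re-initializing an already initialized module.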
4 changes: 2 additions & 2 deletions mmedit/models/editors/deepfillv1/deepfill_encoder.py
@@ -1,13 +1,13 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from mmcv.cnn import ConvModule
+from mmengine.model import BaseModule

from mmedit.models.base_archs import SimpleGatedConvModule
from mmedit.registry import COMPONENTS


@COMPONENTS.register_module()
-class DeepFillEncoder(nn.Module):
+class DeepFillEncoder(BaseModule):
"""Encoder used in DeepFill model.

This implementation follows:
4 changes: 2 additions & 2 deletions mmedit/models/editors/deepfillv1/deepfill_refiner.py
@@ -1,13 +1,13 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
+from mmengine.model import BaseModule

from mmedit.registry import COMPONENTS


@COMPONENTS.register_module()
-class DeepFillRefiner(nn.Module):
+class DeepFillRefiner(BaseModule):
"""Refiner used in DeepFill model.

This implementation follows:
32 changes: 11 additions & 21 deletions mmedit/models/editors/deepfillv2/two_stage_encoder_decoder.py
@@ -1,16 +1,15 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
-from mmengine import MMLogger
+from mmengine.model import BaseModule
from mmengine.model.weight_init import constant_init, normal_init
-from mmengine.runner import load_checkpoint
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm

from mmedit.registry import BACKBONES


@BACKBONES.register_module()
-class DeepFillEncoderDecoder(nn.Module):
+class DeepFillEncoderDecoder(BaseModule):
"""Two-stage encoder-decoder structure used in DeepFill model.

The details are in:
@@ -73,22 +72,13 @@ def forward(self, x):
return stage1_res, stage2_res

# TODO: study the effects of init functions
-    def init_weights(self, pretrained=None):
-        """Init weights for models.
+    def init_weights(self):
+        """Init weights for models."""

-        Args:
-            pretrained (str, optional): Path for pretrained weights. If given
-                None, pretrained weights will not be loaded. Defaults to None.
-        """
-        if isinstance(pretrained, str):
-            logger = MMLogger.get_current_instance()
-            load_checkpoint(self, pretrained, strict=False, logger=logger)
-        elif pretrained is None:
-            for m in self.modules():
-                if isinstance(m, nn.Conv2d):
-                    normal_init(m, 0, 0.02)
-                elif isinstance(m, (_BatchNorm, nn.InstanceNorm2d)):
-                    constant_init(m, 1)
-        else:
-            raise TypeError('pretrained must be a str or None but'
-                            f' got {type(pretrained)} instead.')
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                normal_init(m, 0, 0.02)
+            elif isinstance(m, (_BatchNorm, nn.InstanceNorm2d)):
+                constant_init(m, 1)
+
+        self._is_init = True
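The encoder-decoder follows the same shape, with constant_init(m, 1) putting the norm layers into an identity-like state at init (affine weight 1, bias 0). A companion sketch to the normal_init one above, again approximating mmengine's helper:

import torch.nn as nn

def constant_init_sketch(module, val, bias=0.0):
    # Approximates mmengine.model.weight_init.constant_init.
    if hasattr(module, 'weight') and module.weight is not None:
        nn.init.constant_(module.weight, val)
    if hasattr(module, 'bias') and module.bias is not None:
        nn.init.constant_(module.bias, bias)

bn = nn.BatchNorm2d(8)
constant_init_sketch(bn, 1)  # gamma = 1, beta = 0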
4 changes: 2 additions & 2 deletions mmedit/models/editors/edvr/edvr_net.py
@@ -232,7 +232,7 @@ def forward(self, x, extra_feat):
self.deform_groups)


-class PCDAlignment(nn.Module):
+class PCDAlignment(BaseModule):
"""Alignment module using Pyramid, Cascading and Deformable convolution
(PCD). It is used in EDVRNet.

@@ -361,7 +361,7 @@ def forward(self, neighbor_feats, ref_feats):
return feat


-class TSAFusion(nn.Module):
+class TSAFusion(BaseModule):
"""Temporal Spatial Attention (TSA) fusion module. It is used in EDVRNet.

Args:
3 changes: 2 additions & 1 deletion mmedit/models/editors/global_local/gl_decoder.py
@@ -4,12 +4,13 @@
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
+from mmengine.model import BaseModule

from mmedit.registry import COMPONENTS


@COMPONENTS.register_module()
-class GLDecoder(nn.Module):
+class GLDecoder(BaseModule):
"""Decoder used in Global&Local model.

This implementation follows:
3 changes: 2 additions & 1 deletion mmedit/models/editors/global_local/gl_dilation.py
@@ -1,13 +1,14 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from mmcv.cnn import ConvModule
+from mmengine.model import BaseModule

from mmedit.models.base_archs import SimpleGatedConvModule
from mmedit.registry import COMPONENTS


@COMPONENTS.register_module()
-class GLDilationNeck(nn.Module):
+class GLDilationNeck(BaseModule):
"""Dilation Backbone used in Global&Local model.

This implementation follows: