Skip to content

Commit

Permalink
Merge bf55d34 into ad517d4
Browse files Browse the repository at this point in the history
  • Loading branch information
Yshuo-Li authored Mar 25, 2021
2 parents ad517d4 + bf55d34 commit 85ec6f5
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 1 deletion.
3 changes: 2 additions & 1 deletion mmedit/models/components/refiners/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .deepfill_refiner import DeepFillRefiner
from .mlp_refiner import MLPRefiner
from .plain_refiner import PlainRefiner

# Public API of the refiners subpackage.
__all__ = ['PlainRefiner', 'DeepFillRefiner', 'MLPRefiner']
58 changes: 58 additions & 0 deletions mmedit/models/components/refiners/mlp_refiner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import torch.nn as nn
from mmcv.runner import load_checkpoint

from mmedit.models.registry import COMPONENTS
from mmedit.utils import get_root_logger


@COMPONENTS.register_module()
class MLPRefiner(nn.Module):
    """Multilayer perceptron (MLP) refiner used in LIIF.

    Args:
        in_dim (int): Input dimension.
        out_dim (int): Output dimension.
        hidden_list (list[int]): List of hidden dimensions.
    """

    def __init__(self, in_dim, out_dim, hidden_list):
        super().__init__()
        # Build a Linear + ReLU pair for every hidden dimension, ending
        # with a final Linear projection to ``out_dim`` (no activation).
        layers = []
        lastv = in_dim
        for hidden in hidden_list:
            layers.append(nn.Linear(lastv, hidden))
            layers.append(nn.ReLU())
            lastv = hidden
        layers.append(nn.Linear(lastv, out_dim))
        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        """Forward function.

        Args:
            x (Tensor): Input of shape ``(..., in_dim)``.

        Returns:
            Tensor: Output of shape ``(..., out_dim)``.
        """
        # Flatten all leading dims so nn.Linear sees a 2-D batch, then
        # restore the original leading shape on the way out.
        shape = x.shape[:-1]
        x = self.layers(x.view(-1, x.shape[-1]))
        return x.view(*shape, -1)

    def init_weights(self, pretrained=None, strict=True):
        """Init weights for models.

        Args:
            pretrained (str, optional): Path for pretrained weights. If given
                None, pretrained weights will not be loaded. Defaults to None.
            strict (bool, optional): Whether to strictly load the pretrained
                model. Defaults to True.

        Raises:
            TypeError: If ``pretrained`` is neither a str nor None.
        """
        if isinstance(pretrained, str):
            logger = get_root_logger()
            load_checkpoint(self, pretrained, strict=strict, logger=logger)
        elif pretrained is None:
            # Rely on PyTorch's default initialization of nn.Linear.
            pass
        else:
            raise TypeError(f'"pretrained" must be a str or None. '
                            f'But received {type(pretrained)}.')
33 changes: 33 additions & 0 deletions tests/test_mlp_refiner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import torch
import torch.nn as nn

from mmedit.models.builder import build_component


def test_mlp_refiner():
    """Smoke-test MLPRefiner: build from config, forward, one optim step."""
    model_cfg = dict(
        type='MLPRefiner', in_dim=8, out_dim=3, hidden_list=[8, 8, 8, 8])
    mlp = build_component(model_cfg)

    # test attributes
    assert mlp.__class__.__name__ == 'MLPRefiner'

    # prepare data
    inputs = torch.rand(2, 8)
    targets = torch.rand(2, 3)
    if torch.cuda.is_available():
        inputs = inputs.cuda()
        targets = targets.cuda()
        mlp = mlp.cuda()
    data_batch = {'in': inputs, 'target': targets}
    # prepare optimizer
    criterion = nn.L1Loss()
    optimizer = torch.optim.Adam(mlp.parameters(), lr=1e-4)

    # test forward and a single optimization step
    # (call the module, not .forward(), so registered hooks run)
    output = mlp(data_batch['in'])
    assert output.shape == data_batch['target'].shape
    loss = criterion(output, data_batch['target'])
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

0 comments on commit 85ec6f5

Please sign in to comment.