Showing 3 changed files with 93 additions and 1 deletion.
@@ -1,4 +1,5 @@
 from .deepfill_refiner import DeepFillRefiner
+from .mlp_refiner import MLPRefiner
 from .plain_refiner import PlainRefiner

-__all__ = ['PlainRefiner', 'DeepFillRefiner']
+__all__ = ['PlainRefiner', 'DeepFillRefiner', 'MLPRefiner']
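
With this change, MLPRefiner is exported next to the existing refiners and, because the class registers itself in COMPONENTS, it can be built from a config dict. A minimal sketch, mirroring the build_component call used in the test file below (the dimensions are illustrative):

from mmedit.models.builder import build_component

# The 'type' key selects the registered component; the remaining keys map to
# MLPRefiner.__init__ (in_dim, out_dim, hidden_list). Values are illustrative.
refiner_cfg = dict(type='MLPRefiner', in_dim=8, out_dim=3, hidden_list=[256, 256])
refiner = build_component(refiner_cfg)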
@@ -0,0 +1,58 @@
import torch.nn as nn
from mmcv.runner import load_checkpoint

from mmedit.models.registry import COMPONENTS
from mmedit.utils import get_root_logger


@COMPONENTS.register_module()
class MLPRefiner(nn.Module):
    """Multilayer perceptron (MLP) refiner used in LIIF.

    Args:
        in_dim (int): Input dimension.
        out_dim (int): Output dimension.
        hidden_list (list[int]): List of hidden dimensions.
    """

    def __init__(self, in_dim, out_dim, hidden_list):
        super().__init__()
        layers = []
        lastv = in_dim
        for hidden in hidden_list:
            layers.append(nn.Linear(lastv, hidden))
            layers.append(nn.ReLU())
            lastv = hidden
        layers.append(nn.Linear(lastv, out_dim))
        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        """Forward function.

        Args:
            x (Tensor): The input of the MLP.

        Returns:
            Tensor: The output of the MLP.
        """
        shape = x.shape[:-1]
        x = self.layers(x.view(-1, x.shape[-1]))
        return x.view(*shape, -1)

    def init_weights(self, pretrained=None, strict=True):
        """Init weights for models.

        Args:
            pretrained (str, optional): Path for pretrained weights. If given
                None, pretrained weights will not be loaded. Defaults to None.
            strict (bool, optional): Whether to strictly load the pretrained
                model. Defaults to True.
        """
        if isinstance(pretrained, str):
            logger = get_root_logger()
            load_checkpoint(self, pretrained, strict=strict, logger=logger)
        elif pretrained is None:
            pass
        else:
            raise TypeError('"pretrained" must be a str or None. '
                            f'But received {type(pretrained)}.')
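
The forward pass flattens every leading dimension before the linear stack and restores it afterwards, so the refiner accepts inputs with any number of batch or coordinate dimensions as long as the last dimension equals in_dim. A small sketch of that shape behaviour (the sizes are illustrative):

import torch

mlp = MLPRefiner(in_dim=8, out_dim=3, hidden_list=[8, 8])
x = torch.rand(4, 100, 8)        # e.g. (batch, num_queries, in_dim)
y = mlp(x)                       # internally reshaped to (400, 8) -> (400, 3)
assert y.shape == (4, 100, 3)    # leading dims restored, last dim == out_dim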
@@ -0,0 +1,33 @@
import torch
import torch.nn as nn

from mmedit.models.builder import build_component


def test_mlp_refiner():
    model_cfg = dict(
        type='MLPRefiner', in_dim=8, out_dim=3, hidden_list=[8, 8, 8, 8])
    mlp = build_component(model_cfg)

    # test attributes
    assert mlp.__class__.__name__ == 'MLPRefiner'

    # prepare data
    inputs = torch.rand(2, 8)
    targets = torch.rand(2, 3)
    if torch.cuda.is_available():
        inputs = inputs.cuda()
        targets = targets.cuda()
        mlp = mlp.cuda()
    data_batch = {'in': inputs, 'target': targets}
    # prepare loss and optimizer
    criterion = nn.L1Loss()
    optimizer = torch.optim.Adam(mlp.parameters(), lr=1e-4)

    # run a forward pass and one optimization step
    output = mlp.forward(data_batch['in'])
    assert output.shape == data_batch['target'].shape
    loss = criterion(output, data_batch['target'])
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
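
The test does not exercise init_weights. Loading pretrained weights would look roughly like the sketch below; the checkpoint path is hypothetical and only illustrates the API added above:

# Hypothetical checkpoint path, shown only to illustrate init_weights.
mlp = MLPRefiner(in_dim=8, out_dim=3, hidden_list=[8, 8, 8, 8])
mlp.init_weights(pretrained='work_dirs/liif/mlp_refiner.pth', strict=True)

# Passing None keeps the default initialization; any other type raises TypeError.
mlp.init_weights(pretrained=None)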