Skip to content

Commit

Permalink
[Feature] Add TTSR Net (open-mmlab#314)
Browse files Browse the repository at this point in the history
* [Feature] Add TTSRNet

* Rename

* Rename

* Rename

* Add license

Co-authored-by: liyinshuo <liyinshuo@sensetime.com>
  • Loading branch information
Yshuo-Li and liyinshuo authored May 24, 2021
1 parent 5ae1522 commit aecee6f
Show file tree
Hide file tree
Showing 4 changed files with 249 additions and 4 deletions.
4 changes: 2 additions & 2 deletions mmedit/models/backbones/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# yapf: enable
from .generation_backbones import ResnetGenerator, UnetGenerator
from .sr_backbones import (EDSR, RDN, SRCNN, BasicVSRNet, EDVRNet, IconVSR,
MSRResNet, RRDBNet, TOFlow)
MSRResNet, RRDBNet, TOFlow, TTSRNet)

__all__ = [
'MSRResNet', 'VGG16', 'PlainDecoder', 'SimpleEncoderDecoder',
Expand All @@ -25,5 +25,5 @@
'DeepFillEncoderDecoder', 'EDVRNet', 'IndexedUpsample', 'IndexNetEncoder',
'IndexNetDecoder', 'TOFlow', 'ResGCAEncoder', 'ResGCADecoder', 'SRCNN',
'UnetGenerator', 'ResnetGenerator', 'FBAResnetDilated', 'FBADecoder',
'BasicVSRNet', 'IconVSR'
'BasicVSRNet', 'IconVSR', 'TTSRNet'
]
3 changes: 2 additions & 1 deletion mmedit/models/backbones/sr_backbones/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@
from .sr_resnet import MSRResNet
from .srcnn import SRCNN
from .tof import TOFlow
from .ttsr_net import TTSRNet

__all__ = [
'MSRResNet', 'RRDBNet', 'EDSR', 'EDVRNet', 'TOFlow', 'SRCNN',
'BasicVSRNet', 'IconVSR', 'RDN'
'BasicVSRNet', 'IconVSR', 'RDN', 'TTSRNet'
]
226 changes: 225 additions & 1 deletion mmedit/models/backbones/sr_backbones/ttsr_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,12 @@
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import build_conv_layer
from mmcv.runner import load_checkpoint

from mmedit.models.common import ResidualBlockNoBN, make_layer
from mmedit.models.common import (PixelShufflePack, ResidualBlockNoBN,
make_layer)
from mmedit.models.registry import BACKBONES
from mmedit.utils import get_root_logger

# Use partial to specify some default arguments
_conv3x3_layer = partial(
Expand Down Expand Up @@ -175,6 +179,9 @@ class MergeFeatures(nn.Module):
Final module of Texture Transformer Network for Image Super-Resolution.
Args:
mid_channels (int): Channel number of intermediate features
out_channels (int): Number of channels in the output image
"""

def __init__(self, mid_channels, out_channels):
Expand Down Expand Up @@ -210,3 +217,220 @@ def forward(self, x1, x2, x4):
x = torch.clamp(x, -1, 1)

return x


@BACKBONES.register_module()
class TTSRNet(nn.Module):
    """TTSR network structure (main-net) for reference-based super-resolution.

    Paper: Learning Texture Transformer Network for Image Super-Resolution

    Adapted from 'https://github.com/researchmm/TTSR.git'
    'https://github.com/researchmm/TTSR'
    Copyright permission at 'https://github.com/researchmm/TTSR/issues/38'.

    The network upsamples the input in two PixelShuffle steps (2x each,
    4x overall) and injects transferred HR texture features at each of the
    three scales, weighted by a soft-attention map.

    Args:
        in_channels (int): Number of channels in the input image
        out_channels (int): Number of channels in the output image
        mid_channels (int): Channel number of intermediate features.
            Default: 64
        texture_channels (int): Channel number of the level1 transferred
            texture features; level2 and level3 carry ``2x`` and ``4x``
            this many channels respectively. Default: 64
        num_blocks (tuple[int]): Block numbers in the trunk network.
            Default: (16, 16, 8, 4)
        res_scale (float): Used to scale the residual in residual block.
            Default: 1.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 mid_channels=64,
                 texture_channels=64,
                 num_blocks=(16, 16, 8, 4),
                 res_scale=1.0):
        super().__init__()

        # Kept for the forward-time sanity check on t_level1's channels.
        self.texture_channels = texture_channels

        # Shallow feature extraction on the LR input (1x scale).
        self.sfe = SFE(in_channels, mid_channels, num_blocks[0], res_scale)

        # stage 1
        # Level-3 texture has 4 * texture_channels channels at 1x scale.
        self.conv_first1 = _conv3x3_layer(4 * texture_channels + mid_channels,
                                          mid_channels)

        self.res_block1 = make_layer(
            ResidualBlockNoBN,
            num_blocks[1],
            mid_channels=mid_channels,
            res_scale=res_scale)

        self.conv_last1 = _conv3x3_layer(mid_channels, mid_channels)

        # up-sampling 1 -> 2
        self.up1 = PixelShufflePack(
            in_channels=mid_channels,
            out_channels=mid_channels,
            scale_factor=2,
            upsample_kernel=3)

        # stage 2
        # Level-2 texture has 2 * texture_channels channels at 2x scale.
        self.conv_first2 = _conv3x3_layer(2 * texture_channels + mid_channels,
                                          mid_channels)

        # Cross-scale feature integration between the 1x and 2x branches.
        self.csfi2 = CSFI2(mid_channels)

        self.res_block2_1 = make_layer(
            ResidualBlockNoBN,
            num_blocks[2],
            mid_channels=mid_channels,
            res_scale=res_scale)
        self.res_block2_2 = make_layer(
            ResidualBlockNoBN,
            num_blocks[2],
            mid_channels=mid_channels,
            res_scale=res_scale)

        self.conv_last2_1 = _conv3x3_layer(mid_channels, mid_channels)
        self.conv_last2_2 = _conv3x3_layer(mid_channels, mid_channels)

        # up-sampling 2 -> 3
        self.up2 = PixelShufflePack(
            in_channels=mid_channels,
            out_channels=mid_channels,
            scale_factor=2,
            upsample_kernel=3)

        # stage 3
        # Level-1 texture has texture_channels channels at 4x scale.
        self.conv_first3 = _conv3x3_layer(texture_channels + mid_channels,
                                          mid_channels)

        # Cross-scale feature integration across the 1x, 2x and 4x branches.
        self.csfi3 = CSFI3(mid_channels)

        self.res_block3_1 = make_layer(
            ResidualBlockNoBN,
            num_blocks[3],
            mid_channels=mid_channels,
            res_scale=res_scale)
        self.res_block3_2 = make_layer(
            ResidualBlockNoBN,
            num_blocks[3],
            mid_channels=mid_channels,
            res_scale=res_scale)
        self.res_block3_3 = make_layer(
            ResidualBlockNoBN,
            num_blocks[3],
            mid_channels=mid_channels,
            res_scale=res_scale)

        self.conv_last3_1 = _conv3x3_layer(mid_channels, mid_channels)
        self.conv_last3_2 = _conv3x3_layer(mid_channels, mid_channels)
        self.conv_last3_3 = _conv3x3_layer(mid_channels, mid_channels)

        # end, merge features
        self.merge_features = MergeFeatures(mid_channels, out_channels)

    def forward(self, x, s=None, t_level3=None, t_level2=None, t_level1=None):
        """Forward function.

        Args:
            x (Tensor): Input tensor with shape (n, c, h, w).
            s (Tensor): Soft-Attention tensor with shape (n, 1, h, w).
            t_level3 (Tensor): Transferred HR texture T in level3.
                (n, 4c, h, w)
            t_level2 (Tensor): Transferred HR texture T in level2.
                (n, 2c, 2h, 2w)
            t_level1 (Tensor): Transferred HR texture T in level1.
                (n, c, 4h, 4w)

        Returns:
            Tensor: Forward results.
        """

        assert t_level1.shape[1] == self.texture_channels

        x1 = self.sfe(x)

        # stage 1
        x1_res = torch.cat((x1, t_level3), dim=1)
        x1_res = self.conv_first1(x1_res)

        # soft-attention
        # Texture is weighted by s before being added, so unreliable
        # transferred regions contribute less.
        x1 = x1 + x1_res * s

        x1_res = self.res_block1(x1)
        x1_res = self.conv_last1(x1_res)

        x1 = x1 + x1_res

        # stage 2
        x21 = x1
        x22 = self.up1(x1)
        x22 = F.relu(x22)

        x22_res = torch.cat((x22, t_level2), dim=1)
        x22_res = self.conv_first2(x22_res)

        # soft-attention
        # s is defined at 1x scale; upscale it 2x to match this branch.
        x22_res = x22_res * F.interpolate(
            s, scale_factor=2, mode='bicubic', align_corners=False)
        x22 = x22 + x22_res

        x21_res, x22_res = self.csfi2(x21, x22)

        x21_res = self.res_block2_1(x21_res)
        x22_res = self.res_block2_2(x22_res)

        x21_res = self.conv_last2_1(x21_res)
        x22_res = self.conv_last2_2(x22_res)

        x21 = x21 + x21_res
        x22 = x22 + x22_res

        # stage 3
        x31 = x21
        x32 = x22
        x33 = self.up2(x22)
        x33 = F.relu(x33)

        x33_res = torch.cat((x33, t_level1), dim=1)
        x33_res = self.conv_first3(x33_res)

        # soft-attention
        # s is defined at 1x scale; upscale it 4x to match this branch.
        x33_res = x33_res * F.interpolate(
            s, scale_factor=4, mode='bicubic', align_corners=False)
        x33 = x33 + x33_res

        x31_res, x32_res, x33_res = self.csfi3(x31, x32, x33)

        x31_res = self.res_block3_1(x31_res)
        x32_res = self.res_block3_2(x32_res)
        x33_res = self.res_block3_3(x33_res)

        x31_res = self.conv_last3_1(x31_res)
        x32_res = self.conv_last3_2(x32_res)
        x33_res = self.conv_last3_3(x33_res)

        x31 = x31 + x31_res
        x32 = x32 + x32_res
        x33 = x33 + x33_res
        x = self.merge_features(x31, x32, x33)

        return x

    def init_weights(self, pretrained=None, strict=True):
        """Init weights for models.

        Args:
            pretrained (str, optional): Path for pretrained weights. If given
                None, pretrained weights will not be loaded. Defaults to None.
            strict (bool, optional): Whether strictly load the pretrained
                model. Defaults to True.
        """
        if isinstance(pretrained, str):
            logger = get_root_logger()
            load_checkpoint(self, pretrained, strict=strict, logger=logger)
        elif pretrained is None:
            pass  # use default initialization
        else:
            raise TypeError('"pretrained" must be a str or None. '
                            f'But received {type(pretrained)}.')
20 changes: 20 additions & 0 deletions tests/test_ttsr.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import torch

from mmedit.models import build_backbone
from mmedit.models.backbones.sr_backbones.ttsr_net import (CSFI2, CSFI3, SFE,
MergeFeatures)

Expand Down Expand Up @@ -36,3 +37,22 @@ def test_merge_features():
merge_features = MergeFeatures(mid_channels=16, out_channels=3)
out = merge_features(inputs1, inputs2, inputs4)
assert out.shape == (2, 3, 96, 96)


def test_ttsr_net():
    """Smoke-test TTSRNet: a 24x24 input must be upsampled 4x to 96x96."""
    batch = 2
    tex_channels = 16
    inputs = torch.rand(batch, 3, 24, 24)
    soft_attention = torch.rand(batch, 1, 24, 24)
    # Texture levels 3/2/1 carry 4x/2x/1x texture_channels at 1x/2x/4x scale.
    texture_levels = (
        torch.rand(batch, 4 * tex_channels, 24, 24),
        torch.rand(batch, 2 * tex_channels, 48, 48),
        torch.rand(batch, tex_channels, 96, 96),
    )

    ttsr = build_backbone(
        dict(
            type='TTSRNet',
            in_channels=3,
            out_channels=3,
            mid_channels=16,
            texture_channels=tex_channels))
    outputs = ttsr(inputs, soft_attention, *texture_levels)

    assert outputs.shape == (batch, 3, 96, 96)

0 comments on commit aecee6f

Please sign in to comment.