Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Add TTSR Net #314

Merged
merged 5 commits into from
May 24, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions mmedit/models/backbones/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# yapf: enable
from .generation_backbones import ResnetGenerator, UnetGenerator
from .sr_backbones import (EDSR, RDN, SRCNN, BasicVSRNet, EDVRNet, IconVSR,
MSRResNet, RRDBNet, TOFlow)
MSRResNet, RRDBNet, TOFlow, TTSRNet)

__all__ = [
'MSRResNet', 'VGG16', 'PlainDecoder', 'SimpleEncoderDecoder',
Expand All @@ -25,5 +25,5 @@
'DeepFillEncoderDecoder', 'EDVRNet', 'IndexedUpsample', 'IndexNetEncoder',
'IndexNetDecoder', 'TOFlow', 'ResGCAEncoder', 'ResGCADecoder', 'SRCNN',
'UnetGenerator', 'ResnetGenerator', 'FBAResnetDilated', 'FBADecoder',
'BasicVSRNet', 'IconVSR'
'BasicVSRNet', 'IconVSR', 'TTSRNet'
]
3 changes: 2 additions & 1 deletion mmedit/models/backbones/sr_backbones/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@
from .sr_resnet import MSRResNet
from .srcnn import SRCNN
from .tof import TOFlow
from .ttsr_net import TTSRNet

__all__ = [
'MSRResNet', 'RRDBNet', 'EDSR', 'EDVRNet', 'TOFlow', 'SRCNN',
'BasicVSRNet', 'IconVSR', 'RDN'
'BasicVSRNet', 'IconVSR', 'RDN', 'TTSRNet'
]
226 changes: 225 additions & 1 deletion mmedit/models/backbones/sr_backbones/ttsr_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,12 @@
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import build_conv_layer
from mmcv.runner import load_checkpoint

from mmedit.models.common import ResidualBlockNoBN, make_layer
from mmedit.models.common import (PixelShufflePack, ResidualBlockNoBN,
make_layer)
from mmedit.models.registry import BACKBONES
from mmedit.utils import get_root_logger

# Use partial to specify some default arguments
_conv3x3_layer = partial(
Expand Down Expand Up @@ -175,6 +179,9 @@ class MergeFeatures(nn.Module):
Final module of Texture Transformer Network for Image Super-Resolution.
Args:
mid_channels (int): Channel number of intermediate features
out_channels (int): Number of channels in the output image
"""

def __init__(self, mid_channels, out_channels):
Expand Down Expand Up @@ -210,3 +217,220 @@ def forward(self, x1, x2, x4):
x = torch.clamp(x, -1, 1)

return x


@BACKBONES.register_module()
class TTSRNet(nn.Module):
    """TTSR network structure (main-net) for reference-based super-resolution.

    Paper: Learning Texture Transformer Network for Image Super-Resolution

    Adapted from 'https://github.com/researchmm/TTSR.git'
    'https://github.com/researchmm/TTSR'
    Copyright permission at 'https://github.com/researchmm/TTSR/issues/38'.

    Args:
        in_channels (int): Number of channels in the input image.
        out_channels (int): Number of channels in the output image.
        mid_channels (int): Channel number of intermediate features.
            Default: 64
        texture_channels (int): Number of channels of the transferred HR
            texture at the finest level (level1). The level2 and level3
            textures fed to ``forward`` are expected to carry 2x and 4x
            this channel count, respectively (see ``conv_first*`` below).
            Default: 64
        num_blocks (tuple[int]): Block numbers in the trunk network:
            (SFE, stage1, stage2, stage3). Default: (16, 16, 8, 4)
        res_scale (float): Used to scale the residual in residual block.
            Default: 1.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 mid_channels=64,
                 texture_channels=64,
                 num_blocks=(16, 16, 8, 4),
                 res_scale=1.0):
        super().__init__()

        self.texture_channels = texture_channels

        # Shallow feature extraction on the LR input.
        self.sfe = SFE(in_channels, mid_channels, num_blocks[0], res_scale)

        # stage 1: fuse level3 texture (4x texture_channels) at LR scale.
        self.conv_first1 = _conv3x3_layer(4 * texture_channels + mid_channels,
                                          mid_channels)

        self.res_block1 = make_layer(
            ResidualBlockNoBN,
            num_blocks[1],
            mid_channels=mid_channels,
            res_scale=res_scale)

        self.conv_last1 = _conv3x3_layer(mid_channels, mid_channels)

        # up-sampling 1 -> 2
        self.up1 = PixelShufflePack(
            in_channels=mid_channels,
            out_channels=mid_channels,
            scale_factor=2,
            upsample_kernel=3)

        # stage 2: fuse level2 texture (2x texture_channels) at 2x scale.
        self.conv_first2 = _conv3x3_layer(2 * texture_channels + mid_channels,
                                          mid_channels)

        # Cross-scale feature integration between the two current scales.
        self.csfi2 = CSFI2(mid_channels)

        self.res_block2_1 = make_layer(
            ResidualBlockNoBN,
            num_blocks[2],
            mid_channels=mid_channels,
            res_scale=res_scale)
        self.res_block2_2 = make_layer(
            ResidualBlockNoBN,
            num_blocks[2],
            mid_channels=mid_channels,
            res_scale=res_scale)

        self.conv_last2_1 = _conv3x3_layer(mid_channels, mid_channels)
        self.conv_last2_2 = _conv3x3_layer(mid_channels, mid_channels)

        # up-sampling 2 -> 3
        self.up2 = PixelShufflePack(
            in_channels=mid_channels,
            out_channels=mid_channels,
            scale_factor=2,
            upsample_kernel=3)

        # stage 3: fuse level1 texture (1x texture_channels) at 4x scale.
        self.conv_first3 = _conv3x3_layer(texture_channels + mid_channels,
                                          mid_channels)

        # Cross-scale feature integration across all three scales.
        self.csfi3 = CSFI3(mid_channels)

        self.res_block3_1 = make_layer(
            ResidualBlockNoBN,
            num_blocks[3],
            mid_channels=mid_channels,
            res_scale=res_scale)
        self.res_block3_2 = make_layer(
            ResidualBlockNoBN,
            num_blocks[3],
            mid_channels=mid_channels,
            res_scale=res_scale)
        self.res_block3_3 = make_layer(
            ResidualBlockNoBN,
            num_blocks[3],
            mid_channels=mid_channels,
            res_scale=res_scale)

        self.conv_last3_1 = _conv3x3_layer(mid_channels, mid_channels)
        self.conv_last3_2 = _conv3x3_layer(mid_channels, mid_channels)
        self.conv_last3_3 = _conv3x3_layer(mid_channels, mid_channels)

        # end, merge features from all three scales into the output image
        self.merge_features = MergeFeatures(mid_channels, out_channels)

    def forward(self, x, s=None, t_level3=None, t_level2=None, t_level1=None):
        """Forward function.

        Although ``s`` and the texture tensors default to None, they are
        required: the body indexes them unconditionally.

        Args:
            x (Tensor): Input tensor with shape (n, c, h, w).
            s (Tensor): Soft-Attention tensor with shape (n, 1, h, w).
            t_level3 (Tensor): Transferred HR texture T in level3.
                (n, 4c, h, w)
            t_level2 (Tensor): Transferred HR texture T in level2.
                (n, 2c, 2h, 2w)
            t_level1 (Tensor): Transferred HR texture T in level1.
                (n, c, 4h, 4w)

        Returns:
            Tensor: Forward results.
        """

        # Texture channel widths are fixed by conv_first1/2/3 at build time.
        assert t_level1.shape[1] == self.texture_channels

        x1 = self.sfe(x)

        # stage 1 (LR scale): concatenate level3 texture and project back
        # to mid_channels.
        x1_res = torch.cat((x1, t_level3), dim=1)
        x1_res = self.conv_first1(x1_res)

        # soft-attention: texture residual weighted by the attention map s
        x1 = x1 + x1_res * s

        x1_res = self.res_block1(x1)
        x1_res = self.conv_last1(x1_res)

        x1 = x1 + x1_res

        # stage 2: keep the LR branch (x21) and open a 2x branch (x22).
        x21 = x1
        x22 = self.up1(x1)
        x22 = F.relu(x22)

        x22_res = torch.cat((x22, t_level2), dim=1)
        x22_res = self.conv_first2(x22_res)

        # soft-attention: s is upscaled to match the 2x branch
        x22_res = x22_res * F.interpolate(
            s, scale_factor=2, mode='bicubic', align_corners=False)
        x22 = x22 + x22_res

        # exchange information between the two scales
        x21_res, x22_res = self.csfi2(x21, x22)

        x21_res = self.res_block2_1(x21_res)
        x22_res = self.res_block2_2(x22_res)

        x21_res = self.conv_last2_1(x21_res)
        x22_res = self.conv_last2_2(x22_res)

        x21 = x21 + x21_res
        x22 = x22 + x22_res

        # stage 3: carry both branches forward and open a 4x branch (x33).
        x31 = x21
        x32 = x22
        x33 = self.up2(x22)
        x33 = F.relu(x33)

        x33_res = torch.cat((x33, t_level1), dim=1)
        x33_res = self.conv_first3(x33_res)

        # soft-attention: s is upscaled to match the 4x branch
        x33_res = x33_res * F.interpolate(
            s, scale_factor=4, mode='bicubic', align_corners=False)
        x33 = x33 + x33_res

        # exchange information among all three scales
        x31_res, x32_res, x33_res = self.csfi3(x31, x32, x33)

        x31_res = self.res_block3_1(x31_res)
        x32_res = self.res_block3_2(x32_res)
        x33_res = self.res_block3_3(x33_res)

        x31_res = self.conv_last3_1(x31_res)
        x32_res = self.conv_last3_2(x32_res)
        x33_res = self.conv_last3_3(x33_res)

        x31 = x31 + x31_res
        x32 = x32 + x32_res
        x33 = x33 + x33_res
        x = self.merge_features(x31, x32, x33)

        return x

    def init_weights(self, pretrained=None, strict=True):
        """Init weights for models.

        Args:
            pretrained (str, optional): Path for pretrained weights. If given
                None, pretrained weights will not be loaded. Defaults to None.
            strict (bool, optional): Whether strictly load the pretrained
                model. Defaults to True.
        """
        if isinstance(pretrained, str):
            logger = get_root_logger()
            load_checkpoint(self, pretrained, strict=strict, logger=logger)
        elif pretrained is None:
            pass  # use default initialization
        else:
            raise TypeError('"pretrained" must be a str or None. '
                            f'But received {type(pretrained)}.')
20 changes: 20 additions & 0 deletions tests/test_ttsr.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import torch

from mmedit.models import build_backbone
from mmedit.models.backbones.sr_backbones.ttsr_net import (CSFI2, CSFI3, SFE,
MergeFeatures)

Expand Down Expand Up @@ -36,3 +37,22 @@ def test_merge_features():
merge_features = MergeFeatures(mid_channels=16, out_channels=3)
out = merge_features(inputs1, inputs2, inputs4)
assert out.shape == (2, 3, 96, 96)


def test_ttsr_net():
    """Smoke-test TTSRNet end-to-end: 4x SR from a 24x24 LR input."""
    lr_img = torch.rand(2, 3, 24, 24)
    soft_attention = torch.rand(2, 1, 24, 24)
    # Transferred textures: level3 carries 4x texture_channels at LR size,
    # level2 carries 2x at double size, level1 carries 1x at quadruple size.
    texture_level3 = torch.rand(2, 64, 24, 24)
    texture_level2 = torch.rand(2, 32, 48, 48)
    texture_level1 = torch.rand(2, 16, 96, 96)

    model = build_backbone(
        dict(
            type='TTSRNet',
            in_channels=3,
            out_channels=3,
            mid_channels=16,
            texture_channels=16))
    sr_img = model(lr_img, soft_attention, texture_level3, texture_level2,
                   texture_level1)

    assert sr_img.shape == (2, 3, 96, 96)