Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Add Structural Feature Encoder #311

Merged
merged 4 commits into from
May 17, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 55 additions & 0 deletions mmedit/models/backbones/sr_backbones/ttsr_net.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
from functools import partial

import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import build_conv_layer

from mmedit.models.common import ResidualBlockNoBN, make_layer

# Use partial to specify some default arguments
_norm_conv_layer = partial(
build_conv_layer, dict(type='Conv2d'), kernel_size=3, padding=1)


class SFE(nn.Module):
"""Structural Feature Encoder

Backbone of Texture Transformer Network for Image Super-Resolution.

Args:
in_channels (int): Number of channels in the input image
mid_channels (int): Channel number of intermediate features
num_blocks (int): Block number in the trunk network
res_scale (float): Used to scale the residual in residual block.
Default: 1.
"""

def __init__(self, in_channels, mid_channels, num_blocks, res_scale):
super().__init__()

self.num_blocks = num_blocks
self.conv_first = _norm_conv_layer(in_channels, mid_channels)

self.body = make_layer(
ResidualBlockNoBN,
num_blocks,
mid_channels=mid_channels,
res_scale=res_scale)

self.conv_last = _norm_conv_layer(mid_channels, mid_channels)

def forward(self, x):
"""Forward function.

Args:
x (Tensor): Input tensor with shape (n, c, h, w).

Returns:
Tensor: Forward results.
"""

x1 = x = F.relu(self.conv_first(x))
x = self.body(x)
x = self.conv_last(x)
x = x + x1
return x
14 changes: 14 additions & 0 deletions tests/test_ttsr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import torch

from mmedit.models.backbones.sr_backbones.ttsr_net import SFE


def test_sfe():
inputs = torch.rand(2, 3, 48, 48)
sfe = SFE(3, 64, 16, 1.)
outputs = sfe(inputs)
assert outputs.shape == (2, 64, 48, 48)


if __name__ == '__main__':
test_sfe()