Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Add Structural Feature Encoder #311

Merged
merged 4 commits into from
May 17, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 71 additions & 0 deletions mmedit/models/backbones/sr_backbones/ttsr_net.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import torch.nn as nn
import torch.nn.functional as F

from mmedit.models.common import ResidualBlockNoBN, make_layer


def norm_conv_layer(in_channels, out_channels, stride=1):
Yshuo-Li marked this conversation as resolved.
Show resolved Hide resolved
"""Norm conv layer with kernal_size=3.

Args:
in_channels (int): Number of channels in the input image
out_channels (int): Number of channels produced by the convolution
stride (int or tuple, optional): Stride of the convolution. Default: 1

results:
conv_layer (Conv2d): Conv layer with kernal_size=3.
"""

conv_layer = nn.Conv2d(
in_channels,
out_channels,
kernel_size=3,
stride=stride,
padding=1,
bias=True)
return conv_layer


class SFE(nn.Module):
"""Structural Feature Encoder

Backbone of Texture Transformer Network for Image Super-Resolution.

Args:
in_channels (int): Number of channels in the input image
mid_channels (int): Channel number of intermediate features
num_blocks (int): Block number in the trunk network
res_scale (float): Used to scale the residual in residual block.
Default: 1.
"""

def __init__(self, in_channels, mid_channels, num_blocks, res_scale):
super().__init__()

self.num_blocks = num_blocks
self.conv_head = norm_conv_layer(in_channels, mid_channels)

self.body = make_layer(
ResidualBlockNoBN,
num_blocks,
mid_channels=mid_channels,
res_scale=res_scale)

self.conv_tail = norm_conv_layer(mid_channels, mid_channels)

def forward(self, x):
"""Forward function.

Args:
x (Tensor): Input tensor with shape (n, c, h, w).

Returns:
Tensor: Forward results.
"""

x = F.relu(self.conv_head(x))
Yshuo-Li marked this conversation as resolved.
Show resolved Hide resolved
x1 = x
x = self.body(x)
x = self.conv_tail(x)
x = x + x1
return x
14 changes: 14 additions & 0 deletions tests/test_ttsr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import torch

from mmedit.models.backbones.sr_backbones.ttsr_net import SFE


def test_ttsr():
Yshuo-Li marked this conversation as resolved.
Show resolved Hide resolved
inputs = torch.rand(2, 3, 48, 48)
sfe = SFE(3, 64, 16, 1.)
outputs = sfe(inputs)
assert outputs.shape == (2, 64, 48, 48)


if __name__ == '__main__':
test_ttsr()