Commit

re-update all
liyinshuo committed Jun 28, 2021
1 parent fb2bebb commit 3fd231d
Showing 7 changed files with 86 additions and 81 deletions.
33 changes: 18 additions & 15 deletions mmedit/models/backbones/sr_backbones/ttsr_net.py
@@ -329,33 +329,30 @@ def __init__(self,
         # end, merge features
         self.merge_features = MergeFeatures(mid_channels, out_channels)
 
-    def forward(self, x, s=None, t_level3=None, t_level2=None, t_level1=None):
+    def forward(self, x, soft_attention, textures):
         """Forward function.
 
         Args:
             x (Tensor): Input tensor with shape (n, c, h, w).
-            s (Tensor): Soft-Attention tensor with shape (n, 1, h, w).
-            t_level3 (Tensor): Transferred HR texture T in level3.
-                (n, 4c, h, w)
-            t_level2 (Tensor): Transferred HR texture T in level2.
-                (n, 2c, 2h, 2w)
-            t_level1 (Tensor): Transferred HR texture T in level1.
-                (n, c, 4h, 4w)
+            soft_attention (Tensor): Soft-Attention tensor with shape
+                (n, 1, h, w).
+            textures (Tuple[Tensor]): Transferred HR texture tensors.
+                [(N, C, H, W), (N, C/2, 2H, 2W), ...]
 
         Returns:
             Tensor: Forward results.
         """
 
-        assert t_level1.shape[1] == self.texture_channels
+        assert textures[-1].shape[1] == self.texture_channels
 
         x1 = self.sfe(x)
 
         # stage 1
-        x1_res = torch.cat((x1, t_level3), dim=1)
+        x1_res = torch.cat((x1, textures[0]), dim=1)
         x1_res = self.conv_first1(x1_res)
 
         # soft-attention
-        x1 = x1 + x1_res * s
+        x1 = x1 + x1_res * soft_attention
 
         x1_res = self.res_block1(x1)
         x1_res = self.conv_last1(x1_res)
@@ -367,12 +364,15 @@ def forward(self, x, s=None, t_level3=None, t_level2=None, t_level1=None):
         x22 = self.up1(x1)
         x22 = F.relu(x22)
 
-        x22_res = torch.cat((x22, t_level2), dim=1)
+        x22_res = torch.cat((x22, textures[1]), dim=1)
         x22_res = self.conv_first2(x22_res)
 
         # soft-attention
         x22_res = x22_res * F.interpolate(
-            s, scale_factor=2, mode='bicubic', align_corners=False)
+            soft_attention,
+            scale_factor=2,
+            mode='bicubic',
+            align_corners=False)
         x22 = x22 + x22_res
 
         x21_res, x22_res = self.csfi2(x21, x22)
@@ -392,12 +392,15 @@
         x33 = self.up2(x22)
         x33 = F.relu(x33)
 
-        x33_res = torch.cat((x33, t_level1), dim=1)
+        x33_res = torch.cat((x33, textures[2]), dim=1)
         x33_res = self.conv_first3(x33_res)
 
         # soft-attention
         x33_res = x33_res * F.interpolate(
-            s, scale_factor=4, mode='bicubic', align_corners=False)
+            soft_attention,
+            scale_factor=4,
+            mode='bicubic',
+            align_corners=False)
         x33 = x33 + x33_res
 
         x31_res, x32_res, x33_res = self.csfi3(x31, x32, x33)
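
Note on the hunk above: forward() now takes a single soft_attention map of shape (n, 1, h, w) and reuses it at every stage, bicubically upsampled by 2**i to each stage's resolution. A minimal, self-contained sketch of that pattern (not part of the commit; shapes are illustrative and the residual branch is collapsed into the feature itself for brevity):

import torch
import torch.nn.functional as F

soft_attention = torch.rand(2, 1, 16, 16)  # (n, 1, h, w)
# Stage features at 1x, 2x and 4x the base resolution, as in stages 1-3.
feats = [torch.randn(2, 64, 16 * 2**i, 16 * 2**i) for i in range(3)]

for i, feat in enumerate(feats):
    # Upsample the shared attention map to the stage resolution (no-op at i=0).
    att = soft_attention if i == 0 else F.interpolate(
        soft_attention, scale_factor=2**i, mode='bicubic',
        align_corners=False)
    fused = feat + feat * att  # attention-weighted residual, as in forward()
    print(i, fused.shape)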
10 changes: 5 additions & 5 deletions mmedit/models/extractors/lte.py
@@ -63,10 +63,10 @@ def forward(self, x):
             x (Tensor): Input tensor with shape (n, 3, h, w).
 
         Returns:
-            Forward results in 3 levels.
-            x_level1 (Tensor): Forward results in level 1 (n, 64, h, w).
-            x_level2 (Tensor): Forward results in level 2 (n, 128, h/2, w/2).
-            x_level3 (Tensor): Forward results in level 3 (n, 256, h/4, w/4).
+            Tuple[Tensor]: Forward results in 3 levels.
+                x_level3: Forward results in level 3 (n, 256, h/4, w/4).
+                x_level2: Forward results in level 2 (n, 128, h/2, w/2).
+                x_level1: Forward results in level 1 (n, 64, h, w).
         """
 
         x = self.img_normalize(x)
@@ -75,7 +75,7 @@ def forward(self, x):
         x_level2 = x = self.slice2(x)
         x_level3 = x = self.slice3(x)
 
-        return x_level1, x_level2, x_level3
+        return [x_level3, x_level2, x_level1]
 
     def init_weights(self, pretrained=None, strict=True):
         """Init weights for models.
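
After this change LTE hands back its pyramid coarse-to-fine as a list, matching the textures/refs ordering consumed downstream. A toy stand-in with the same contract (illustrative only, not the real LTE, which is VGG-based):

import torch
import torch.nn as nn

class ToyExtractor(nn.Module):
    """Stand-in for LTE's three-level, coarse-to-fine output contract."""

    def __init__(self):
        super().__init__()
        self.slice1 = nn.Conv2d(3, 64, 3, padding=1)
        self.slice2 = nn.Conv2d(64, 128, 3, stride=2, padding=1)
        self.slice3 = nn.Conv2d(128, 256, 3, stride=2, padding=1)

    def forward(self, x):
        x_level1 = x = self.slice1(x)
        x_level2 = x = self.slice2(x)
        x_level3 = x = self.slice3(x)
        # Coarse-to-fine list, matching LTE after this commit.
        return [x_level3, x_level2, x_level1]

refs = ToyExtractor()(torch.rand(2, 3, 64, 64))
assert refs[0].shape == (2, 256, 16, 16)   # level 3 first
assert refs[-1].shape == (2, 64, 64, 64)   # level 1 last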
91 changes: 39 additions & 52 deletions mmedit/models/transformers/search_transformer.py
@@ -35,13 +35,12 @@ def gather(self, inputs, dim, index):
 
         return outputs
 
-    def forward(self, lq_up_level3, ref_downup_level3, ref_level1, ref_level2,
-                ref_level3):
+    def forward(self, lq_up, ref_downup, refs):
         """Texture transformer
 
         Q = LTE(lq_up)
         K = LTE(ref_downup)
-        V = LTE(ref), from V_level1 to V_level3
+        V = LTE(ref), from V_level_n to V_level_1
 
         Relevance embedding aims to embed the relevance between the LQ and
         Ref image by estimating the similarity between Q and K.
@@ -51,41 +50,40 @@ def forward(self, lq_up_level3, ref_downup_level3, ref_level1, ref_level2,
         features T and the LQ features F from the backbone.
         Args:
-            All args are features come from extractor (sucn as LTE).
+            All args are features come from extractor (such as LTE).
             These features contain 3 levels.
             When upscale_factor=4, the size ratio of these features is
-            level1:level2:level3 = 4:2:1.
-            lq_up_level3 (Tensor): level3 feature of 4x bicubic-upsampled lq
-                image. (N, 4C, H, W)
-            ref_downup_level3 (Tensor): level3 feature of ref_downup.
-                ref_downup is obtained by applying bicubic down-sampling and
-                up-sampling with factor 4x on ref. (N, 4C, H, W)
-            ref_level1 (Tensor): level1 feature of ref image. (N, C, 4H, 4W)
-            ref_level2 (Tensor): level2 feature of ref image. (N, 2C, 2H, 2W)
-            ref_level3 (Tensor): level3 feature of ref image. (N, 4C, H, W)
+            level3:level2:level1 = 1:2:4.
+            lq_up (Tensor): Tensor of 4x bicubic-upsampled lq image.
+                (N, C, H, W)
+            ref_downup (Tensor): Tensor of ref_downup. ref_downup is obtained
+                by applying bicubic down-sampling and up-sampling with factor
+                4x on ref. (N, C, H, W)
+            refs (Tuple[Tensor]): Tuple of ref tensors.
+                [(N, C, H, W), (N, C/2, 2H, 2W), ...]
 
         Returns:
-            s (Tensor): Soft-Attention tensor. (N, 1, H, W)
-            t_level3 (Tensor): Transferred GT texture T in level3.
-                (N, 4C, H, W)
-            t_level2 (Tensor): Transferred GT texture T in level2.
-                (N, 2C, 2H, 2W)
-            t_level1 (Tensor): Transferred GT texture T in level1.
-                (N, C, 4H, 4W)
+            soft_attention (Tensor): Soft-Attention tensor. (N, 1, H, W)
+            textures (Tuple[Tensor]): Transferred GT textures.
+                [(N, C, H, W), (N, C/2, 2H, 2W), ...]
         """
 
+        levels = len(refs)
         # query
-        query = F.unfold(lq_up_level3, kernel_size=(3, 3), padding=1)
+        query = F.unfold(lq_up, kernel_size=(3, 3), padding=1)
 
         # key
-        key = F.unfold(ref_downup_level3, kernel_size=(3, 3), padding=1)
+        key = F.unfold(ref_downup, kernel_size=(3, 3), padding=1)
         key_t = key.permute(0, 2, 1)
 
         # values
-        value_level3 = F.unfold(ref_level3, kernel_size=(3, 3), padding=1)
-        value_level2 = F.unfold(
-            ref_level2, kernel_size=(6, 6), padding=2, stride=2)
-        value_level1 = F.unfold(
-            ref_level1, kernel_size=(12, 12), padding=4, stride=4)
+        values = [
+            F.unfold(
+                refs[i],
+                kernel_size=3 * pow(2, i),
+                padding=pow(2, i),
+                stride=pow(2, i)) for i in range(levels)
+        ]
 
         key_t = F.normalize(key_t, dim=2)  # [N, H*W, C*k*k]
         query = F.normalize(query, dim=1)  # [N, C*k*k, H*W]
@@ -95,30 +93,19 @@ def forward(self, lq_up_level3, ref_downup_level3, ref_level1, ref_level2,
         max_val, max_index = torch.max(rel_embedding, dim=1)  # [N, H*W]
 
         # hard-attention
-        t_level3_unfold = self.gather(value_level3, 2, max_index)
-        t_level2_unfold = self.gather(value_level2, 2, max_index)
-        t_level1_unfold = self.gather(value_level1, 2, max_index)
+        textures = [self.gather(value, 2, max_index) for value in values]
 
         # to tensor
-        t_level3 = F.fold(
-            t_level3_unfold,
-            output_size=lq_up_level3.size()[-2:],
-            kernel_size=(3, 3),
-            padding=1) / (3. * 3.)
-        t_level2 = F.fold(
-            t_level2_unfold,
-            output_size=(lq_up_level3.size(2) * 2, lq_up_level3.size(3) * 2),
-            kernel_size=(6, 6),
-            padding=2,
-            stride=2) / (3. * 3.)
-        t_level1 = F.fold(
-            t_level1_unfold,
-            output_size=(lq_up_level3.size(2) * 4, lq_up_level3.size(3) * 4),
-            kernel_size=(12, 12),
-            padding=4,
-            stride=4) / (3. * 3.)
-
-        s = max_val.view(
-            max_val.size(0), 1, lq_up_level3.size(2), lq_up_level3.size(3))
-
-        return s, t_level3, t_level2, t_level1
+        h, w = lq_up.size()[-2:]
+        textures = [
+            F.fold(
+                textures[i],
+                output_size=(h * pow(2, i), w * pow(2, i)),
+                kernel_size=3 * pow(2, i),
+                padding=pow(2, i),
+                stride=pow(2, i)) / 9. for i in range(levels)
+        ]
+
+        soft_attention = max_val.view(max_val.size(0), 1, h, w)
+
+        return soft_attention, textures
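
What makes the list-comprehension rewrite above work is the per-level unfold parametrization: at level i, kernel 3 * 2**i with padding 2**i and stride 2**i yields the same patch count at every scale, so the single max_index from the relevance embedding can gather textures at all levels. A small runnable check of that invariant (sizes are illustrative):

import torch
import torch.nn.functional as F

# Three ref levels: channels halve and spatial size doubles per level.
refs = [torch.randn(2, 32 // 2**i, 32 * 2**i, 32 * 2**i) for i in range(3)]

patch_counts = []
for i, ref in enumerate(refs):
    patches = F.unfold(
        ref, kernel_size=3 * 2**i, padding=2**i, stride=2**i)
    patch_counts.append(patches.shape[-1])

# Every level produces H*W = 32*32 patches, so one argmax over the
# relevance map indexes matching patches at all scales.
assert patch_counts == [1024, 1024, 1024]

The paired F.fold divides by 9. because kernel/stride = 3 at every level, so each interior pixel is summed over 3 x 3 overlapping patches.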
4 changes: 0 additions & 4 deletions tests/test_models/test_common/test_img_normalize.py
@@ -19,7 +19,3 @@ def test_normalize_layer():
     std_y = y.std(dim=1)
     assert sum(torch.div(std_x, std_y) - rgb_std) < 1e-5
     assert sum(torch.div(mean_x - rgb_mean, rgb_std) - mean_y) < 1e-5
-
-
-if __name__ == '__main__':
-    test_normalize_layer()
4 changes: 2 additions & 2 deletions tests/test_models/test_extractors/test_lte.py
@@ -13,7 +13,7 @@ def test_lte():
 
     x = torch.rand(2, 3, 64, 64)
 
-    x_level1, x_level2, x_level3 = lte(x)
+    x_level3, x_level2, x_level1 = lte(x)
     assert x_level1.shape == (2, 64, 64, 64)
     assert x_level2.shape == (2, 128, 32, 32)
     assert x_level3.shape == (2, 256, 16, 16)
@@ -22,7 +22,7 @@
     with pytest.raises(IOError):
         model_cfg['pretrained'] = ''
         lte = build_component(model_cfg)
-        x_level1, x_level2, x_level3 = lte(x)
+        x_level3, x_level2, x_level1 = lte(x)
         lte.init_weights('')
     with pytest.raises(TypeError):
         lte.init_weights(1)
18 changes: 18 additions & 0 deletions tests/test_models/test_restorers/test_ttsr.py
@@ -62,6 +62,24 @@ def test_ttsr_net():
 
 
 def test_ttsr():
+    model_cfg = dict(
+        type='TTSR',
+        generator=dict(
+            type='TTSRNet',
+            in_channels=3,
+            out_channels=3,
+            mid_channels=64,
+            num_blocks=(16, 16, 8, 4)),
+        extractor=dict(type='LTE'),
+        transformer=dict(type='SearchTransformer'),
+        pixel_loss=dict(type='L1Loss', loss_weight=1.0, reduction='mean'))
+
+    scale = 4
+    train_cfg = None
+    test_cfg = Config(dict(metrics=['PSNR', 'SSIM'], crop_border=scale))
+
+    # build restorer
+    restorer = build_model(model_cfg, train_cfg=train_cfg, test_cfg=test_cfg)
 
     model_cfg = dict(
         type='TTSR',
7 changes: 4 additions & 3 deletions tests/test_models/test_transformer/test_search_transformer.py
@@ -1,6 +1,6 @@
 import torch
 
-from mmedit.models import build_component
+from mmedit.models.builder import build_component
 
 
 def test_search_transformer():
@@ -13,8 +13,9 @@ def test_search_transformer():
     ref_level2 = torch.randn((2, 16, 64, 64))
     ref_level1 = torch.randn((2, 8, 128, 128))
 
-    s, t_level3, t_level2, t_level1 = model(lr_pad_level3, ref_pad_level3,
-                                            ref_level1, ref_level2, ref_level3)
+    s, textures = model(lr_pad_level3, ref_pad_level3,
+                        (ref_level3, ref_level2, ref_level1))
+    t_level3, t_level2, t_level1 = textures
 
     assert s.shape == (2, 1, 32, 32)
     assert t_level3.shape == (2, 32, 32, 32)
