From ea5ba63d47afc20e24fae8cb204fddf5ed5c0b9b Mon Sep 17 00:00:00 2001 From: Wenjia Wang <33180378+JasonBoy1@users.noreply.github.com> Date: Sat, 19 Sep 2020 11:49:15 +0800 Subject: [PATCH] Update tsrn.py --- src/model/tsrn.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/model/tsrn.py b/src/model/tsrn.py index 71404ec..e7a559a 100644 --- a/src/model/tsrn.py +++ b/src/model/tsrn.py @@ -41,7 +41,7 @@ def __init__(self, scale_factor=2, width=128, height=32, STN=False, srb_nums=5, block_ = [UpsampleBLock(2*hidden_units, 2) for _ in range(upsample_block_num)] block_.append(nn.Conv2d(2*hidden_units, in_planes, kernel_size=9, padding=4)) setattr(self, 'block%d' % (srb_nums + 3), nn.Sequential(*block_)) - self.tps_inputsize = [height//scale_factor, width//scale_factor] + self.tps_inputsize = [32, 64] tps_outputsize = [height//scale_factor, width//scale_factor] num_control_points = 20 tps_margins = [0.05, 0.05] @@ -60,7 +60,7 @@ def __init__(self, scale_factor=2, width=128, height=32, STN=False, srb_nums=5, def forward(self, x): # embed() if self.stn and self.training: - # x = F.interpolate(x, self.tps_inputsize, mode='bilinear', align_corners=True) + x = F.interpolate(x, self.tps_inputsize, mode='bilinear', align_corners=True) _, ctrl_points_x = self.stn_head(x) x, _ = self.tps(x, ctrl_points_x) block = {'1': self.block1(x)}