Skip to content

Commit

Permalink
Update tsrn.py
Browse files Browse the repository at this point in the history
  • Loading branch information
WenjiaWang0312 authored Sep 19, 2020
1 parent 5265ef2 commit ea5ba63
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/model/tsrn.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def __init__(self, scale_factor=2, width=128, height=32, STN=False, srb_nums=5,
block_ = [UpsampleBLock(2*hidden_units, 2) for _ in range(upsample_block_num)]
block_.append(nn.Conv2d(2*hidden_units, in_planes, kernel_size=9, padding=4))
setattr(self, 'block%d' % (srb_nums + 3), nn.Sequential(*block_))
self.tps_inputsize = [height//scale_factor, width//scale_factor]
self.tps_inputsize = [32, 64]
tps_outputsize = [height//scale_factor, width//scale_factor]
num_control_points = 20
tps_margins = [0.05, 0.05]
Expand All @@ -60,7 +60,7 @@ def __init__(self, scale_factor=2, width=128, height=32, STN=False, srb_nums=5,
def forward(self, x):
# embed()
if self.stn and self.training:
# x = F.interpolate(x, self.tps_inputsize, mode='bilinear', align_corners=True)
x = F.interpolate(x, self.tps_inputsize, mode='bilinear', align_corners=True)
_, ctrl_points_x = self.stn_head(x)
x, _ = self.tps(x, ctrl_points_x)
block = {'1': self.block1(x)}
Expand Down

0 comments on commit ea5ba63

Please sign in to comment.