Skip to content

Commit 39adfc7

Browse files
adam9500370meetps
authored andcommitted
Fix #125 and #154
1 parent 89f4abe commit 39adfc7

File tree

5 files changed

+21
-27
lines changed

5 files changed

+21
-27
lines changed

ptsemseg/loss.py

+3-5
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,10 @@ def cross_entropy2d(input, target, weight=None, size_average=True):
1111
# Handle inconsistent size between input and target
1212
if h > ht and w > wt: # upsample labels
1313
target = target.unsqueeze(1)
14-
target = F.upsample(target, size=(h, w), mode="nearest")
14+
target = F.interpolate(target.float(), size=(h, w), mode="nearest").long()
1515
target = target.squeeze(1)
1616
elif h < ht and w < wt: # upsample images
17-
input = F.upsample(input, size=(ht, wt), mode="bilinear")
17+
input = F.interpolate(input, size=(ht, wt), mode="bilinear", align_corners=True)
1818
elif h != ht and w != wt:
1919
raise Exception("Only support upsampling")
2020

@@ -72,9 +72,7 @@ def multi_scale_cross_entropy2d(
7272
if scale_weight == None: # scale_weight: torch tensor type
7373
n_inp = len(input)
7474
scale = 0.4
75-
scale_weight = torch.pow(scale * torch.ones(n_inp), torch.arange(n_inp))
76-
if input.is_cuda:
77-
scale_weight = scale_weight.cuda()
75+
scale_weight = torch.pow(scale * torch.ones(n_inp), torch.arange(n_inp).float()).to('cuda' if input.is_cuda else 'cpu')
7876

7977
loss = 0.0
8078
for i, inp in enumerate(input):

ptsemseg/loss/loss.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,10 @@ def cross_entropy2d(input, target, weight=None, size_average=True):
1111
# Handle inconsistent size between input and target
1212
if h > ht and w > wt: # upsample labels
1313
target = target.unsequeeze(1)
14-
target = F.upsample(target, size=(h, w), mode="nearest")
14+
target = F.interpolate(target.float(), size=(h, w), mode="nearest").long()
1515
target = target.sequeeze(1)
1616
elif h < ht and w < wt: # upsample images
17-
input = F.upsample(input, size=(ht, wt), mode="bilinear")
17+
input = F.interpolate(input, size=(ht, wt), mode="bilinear", align_corners=True)
1818
elif h != ht and w != wt:
1919
raise Exception("Only support upsampling")
2020

@@ -33,7 +33,7 @@ def multi_scale_cross_entropy2d(
3333
if scale_weight == None: # scale_weight: torch tensor type
3434
n_inp = len(input)
3535
scale = 0.4
36-
scale_weight = torch.pow(scale * torch.ones(n_inp), torch.arange(n_inp))
36+
scale_weight = torch.pow(scale * torch.ones(n_inp), torch.arange(n_inp).float()).to('cuda' if input.is_cuda else 'cpu')
3737

3838
loss = 0.0
3939
for i, inp in enumerate(input):

ptsemseg/models/icnet.py

+7-6
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@ def forward(self, x):
179179
h, w = x.shape[2:]
180180

181181
# H, W -> H/2, W/2
182-
x_sub2 = interp(x, output_size=get_interp_size(x, s_factor=2))
182+
x_sub2 = F.interpolate(x, size=get_interp_size(x, s_factor=2), mode='bilinear', align_corners=True)
183183

184184
# H/2, W/2 -> H/4, W/4
185185
x_sub2 = self.convbnrelu1_1(x_sub2)
@@ -193,7 +193,7 @@ def forward(self, x):
193193
x_sub2 = self.res_block2(x_sub2)
194194
x_sub2 = self.res_block3_conv(x_sub2)
195195
# H/16, W/16 -> H/32, W/32
196-
x_sub4 = interp(x_sub2, output_size=get_interp_size(x_sub2, s_factor=2))
196+
x_sub4 = F.interpolate(x_sub2, size=get_interp_size(x_sub2, s_factor=2), mode='bilinear', align_corners=True)
197197
x_sub4 = self.res_block3_identity(x_sub4)
198198

199199
x_sub4 = self.res_block4(x_sub4)
@@ -209,18 +209,19 @@ def forward(self, x):
209209
x_sub24, sub4_cls = self.cff_sub24(x_sub4, x_sub2)
210210
x_sub12, sub24_cls = self.cff_sub12(x_sub24, x_sub1)
211211

212-
x_sub12 = F.upsample(
213-
x_sub12, size=get_interp_size(x_sub12, z_factor=2), mode="bilinear"
212+
x_sub12 = F.interpolate(
213+
x_sub12, size=get_interp_size(x_sub12, z_factor=2), mode="bilinear", align_corners=True
214214
)
215215
sub124_cls = self.classification(x_sub12)
216216

217217
if self.training:
218-
return sub4_cls, sub24_cls, sub124_cls
218+
return (sub124_cls, sub24_cls, sub4_cls)
219219
else: # eval mode
220-
sub124_cls = F.upsample(
220+
sub124_cls = F.interpolate(
221221
sub124_cls,
222222
size=get_interp_size(sub124_cls, z_factor=4),
223223
mode="bilinear",
224+
align_corners=True
224225
) # Test only
225226
return sub124_cls
226227

ptsemseg/models/utils.py

+7-12
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@
33
import numpy as np
44
import torch.nn.functional as F
55

6-
from torch.autograd import Variable
7-
86

97
class conv2DBatchNorm(nn.Module):
108
def __init__(
@@ -572,7 +570,7 @@ def forward(self, x):
572570
# out = F.adaptive_avg_pool2d(x, output_size=(pool_size, pool_size))
573571
if self.model_name != "icnet":
574572
out = module(out)
575-
out = F.upsample(out, size=(h, w), mode="bilinear")
573+
out = F.interpolate(out, size=(h, w), mode="bilinear", align_corners=True)
576574
output_slices.append(out)
577575

578576
return torch.cat(output_slices, dim=1)
@@ -586,7 +584,7 @@ def forward(self, x):
586584
# out = F.adaptive_avg_pool2d(x, output_size=(pool_size, pool_size))
587585
if self.model_name != "icnet":
588586
out = module(out)
589-
out = F.upsample(out, size=(h, w), mode="bilinear")
587+
out = F.interpolate(out, size=(h, w), mode="bilinear", align_corners=True)
590588
pp_sum = pp_sum + out
591589

592590
return pp_sum
@@ -791,8 +789,8 @@ def __init__(
791789
)
792790

793791
def forward(self, x_low, x_high):
794-
x_low_upsampled = F.upsample(
795-
x_low, size=get_interp_size(x_low, z_factor=2), mode="bilinear"
792+
x_low_upsampled = F.interpolate(
793+
x_low, size=get_interp_size(x_low, z_factor=2), mode="bilinear", align_corners=True
796794
)
797795

798796
low_cls = self.low_classifier_conv(x_low_upsampled)
@@ -824,16 +822,13 @@ def interp(input, output_size, mode="bilinear"):
824822
oh, ow = output_size
825823

826824
# normalize to [-1, 1]
827-
h = torch.arange(0, oh) / (oh - 1) * 2 - 1
828-
w = torch.arange(0, ow) / (ow - 1) * 2 - 1
825+
h = torch.arange(0, oh, dtype=torch.float, device='cuda' if input.is_cuda else 'cpu') / (oh - 1) * 2 - 1
826+
w = torch.arange(0, ow, dtype=torch.float, device='cuda' if input.is_cuda else 'cpu') / (ow - 1) * 2 - 1
829827

830-
grid = torch.zeros(oh, ow, 2)
828+
grid = torch.zeros(oh, ow, 2, dtype=torch.float, device='cuda' if input.is_cuda else 'cpu')
831829
grid[:, :, 0] = w.unsqueeze(0).repeat(oh, 1)
832830
grid[:, :, 1] = h.unsqueeze(0).repeat(ow, 1).transpose(0, 1)
833831
grid = grid.unsqueeze(0).repeat(n, 1, 1, 1) # grid.shape: [n, oh, ow, 2]
834-
grid = Variable(grid)
835-
if input.is_cuda:
836-
grid = grid.cuda()
837832

838833
return F.grid_sample(input, grid, mode=mode)
839834

validate.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def validate(cfg, args):
3636
split=cfg['data']['val_split'],
3737
is_transform=True,
3838
img_size=(cfg['data']['img_rows'],
39-
cfg['data']['img_rows']),
39+
cfg['data']['img_cols']),
4040
)
4141

4242
n_classes = loader.n_classes

0 commit comments

Comments
 (0)