diff --git a/ptsemseg/loader/ade20k_loader.py b/ptsemseg/loader/ade20k_loader.py
index 8a37bed3..92894660 100644
--- a/ptsemseg/loader/ade20k_loader.py
+++ b/ptsemseg/loader/ade20k_loader.py
@@ -2,6 +2,8 @@
 import torch
 import torchvision
 import numpy as np
+import cv2
+from pathlib import Path
 import scipy.misc as m
 import matplotlib.pyplot as plt
@@ -22,34 +24,42 @@ def __init__(
         test_mode=False,
     ):
         self.root = root
-        self.split = split
+        assert split in ['training', 'validation'], f"unknown split: {split}"
+        self.split = split
         self.is_transform = is_transform
         self.augmentations = augmentations
         self.img_norm = img_norm
         self.test_mode = test_mode
         self.n_classes = 150
         self.img_size = img_size if isinstance(img_size, tuple) else (img_size, img_size)
         self.mean = np.array([104.00699, 116.66877, 122.67892])
-        self.files = collections.defaultdict(list)
-
-        if not self.test_mode:
-            for split in ["training", "validation"]:
-                file_list = recursive_glob(
-                    rootdir=self.root + "images/" + self.split + "/", suffix=".jpg"
-                )
-                self.files[split] = file_list
+        img_path = Path(self.root) / 'images' / self.split
+        self.files = sorted(img_path.glob('*.jpg'))  # sorted for deterministic ordering
+
+        if not self.files:
+            raise FileNotFoundError(f"No images found in {img_path}")
+        print(f"Found {len(self.files)} {self.split} images.")

     def __len__(self):
-        return len(self.files[self.split])
+        return len(self.files)

     def __getitem__(self, index):
-        img_path = self.files[self.split][index].rstrip()
-        lbl_path = img_path[:-4] + "_seg.png"
+        # self.files now holds pathlib.Path objects, not strings
+        img_path = str(self.files[index])
+        lbl_path = img_path.replace('images', 'annotations')[:-4] + ".png"

-        img = m.imread(img_path)
+        img = cv2.imread(img_path)  # cv2 returns BGR channel order
         img = np.array(img, dtype=np.uint8)
-        lbl = m.imread(lbl_path)
+        # the annotation PNGs are single-channel class-index maps
+        lbl = cv2.imread(lbl_path, cv2.IMREAD_GRAYSCALE)
         lbl = np.array(lbl, dtype=np.int32)

         if self.augmentations is not None:
@@ -61,7 +71,7 @@ def __getitem__(self, index):
         return img, lbl

     def transform(self, img, lbl):
-        img = m.imresize(img, (self.img_size[0], self.img_size[1]))  # uint8 with RGB mode
-        img = img[:, :, ::-1]  # RGB -> BGR
+        img = cv2.resize(img, (self.img_size[1], self.img_size[0]))  # cv2 dsize is (width, height)
+        # no RGB -> BGR flip needed: cv2.imread already returns BGR,
+        # which matches the BGR ordering of self.mean
         img = img.astype(np.float64)
         img -= self.mean
@@ -75,9 +85,10 @@
-        lbl = self.encode_segmap(lbl)
+        # annotation PNGs already store class indices, so the RGB
+        # decoding in encode_segmap is no longer needed
         classes = np.unique(lbl)
         lbl = lbl.astype(float)
-        lbl = m.imresize(lbl, (self.img_size[0], self.img_size[1]), "nearest", mode="F")
+        # nearest-neighbour keeps label ids intact; the default bilinear would blend them
+        lbl = cv2.resize(lbl, (self.img_size[1], self.img_size[0]), interpolation=cv2.INTER_NEAREST)
         lbl = lbl.astype(int)
-        assert np.all(classes == np.unique(lbl))
+        # resizing can drop classes that cover only a few pixels,
+        # so the strict equality assert no longer holds in general
+        #assert np.all(classes == np.unique(lbl))

         img = torch.from_numpy(img).float()
         lbl = torch.from_numpy(lbl).long()
diff --git a/ptsemseg/models/frrn.py b/ptsemseg/models/frrn.py
index e2ec72b0..6053af99 100644
--- a/ptsemseg/models/frrn.py
+++ b/ptsemseg/models/frrn.py
@@ -26,7 +26,7 @@ class frrn(nn.Module):
     2) TF implementation by @kiwonjoon: https://github.com/hiwonjoon/tf-frrn
     """

-    def __init__(self, n_classes=21, model_type="B", group_norm=False, n_groups=16):
+    def __init__(self, n_classes=150, model_type="B", group_norm=False, n_groups=16):
         super(frrn, self).__init__()
         self.n_classes = n_classes
         self.model_type = model_type
@@ -144,7 +144,7 @@ def forward(self, x):
         for n_blocks, channels, scale in self.decoder_frru_specs:
             # bilinear upsample smaller feature map
             upsample_size = torch.Size([_s * 2 for _s in y.size()[-2:]])
-            y_upsampled = F.upsample(y, size=upsample_size, mode="bilinear", align_corners=True)
+            y_upsampled = F.interpolate(y, size=upsample_size, mode="bilinear", align_corners=True)
             # pass through decoding FRRUs
             for block in range(n_blocks):
                 key = "_".join(map(str, ["decoding_frru", n_blocks, channels, scale, block]))
@@ -155,7 +155,7 @@ def forward(self, x):
         # merge streams
         x = torch.cat(
-            [F.upsample(y, scale_factor=2, mode="bilinear", align_corners=True), z], dim=1
+            [F.interpolate(y, scale_factor=2, mode="bilinear", align_corners=True), z], dim=1
         )
         x = self.merge_conv(x)
diff --git a/ptsemseg/models/utils.py b/ptsemseg/models/utils.py
index ec24bba0..679d3fb0 100644
--- a/ptsemseg/models/utils.py
+++ b/ptsemseg/models/utils.py
@@ -401,7 +401,7 @@ def forward(self, y, z):
         x = self.conv_res(y_prime)
         upsample_size = torch.Size([_s * self.scale for _s in y_prime.shape[-2:]])
-        x = F.upsample(x, size=upsample_size, mode="nearest")
+        x = F.interpolate(x, size=upsample_size, mode="nearest")
         z_prime = z + x

         return y_prime, z_prime
@@ -482,14 +482,14 @@ def __init__(self, channels, up_scale_high, up_scale_low, high_shape, low_shape)
         self.conv_low = nn.Conv2d(low_shape[1], channels, kernel_size=3)

     def forward(self, x_high, x_low):
-        high_upsampled = F.upsample(
+        high_upsampled = F.interpolate(
             self.conv_high(x_high), scale_factor=self.up_scale_high, mode="bilinear"
         )

         if x_low is None:
             return high_upsampled

-        low_upsampled = F.upsample(
+        low_upsampled = F.interpolate(
             self.conv_low(x_low), scale_factor=self.up_scale_low, mode="bilinear"
         )
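
A quick smoke test for the reworked loader (a sketch, not part of the diff): it assumes the standard ADEChallengeData2016 layout with images/ and annotations/ subfolders per split, that the loader class exported by ade20k_loader.py is named ADE20KLoader, and a placeholder dataset root path.

    from torch.utils.data import DataLoader
    from ptsemseg.loader.ade20k_loader import ADE20KLoader  # class name assumed

    # 'datasets/ADEChallengeData2016/' is a placeholder path
    dataset = ADE20KLoader(root='datasets/ADEChallengeData2016/',
                           split='training', is_transform=True, img_size=512)
    train_loader = DataLoader(dataset, batch_size=4, shuffle=True)

    imgs, lbls = next(iter(train_loader))
    print(imgs.shape, lbls.shape)  # expected: [4, 3, 512, 512] and [4, 512, 512]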
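On the F.upsample -> F.interpolate renames: F.upsample has been a deprecated alias of F.interpolate since PyTorch 0.4.1, so these are drop-in replacements that take identical arguments and produce identical output. A minimal sanity check, assuming any reasonably recent PyTorch:

    import torch
    import torch.nn.functional as F

    y = torch.randn(1, 8, 16, 16)
    # same call signature the FRRN decoder uses after the rename
    out = F.interpolate(y, scale_factor=2, mode='bilinear', align_corners=True)
    print(out.shape)  # torch.Size([1, 8, 32, 32])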