Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sourcery Starbot ⭐ refactored meetps/pytorch-semseg #264

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ptsemseg/augmentations/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,5 +45,5 @@ def get_composed_augmentations(aug_dict):
augmentations = []
for aug_key, aug_param in aug_dict.items():
augmentations.append(key2aug[aug_key](aug_param))
logger.info("Using {} aug with params {}".format(aug_key, aug_param))
logger.info(f"Using {aug_key} aug with params {aug_param}")
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function get_composed_augmentations refactored with the following changes:

return Compose(augmentations)
21 changes: 8 additions & 13 deletions ptsemseg/augmentations/augmentations.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,13 +157,8 @@ def __call__(self, img, mask):
x_offset = int(2 * (random.random() - 0.5) * self.offset[0])
y_offset = int(2 * (random.random() - 0.5) * self.offset[1])

x_crop_offset = x_offset
y_crop_offset = y_offset
if x_offset < 0:
x_crop_offset = 0
if y_offset < 0:
y_crop_offset = 0

x_crop_offset = max(x_offset, 0)
y_crop_offset = max(y_offset, 0)
Comment on lines -160 to +161
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function RandomTranslate.__call__ refactored with the following changes:

cropped_img = tf.crop(
img,
y_crop_offset,
Expand All @@ -175,13 +170,13 @@ def __call__(self, img, mask):
if x_offset >= 0 and y_offset >= 0:
padding_tuple = (0, 0, x_offset, y_offset)

elif x_offset >= 0 and y_offset < 0:
elif x_offset >= 0:
padding_tuple = (0, abs(y_offset), x_offset, 0)

elif x_offset < 0 and y_offset >= 0:
elif y_offset >= 0:
padding_tuple = (abs(x_offset), 0, 0, y_offset)

elif x_offset < 0 and y_offset < 0:
else:
padding_tuple = (abs(x_offset), abs(y_offset), 0, 0)

return (
Expand Down Expand Up @@ -237,11 +232,11 @@ def __call__(self, img, mask):
if w > h:
ow = self.size
oh = int(self.size * h / w)
return (img.resize((ow, oh), Image.BILINEAR), mask.resize((ow, oh), Image.NEAREST))
else:
oh = self.size
ow = int(self.size * w / h)
return (img.resize((ow, oh), Image.BILINEAR), mask.resize((ow, oh), Image.NEAREST))

return (img.resize((ow, oh), Image.BILINEAR), mask.resize((ow, oh), Image.NEAREST))
Comment on lines -240 to +239
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function Scale.__call__ refactored with the following changes:



class RandomSizedCrop(object):
Expand All @@ -250,7 +245,7 @@ def __init__(self, size):

def __call__(self, img, mask):
assert img.size == mask.size
for attempt in range(10):
for _ in range(10):
Comment on lines -253 to +248
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function RandomSizedCrop.__call__ refactored with the following changes:

area = img.size[0] * img.size[1]
target_area = random.uniform(0.45, 1.0) * area
aspect_ratio = random.uniform(0.5, 2)
Expand Down
13 changes: 6 additions & 7 deletions ptsemseg/loader/ade20k_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def __init__(
if not self.test_mode:
for split in ["training", "validation"]:
file_list = recursive_glob(
rootdir=self.root + "images/" + self.split + "/", suffix=".jpg"
rootdir=f"{self.root}images/{self.split}/", suffix=".jpg"
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function ADE20KLoader.__init__ refactored with the following changes:

)
self.files[split] = file_list

Expand All @@ -44,7 +44,7 @@ def __len__(self):

def __getitem__(self, index):
img_path = self.files[self.split][index].rstrip()
lbl_path = img_path[:-4] + "_seg.png"
lbl_path = f"{img_path[:-4]}_seg.png"
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function ADE20KLoader.__getitem__ refactored with the following changes:


img = m.imread(img_path)
img = np.array(img, dtype=np.uint8)
Expand Down Expand Up @@ -96,7 +96,7 @@ def decode_segmap(self, temp, plot=False):
r = temp.copy()
g = temp.copy()
b = temp.copy()
for l in range(0, self.n_classes):
for l in range(self.n_classes):
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function ADE20KLoader.decode_segmap refactored with the following changes:

r[temp == l] = 10 * (l % 10)
g[temp == l] = l
b[temp == l] = 0
Expand All @@ -105,11 +105,10 @@ def decode_segmap(self, temp, plot=False):
rgb[:, :, 0] = r / 255.0
rgb[:, :, 1] = g / 255.0
rgb[:, :, 2] = b / 255.0
if plot:
plt.imshow(rgb)
plt.show()
else:
if not plot:
return rgb
plt.imshow(rgb)
plt.show()


if __name__ == "__main__":
Expand Down
8 changes: 4 additions & 4 deletions ptsemseg/loader/camvid_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,16 @@ def __init__(

if not self.test_mode:
for split in ["train", "test", "val"]:
file_list = os.listdir(root + "/" + split)
file_list = os.listdir(f"{root}/{split}")
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function camvidLoader.__init__ refactored with the following changes:

self.files[split] = file_list

def __len__(self):
return len(self.files[self.split])

def __getitem__(self, index):
img_name = self.files[self.split][index]
img_path = self.root + "/" + self.split + "/" + img_name
lbl_path = self.root + "/" + self.split + "annot/" + img_name
img_path = f"{self.root}/{self.split}/{img_name}"
lbl_path = f"{self.root}/{self.split}annot/{img_name}"
Comment on lines -44 to +45
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function camvidLoader.__getitem__ refactored with the following changes:


img = m.imread(img_path)
img = np.array(img, dtype=np.uint8)
Expand Down Expand Up @@ -107,7 +107,7 @@ def decode_segmap(self, temp, plot=False):
r = temp.copy()
g = temp.copy()
b = temp.copy()
for l in range(0, self.n_classes):
for l in range(self.n_classes):
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function camvidLoader.decode_segmap refactored with the following changes:

r[temp == l] = label_colours[l, 0]
g[temp == l] = label_colours[l, 1]
b[temp == l] = label_colours[l, 2]
Expand Down
6 changes: 3 additions & 3 deletions ptsemseg/loader/cityscapes_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def __init__(
self.class_map = dict(zip(self.valid_classes, range(19)))

if not self.files[split]:
raise Exception("No files for split=[%s] found in %s" % (split, self.images_base))
raise Exception(f"No files for split=[{split}] found in {self.images_base}")
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function cityscapesLoader.__init__ refactored with the following changes:


print("Found %d %s images" % (len(self.files[split]), split))

Expand All @@ -150,7 +150,7 @@ def __getitem__(self, index):
lbl_path = os.path.join(
self.annotations_base,
img_path.split(os.sep)[-2],
os.path.basename(img_path)[:-15] + "gtFine_labelIds.png",
f"{os.path.basename(img_path)[:-15]}gtFine_labelIds.png",
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function cityscapesLoader.__getitem__ refactored with the following changes:

)

img = m.imread(img_path)
Expand Down Expand Up @@ -205,7 +205,7 @@ def decode_segmap(self, temp):
r = temp.copy()
g = temp.copy()
b = temp.copy()
for l in range(0, self.n_classes):
for l in range(self.n_classes):
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function cityscapesLoader.decode_segmap refactored with the following changes:

r[temp == l] = self.label_colours[l][0]
g[temp == l] = self.label_colours[l][1]
b[temp == l] = self.label_colours[l][2]
Expand Down
10 changes: 4 additions & 6 deletions ptsemseg/loader/mapillary_vistas_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def __init__(
self.ignore_id = 250

if not self.files[split]:
raise Exception("No files for split=[%s] found in %s" % (split, self.images_base))
raise Exception(f"No files for split=[{split}] found in {self.images_base}")
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function mapillaryVistasLoader.__init__ refactored with the following changes:


print("Found %d %s images" % (len(self.files[split]), split))

Expand All @@ -53,7 +53,7 @@ def parse_config(self):
class_names = []
class_ids = []
class_colors = []
print("There are {} labels in the config file".format(len(labels)))
print(f"There are {len(labels)} labels in the config file")
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function mapillaryVistasLoader.parse_config refactored with the following changes:

for label_id, label in enumerate(labels):
class_names.append(label["readable"])
class_ids.append(label_id)
Expand Down Expand Up @@ -86,9 +86,7 @@ def __getitem__(self, index):
return img, lbl

def transform(self, img, lbl):
if self.img_size == ("same", "same"):
pass
else:
if self.img_size != ("same", "same"):
img = img.resize(
(self.img_size[0], self.img_size[1]), resample=Image.LANCZOS
) # uint8 with RGB mode
Expand All @@ -103,7 +101,7 @@ def decode_segmap(self, temp):
r = temp.copy()
g = temp.copy()
b = temp.copy()
for l in range(0, self.n_classes):
for l in range(self.n_classes):
r[temp == l] = self.class_colors[l][0]
g[temp == l] = self.class_colors[l][1]
b[temp == l] = self.class_colors[l][2]
Expand Down
14 changes: 6 additions & 8 deletions ptsemseg/loader/mit_sceneparsing_benchmark_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def __init__(
self.files[split] = recursive_glob(rootdir=self.images_base, suffix=".jpg")

if not self.files[split]:
raise Exception("No files for split=[%s] found in %s" % (split, self.images_base))
raise Exception(f"No files for split=[{split}] found in {self.images_base}")

print("Found %d %s images" % (len(self.files[split]), split))

Expand All @@ -71,7 +71,9 @@ def __getitem__(self, index):
:param index:
"""
img_path = self.files[self.split][index].rstrip()
lbl_path = os.path.join(self.annotations_base, os.path.basename(img_path)[:-4] + ".png")
lbl_path = os.path.join(
self.annotations_base, f"{os.path.basename(img_path)[:-4]}.png"
)

img = m.imread(img_path, mode="RGB")
img = np.array(img, dtype=np.uint8)
Expand All @@ -93,9 +95,7 @@ def transform(self, img, lbl):
:param img:
:param lbl:
"""
if self.img_size == ("same", "same"):
pass
else:
if self.img_size != ("same", "same"):
img = m.imresize(img, (self.img_size[0], self.img_size[1])) # uint8 with RGB mode
img = img[:, :, ::-1] # RGB -> BGR
img = img.astype(np.float64)
Expand All @@ -109,9 +109,7 @@ def transform(self, img, lbl):

classes = np.unique(lbl)
lbl = lbl.astype(float)
if self.img_size == ("same", "same"):
pass
else:
if self.img_size != ("same", "same"):
lbl = m.imresize(lbl, (self.img_size[0], self.img_size[1]), "nearest", mode="F")
lbl = lbl.astype(int)

Expand Down
6 changes: 3 additions & 3 deletions ptsemseg/loader/nyuv2_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def __getitem__(self, index):
img_path = self.files[self.split][index].rstrip()
img_number = img_path.split("_")[-1][:4]
lbl_path = os.path.join(
self.root, self.split + "_annot", "new_nyu_class13_" + img_number + ".png"
self.root, f"{self.split}_annot", f"new_nyu_class13_{img_number}.png"
)

img = m.imread(img_path)
Expand All @@ -67,7 +67,7 @@ def __getitem__(self, index):
lbl = m.imread(lbl_path)
lbl = np.array(lbl, dtype=np.uint8)

if not (len(img.shape) == 3 and len(lbl.shape) == 2):
if len(img.shape) != 3 or len(lbl.shape) != 2:
return self.__getitem__(np.random.randint(0, self.__len__()))

if self.augmentations is not None:
Expand Down Expand Up @@ -128,7 +128,7 @@ def decode_segmap(self, temp):
r = temp.copy()
g = temp.copy()
b = temp.copy()
for l in range(0, self.n_classes):
for l in range(self.n_classes):
r[temp == l] = self.cmap[l, 0]
g[temp == l] = self.cmap[l, 1]
b[temp == l] = self.cmap[l, 2]
Expand Down
25 changes: 11 additions & 14 deletions ptsemseg/loader/pascal_voc_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def __init__(

if not self.test_mode:
for split in ["train", "val", "trainval"]:
path = pjoin(self.root, "ImageSets/Segmentation", split + ".txt")
path = pjoin(self.root, "ImageSets/Segmentation", f"{split}.txt")
file_list = tuple(open(path, "r"))
file_list = [id_.rstrip() for id_ in file_list]
self.files[split] = file_list
Expand All @@ -84,8 +84,8 @@ def __len__(self):

def __getitem__(self, index):
im_name = self.files[self.split][index]
im_path = pjoin(self.root, "JPEGImages", im_name + ".jpg")
lbl_path = pjoin(self.root, "SegmentationClass/pre_encoded", im_name + ".png")
im_path = pjoin(self.root, "JPEGImages", f"{im_name}.jpg")
lbl_path = pjoin(self.root, "SegmentationClass/pre_encoded", f"{im_name}.png")
im = Image.open(im_path)
lbl = Image.open(lbl_path)
if self.augmentations is not None:
Expand All @@ -95,9 +95,7 @@ def __getitem__(self, index):
return im, lbl

def transform(self, img, lbl):
if self.img_size == ("same", "same"):
pass
else:
if self.img_size != ("same", "same"):
img = img.resize((self.img_size[0], self.img_size[1])) # uint8 with RGB mode
lbl = lbl.resize((self.img_size[0], self.img_size[1]))
img = self.tf(img)
Expand Down Expand Up @@ -171,19 +169,18 @@ def decode_segmap(self, label_mask, plot=False):
r = label_mask.copy()
g = label_mask.copy()
b = label_mask.copy()
for ll in range(0, self.n_classes):
for ll in range(self.n_classes):
r[label_mask == ll] = label_colours[ll, 0]
g[label_mask == ll] = label_colours[ll, 1]
b[label_mask == ll] = label_colours[ll, 2]
rgb = np.zeros((label_mask.shape[0], label_mask.shape[1], 3))
rgb[:, :, 0] = r / 255.0
rgb[:, :, 1] = g / 255.0
rgb[:, :, 2] = b / 255.0
if plot:
plt.imshow(rgb)
plt.show()
else:
if not plot:
return rgb
plt.imshow(rgb)
plt.show()

def setup_annotations(self):
"""Sets up Berkley annotations by adding image indices to the
Expand Down Expand Up @@ -213,14 +210,14 @@ def setup_annotations(self):
if len(pre_encoded) != expected:
print("Pre-encoding segmentation masks...")
for ii in tqdm(sbd_train_list):
lbl_path = pjoin(sbd_path, "dataset/cls", ii + ".mat")
lbl_path = pjoin(sbd_path, "dataset/cls", f"{ii}.mat")
data = io.loadmat(lbl_path)
lbl = data["GTcls"][0]["Segmentation"][0].astype(np.int32)
lbl = m.toimage(lbl, high=lbl.max(), low=lbl.min())
m.imsave(pjoin(target_path, ii + ".png"), lbl)
m.imsave(pjoin(target_path, f"{ii}.png"), lbl)

for ii in tqdm(self.files["trainval"]):
fname = ii + ".png"
fname = f"{ii}.png"
lbl_path = pjoin(self.root, "SegmentationClass", fname)
lbl = self.encode_segmap(m.imread(lbl_path))
lbl = m.toimage(lbl, high=lbl.max(), low=lbl.min())
Expand Down
8 changes: 5 additions & 3 deletions ptsemseg/loader/sunrgbd_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,9 @@ def __init__(

for split in ["train", "test"]:
file_list = sorted(
recursive_glob(rootdir=self.root + "annotations/" + split + "/", suffix="png")
recursive_glob(
rootdir=f"{self.root}annotations/{split}/", suffix="png"
)
)
self.anno_files[split] = file_list

Expand All @@ -72,7 +74,7 @@ def __getitem__(self, index):
lbl = m.imread(lbl_path)
lbl = np.array(lbl, dtype=np.uint8)

if not (len(img.shape) == 3 and len(lbl.shape) == 2):
if len(img.shape) != 3 or len(lbl.shape) != 2:
return self.__getitem__(np.random.randint(0, self.__len__()))

if self.augmentations is not None:
Expand Down Expand Up @@ -133,7 +135,7 @@ def decode_segmap(self, temp):
r = temp.copy()
g = temp.copy()
b = temp.copy()
for l in range(0, self.n_classes):
for l in range(self.n_classes):
r[temp == l] = self.cmap[l, 0]
g[temp == l] = self.cmap[l, 1]
b[temp == l] = self.cmap[l, 2]
Expand Down
4 changes: 2 additions & 2 deletions ptsemseg/loss/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def get_loss_function(cfg):
loss_params = {k: v for k, v in loss_dict.items() if k != "name"}

if loss_name not in key2loss:
raise NotImplementedError("Loss {} not implemented".format(loss_name))
raise NotImplementedError(f"Loss {loss_name} not implemented")

logger.info("Using {} with {} params".format(loss_name, loss_params))
logger.info(f"Using {loss_name} with {loss_params} params")
return functools.partial(key2loss[loss_name], **loss_params)
9 changes: 6 additions & 3 deletions ptsemseg/loss/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,13 @@ def cross_entropy2d(input, target, weight=None, size_average=True):

input = input.transpose(1, 2).transpose(2, 3).contiguous().view(-1, c)
target = target.view(-1)
loss = F.cross_entropy(
input, target, weight=weight, size_average=size_average, ignore_index=250
return F.cross_entropy(
input,
target,
weight=weight,
size_average=size_average,
ignore_index=250,
)
return loss


def multi_scale_cross_entropy2d(input, target, weight=None, size_average=True, scale_weight=None):
Expand Down
Loading