Skip to content

Commit 39a7a2f

Browse files
committed
Add some utility functions to Cityscapes
1 parent f516753 commit 39a7a2f

File tree

1 file changed

+55
-1
lines changed

1 file changed

+55
-1
lines changed

torchvision/datasets/cityscapes.py

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,11 @@
22
import os
33
from collections import namedtuple
44

5-
from .vision import VisionDataset
5+
import torch
66
from PIL import Image
77

8+
from .vision import VisionDataset
9+
810

911
class Cityscapes(VisionDataset):
1012
"""`Cityscapes <http://www.cityscapes-dataset.com/>`_ Dataset.
@@ -174,6 +176,58 @@ def __getitem__(self, index):
174176
def __len__(self):
175177
return len(self.images)
176178

179+
@staticmethod
180+
def convert_id_to_train_id(target):
181+
target_copy = target.clone()
182+
183+
for cls in Cityscapes.classes:
184+
target_copy[target == cls.id] = cls.train_id
185+
186+
return target_copy
187+
188+
@staticmethod
189+
def convert_train_id_to_id(target):
190+
target_copy = target.clone()
191+
192+
for cls in Cityscapes.classes:
193+
target_copy[target == cls.train_id] = cls.id
194+
195+
return target_copy
196+
197+
@staticmethod
198+
def get_class_from_name(name):
199+
for cls in Cityscapes.classes:
200+
if cls.name == name:
201+
return cls
202+
return None
203+
204+
@staticmethod
205+
def get_class_from_id(id):
206+
for cls in Cityscapes.classes:
207+
if cls.id == id:
208+
return cls
209+
return None
210+
211+
@staticmethod
212+
def get_class_from_train_id(train_id):
213+
for cls in Cityscapes.classes:
214+
if cls.train_id == train_id:
215+
return cls
216+
return None
217+
218+
@staticmethod
219+
def get_colormap():
220+
cmap = torch.zeros([256, 3], dtype=torch.uint8)
221+
222+
for cls in Cityscapes.classes:
223+
cmap[cls.id, :] = torch.tensor(cls.color)
224+
225+
return cmap
226+
227+
@staticmethod
228+
def num_classes():
229+
return len([cls for cls in Cityscapes.classes if not cls.ignore_in_eval])
230+
177231
def extra_repr(self):
178232
lines = ["Split: {split}", "Mode: {mode}", "Type: {target_type}"]
179233
return '\n'.join(lines).format(**self.__dict__)

0 commit comments

Comments
 (0)