|
2 | 2 | import os
|
3 | 3 | from collections import namedtuple
|
4 | 4 |
|
5 |
| -from .vision import VisionDataset |
| 5 | +import torch |
6 | 6 | from PIL import Image
|
7 | 7 |
|
| 8 | +from .vision import VisionDataset |
| 9 | + |
8 | 10 |
|
9 | 11 | class Cityscapes(VisionDataset):
|
10 | 12 | """`Cityscapes <http://www.cityscapes-dataset.com/>`_ Dataset.
|
@@ -174,6 +176,58 @@ def __getitem__(self, index):
|
174 | 176 | def __len__(self):
|
175 | 177 | return len(self.images)
|
176 | 178 |
|
| 179 | + @staticmethod |
| 180 | + def convert_id_to_train_id(target): |
| 181 | + target_copy = target.clone() |
| 182 | + |
| 183 | + for cls in Cityscapes.classes: |
| 184 | + target_copy[target == cls.id] = cls.train_id |
| 185 | + |
| 186 | + return target_copy |
| 187 | + |
| 188 | + @staticmethod |
| 189 | + def convert_train_id_to_id(target): |
| 190 | + target_copy = target.clone() |
| 191 | + |
| 192 | + for cls in Cityscapes.classes: |
| 193 | + target_copy[target == cls.train_id] = cls.id |
| 194 | + |
| 195 | + return target_copy |
| 196 | + |
| 197 | + @staticmethod |
| 198 | + def get_class_from_name(name): |
| 199 | + for cls in Cityscapes.classes: |
| 200 | + if cls.name == name: |
| 201 | + return cls |
| 202 | + return None |
| 203 | + |
| 204 | + @staticmethod |
| 205 | + def get_class_from_id(id): |
| 206 | + for cls in Cityscapes.classes: |
| 207 | + if cls.id == id: |
| 208 | + return cls |
| 209 | + return None |
| 210 | + |
| 211 | + @staticmethod |
| 212 | + def get_class_from_train_id(train_id): |
| 213 | + for cls in Cityscapes.classes: |
| 214 | + if cls.train_id == train_id: |
| 215 | + return cls |
| 216 | + return None |
| 217 | + |
| 218 | + @staticmethod |
| 219 | + def get_colormap(): |
| 220 | + cmap = torch.zeros([256, 3], dtype=torch.uint8) |
| 221 | + |
| 222 | + for cls in Cityscapes.classes: |
| 223 | + cmap[cls.id, :] = torch.tensor(cls.color) |
| 224 | + |
| 225 | + return cmap |
| 226 | + |
| 227 | + @staticmethod |
| 228 | + def num_classes(): |
| 229 | + return len([cls for cls in Cityscapes.classes if not cls.ignore_in_eval]) |
| 230 | + |
177 | 231 | def extra_repr(self):
|
178 | 232 | lines = ["Split: {split}", "Mode: {mode}", "Type: {target_type}"]
|
179 | 233 | return '\n'.join(lines).format(**self.__dict__)
|
|
0 commit comments