@@ -265,30 +265,35 @@ def _load_items(self):
265
265
recursive = True ,
266
266
)
267
267
mask_suffix = CityscapesPath .GT_INSTANCE_MASK_SUFFIX
268
+
269
+ self ._categories = self ._load_categories (
270
+ self ._path , use_train_label_map = mask_suffix is CityscapesPath .LABEL_TRAIN_IDS_SUFFIX
271
+ )
272
+
273
+ label_ids = []
274
+ for label_cat in self ._categories [AnnotationType .label ]:
275
+ label_id , _ = self ._categories [AnnotationType .label ].find (label_cat .name )
276
+ if label_id :
277
+ label_ids .append (label_id )
278
+
268
279
for mask_path in masks :
269
280
item_id = self ._get_id_from_mask_path (mask_path , mask_suffix )
270
281
271
282
anns = []
272
283
instances_mask = load_image (mask_path , dtype = np .int32 )
273
- segm_ids = np .unique (instances_mask )
274
- for segm_id in segm_ids :
275
- # either is_crowd or ann_id should be set
276
- if segm_id < 1000 :
277
- label_id = segm_id
278
- is_crowd = True
279
- ann_id = None
280
- else :
281
- label_id = segm_id // 1000
282
- is_crowd = False
283
- ann_id = segm_id % 1000
284
+ mask_id = 1
285
+ for label_id in label_ids :
286
+ if label_id not in instances_mask :
287
+ continue
288
+ binary_mask = self ._lazy_extract_mask (instances_mask , label_id )
284
289
anns .append (
285
290
Mask (
286
- image = self ._lazy_extract_mask (instances_mask , segm_id ),
291
+ id = mask_id ,
292
+ image = binary_mask ,
287
293
label = label_id ,
288
- id = ann_id ,
289
- attributes = {"is_crowd" : is_crowd },
290
294
)
291
295
)
296
+ mask_id += 1
292
297
293
298
image = image_path_by_id .pop (item_id , None )
294
299
if image :
@@ -303,9 +308,6 @@ def _load_items(self):
303
308
id = item_id , subset = self ._subset , media = Image .from_file (path = path )
304
309
)
305
310
306
- self ._categories = self ._load_categories (
307
- self ._path , use_train_label_map = mask_suffix is CityscapesPath .LABEL_TRAIN_IDS_SUFFIX
308
- )
309
311
return items
310
312
311
313
@staticmethod
@@ -429,8 +431,8 @@ def _apply_impl(self):
429
431
masks ,
430
432
instance_ids = [
431
433
self ._label_id_mapping (m .label )
432
- if m .attributes .get ("is_crowd" , False )
433
- else self ._label_id_mapping (m .label ) * 1000 + (m .id or (i + 1 ))
434
+ # if m.attributes.get("is_crowd", False)
435
+ # else self._label_id_mapping(m.label) * 1000 + (m.id or (i + 1))
434
436
for i , m in enumerate (masks )
435
437
],
436
438
instance_labels = [self ._label_id_mapping (m .label ) for m in masks ],
0 commit comments