@@ -270,7 +270,7 @@ def from_instance_masks(instance_masks,
270
270
if instance_ids is not None :
271
271
assert len (instance_ids ) == len (instance_masks )
272
272
else :
273
- instance_ids = [ 1 + i for i in range (len (instance_masks ))]
273
+ instance_ids = range (1 , len (instance_masks ) + 1 )
274
274
275
275
if instance_labels is not None :
276
276
assert len (instance_labels ) == len (instance_masks )
@@ -310,15 +310,13 @@ def instance_mask(self):
310
310
def instance_count (self ):
311
311
return int (self .instance_mask .max ())
312
312
313
- def get_instance_labels (self , class_count = None ):
314
- if class_count is None :
315
- class_count = np .max (self .class_mask ) + 1
316
-
317
- m = self .class_mask * class_count + self .instance_mask
318
- m = m .astype (int )
313
+ def get_instance_labels (self ):
314
+ class_shift = 16
315
+ m = (self .class_mask .astype (np .uint32 ) << class_shift ) \
316
+ + self .instance_mask .astype (np .uint32 )
319
317
keys = np .unique (m )
320
- instance_labels = {k % class_count : k // class_count
321
- for k in keys if k % class_count != 0
318
+ instance_labels = {k & (( 1 << class_shift ) - 1 ) : k >> class_shift
319
+ for k in keys if k & (( 1 << class_shift ) - 1 ) != 0
322
320
}
323
321
return instance_labels
324
322
@@ -783,4 +781,4 @@ def categories(self):
783
781
return self ._extractor .categories ()
784
782
785
783
def transform_item (self , item ):
786
- raise NotImplementedError ()
784
+ raise NotImplementedError ()
0 commit comments