Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Datumaro] Optimize mask conversions #1097

Merged
merged 1 commit into from
Jan 27, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions datumaro/datumaro/components/converters/voc.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
parse_label_map, make_voc_label_map, make_voc_categories, write_label_map
)
from datumaro.util.image import save_image
from datumaro.util.mask_tools import apply_colormap, remap_mask
from datumaro.util.mask_tools import paint_mask, remap_mask


def _write_xml_bbox(bbox, parent_elem):
Expand Down Expand Up @@ -363,7 +363,7 @@ def save_segm(self, path, annotation, colormap=None):
if colormap is None:
colormap = self._categories[AnnotationType.mask].colormap
data = self._remap_mask(data)
data = apply_colormap(data, colormap)
data = paint_mask(data, colormap)
save_image(path, data)

def save_label_map(self):
Expand Down
82 changes: 48 additions & 34 deletions datumaro/datumaro/util/mask_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,65 +30,79 @@ def invert_colormap(colormap):
tuple(a): index for index, a in colormap.items()
}

def check_is_mask(mask):
assert len(mask.shape) in {2, 3}
if len(mask.shape) == 3:
assert mask.shape[2] == 1

_default_colormap = generate_colormap()
_default_unpaint_colormap = invert_colormap(_default_colormap)

def _default_unpaint_colormap_fn(r, g, b):
return _default_unpaint_colormap[(r, g, b)]
def unpaint_mask(painted_mask, inverse_colormap=None):
# Covert color mask to index mask

def unpaint_mask(painted_mask, colormap=None):
# expect HWC BGR [0; 255] image
# expect RGB->index colormap
# mask: HWC BGR [0; 255]
# colormap: (R, G, B) -> index
assert len(painted_mask.shape) == 3
if colormap is None:
colormap = _default_unpaint_colormap_fn
if callable(colormap):
map_fn = lambda a: colormap(int(a[2]), int(a[1]), int(a[0]))
if inverse_colormap is None:
inverse_colormap = _default_unpaint_colormap

if callable(inverse_colormap):
map_fn = lambda a: inverse_colormap(
(a >> 16) & 255, (a >> 8) & 255, a & 255
)
else:
map_fn = lambda a: colormap[(int(a[2]), int(a[1]), int(a[0]))]
map_fn = lambda a: inverse_colormap[(
(a >> 16) & 255, (a >> 8) & 255, a & 255
)]

unpainted_mask = np.apply_along_axis(map_fn,
1, np.reshape(painted_mask, (-1, 3)))
unpainted_mask = np.reshape(unpainted_mask, (painted_mask.shape[:2]))
return unpainted_mask.astype(int)
painted_mask = painted_mask.astype(int)
painted_mask = painted_mask[:, :, 0] + \
(painted_mask[:, :, 1] << 8) + \
(painted_mask[:, :, 2] << 16)
uvals, unpainted_mask = np.unique(painted_mask, return_inverse=True)
palette = np.array([map_fn(v) for v in uvals], dtype=np.float32)
unpainted_mask = palette[unpainted_mask].reshape(painted_mask.shape[:2])

return unpainted_mask

def apply_colormap(mask, colormap=None):
# expect HW [0; max_index] mask
# expect index->RGB colormap
assert len(mask.shape) == 2
def paint_mask(mask, colormap=None):
# Applies colormap to index mask

# mask: HW(C) [0; max_index] mask
# colormap: index -> (R, G, B)
check_is_mask(mask)

if colormap is None:
colormap = _default_colormap
if callable(colormap):
map_fn = lambda p: colormap(int(p[0]))[::-1]
map_fn = colormap
else:
map_fn = lambda p: colormap[int(p[0])][::-1]
painted_mask = np.apply_along_axis(map_fn, 1, np.reshape(mask, (-1, 1)))
map_fn = lambda c: colormap.get(c, (-1, -1, -1))
palette = np.array([map_fn(c)[::-1] for c in range(256)], dtype=np.float32)

painted_mask = np.reshape(painted_mask, (*mask.shape, 3))
return painted_mask.astype(np.float32)
mask = mask.astype(np.uint8)
painted_mask = palette[mask].reshape((*mask.shape[:2], 3))
return painted_mask

def remap_mask(mask, map_fn):
# Changes mask elements from one colormap to another
assert len(mask.shape) == 2

shape = mask.shape
mask = np.reshape(mask, (-1, 1))
mask = np.apply_along_axis(map_fn, 1, mask)
mask = np.reshape(mask, shape)
return mask
# mask: HW(C) [0; max_index] mask
check_is_mask(mask)

return np.array([map_fn(c) for c in range(256)], dtype=np.uint8)[mask]


def load_mask(path, colormap=None):
def load_mask(path, inverse_colormap=None):
mask = load_image(path)
if colormap is not None:
if inverse_colormap is not None:
if len(mask.shape) == 3 and mask.shape[2] != 1:
mask = unpaint_mask(mask, colormap=colormap)
mask = unpaint_mask(mask, inverse_colormap)
return mask

def lazy_mask(path, colormap=None):
return lazy_image(path, lambda path: load_mask(path, colormap))
def lazy_mask(path, inverse_colormap=None):
return lazy_image(path, lambda path: load_mask(path, inverse_colormap))


def mask_to_rle(binary_mask):
Expand Down
57 changes: 56 additions & 1 deletion datumaro/tests/test_masks.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,4 +66,59 @@ def test_can_crop_covered_segments(self):
self.assertEqual(len(initial), len(computed))
for i, (e_mask, c_mask) in enumerate(zip(expected, computed)):
self.assertTrue(np.array_equal(e_mask, c_mask),
'#%s: %s\n%s\n' % (i, e_mask, c_mask))
'#%s: %s\n%s\n' % (i, e_mask, c_mask))

class ColormapOperationsTest(TestCase):
def test_can_paint_mask(self):
mask = np.zeros((1, 3), dtype=np.uint8)
mask[:, 0] = 0
mask[:, 1] = 1
mask[:, 2] = 2

colormap = mask_tools.generate_colormap(3)

expected = np.zeros((*mask.shape, 3), dtype=np.uint8)
expected[:, 0] = colormap[0][::-1]
expected[:, 1] = colormap[1][::-1]
expected[:, 2] = colormap[2][::-1]

actual = mask_tools.paint_mask(mask, colormap)

self.assertTrue(np.array_equal(expected, actual),
'%s\nvs.\n%s' % (expected, actual))

def test_can_unpaint_mask(self):
colormap = mask_tools.generate_colormap(3)
inverse_colormap = mask_tools.invert_colormap(colormap)

mask = np.zeros((1, 3, 3), dtype=np.uint8)
mask[:, 0] = colormap[0][::-1]
mask[:, 1] = colormap[1][::-1]
mask[:, 2] = colormap[2][::-1]

expected = np.zeros((1, 3), dtype=np.uint8)
expected[:, 0] = 0
expected[:, 1] = 1
expected[:, 2] = 2

actual = mask_tools.unpaint_mask(mask, inverse_colormap)

self.assertTrue(np.array_equal(expected, actual),
'%s\nvs.\n%s' % (expected, actual))

def test_can_remap_mask(self):
class_count = 10
remap_fn = lambda c: class_count - c

src = np.empty((class_count, class_count), dtype=np.uint8)
for c in range(class_count):
src[c:, c:] = c

expected = np.empty_like(src)
for c in range(class_count):
expected[c:, c:] = remap_fn(c)

actual = mask_tools.remap_mask(src, remap_fn)

self.assertTrue(np.array_equal(expected, actual),
'%s\nvs.\n%s' % (expected, actual))