Skip to content

[Feature] Add Rgb2Gray transform #227

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Nov 9, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 55 additions & 0 deletions mmseg/datasets/pipelines/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,6 +548,61 @@ def __repr__(self):
return repr_str


@PIPELINES.register_module()
class RGB2Gray(object):
"""Convert RGB image to grayscale image.

This transform calculate the weighted mean of input image channels with
``weights`` and then expand the channels to ``out_channels``. When
``out_channels`` is None, the number of output channels is the same as
input channels.

Args:
out_channels (int): Expected number of output channels after
transforming. Default: None.
weights (tuple[float]): The weights to calculate the weighted mean.
Default: (0.299, 0.587, 0.114).
"""

def __init__(self, out_channels=None, weights=(0.299, 0.587, 0.114)):
assert out_channels is None or out_channels > 0
self.out_channels = out_channels
assert isinstance(weights, tuple)
for item in weights:
assert isinstance(item, (float, int))
self.weights = weights

def __call__(self, results):
"""Call function to convert RGB image to grayscale image.

Args:
results (dict): Result dict from loading pipeline.

Returns:
dict: Result dict with grayscale image.
"""
img = results['img']
assert len(img.shape) == 3
assert img.shape[2] == len(self.weights)
weights = np.array(self.weights).reshape((1, 1, -1))
img = (img * weights).sum(2, keepdims=True)
if self.out_channels is None:
img = img.repeat(weights.shape[2], axis=2)
else:
img = img.repeat(self.out_channels, axis=2)

results['img'] = img
results['img_shape'] = img.shape

return results

def __repr__(self):
repr_str = self.__class__.__name__
repr_str += f'(out_channels={self.out_channels}, ' \
f'weights={self.weights})'
return repr_str


@PIPELINES.register_module()
class SegRescale(object):
"""Rescale semantic segmentation maps.
Expand Down
67 changes: 67 additions & 0 deletions tests/test_data/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,73 @@ def test_normalize():
assert np.allclose(results['img'], converted_img)


def test_rgb2gray():
# test assertion out_channels should be greater than 0
with pytest.raises(AssertionError):
transform = dict(type='RGB2Gray', out_channels=-1)
build_from_cfg(transform, PIPELINES)
# test assertion weights should be tuple[float]
with pytest.raises(AssertionError):
transform = dict(type='RGB2Gray', out_channels=1, weights=1.1)
build_from_cfg(transform, PIPELINES)

# test out_channels is None
transform = dict(type='RGB2Gray')
transform = build_from_cfg(transform, PIPELINES)

assert str(transform) == f'RGB2Gray(' \
f'out_channels={None}, ' \
f'weights={(0.299, 0.587, 0.114)})'

results = dict()
img = mmcv.imread(
osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color')
h, w, c = img.shape
seg = np.array(
Image.open(osp.join(osp.dirname(__file__), '../data/seg.png')))
results['img'] = img
results['gt_semantic_seg'] = seg
results['seg_fields'] = ['gt_semantic_seg']
results['img_shape'] = img.shape
results['ori_shape'] = img.shape
# Set initial values for default meta_keys
results['pad_shape'] = img.shape
results['scale_factor'] = 1.0

results = transform(results)
assert results['img'].shape == (h, w, c)
assert results['img_shape'] == (h, w, c)
assert results['ori_shape'] == (h, w, c)

# test out_channels = 2
transform = dict(type='RGB2Gray', out_channels=2)
transform = build_from_cfg(transform, PIPELINES)

assert str(transform) == f'RGB2Gray(' \
f'out_channels={2}, ' \
f'weights={(0.299, 0.587, 0.114)})'

results = dict()
img = mmcv.imread(
osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color')
h, w, c = img.shape
seg = np.array(
Image.open(osp.join(osp.dirname(__file__), '../data/seg.png')))
results['img'] = img
results['gt_semantic_seg'] = seg
results['seg_fields'] = ['gt_semantic_seg']
results['img_shape'] = img.shape
results['ori_shape'] = img.shape
# Set initial values for default meta_keys
results['pad_shape'] = img.shape
results['scale_factor'] = 1.0

results = transform(results)
assert results['img'].shape == (h, w, 2)
assert results['img_shape'] == (h, w, 2)
assert results['ori_shape'] == (h, w, c)


def test_seg_rescale():
results = dict()
seg = np.array(
Expand Down