Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature]add CLAHE transform #229

Merged
merged 17 commits into from
Dec 2, 2020
8 changes: 4 additions & 4 deletions mmseg/datasets/pipelines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@
Transpose, to_tensor)
from .loading import LoadAnnotations, LoadImageFromFile
from .test_time_aug import MultiScaleFlipAug
from .transforms import (Normalize, Pad, PhotoMetricDistortion, RandomCrop,
RandomFlip, RandomRotate, Rerange, Resize, RGB2Gray,
SegRescale)
from .transforms import (CLAHE, AdjustGamma, Normalize, Pad,
PhotoMetricDistortion, RandomCrop, RandomFlip,
RandomRotate, Rerange, Resize, RGB2Gray, SegRescale)

__all__ = [
'Compose', 'to_tensor', 'ToTensor', 'ImageToTensor', 'ToDataContainer',
'Transpose', 'Collect', 'LoadAnnotations', 'LoadImageFromFile',
'MultiScaleFlipAug', 'Resize', 'RandomFlip', 'Pad', 'RandomCrop',
'Normalize', 'SegRescale', 'PhotoMetricDistortion', 'RandomRotate',
'Rerange', 'RGB2Gray'
'AdjustGamma', 'CLAHE', 'Rerange', 'RGB2Gray'
]
48 changes: 46 additions & 2 deletions mmseg/datasets/pipelines/transforms.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import mmcv
import numpy as np
from mmcv.utils import deprecated_api_warning
from mmcv.utils import deprecated_api_warning, is_tuple_of
from numpy import random

from ..builder import PIPELINES
Expand Down Expand Up @@ -415,7 +415,6 @@ def __call__(self, results):

Args:
results (dict): Result dict from loading pipeline.

Returns:
dict: Reranged results.
"""
Expand All @@ -439,6 +438,51 @@ def __repr__(self):
return repr_str


@PIPELINES.register_module()
class CLAHE(object):
"""Use CLAHE method to process the image.
xvjiarui marked this conversation as resolved.
Show resolved Hide resolved

See `ZUIDERVELD,K. Contrast Limited Adaptive Histogram Equalization[J].
Graphics Gems, 1994:474-485.` for more information.

Args:
clip_limit (float): Threshold for contrast limiting. Default: 40.0.
tile_grid_size (tuple[int]): Size of grid for histogram equalization.
Input image will be divided into equally sized rectangular tiles.
It defines the number of tiles in row and column. Default: (8, 8).
"""

def __init__(self, clip_limit=40.0, tile_grid_size=(8, 8)):
assert isinstance(clip_limit, (float, int))
self.clip_limit = clip_limit
assert is_tuple_of(tile_grid_size, int)
assert len(tile_grid_size) == 2
self.tile_grid_size = tile_grid_size

def __call__(self, results):
"""Call function to Use CLAHE method process images.

Args:
results (dict): Result dict from loading pipeline.

Returns:
dict: Processed results.
"""

for i in range(results['img'].shape[2]):
results['img'][:, :, i] = mmcv.clahe(
np.array(results['img'][:, :, i], dtype=np.uint8),
self.clip_limit, self.tile_grid_size)

return results

def __repr__(self):
repr_str = self.__class__.__name__
repr_str += f'(clip_limit={self.clip_limit}, '\
f'tile_grid_size={self.tile_grid_size})'
return repr_str


@PIPELINES.register_module()
class RandomCrop(object):
"""Random crop the image & seg.
Expand Down
40 changes: 40 additions & 0 deletions tests/test_data/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,46 @@ def test_rerange():
assert str(transform) == f'Rerange(min_value={0}, max_value={255})'


def test_CLAHE():
# test assertion if clip_limit is None
with pytest.raises(AssertionError):
transform = dict(type='CLAHE', clip_limit=None)
build_from_cfg(transform, PIPELINES)

# test assertion if tile_grid_size is illegal
with pytest.raises(AssertionError):
transform = dict(type='CLAHE', tile_grid_size=(8.0, 8.0))
build_from_cfg(transform, PIPELINES)

# test assertion if tile_grid_size is illegal
with pytest.raises(AssertionError):
transform = dict(type='CLAHE', tile_grid_size=(9, 9, 9))
build_from_cfg(transform, PIPELINES)

transform = dict(type='CLAHE', clip_limit=2)
transform = build_from_cfg(transform, PIPELINES)
results = dict()
img = mmcv.imread(
osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color')
original_img = copy.deepcopy(img)
results['img'] = img
results['img_shape'] = img.shape
results['ori_shape'] = img.shape
# Set initial values for default meta_keys
results['pad_shape'] = img.shape
results['scale_factor'] = 1.0

results = transform(results)

converted_img = np.empty(original_img.shape)
for i in range(original_img.shape[2]):
converted_img[:, :, i] = mmcv.clahe(
np.array(original_img[:, :, i], dtype=np.uint8), 2, (8, 8))

assert np.allclose(results['img'], converted_img)
assert str(transform) == f'CLAHE(clip_limit={2}, tile_grid_size={(8, 8)})'


def test_seg_rescale():
results = dict()
seg = np.array(
Expand Down