diff --git a/mmseg/datasets/pipelines/__init__.py b/mmseg/datasets/pipelines/__init__.py index b4b7148089..8b9046b07b 100644 --- a/mmseg/datasets/pipelines/__init__.py +++ b/mmseg/datasets/pipelines/__init__.py @@ -3,14 +3,14 @@ Transpose, to_tensor) from .loading import LoadAnnotations, LoadImageFromFile from .test_time_aug import MultiScaleFlipAug -from .transforms import (Normalize, Pad, PhotoMetricDistortion, RandomCrop, - RandomFlip, RandomRotate, Rerange, Resize, RGB2Gray, - SegRescale) +from .transforms import (CLAHE, AdjustGamma, Normalize, Pad, + PhotoMetricDistortion, RandomCrop, RandomFlip, + RandomRotate, Rerange, Resize, RGB2Gray, SegRescale) __all__ = [ 'Compose', 'to_tensor', 'ToTensor', 'ImageToTensor', 'ToDataContainer', 'Transpose', 'Collect', 'LoadAnnotations', 'LoadImageFromFile', 'MultiScaleFlipAug', 'Resize', 'RandomFlip', 'Pad', 'RandomCrop', 'Normalize', 'SegRescale', 'PhotoMetricDistortion', 'RandomRotate', - 'Rerange', 'RGB2Gray' + 'AdjustGamma', 'CLAHE', 'Rerange', 'RGB2Gray' ] diff --git a/mmseg/datasets/pipelines/transforms.py b/mmseg/datasets/pipelines/transforms.py index c138b21c20..4f586cfe44 100644 --- a/mmseg/datasets/pipelines/transforms.py +++ b/mmseg/datasets/pipelines/transforms.py @@ -1,6 +1,6 @@ import mmcv import numpy as np -from mmcv.utils import deprecated_api_warning +from mmcv.utils import deprecated_api_warning, is_tuple_of from numpy import random from ..builder import PIPELINES @@ -415,7 +415,6 @@ def __call__(self, results): Args: results (dict): Result dict from loading pipeline. - Returns: dict: Reranged results. """ @@ -439,6 +438,51 @@ def __repr__(self): return repr_str +@PIPELINES.register_module() +class CLAHE(object): + """Use CLAHE method to process the image. + + See `ZUIDERVELD,K. Contrast Limited Adaptive Histogram Equalization[J]. + Graphics Gems, 1994:474-485.` for more information. + + Args: + clip_limit (float): Threshold for contrast limiting. Default: 40.0. + tile_grid_size (tuple[int]): Size of grid for histogram equalization. + Input image will be divided into equally sized rectangular tiles. + It defines the number of tiles in row and column. Default: (8, 8). + """ + + def __init__(self, clip_limit=40.0, tile_grid_size=(8, 8)): + assert isinstance(clip_limit, (float, int)) + self.clip_limit = clip_limit + assert is_tuple_of(tile_grid_size, int) + assert len(tile_grid_size) == 2 + self.tile_grid_size = tile_grid_size + + def __call__(self, results): + """Call function to Use CLAHE method process images. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Processed results. + """ + + for i in range(results['img'].shape[2]): + results['img'][:, :, i] = mmcv.clahe( + np.array(results['img'][:, :, i], dtype=np.uint8), + self.clip_limit, self.tile_grid_size) + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(clip_limit={self.clip_limit}, '\ + f'tile_grid_size={self.tile_grid_size})' + return repr_str + + @PIPELINES.register_module() class RandomCrop(object): """Random crop the image & seg. diff --git a/tests/test_data/test_transform.py b/tests/test_data/test_transform.py index 19cf6d5337..1833d791e8 100644 --- a/tests/test_data/test_transform.py +++ b/tests/test_data/test_transform.py @@ -409,6 +409,46 @@ def test_rerange(): assert str(transform) == f'Rerange(min_value={0}, max_value={255})' +def test_CLAHE(): + # test assertion if clip_limit is None + with pytest.raises(AssertionError): + transform = dict(type='CLAHE', clip_limit=None) + build_from_cfg(transform, PIPELINES) + + # test assertion if tile_grid_size is illegal + with pytest.raises(AssertionError): + transform = dict(type='CLAHE', tile_grid_size=(8.0, 8.0)) + build_from_cfg(transform, PIPELINES) + + # test assertion if tile_grid_size is illegal + with pytest.raises(AssertionError): + transform = dict(type='CLAHE', tile_grid_size=(9, 9, 9)) + build_from_cfg(transform, PIPELINES) + + transform = dict(type='CLAHE', clip_limit=2) + transform = build_from_cfg(transform, PIPELINES) + results = dict() + img = mmcv.imread( + osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color') + original_img = copy.deepcopy(img) + results['img'] = img + results['img_shape'] = img.shape + results['ori_shape'] = img.shape + # Set initial values for default meta_keys + results['pad_shape'] = img.shape + results['scale_factor'] = 1.0 + + results = transform(results) + + converted_img = np.empty(original_img.shape) + for i in range(original_img.shape[2]): + converted_img[:, :, i] = mmcv.clahe( + np.array(original_img[:, :, i], dtype=np.uint8), 2, (8, 8)) + + assert np.allclose(results['img'], converted_img) + assert str(transform) == f'CLAHE(clip_limit={2}, tile_grid_size={(8, 8)})' + + def test_seg_rescale(): results = dict() seg = np.array(