Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fix] Add swap_labe_pairs in RandomFlip #2332

Merged
merged 9 commits into from
Oct 17, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 58 additions & 21 deletions mmcv/transforms/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -1025,21 +1025,25 @@ class RandomFlip(BaseTransform):

- flip
- flip_direction
- swap_seg_labels (optional)

Args:
prob (float | list[float], optional): The flipping probability.
Defaults to None.
direction(str | list[str]): The flipping direction. Options
If input is a list, the length must equal ``prob``. Each
element in ``prob`` indicates the flip probability of
corresponding direction. Defaults to 'horizontal'.
prob (float | list[float], optional): The flipping probability.
Defaults to None.
direction(str | list[str]): The flipping direction. Options
If input is a list, the length must equal ``prob``. Each
element in ``prob`` indicates the flip probability of
corresponding direction. Defaults to 'horizontal'.
swap_seg_labels (list, optional): The label pair need to be swapped
for ground truth, like 'left arm' and 'right arm' need to be
swapped after horizontal flipping. For example, ``[(1, 5)]``,
where 1/5 is the label of the left/right arm. Defaults to None.
"""

def __init__(
self,
prob: Optional[Union[float, Iterable[float]]] = None,
direction: Union[str,
Sequence[Optional[str]]] = 'horizontal') -> None:
def __init__(self,
prob: Optional[Union[float, Iterable[float]]] = None,
direction: Union[str, Sequence[Optional[str]]] = 'horizontal',
swap_seg_labels: Optional[Sequence] = None) -> None:
if isinstance(prob, list):
assert mmengine.is_list_of(prob, float)
assert 0 <= sum(prob) <= 1
Expand All @@ -1049,6 +1053,7 @@ def __init__(
raise ValueError(f'probs must be float or list of float, but \
got `{type(prob)}`.')
self.prob = prob
self.swap_seg_labels = swap_seg_labels

valid_directions = ['horizontal', 'vertical', 'diagonal']
if isinstance(direction, str):
Expand All @@ -1064,8 +1069,8 @@ def __init__(
if isinstance(prob, list):
assert len(prob) == len(self.direction)

def flip_bbox(self, bboxes: np.ndarray, img_shape: Tuple[int, int],
direction: str) -> np.ndarray:
def _flip_bbox(self, bboxes: np.ndarray, img_shape: Tuple[int, int],
direction: str) -> np.ndarray:
"""Flip bboxes horizontally.

Args:
Expand Down Expand Up @@ -1096,8 +1101,12 @@ def flip_bbox(self, bboxes: np.ndarray, img_shape: Tuple[int, int],
or 'diagonal', but got '{direction}'")
return flipped

def flip_keypoints(self, keypoints: np.ndarray, img_shape: Tuple[int, int],
direction: str) -> np.ndarray:
def _flip_keypoints(
self,
keypoints: np.ndarray,
img_shape: Tuple[int, int],
direction: str,
) -> np.ndarray:
"""Flip keypoints horizontally, vertically or diagonally.

Args:
Expand Down Expand Up @@ -1127,6 +1136,33 @@ def flip_keypoints(self, keypoints: np.ndarray, img_shape: Tuple[int, int],
flipped = np.concatenate([keypoints, meta_info], axis=-1)
return flipped

def _flip_seg_map(self, seg_map: dict, direction: str) -> np.ndarray:
"""Flip segmentation map horizontally, vertically or diagonally.

Args:
seg_map (numpy.ndarray): segmentation map, shape (H, W).
direction (str): Flip direction. Options are 'horizontal',
'vertical'.

Returns:
numpy.ndarray: Flipped segmentation map.
"""
seg_map = mmcv.imflip(seg_map, direction=direction)
if self.swap_seg_labels is not None:
# to handle datasets with left/right annotations
# like 'Left-arm' and 'Right-arm' in LIP dataset
# Modified from https://github.com/openseg-group/openseg.pytorch/blob/master/lib/datasets/tools/cv2_aug_transforms.py # noqa:E501
# Licensed under MIT license
temp = seg_map.copy()
assert isinstance(self.swap_seg_labels, (tuple, list))
for pair in self.swap_seg_labels:
assert isinstance(pair, (tuple, list)) and len(pair) == 2, \
'swap_seg_labels must be a sequence with pair, but got ' \
f'{self.swap_seg_labels}.'
seg_map[temp == pair[0]] = pair[1]
seg_map[temp == pair[1]] = pair[0]
return seg_map

@cache_randomness
def _choose_direction(self) -> str:
"""Choose the flip direction according to `prob` and `direction`"""
Expand Down Expand Up @@ -1162,19 +1198,20 @@ def _flip(self, results: dict) -> None:

# flip bboxes
if results.get('gt_bboxes', None) is not None:
results['gt_bboxes'] = self.flip_bbox(results['gt_bboxes'],
img_shape,
results['flip_direction'])
results['gt_bboxes'] = self._flip_bbox(results['gt_bboxes'],
img_shape,
results['flip_direction'])

# flip keypoints
if results.get('gt_keypoints', None) is not None:
results['gt_keypoints'] = self.flip_keypoints(
results['gt_keypoints'] = self._flip_keypoints(
results['gt_keypoints'], img_shape, results['flip_direction'])

# flip segs
# flip seg map
if results.get('gt_seg_map', None) is not None:
results['gt_seg_map'] = mmcv.imflip(
results['gt_seg_map'] = self._flip_seg_map(
results['gt_seg_map'], direction=results['flip_direction'])
results['swap_seg_labels'] = self.swap_seg_labels

def _flip_on_direction(self, results: dict) -> None:
"""Function to flip images, bounding boxes, semantic segmentation map
Expand Down
38 changes: 32 additions & 6 deletions tests/test_transforms/test_transforms_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -777,49 +777,75 @@ def test_transform(self):
'img': np.random.random((224, 224, 3)),
'gt_bboxes': np.array([[0, 1, 100, 101]]),
'gt_keypoints': np.array([[[100, 100, 1.0]]]),
'gt_seg_map': np.random.random((224, 224, 3))
# seg map flip is irrelative with image, so there is no requirement
# that gt_set_map of test data matches image.
'gt_seg_map': np.array([[0, 1], [2, 3]])
}

# horizontal flip
TRANSFORMS = RandomFlip([1.0], ['horizontal'])
results_update = TRANSFORMS.transform(copy.deepcopy(results))
assert (results_update['gt_bboxes'] == np.array([[124, 1, 224,
101]])).all()
assert (results_update['gt_seg_map'] == np.array([[1, 0], [3,
2]])).all()

# diagnal flip
# diagonal flip
TRANSFORMS = RandomFlip([1.0], ['diagonal'])
results_update = TRANSFORMS.transform(copy.deepcopy(results))
assert (results_update['gt_bboxes'] == np.array([[124, 123, 224,
223]])).all()
assert (results_update['gt_seg_map'] == np.array([[3, 2], [1,
0]])).all()

# vertical flip
TRANSFORMS = RandomFlip([1.0], ['vertical'])
results_update = TRANSFORMS.transform(copy.deepcopy(results))
assert (results_update['gt_bboxes'] == np.array([[0, 123, 100,
223]])).all()
assert (results_update['gt_seg_map'] == np.array([[2, 3], [0,
1]])).all()

# horizontal flip when direction is None
TRANSFORMS = RandomFlip(1.0)
results_update = TRANSFORMS.transform(copy.deepcopy(results))
assert (results_update['gt_bboxes'] == np.array([[124, 1, 224,
101]])).all()
assert (results_update['gt_seg_map'] == np.array([[1, 0], [3,
2]])).all()

# horizontal flip and swap label pair
TRANSFORMS = RandomFlip([1.0], ['horizontal'],
swap_seg_labels=[[0, 1]])
results_update = TRANSFORMS.transform(copy.deepcopy(results))
assert (results_update['gt_seg_map'] == np.array([[0, 1], [3,
2]])).all()
assert results_update['swap_seg_labels'] == [[0, 1]]

TRANSFORMS = RandomFlip(0.0)
results_update = TRANSFORMS.transform(copy.deepcopy(results))
assert (results_update['gt_bboxes'] == np.array([[0, 1, 100,
101]])).all()
assert (results_update['gt_seg_map'] == np.array([[0, 1], [2,
3]])).all()

# flip direction is invalid in bbox flip
with pytest.raises(ValueError):
TRANSFORMS = RandomFlip(1.0)
results_update = TRANSFORMS.flip_bbox(results['gt_bboxes'],
(224, 224), 'invalid')
results_update = TRANSFORMS._flip_bbox(results['gt_bboxes'],
(224, 224), 'invalid')

# flip direction is invalid in keypoints flip
with pytest.raises(ValueError):
TRANSFORMS = RandomFlip(1.0)
results_update = TRANSFORMS.flip_keypoints(results['gt_keypoints'],
(224, 224), 'invalid')
results_update = TRANSFORMS._flip_keypoints(
results['gt_keypoints'], (224, 224), 'invalid')

# swap pair is invalid
with pytest.raises(AssertionError):
TRANSFORMS = RandomFlip(1.0, swap_seg_labels='invalid')
results_update = TRANSFORMS._flip_seg_map(results['gt_seg_map'],
'horizontal')

def test_repr(self):
TRANSFORMS = RandomFlip(0.1)
Expand Down