Skip to content

Commit

Permalink
[Bug Fix] Fix TTA resize scale (open-mmlab#334)
Browse files Browse the repository at this point in the history
* fix tta bug

* modify as suggested

* fix test_tta bug
  • Loading branch information
yamengxi authored Jan 7, 2021
1 parent 7c4e505 commit 022b055
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 6 deletions.
2 changes: 1 addition & 1 deletion mmseg/datasets/pipelines/test_time_aug.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def __call__(self, results):
aug_data = []
if self.img_scale is None and mmcv.is_list_of(self.img_ratios, float):
h, w = results['img'].shape[:2]
img_scale = [(int(h * ratio), int(w * ratio))
img_scale = [(int(w * ratio), int(h * ratio))
for ratio in self.img_ratios]
else:
img_scale = self.img_scale
Expand Down
5 changes: 3 additions & 2 deletions mmseg/datasets/pipelines/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,8 +156,9 @@ def _random_scale(self, results):

if self.ratio_range is not None:
if self.img_scale is None:
scale, scale_idx = self.random_sample_ratio(
results['img'].shape[:2], self.ratio_range)
h, w = results['img'].shape[:2]
scale, scale_idx = self.random_sample_ratio((w, h),
self.ratio_range)
else:
scale, scale_idx = self.random_sample_ratio(
self.img_scale[0], self.ratio_range)
Expand Down
6 changes: 3 additions & 3 deletions tests/test_data/test_tta.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def test_multi_scale_flip_aug():
)
tta_module = build_from_cfg(tta_transform, PIPELINES)
tta_results = tta_module(results.copy())
assert tta_results['scale'] == [(144, 256), (288, 512), (576, 1024)]
assert tta_results['scale'] == [(256, 144), (512, 288), (1024, 576)]
assert tta_results['flip'] == [False, False, False]

tta_transform = dict(
Expand All @@ -120,8 +120,8 @@ def test_multi_scale_flip_aug():
)
tta_module = build_from_cfg(tta_transform, PIPELINES)
tta_results = tta_module(results.copy())
assert tta_results['scale'] == [(144, 256), (144, 256), (288, 512),
(288, 512), (576, 1024), (576, 1024)]
assert tta_results['scale'] == [(256, 144), (256, 144), (512, 288),
(512, 288), (1024, 576), (1024, 576)]
assert tta_results['flip'] == [False, True, False, True, False, True]

tta_transform = dict(
Expand Down

0 comments on commit 022b055

Please sign in to comment.