Skip to content

Commit

Permalink
[Enhancement]Add out_file in add_datasample to directly save image (#…
Browse files Browse the repository at this point in the history
…2090)

* [Enhancement]Add `out_file` in add_datasample to for save vis image directly

* comments

* ut
  • Loading branch information
MeowZheng authored Sep 20, 2022
1 parent 230246f commit 2a18328
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 24 deletions.
3 changes: 1 addition & 2 deletions mmseg/apis/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,9 +199,8 @@ def show_result_pyplot(model: BaseSegmentor,
draw_gt=draw_gt,
draw_pred=draw_pred,
wait_time=wait_time,
out_file=out_file,
show=show)
vis_img = visualizer.get_image()
if out_file is not None:
mmcv.imwrite(vis_img, out_file)

return vis_img
29 changes: 20 additions & 9 deletions mmseg/visualization/local_visualizer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Tuple

import mmcv
import numpy as np
from mmengine.dist import master_only
from mmengine.structures import PixelData
Expand Down Expand Up @@ -99,22 +100,28 @@ def _draw_sem_seg(self, image: np.ndarray, sem_seg: PixelData,
return self.get_image()

@master_only
def add_datasample(self,
name: str,
image: np.ndarray,
data_sample: Optional[SegDataSample] = None,
draw_gt: bool = True,
draw_pred: bool = True,
show: bool = False,
wait_time: float = 0,
step: int = 0) -> None:
def add_datasample(
self,
name: str,
image: np.ndarray,
data_sample: Optional[SegDataSample] = None,
draw_gt: bool = True,
draw_pred: bool = True,
show: bool = False,
wait_time: float = 0,
# TODO: Supported in mmengine's Viusalizer.
out_file: Optional[str] = None,
step: int = 0) -> None:
"""Draw datasample and save to all backends.
- If GT and prediction are plotted at the same time, they are
displayed in a stitched image where the left image is the
ground truth and the right image is the prediction.
- If ``show`` is True, all storage backends are ignored, and
the images will be displayed in a local window.
- If ``out_file`` is specified, the drawn image will be
saved to ``out_file``. it is usually used when the display
is not available.
Args:
name (str): The image identifier.
Expand All @@ -128,6 +135,7 @@ def add_datasample(self,
Defaults to True.
show (bool): Whether to display the drawn image. Default to False.
wait_time (float): The interval of show (s). Defaults to 0.
out_file (str): Path to output file. Defaults to None.
step (int): Global step value to record. Defaults to 0.
"""
classes = self.dataset_meta.get('classes', None)
Expand Down Expand Up @@ -166,5 +174,8 @@ def add_datasample(self,

if show:
self.show(drawn_img, win_name=name, wait_time=wait_time)

if out_file is not None:
mmcv.imwrite(drawn_img, out_file)
else:
self.add_image(name, drawn_img, step)
22 changes: 9 additions & 13 deletions tests/test_visualization/test_local_visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,19 +118,14 @@ def test_cityscapes_add_datasample_forward(gt_sem_seg):
[255, 0, 0], [0, 0, 142], [0, 0, 70],
[0, 60, 100], [0, 80, 100], [0, 0, 230],
[119, 11, 32]])
seg_local_visualizer.add_datasample(out_file, image,
data_sample)

# test out_file
seg_local_visualizer.add_datasample(out_file, image,
data_sample)
assert os.path.exists(
osp.join(tmp_dir, 'vis_data', 'vis_image',
out_file + '_0.png'))
drawn_img = cv2.imread(
osp.join(tmp_dir, 'vis_data', 'vis_image',
out_file + '_0.png'))
assert drawn_img.shape == (h, w, 3)
seg_local_visualizer.add_datasample(
out_file,
image,
data_sample,
out_file=osp.join(tmp_dir, 'test.png'))
self._assert_image_and_shape(
osp.join(tmp_dir, 'test.png'), (h, w, 3))

# test gt_instances and pred_instances
pred_sem_seg_data = dict(
Expand All @@ -139,12 +134,13 @@ def test_cityscapes_add_datasample_forward(gt_sem_seg):

data_sample.pred_sem_seg = pred_sem_seg

# test draw prediction with gt
seg_local_visualizer.add_datasample(out_file, image,
data_sample)
self._assert_image_and_shape(
osp.join(tmp_dir, 'vis_data', 'vis_image',
out_file + '_0.png'), (h, w * 2, 3))

# test draw prediction without gt
seg_local_visualizer.add_datasample(
out_file, image, data_sample, draw_gt=False)
self._assert_image_and_shape(
Expand Down

0 comments on commit 2a18328

Please sign in to comment.