-
Notifications
You must be signed in to change notification settings - Fork 2.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Feature] Support MMSegInferencer #2413
Merged
+528
−15
Merged
Changes from 12 commits
Commits
Show all changes
15 commits
Select commit
Hold shift + click to select a range
c72f8c2
inferencer refactor
xiexinch 24f9b17
support show, save mask and save results
xiexinch fd79dcd
update ut
xiexinch 081638f
Merge remote-tracking branch 'upstream/dev-1.x' into mmseg_inferencer
xiexinch e0f9e02
Merge remote-tracking branch 'upstream/dev-1.x' into mmseg_inferencer
xiexinch fcfe1f1
Merge remote-tracking branch 'upstream/dev-1.x' into mmseg_inferencer
xiexinch 9ac9723
fix return image and add docstring
xiexinch 418a6d7
remove print_result
xiexinch 98749a7
update ut
xiexinch 63b07e1
remove print_result
xiexinch 5405417
Merge remote-tracking branch 'upstream/dev-1.x' into mmseg_inferencer
xiexinch dabc121
remove data type convertion
xiexinch 5d3f6d9
add mmengine requirement to docstring
xiexinch 7d5474c
Merge branch 'dev-1.x' into mmseg_inferencer
MeowZheng 39076e8
mmengine version
MeowZheng File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from argparse import ArgumentParser | ||
|
||
from mmseg.apis import MMSegInferencer | ||
|
||
|
||
def main(): | ||
parser = ArgumentParser() | ||
parser.add_argument('img', help='Image file') | ||
parser.add_argument('model', help='Config file') | ||
parser.add_argument('--checkpoint', default=None, help='Checkpoint file') | ||
parser.add_argument( | ||
'--out-dir', default='', help='Path to save result file') | ||
parser.add_argument( | ||
'--show', | ||
action='store_true', | ||
default=False, | ||
help='Whether to display the drawn image.') | ||
parser.add_argument( | ||
'--save-mask', | ||
action='store_true', | ||
default=False, | ||
help='Enable save the mask file') | ||
parser.add_argument( | ||
'--dataset-name', | ||
default='cityscapes', | ||
help='Color palette used for segmentation map') | ||
parser.add_argument( | ||
'--device', default='cuda:0', help='Device used for inference') | ||
parser.add_argument( | ||
'--opacity', | ||
type=float, | ||
default=0.5, | ||
help='Opacity of painted segmentation map. In (0, 1] range.') | ||
args = parser.parse_args() | ||
|
||
# build the model from a config file and a checkpoint file | ||
mmseg_inferencer = MMSegInferencer( | ||
args.model, | ||
args.checkpoint, | ||
dataset_name=args.dataset_name, | ||
device=args.device) | ||
|
||
# test a single image | ||
mmseg_inferencer( | ||
args.img, | ||
show=args.show, | ||
out_dir=args.out_dir, | ||
save_mask=args.save_mask, | ||
opacity=args.opacity) | ||
|
||
|
||
if __name__ == '__main__': | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,7 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from .inference import inference_model, init_model, show_result_pyplot | ||
from .mmseg_inferencer import MMSegInferencer | ||
|
||
__all__ = ['init_model', 'inference_model', 'show_result_pyplot'] | ||
__all__ = [ | ||
'init_model', 'inference_model', 'show_result_pyplot', 'MMSegInferencer' | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,278 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
import os.path as osp | ||
from typing import List, Optional, Sequence, Union | ||
|
||
import mmcv | ||
import mmengine | ||
import numpy as np | ||
from mmcv.transforms import Compose | ||
from mmengine.infer.infer import BaseInferencer, ModelType | ||
|
||
from mmseg.structures import SegDataSample | ||
from mmseg.utils import ConfigType, SampleList, register_all_modules | ||
from mmseg.visualization import SegLocalVisualizer | ||
|
||
InputType = Union[str, np.ndarray] | ||
InputsType = Union[InputType, Sequence[InputType]] | ||
PredType = Union[SegDataSample, SampleList] | ||
|
||
|
||
class MMSegInferencer(BaseInferencer): | ||
"""MMSegInferencer. | ||
|
||
Args: | ||
model (str, optional): Path to the config file or the model name | ||
defined in metafile. For example, it could be | ||
"fcn_r50-d8_4xb2-40k_cityscapes-512x1024" or | ||
"configs/fcn/fcn_r50-d8_4xb2-40k_cityscapes-512x1024.py" | ||
weights (str, optional): Path to the checkpoint. If it is not specified | ||
and model is a model name of metafile, the weights will be loaded | ||
from metafile. Defaults to None. | ||
palette (List[List[int]], optional): The palette of | ||
segmentation map. | ||
classes (Tuple[str], optional): Category information. | ||
dataset_name (str, optional): Name of the datasets supported in mmseg. | ||
device (str, optional): Device to run inference. If None, the available | ||
device will be automatically used. Defaults to None. | ||
scope (str, optional): The scope of the model. Defaults to None. | ||
""" | ||
|
||
preprocess_kwargs: set = set() | ||
forward_kwargs: set = {'mode', 'out_dir'} | ||
visualize_kwargs: set = { | ||
'show', 'wait_time', 'draw_pred', 'img_out_dir', 'opacity' | ||
} | ||
postprocess_kwargs: set = { | ||
'pred_out_dir', 'return_datasample', 'save_mask', 'mask_dir' | ||
} | ||
|
||
def __init__(self, | ||
model: Union[ModelType, str], | ||
weights: Optional[str] = None, | ||
palette: Optional[Union[str, List]] = None, | ||
classes: Optional[Union[str, List]] = None, | ||
dataset_name: Optional[str] = None, | ||
device: Optional[str] = None, | ||
scope: Optional[str] = 'mmseg') -> None: | ||
# A global counter tracking the number of images processes, for | ||
# naming of the output images | ||
self.num_visualized_imgs = 0 | ||
register_all_modules() | ||
super().__init__( | ||
model=model, weights=weights, device=device, scope=scope) | ||
|
||
assert isinstance(self.visualizer, SegLocalVisualizer) | ||
self.visualizer.set_dataset_meta(palette, classes, dataset_name) | ||
|
||
def __call__(self, | ||
inputs: InputsType, | ||
return_datasamples: bool = False, | ||
batch_size: int = 1, | ||
show: bool = False, | ||
wait_time: int = 0, | ||
draw_pred: bool = True, | ||
out_dir: str = '', | ||
save_mask: bool = False, | ||
mask_dir: str = 'mask', | ||
**kwargs) -> dict: | ||
"""Call the inferencer. | ||
|
||
Args: | ||
inputs (Union[str, np.ndarray]): Inputs for the inferencer. | ||
return_datasamples (bool): Whether to return results as | ||
:obj:`SegDataSample`. Defaults to False. | ||
batch_size (int): Batch size. Defaults to 1. | ||
show (bool): Whether to display the image in a popup window. | ||
Defaults to False. | ||
wait_time (float): The interval of show (s). Defaults to 0. | ||
draw_pred (bool): Whether to draw Prediction SegDataSample. | ||
Defaults to True. | ||
out_dir (str): Output directory of inference results. Defaults: ''. | ||
save_mask (bool): Whether save pred mask as a file. | ||
mask_dir (str): Sub directory of `pred_out_dir`, used to save pred | ||
mask file. | ||
|
||
Returns: | ||
dict: Inference and visualization results. | ||
""" | ||
return super().__call__( | ||
inputs=inputs, | ||
return_datasamples=return_datasamples, | ||
batch_size=batch_size, | ||
show=show, | ||
wait_time=wait_time, | ||
draw_pred=draw_pred, | ||
img_out_dir=out_dir, | ||
pred_out_dir=out_dir, | ||
save_mask=save_mask, | ||
mask_dir=mask_dir, | ||
**kwargs) | ||
|
||
def visualize(self, | ||
inputs: list, | ||
preds: List[dict], | ||
show: bool = False, | ||
wait_time: int = 0, | ||
draw_pred: bool = True, | ||
img_out_dir: str = '', | ||
opacity: float = 0.8) -> List[np.ndarray]: | ||
"""Visualize predictions. | ||
|
||
Args: | ||
inputs (list): Inputs preprocessed by :meth:`_inputs_to_list`. | ||
preds (Any): Predictions of the model. | ||
show (bool): Whether to display the image in a popup window. | ||
Defaults to False. | ||
wait_time (float): The interval of show (s). Defaults to 0. | ||
draw_pred (bool): Whether to draw Prediction SegDataSample. | ||
Defaults to True. | ||
img_out_dir (str): Output directory of drawn images. Defaults: '' | ||
opacity (int, float): The transparency of segmentation mask. | ||
Defaults to 0.8. | ||
|
||
Returns: | ||
List[np.ndarray]: Visualization results. | ||
""" | ||
if self.visualizer is None or (not show and img_out_dir == ''): | ||
return None | ||
|
||
if getattr(self, 'visualizer') is None: | ||
raise ValueError('Visualization needs the "visualizer" term' | ||
'defined in the config, but got None') | ||
|
||
self.visualizer.alpha = opacity | ||
|
||
results = [] | ||
|
||
for single_input, pred in zip(inputs, preds): | ||
if isinstance(single_input, str): | ||
img_bytes = mmengine.fileio.get(single_input) | ||
img = mmcv.imfrombytes(img_bytes) | ||
img = img[:, :, ::-1] | ||
img_name = osp.basename(single_input) | ||
elif isinstance(single_input, np.ndarray): | ||
img = single_input.copy() | ||
img_num = str(self.num_visualized_imgs).zfill(8) | ||
img_name = f'{img_num}.jpg' | ||
else: | ||
raise ValueError('Unsupported input type:' | ||
f'{type(single_input)}') | ||
|
||
out_file = osp.join(img_out_dir, img_name) if img_out_dir != ''\ | ||
else None | ||
|
||
self.visualizer.add_datasample( | ||
img_name, | ||
img, | ||
pred, | ||
show=show, | ||
wait_time=wait_time, | ||
draw_gt=False, | ||
draw_pred=draw_pred, | ||
out_file=out_file) | ||
results.append(self.visualizer.get_image()) | ||
self.num_visualized_imgs += 1 | ||
|
||
return results | ||
|
||
def postprocess(self, | ||
preds: PredType, | ||
visualization: List[np.ndarray], | ||
return_datasample: bool = False, | ||
mask_dir: str = 'mask', | ||
save_mask: bool = True, | ||
pred_out_dir: str = '') -> dict: | ||
"""Process the predictions and visualization results from ``forward`` | ||
and ``visualize``. | ||
|
||
This method should be responsible for the following tasks: | ||
|
||
1. Convert datasamples into a json-serializable dict if needed. | ||
2. Pack the predictions and visualization results and return them. | ||
3. Dump or log the predictions. | ||
|
||
Args: | ||
preds (List[Dict]): Predictions of the model. | ||
visualization (np.ndarray): Visualized predictions. | ||
return_datasample (bool): Whether to return results as datasamples. | ||
Defaults to False. | ||
pred_out_dir: File to save the inference results w/o | ||
visualization. If left as empty, no file will be saved. | ||
Defaults to ''. | ||
mask_dir (str): Sub directory of `pred_out_dir`, used to save pred | ||
mask file. | ||
save_mask (bool): Whether save pred mask as a file. | ||
|
||
Returns: | ||
dict: Inference and visualization results with key ``predictions`` | ||
and ``visualization`` | ||
|
||
- ``visualization (Any)``: Returned by :meth:`visualize` | ||
- ``predictions`` (dict or DataSample): Returned by | ||
:meth:`forward` and processed in :meth:`postprocess`. | ||
If ``return_datasample=False``, it usually should be a | ||
json-serializable dict containing only basic data elements such | ||
as strings and numbers. | ||
""" | ||
results_dict = {} | ||
|
||
results_dict['predictions'] = preds | ||
results_dict['visualization'] = visualization | ||
|
||
if pred_out_dir != '': | ||
mmengine.mkdir_or_exist(pred_out_dir) | ||
if save_mask: | ||
preds = [preds] if isinstance(preds, SegDataSample) else preds | ||
for pred in preds: | ||
mmcv.imwrite( | ||
pred.pred_sem_seg.numpy().data[0], | ||
osp.join(pred_out_dir, mask_dir, | ||
osp.basename(pred.metainfo['img_path']))) | ||
else: | ||
mmengine.dump(results_dict, | ||
osp.join(pred_out_dir, 'results.pkl')) | ||
|
||
if return_datasample: | ||
return preds | ||
|
||
return results_dict | ||
|
||
def _init_pipeline(self, cfg: ConfigType) -> Compose: | ||
"""Initialize the test pipeline. | ||
|
||
Return a pipeline to handle various input data, such as ``str``, | ||
``np.ndarray``. It is an abstract method in BaseInferencer, and should | ||
be implemented in subclasses. | ||
|
||
The returned pipeline will be used to process a single data. | ||
It will be used in :meth:`preprocess` like this: | ||
|
||
.. code-block:: python | ||
def preprocess(self, inputs, batch_size, **kwargs): | ||
... | ||
dataset = map(self.pipeline, dataset) | ||
... | ||
""" | ||
pipeline_cfg = cfg.test_dataloader.dataset.pipeline | ||
# Loading annotations is also not applicable | ||
idx = self._get_transform_idx(pipeline_cfg, 'LoadAnnotations') | ||
if idx != -1: | ||
del pipeline_cfg[idx] | ||
load_img_idx = self._get_transform_idx(pipeline_cfg, | ||
'LoadImageFromFile') | ||
|
||
if load_img_idx == -1: | ||
raise ValueError( | ||
'LoadImageFromFile is not found in the test pipeline') | ||
pipeline_cfg[load_img_idx]['type'] = 'InferencerLoader' | ||
return Compose(pipeline_cfg) | ||
|
||
def _get_transform_idx(self, pipeline_cfg: ConfigType, name: str) -> int: | ||
"""Returns the index of the transform in a pipeline. | ||
|
||
If the transform is not found, returns -1. | ||
""" | ||
for i, transform in enumerate(pipeline_cfg): | ||
if transform['type'] == name: | ||
return i | ||
return -1 |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
might add MMSegInferencer needs mmengine>=0.5.0 in docstring