Skip to content

Commit

Permalink
[Fix] Fix mmseg.api.inference inference_segmentor (#1849)
Browse files Browse the repository at this point in the history
* [Fix] Fix mmseg.api.inference inference_segmentor

Motivation
Fix inference_segmentor not working with multiple images path or images. List[str/ndarray]

Modification
- process images if instance is list

* fix typo

* Update mmseg/apis/inference.py

Co-authored-by: Hakjin Lee <nijkah@gmail.com>

Co-authored-by: Hakjin Lee <nijkah@gmail.com>
  • Loading branch information
jinwonkim93 and nijkah authored Sep 13, 2022
1 parent ca7c098 commit ecd1ecb
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions mmseg/apis/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def __call__(self, results):
return results


def inference_segmentor(model, img):
def inference_segmentor(model, imgs):
"""Inference image(s) with the segmentor.
Args:
Expand All @@ -84,9 +84,13 @@ def inference_segmentor(model, img):
test_pipeline = [LoadImage()] + cfg.data.test.pipeline[1:]
test_pipeline = Compose(test_pipeline)
# prepare data
data = dict(img=img)
data = test_pipeline(data)
data = collate([data], samples_per_gpu=1)
data = []
imgs = imgs if isinstance(imgs, list) else [imgs]
for img in imgs:
img_data = dict(img=img)
img_data = test_pipeline(img_data)
data.append(img_data)
data = collate(data, samples_per_gpu=len(imgs))
if next(model.parameters()).is_cuda:
# scatter to specified GPU
data = scatter(data, [device])[0]
Expand Down

0 comments on commit ecd1ecb

Please sign in to comment.