Skip to content

Commit

Permalink
Add feature for image demo. (#128)
Browse files Browse the repository at this point in the history
* Add the switch of whether to display the result.

* Add the inference test of the images folder

* Add video demo.

* Add progress bar in image and video demo.

* Add dir inference support.

* Add url download support.

* Add output path prompt at the end.

* Modify the logic of saving results.

* Update mmengine limit.
  • Loading branch information
JiayuXu0 authored and hhaAndroid committed Oct 13, 2022
1 parent 2c8d5ab commit 1ad4ff2
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 29 deletions.
82 changes: 60 additions & 22 deletions demo/image_demo.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,34 @@
# Copyright (c) OpenMMLab. All rights reserved.
import logging
import os
import urllib
from argparse import ArgumentParser

import mmcv
import torch
from mmdet.apis import inference_detector, init_detector
from mmengine.logging import print_log
from mmengine.utils import ProgressBar, scandir

from mmyolo.registry import VISUALIZERS
from mmyolo.utils import register_all_modules

IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif',
'.tiff', '.webp')


def parse_args():
parser = ArgumentParser()
parser.add_argument('img', help='Image file')
parser.add_argument(
'img', help='Image path, include image file, dir and URL.')
parser.add_argument('config', help='Config file')
parser.add_argument('checkpoint', help='Checkpoint file')
parser.add_argument('--out-file', default=None, help='Path to output file')
parser.add_argument(
'--out-dir', default='./output', help='Path to output file')
parser.add_argument(
'--device', default='cuda:0', help='Device used for inference')
parser.add_argument(
'--palette',
default='coco',
choices=['coco', 'voc', 'citys', 'random'],
help='Color palette used for visualization')
'--show', action='store_true', help='Show the detection results')
parser.add_argument(
'--score-thr', type=float, default=0.3, help='bbox score threshold')
args = parser.parse_args()
Expand All @@ -31,29 +39,59 @@ def main(args):
# register all modules in mmdet into the registries
register_all_modules()

# TODO: Support inference of image directory.
# build the model from a config file and a checkpoint file
model = init_detector(args.config, args.checkpoint, device=args.device)

# init visualizer
visualizer = VISUALIZERS.build(model.cfg.visualizer)
visualizer.dataset_meta = model.dataset_meta

# test a single image
result = inference_detector(model, args.img)

# show the results
img = mmcv.imread(args.img)
img = mmcv.imconvert(img, 'bgr', 'rgb')
visualizer.add_datasample(
'result',
img,
data_sample=result,
draw_gt=False,
show=True,
wait_time=0,
out_file=args.out_file,
pred_score_thr=args.score_thr)
is_dir = os.path.isdir(args.img)
is_url = args.img.startswith(('http:/', 'https:/'))
is_file = os.path.splitext(args.img)[-1] in (IMG_EXTENSIONS)

files = []
if is_dir:
# when input source is dir
for file in scandir(args.img, IMG_EXTENSIONS, recursive=True):
files.append(os.path.join(args.img, file))
elif is_url:
# when input source is url
filename = os.path.basename(
urllib.parse.unquote(args.img).split('?')[0])
torch.hub.download_url_to_file(args.img, filename)
files = [os.path.join(os.getcwd(), filename)]
elif is_file:
# when input source is single image
files = [args.img]
else:
print_log(
'Cannot find image file.', logger='current', level=logging.WARNING)

# start detector inference
progress_bar = ProgressBar(len(files))
for file in files:
result = inference_detector(model, file)
img = mmcv.imread(file)
img = mmcv.imconvert(img, 'bgr', 'rgb')
if is_dir:
filename = os.path.relpath(file, args.img).replace('/', '_')
else:
filename = os.path.basename(file)
out_file = None if args.show else os.path.join(args.out_dir, filename)
visualizer.add_datasample(
filename,
img,
data_sample=result,
draw_gt=False,
show=args.show,
wait_time=0,
out_file=out_file,
pred_score_thr=args.score_thr)
progress_bar.update()
if not args.show:
print_log(
f'\nResults have been saved at {os.path.abspath(args.out_dir)}')


if __name__ == '__main__':
Expand Down
18 changes: 16 additions & 2 deletions docs/en/get_started.md
Original file line number Diff line number Diff line change
Expand Up @@ -99,10 +99,24 @@ The downloading will take several seconds or more, depending on your network env
Option (a). If you install MMYOLO from source, just run the following command.

```shell
python demo/image_demo.py demo/demo.jpg yolov5_s-v61_syncbn_fast_8xb16-300e_coco.py yolov5_s-v61_syncbn_fast_8xb16-300e_coco_20220918_084700-86e02187.pth --device cpu --out-file result.jpg
python demo/image_demo.py demo/demo.jpg \
yolov5_s-v61_syncbn_fast_8xb16-300e_coco.py \
yolov5_s-v61_syncbn_fast_8xb16-300e_coco_20220918_084700-86e02187.pth

# Optional parameters
# --out-dir ./output *The detection results are output to the specified directory. When args have action --show, the script do not save results. Default: ./output
# --device cuda:0 *The computing resources used, including cuda and cpu. Default: cuda:0
# --show *Display the results on the screen. Default: False
# --score-thr 0.3 *Confidence threshold. Default: 0.3
```

You will see a new image `result.jpg` on your current folder, where bounding boxes are plotted.
You will see a new image on your `output` folder, where bounding boxes are plotted.

Supported input types:

- Single image, include `jpg`, `jpeg`, `png`, `ppm`, `bmp`, `pgm`, `tif`, `tiff`, `webp`.
- Folder, all image files in the folder will be traversed and the corresponding results will be output.
- URL, will automatically download from the URL and the corresponding results will be output.

Option (b). If you install MMYOLO with MIM, open your python interpreter and copy&paste the following codes.

Expand Down
18 changes: 14 additions & 4 deletions docs/zh_cn/get_started.md
Original file line number Diff line number Diff line change
Expand Up @@ -100,12 +100,22 @@ mim download mmyolo --config yolov5_s-v61_syncbn_fast_8xb16-300e_coco --dest .
```shell
python demo/image_demo.py demo/demo.jpg \
yolov5_s-v61_syncbn_fast_8xb16-300e_coco.py \
yolov5_s-v61_syncbn_fast_8xb16-300e_coco_20220918_084700-86e02187.pth \
--device cpu \
--out-file result.jpg
yolov5_s-v61_syncbn_fast_8xb16-300e_coco_20220918_084700-86e02187.pth

# 可选参数
# --out-dir ./output *检测结果输出到指定目录下,默认为./output, 当--show参数存在时,不保存检测结果
# --device cuda:0 *使用的计算资源,包括cuda, cpu等,默认为cuda:0
# --show *使用该参数表示在屏幕上显示检测结果,默认为False
# --score-thr 0.3 *置信度阈值,默认为0.3
```

你会在当前文件夹中看到一个新的图像 `result.jpg`,图像中包含有网络预测的检测框。
运行结束后,在 `output` 文件夹中可以看到检测结果图像,图像中包含有网络预测的检测框。

支持输入类型包括

- 单张图片, 支持 `jpg`, `jpeg`, `png`, `ppm`, `bmp`, `pgm`, `tif`, `tiff`, `webp`
- 文件目录,会遍历文件目录下所有图片文件,并输出对应结果。
- 网址,会自动从对应网址下载图片,并输出结果。

方案 2. 如果你通过 MIM 安装的 MMYOLO, 那么可以打开你的 Python 解析器,复制并粘贴以下代码:

Expand Down
2 changes: 1 addition & 1 deletion mmyolo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
mmcv_version = digit_version(mmcv.__version__)

mmengine_minimum_version = '0.1.0'
mmengine_maximum_version = '0.2.0'
mmengine_maximum_version = '0.2.1'
mmengine_version = digit_version(mmengine.__version__)

mmdet_minimum_version = '3.0.0rc1'
Expand Down

0 comments on commit 1ad4ff2

Please sign in to comment.