Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add feature for image demo. #128

Merged
merged 12 commits into from
Oct 12, 2022
82 changes: 60 additions & 22 deletions demo/image_demo.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,34 @@
# Copyright (c) OpenMMLab. All rights reserved.
import logging
import os
import urllib
from argparse import ArgumentParser

import mmcv
import torch
from mmdet.apis import inference_detector, init_detector
from mmengine.logging import print_log
from mmengine.utils import ProgressBar, scandir

from mmyolo.registry import VISUALIZERS
from mmyolo.utils import register_all_modules

IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif',
'.tiff', '.webp')


def parse_args():
parser = ArgumentParser()
parser.add_argument('img', help='Image file')
parser.add_argument(
'img', help='Image path, include image file, dir and URL.')
parser.add_argument('config', help='Config file')
parser.add_argument('checkpoint', help='Checkpoint file')
parser.add_argument('--out-file', default=None, help='Path to output file')
parser.add_argument(
'--out-dir', default='./output', help='Path to output file')
parser.add_argument(
hhaAndroid marked this conversation as resolved.
Show resolved Hide resolved
'--device', default='cuda:0', help='Device used for inference')
parser.add_argument(
'--palette',
default='coco',
choices=['coco', 'voc', 'citys', 'random'],
help='Color palette used for visualization')
'--show', action='store_true', help='Show the detection results')
parser.add_argument(
'--score-thr', type=float, default=0.3, help='bbox score threshold')
args = parser.parse_args()
Expand All @@ -31,29 +39,59 @@ def main(args):
# register all modules in mmdet into the registries
register_all_modules()

# TODO: Support inference of image directory.
# build the model from a config file and a checkpoint file
model = init_detector(args.config, args.checkpoint, device=args.device)

# init visualizer
visualizer = VISUALIZERS.build(model.cfg.visualizer)
visualizer.dataset_meta = model.dataset_meta

# test a single image
result = inference_detector(model, args.img)

# show the results
img = mmcv.imread(args.img)
img = mmcv.imconvert(img, 'bgr', 'rgb')
visualizer.add_datasample(
'result',
img,
data_sample=result,
draw_gt=False,
show=True,
wait_time=0,
out_file=args.out_file,
pred_score_thr=args.score_thr)
is_dir = os.path.isdir(args.img)
is_url = args.img.startswith(('http:/', 'https:/'))
is_file = os.path.splitext(args.img)[-1] in (IMG_EXTENSIONS)

files = []
if is_dir:
# when input source is dir
for file in scandir(args.img, IMG_EXTENSIONS, recursive=True):
files.append(os.path.join(args.img, file))
elif is_url:
# when input source is url
filename = os.path.basename(
urllib.parse.unquote(args.img).split('?')[0])
torch.hub.download_url_to_file(args.img, filename)
files = [os.path.join(os.getcwd(), filename)]
elif is_file:
# when input source is single image
files = [args.img]
else:
print_log(
'Cannot find image file.', logger='current', level=logging.WARNING)

# start detector inference
progress_bar = ProgressBar(len(files))
for file in files:
result = inference_detector(model, file)
img = mmcv.imread(file)
img = mmcv.imconvert(img, 'bgr', 'rgb')
if is_dir:
filename = os.path.relpath(file, args.img).replace('/', '_')
else:
filename = os.path.basename(file)
out_file = None if args.show else os.path.join(args.out_dir, filename)
visualizer.add_datasample(
filename,
img,
data_sample=result,
draw_gt=False,
show=args.show,
wait_time=0,
out_file=out_file,
pred_score_thr=args.score_thr)
progress_bar.update()
if not args.show:
print_log(
f'\nResults have been saved at {os.path.abspath(args.out_dir)}')


if __name__ == '__main__':
Expand Down
18 changes: 16 additions & 2 deletions docs/en/get_started.md
Original file line number Diff line number Diff line change
Expand Up @@ -99,10 +99,24 @@ The downloading will take several seconds or more, depending on your network env
Option (a). If you install MMYOLO from source, just run the following command.

```shell
python demo/image_demo.py demo/demo.jpg yolov5_s-v61_syncbn_fast_8xb16-300e_coco.py yolov5_s-v61_syncbn_fast_8xb16-300e_coco_20220918_084700-86e02187.pth --device cpu --out-file result.jpg
python demo/image_demo.py demo/demo.jpg \
yolov5_s-v61_syncbn_fast_8xb16-300e_coco.py \
yolov5_s-v61_syncbn_fast_8xb16-300e_coco_20220918_084700-86e02187.pth

# Optional parameters
# --out-dir ./output *The detection results are output to the specified directory. When args have action --show, the script do not save results. Default: ./output
# --device cuda:0 *The computing resources used, including cuda and cpu. Default: cuda:0
# --show *Display the results on the screen. Default: False
# --score-thr 0.3 *Confidence threshold. Default: 0.3
```

You will see a new image `result.jpg` on your current folder, where bounding boxes are plotted.
You will see a new image on your `output` folder, where bounding boxes are plotted.

Supported input types:

- Single image, include `jpg`, `jpeg`, `png`, `ppm`, `bmp`, `pgm`, `tif`, `tiff`, `webp`.
- Folder, all image files in the folder will be traversed and the corresponding results will be output.
- URL, will automatically download from the URL and the corresponding results will be output.

Option (b). If you install MMYOLO with MIM, open your python interpreter and copy&paste the following codes.

Expand Down
18 changes: 14 additions & 4 deletions docs/zh_cn/get_started.md
Original file line number Diff line number Diff line change
Expand Up @@ -100,12 +100,22 @@ mim download mmyolo --config yolov5_s-v61_syncbn_fast_8xb16-300e_coco --dest .
```shell
python demo/image_demo.py demo/demo.jpg \
yolov5_s-v61_syncbn_fast_8xb16-300e_coco.py \
yolov5_s-v61_syncbn_fast_8xb16-300e_coco_20220918_084700-86e02187.pth \
--device cpu \
--out-file result.jpg
yolov5_s-v61_syncbn_fast_8xb16-300e_coco_20220918_084700-86e02187.pth

# 可选参数
# --out-dir ./output *检测结果输出到指定目录下,默认为./output, 当--show参数存在时,不保存检测结果
# --device cuda:0 *使用的计算资源,包括cuda, cpu等,默认为cuda:0
# --show *使用该参数表示在屏幕上显示检测结果,默认为False
# --score-thr 0.3 *置信度阈值,默认为0.3
```

你会在当前文件夹中看到一个新的图像 `result.jpg`,图像中包含有网络预测的检测框。
运行结束后,在 `output` 文件夹中可以看到检测结果图像,图像中包含有网络预测的检测框。

支持输入类型包括

- 单张图片, 支持 `jpg`, `jpeg`, `png`, `ppm`, `bmp`, `pgm`, `tif`, `tiff`, `webp`。
- 文件目录,会遍历文件目录下所有图片文件,并输出对应结果。
- 网址,会自动从对应网址下载图片,并输出结果。

方案 2. 如果你通过 MIM 安装的 MMYOLO, 那么可以打开你的 Python 解析器,复制并粘贴以下代码:

Expand Down
2 changes: 1 addition & 1 deletion mmyolo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
mmcv_version = digit_version(mmcv.__version__)

mmengine_minimum_version = '0.1.0'
mmengine_maximum_version = '0.2.0'
mmengine_maximum_version = '0.2.1'
mmengine_version = digit_version(mmengine.__version__)

mmdet_minimum_version = '3.0.0rc1'
Expand Down