[Feature]: support image visualization for tensorboard and wandb #15

Merged 3 commits on Apr 21, 2022
Changes from 1 commit
1 change: 1 addition & 0 deletions benchmarks/extract_dataset_configs/imagenet.py
@@ -23,4 +23,5 @@
        dict(type='CenterCrop', size=224),
        dict(type='ToTensor'),
        dict(type='Normalize', **img_norm_cfg),
        dict(type='Collect', keys=['img', 'gt_labels'])
Collaborator Author:
This has nothing to do with this PR; it fixes a previous bug. Please refer to #6.

    ]))
12 changes: 10 additions & 2 deletions configs/detection/yolox/yolox_s_8xb16_300e_coco.py
@@ -130,7 +130,14 @@
]

# evaluation
eval_config = dict(interval=10, gpu_collect=False)
eval_config = dict(
    interval=10,
    gpu_collect=False,
    visualization_config=dict(
        vis_num=10,
        score_thr=0.5,
    )  # shown by TensorboardLoggerHookV2 and WandbLoggerHookV2
)
eval_pipelines = [
    dict(
        mode='test',
@@ -168,7 +175,8 @@
    interval=100,
    hooks=[
        dict(type='TextLoggerHook'),
        dict(type='TensorboardLoggerHook')
        dict(type='TensorboardLoggerHookV2'),
        # dict(type='WandbLoggerHookV2'),
Collaborator (@wenmengzhou, Apr 20, 2022):
A wandb config example should be added in another file, or in comment format.

Collaborator Author:
Added a comment.

Collaborator:
The wandb initialization config is missing.

    ])

export = dict(use_jit=False)
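
The missing wandb initialization config flagged above could look like the sketch below. This is a hedged example, not part of this PR: it assumes WandbLoggerHookV2 accepts mmcv-style init_kwargs that are forwarded to wandb.init(), and the project/run names are placeholders.

# Illustrative only: init_kwargs keys and names below are assumptions.
log_config = dict(
    interval=100,
    hooks=[
        dict(type='TextLoggerHook'),
        dict(type='TensorboardLoggerHookV2'),
        dict(
            type='WandbLoggerHookV2',
            init_kwargs=dict(
                project='easycv',                # placeholder project name
                name='yolox_s_8xb16_300e_coco',  # placeholder run name
            )),
    ])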
4 changes: 2 additions & 2 deletions easycv/core/visualization/__init__.py
@@ -1,4 +1,4 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from .image import imshow_bboxes, imshow_keypoints
from .image import imshow_bboxes, imshow_keypoints, imshow_label

__all__ = ['imshow_bboxes', 'imshow_keypoints']
__all__ = ['imshow_bboxes', 'imshow_keypoints', 'imshow_label']
121 changes: 116 additions & 5 deletions easycv/core/visualization/image.py
@@ -1,18 +1,119 @@
# Copyright (c) OpenMMLab. All rights reserved.
# Adapted from https://github.com/open-mmlab/mmpose/blob/master/mmpose/core/visualization/image.py
import math
import os
from urllib import request

import cv2
import mmcv
import numpy as np
from mmcv.utils.misc import deprecated_api_warning
from PIL import Image, ImageDraw, ImageFont
from torch.hub import DEFAULT_CACHE_DIR


def download_font(save_path=None):
    url_path = 'http://pai-vision-data-hz.oss-accelerate.aliyuncs.com/EasyCV/pkgs/simhei.ttf'
Collaborator:
Add this font to git and automatically package it into our Python wheel.

Collaborator Author (@Cathy0908, Apr 20, 2022):
Put simhei.ttf in the resource dir, and support packaging it into the wheel and zip.

    if save_path is None:
        save_dir = DEFAULT_CACHE_DIR
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
        save_path = os.path.join(save_dir, 'simhei.ttf')

    if os.path.exists(save_path):
        return save_path

    f = request.urlopen(url_path)
    with open(save_path, 'wb') as fw:
        fw.write(f.read())

    return save_path
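
For the packaging discussion above, a minimal sketch of bundling the font into the wheel with setuptools; the package name and resource path are assumptions, not taken from this PR.

# Illustrative setup.py fragment; paths and names are placeholders.
from setuptools import find_packages, setup

setup(
    name='easycv',
    packages=find_packages(),
    # ship non-Python resources such as resource/simhei.ttf in the wheel
    package_data={'easycv': ['resource/*.ttf']},
)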


def put_text(img, xy, text, fill, font_path, size=20):
    """Support Chinese text."""
    img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
    draw = ImageDraw.Draw(img)
    fontText = ImageFont.truetype(font_path, size=size, encoding='utf-8')
    draw.text(xy, text, fill=fill, font=fontText)
    img = cv2.cvtColor(np.asarray(img), cv2.COLOR_RGB2BGR)
    return img
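
A minimal usage sketch of put_text, assuming a BGR ndarray input; the file paths and label text are placeholders.

# Hypothetical usage: draw a Chinese label at pixel (10, 10).
img = cv2.imread('demo.jpg')          # placeholder input path
font = download_font()                # cached simhei.ttf
img = put_text(img, (10, 10), text='猫: 0.98', fill='red', font_path=font)
cv2.imwrite('demo_out.jpg', img)      # placeholder output path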


def imshow_label(img,
                 labels,
                 text_color='blue',
                 font_size=20,
                 thickness=1,
                 font_scale=0.5,
                 intervel=5,
                 show=True,
                 win_name='',
                 wait_time=0,
                 out_file=None):
"""Draw images with labels on an image.

Args:
img (str or ndarray): The image to be displayed.
labels (str or list[str]): labels of each image.
text_color (str or tuple or :obj:`Color`): Color of texts.
font_size (int): Size of font.
thickness (int): Thickness of lines.
font_scale (float): Font scales of texts.
intervel(int): interval pixels between multiple labels
show (bool): Whether to show the image.
win_name (str): The window name.
wait_time (int): Value of waitKey param.
out_file (str, optional): The filename to write the image.

Returns:
ndarray: The image with bboxes drawn on it.
"""
    font_path = download_font()
    img = mmcv.imread(img)
    img = np.ascontiguousarray(img)
    labels = [labels] if isinstance(labels, str) else labels

    cur_height = 0
    for label in labels:
        # roughly estimate the proper font size
        text_size, text_baseline = cv2.getTextSize(label,
                                                   cv2.FONT_HERSHEY_DUPLEX,
                                                   font_scale, thickness)

        org = (text_baseline + text_size[1],
               text_baseline + text_size[1] + cur_height)

        # support chinese text
Collaborator Author:
Support Chinese fonts.

        # TODO: Unify the font of cv2 and PIL, and auto get font_size according to the font_scale
        img = put_text(
            img,
            org,
            text=label,
            fill=text_color,
            font_path=font_path,
            size=font_size)

        # cv2.putText(img, label, org, cv2.FONT_HERSHEY_DUPLEX, font_scale,
        #             mmcv.color_val(text_color), thickness)

        cur_height += text_baseline + text_size[1] + intervel

    if show:
        mmcv.imshow(img, win_name, wait_time)
    if out_file is not None:
        mmcv.imwrite(img, out_file)

    return img
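
A minimal usage sketch of imshow_label; the paths and label strings are placeholders.

# Hypothetical usage: render predicted class names onto an image and save it.
vis = imshow_label(
    img='demo.jpg',                   # placeholder input path
    labels=['cat: 0.98', '猫'],
    show=False,
    out_file='demo_labels.jpg')       # placeholder output path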


def imshow_bboxes(img,
                  bboxes,
                  labels=None,
                  colors='green',
                  text_color='white',
                  font_size=20,
                  thickness=1,
                  font_scale=0.5,
                  show=True,
@@ -29,6 +130,7 @@ def imshow_bboxes(img,
        labels (str or list[str], optional): Labels of each bbox.
        colors (list[str or tuple or :obj:`Color`]): A list of colors.
        text_color (str or tuple or :obj:`Color`): Color of texts.
        font_size (int): Size of font.
        thickness (int): Thickness of lines.
        font_scale (float): Font scales of texts.
        show (bool): Whether to show the image.
@@ -39,6 +141,7 @@
    Returns:
        ndarray: The image with bboxes drawn on it.
    """
    font_path = download_font()

    # adapt to mmcv.imshow_bboxes input format
    bboxes = np.split(
@@ -58,11 +161,10 @@
        out_file=None)

    if labels is not None:
        if not isinstance(labels, list):
            labels = [labels for _ in range(len(bboxes))]
        assert len(labels) == len(bboxes)

        for bbox, label, color in zip(bboxes, labels, colors):
            label = str(label)
            bbox_int = bbox[0, :4].astype(np.int32)
            # roughly estimate the proper font size
            text_size, text_baseline = cv2.getTextSize(label,
@@ -74,9 +176,18 @@
            text_y2 = text_y1 + text_size[1] + text_baseline
            cv2.rectangle(img, (text_x1, text_y1), (text_x2, text_y2), color,
                          cv2.FILLED)
            cv2.putText(img, label, (text_x1, text_y2 - text_baseline),
                        cv2.FONT_HERSHEY_DUPLEX, font_scale,
                        mmcv.color_val(text_color), thickness)
            # cv2.putText(img, label, (text_x1, text_y2 - text_baseline),
            #             cv2.FONT_HERSHEY_DUPLEX, font_scale,
            #             mmcv.color_val(text_color), thickness)

            # support chinese text
            # TODO: Unify the font of cv2 and PIL, and auto get font_size according to the font_scale
            img = put_text(
                img, (text_x1, text_y1),
                text=label,
                fill=text_color,
                font_path=font_path,
                size=font_size)

    if show:
        mmcv.imshow(img, win_name, wait_time)
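A minimal usage sketch of the updated imshow_bboxes, assuming an (N, 4) array of xyxy boxes and the usual show/out_file arguments visible at the end of the function; values and paths are placeholders.

# Hypothetical usage: draw two labeled boxes without popping a window.
bboxes = np.array([[20, 30, 120, 160], [40, 50, 200, 220]], dtype=np.float32)
vis = imshow_bboxes(
    img='demo.jpg',                   # placeholder path
    bboxes=bboxes,
    labels=['猫: 0.9', 'dog: 0.8'],
    show=False,
    out_file='demo_bboxes.jpg')       # placeholder output path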
35 changes: 35 additions & 0 deletions easycv/datasets/classification/raw.py
@@ -3,6 +3,7 @@
import torch
from PIL import Image

from easycv.core.visualization.image import imshow_label
from easycv.datasets.registry import DATASETS
from easycv.datasets.shared.base import BaseDataset

@@ -59,3 +60,37 @@ def evaluate(self, results, evaluators, logger=None, topk=(1, 5)):
        eval_res = evaluators[0].evaluate(results, gt_labels)

        return eval_res

    def visualize(self, results, vis_num=10, **kwargs):
        """Visualize the model output on validation data.

        Args:
            results: A dictionary containing:
                class: List with length equal to the number of test images.
                img_metas: List with length equal to the number of test
                    images; dicts of image meta info, containing filename,
                    img_shape, origin_img_shape and so on.
            vis_num: Number of images to visualize.

        Returns: A dictionary containing:
            images: Visualized images, list of np.ndarray.
            img_metas: List with length equal to the number of test images;
                dicts of image meta info, containing filename, img_shape,
                origin_img_shape and so on.
        """
        vis_imgs = []

        # TODO: support img_metas for torch.jit
        if results.get('img_metas', None) is None:
            return {}

        img_metas = results['img_metas'][:vis_num]
        labels = results['class']

        for i, img_meta in enumerate(img_metas):
            filename = img_meta['filename']

            # draw this image's own predicted labels; passing all of `labels`
            # would stamp every image's labels onto each image
            vis_img = imshow_label(img=filename, labels=labels[i], show=False)
            vis_imgs.append(vis_img)

        output = {'images': vis_imgs, 'img_metas': img_metas}

        return output
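
A sketch of how the dictionary returned by visualize could be consumed when logging images, e.g. to TensorBoard. The writer setup is illustrative and not EasyCV's actual hook code; `dataset` and `results` are assumed to be the objects this method expects.

# Illustrative consumer: push visualized images to TensorBoard.
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir='runs/vis')    # placeholder log dir
vis = dataset.visualize(results, vis_num=10)
for step, (image, meta) in enumerate(zip(vis['images'], vis['img_metas'])):
    # images are BGR (H, W, C) ndarrays; flip to RGB and copy for torch
    writer.add_image(meta['filename'], image[..., ::-1].copy(), step,
                     dataformats='HWC')
writer.close()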
56 changes: 4 additions & 52 deletions easycv/datasets/detection/mix.py
@@ -9,14 +9,13 @@
import torch

from easycv.datasets.registry import DATASETS, PIPELINES
from easycv.datasets.shared.base import BaseDataset
from easycv.utils import build_from_cfg
from easycv.utils.bbox_util import batched_xyxy2cxcywh_with_shape
from easycv.utils.bbox_util import xyxy2xywh as xyxy2cxcywh
from .raw import DetDataset


@DATASETS.register_module
class DetImagesMixDataset(BaseDataset):
class DetImagesMixDataset(DetDataset):
"""A wrapper of multiple images mixed dataset.

Suitable for training on multiple images mixed data augmentation like
@@ -50,7 +49,7 @@ def __init__(self,
                 label_padding=True):

        super(DetImagesMixDataset, self).__init__(
            data_source, pipeline, profiling=profiling)
            data_source, pipeline, profiling=profiling, classes=classes)

        if skip_type_keys is not None:
            assert all([
@@ -70,10 +69,9 @@ def __init__(self,
        else:
            raise TypeError('pipeline must be a dict')

        self.CLASSES = classes
        if hasattr(self.data_source, 'flag'):
            self.flag = self.data_source.flag
        self.num_samples = self.data_source.get_length()

        if dynamic_scale is not None:
            assert isinstance(dynamic_scale, tuple)

@@ -83,9 +81,6 @@
        self.label_padding = label_padding
        self.max_labels_num = 120

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        results = copy.deepcopy(self.data_source.get_sample(idx))
        for (transform, transform_type) in zip(self.pipeline_yolox,
@@ -116,21 +111,8 @@ def __getitem__(self, idx):
        if 'img_scale' in results:
            results.pop('img_scale')

        # print(result.keys())

        # if self.yolo_format:
        #     # print(type(results['img_metas']), results['img_metas'])
        #     # print(type(results['img_metas']._data), results['img_metas']._data)
        #     img_shape = results['img_metas']._data['img_shape'][:2]
        #     # print(type(results['gt_bboxes']))
        #     gt_bboxes = xyxy2cxcywh_with_shape(results['gt_bboxes']._data, img_shape)
        #     results['gt_bboxes'] = gt_bboxes.float()

        if self.label_padding:

            cxcywh_gt_bboxes = xyxy2cxcywh(results['gt_bboxes']._data)
            # cxcywh_gt_bboxes = results['gt_bboxes']._data

            padded_gt_bboxes = torch.zeros((self.max_labels_num, 4),
                                           device=cxcywh_gt_bboxes.device)
            padded_gt_bboxes[range(cxcywh_gt_bboxes.shape[0])[:self.max_labels_num]] = \
@@ -146,9 +128,6 @@
            results['gt_bboxes'] = padded_gt_bboxes
            results['gt_labels'] = padded_labels

        # ['img_metas', 'img', 'gt_bboxes', 'gt_labels']
        # results.pop('img_metas')
        # print(results['img_metas'], "hhh", idx)
        return results
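
The padding step above fixes the number of boxes per sample so batches collate to uniform shapes. A self-contained sketch of the same idea, with made-up values:

# Standalone illustration of fixed-size label padding.
import torch

max_num = 120
gt = torch.tensor([[10., 20., 30., 40.], [5., 5., 15., 25.]])  # (N, 4) boxes
padded = torch.zeros((max_num, 4))
padded[:min(gt.shape[0], max_num)] = gt[:max_num]
# padded has shape (120, 4): real boxes first, zeros after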

    def update_skip_type_keys(self, skip_type_keys):
@@ -240,30 +219,3 @@ def format_results(self, results, jsonfile_prefix=None, **kwargs):
        tmp_dir = None
        result_files = self.results2json(results, jsonfile_prefix)
        return result_files, tmp_dir

    def evaluate(self, results, evaluators=None, logger=None):
        '''results: a dict of lists of Tensors; list length equals the number of test images.
        '''

        eval_result = dict()

        groundtruth_dict = {}
        groundtruth_dict['groundtruth_boxes'] = [
            batched_xyxy2cxcywh_with_shape(
                self.data_source.get_ann_info(idx)['bboxes'],
                results['img_metas'][idx]['ori_img_shape'])
            for idx in range(len(results['img_metas']))
        ]
        groundtruth_dict['groundtruth_classes'] = [
            self.data_source.get_ann_info(idx)['labels']
            for idx in range(len(results['img_metas']))
        ]
        groundtruth_dict['groundtruth_is_crowd'] = [
            self.data_source.get_ann_info(idx)['groundtruth_is_crowd']
            for idx in range(len(results['img_metas']))
        ]

        for evaluator in evaluators:
            eval_result.update(evaluator.evaluate(results, groundtruth_dict))
        # print(eval_result)
        return eval_result