Skip to content

Commit

Permalink
[Feature]: support image visualization for tensorboard and wandb (#15)
Browse files Browse the repository at this point in the history
* [Feature]: support image visualization for tensorboard and wandb
  • Loading branch information
Cathy0908 authored Apr 21, 2022
1 parent 3a4a3a9 commit 9a3826f
Show file tree
Hide file tree
Showing 20 changed files with 486 additions and 103 deletions.
1 change: 1 addition & 0 deletions MANIFEST.in
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
recursive-include easycv/configs *.py
recursive-include easycv/tools *.py
recursive-include easycv/resource/ *.ttf
1 change: 1 addition & 0 deletions benchmarks/extract_dataset_configs/imagenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,5 @@
dict(type='CenterCrop', size=224),
dict(type='ToTensor'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Collect', keys=['img', 'gt_labels'])
]))
12 changes: 10 additions & 2 deletions configs/detection/yolox/yolox_s_8xb16_300e_coco.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,14 @@
]

# evaluation
eval_config = dict(interval=10, gpu_collect=False)
eval_config = dict(
interval=10,
gpu_collect=False,
visualization_config=dict(
vis_num=10,
score_thr=0.5,
) # show by TensorboardLoggerHookV2 and WandbLoggerHookV2
)
eval_pipelines = [
dict(
mode='test',
Expand Down Expand Up @@ -168,7 +175,8 @@
interval=100,
hooks=[
dict(type='TextLoggerHook'),
dict(type='TensorboardLoggerHook')
dict(type='TensorboardLoggerHookV2'),
# dict(type='WandbLoggerHookV2'),
])

export = dict(use_jit=False)
4 changes: 2 additions & 2 deletions easycv/core/visualization/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from .image import imshow_bboxes, imshow_keypoints
from .image import imshow_bboxes, imshow_keypoints, imshow_label

__all__ = ['imshow_bboxes', 'imshow_keypoints']
__all__ = ['imshow_bboxes', 'imshow_keypoints', 'imshow_label']
111 changes: 106 additions & 5 deletions easycv/core/visualization/image.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,111 @@
# Copyright (c) OpenMMLab. All rights reserved.
# Adapt from https://github.com/open-mmlab/mmpose/blob/master/mmpose/core/visualization/image.py
import math
import os
from os.path import dirname as opd

import cv2
import mmcv
import numpy as np
from mmcv.utils.misc import deprecated_api_warning
from PIL import Image, ImageDraw, ImageFont


def get_font_path():
root_path = opd(opd(opd(os.path.realpath(__file__))))
# find in whl
find_path_whl = os.path.join(root_path, 'resource/simhei.ttf')
# find in source code
find_path_source = os.path.join(opd(root_path), 'resource/simhei.ttf')
if os.path.exists(find_path_whl):
return find_path_whl
elif os.path.exists(find_path_source):
return find_path_source
else:
raise ValueError('Not find font file both in %s and %s' %
(find_path_whl, find_path_source))


_FONT_PATH = get_font_path()


def put_text(img, xy, text, fill, size=20):
"""support chinese text
"""
img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
draw = ImageDraw.Draw(img)
fontText = ImageFont.truetype(_FONT_PATH, size=size, encoding='utf-8')
draw.text(xy, text, fill=fill, font=fontText)
img = cv2.cvtColor(np.asarray(img), cv2.COLOR_RGB2BGR)
return img


def imshow_label(img,
labels,
text_color='blue',
font_size=20,
thickness=1,
font_scale=0.5,
intervel=5,
show=True,
win_name='',
wait_time=0,
out_file=None):
"""Draw images with labels on an image.
Args:
img (str or ndarray): The image to be displayed.
labels (str or list[str]): labels of each image.
text_color (str or tuple or :obj:`Color`): Color of texts.
font_size (int): Size of font.
thickness (int): Thickness of lines.
font_scale (float): Font scales of texts.
intervel(int): interval pixels between multiple labels
show (bool): Whether to show the image.
win_name (str): The window name.
wait_time (int): Value of waitKey param.
out_file (str, optional): The filename to write the image.
Returns:
ndarray: The image with bboxes drawn on it.
"""
img = mmcv.imread(img)
img = np.ascontiguousarray(img)
labels = [labels] if isinstance(labels, str) else labels

cur_height = 0
for label in labels:
# roughly estimate the proper font size
text_size, text_baseline = cv2.getTextSize(label,
cv2.FONT_HERSHEY_DUPLEX,
font_scale, thickness)

org = (text_baseline + text_size[1],
text_baseline + text_size[1] + cur_height)

# support chinese text
# TODO: Unify the font of cv2 and PIL, and auto get font_size according to the font_scale
img = put_text(img, org, text=label, fill=text_color, size=font_size)

# cv2.putText(img, label, org, cv2.FONT_HERSHEY_DUPLEX, font_scale,
# mmcv.color_val(text_color), thickness)

cur_height += text_baseline + text_size[1] + intervel

if show:
mmcv.imshow(img, win_name, wait_time)
if out_file is not None:
mmcv.imwrite(img, out_file)

return img


def imshow_bboxes(img,
bboxes,
labels=None,
colors='green',
text_color='white',
font_size=20,
thickness=1,
font_scale=0.5,
show=True,
Expand All @@ -29,6 +122,7 @@ def imshow_bboxes(img,
labels (str or list[str], optional): labels of each bbox.
colors (list[str or tuple or :obj:`Color`]): A list of colors.
text_color (str or tuple or :obj:`Color`): Color of texts.
font_size (int): Size of font.
thickness (int): Thickness of lines.
font_scale (float): Font scales of texts.
show (bool): Whether to show the image.
Expand Down Expand Up @@ -58,11 +152,10 @@ def imshow_bboxes(img,
out_file=None)

if labels is not None:
if not isinstance(labels, list):
labels = [labels for _ in range(len(bboxes))]
assert len(labels) == len(bboxes)

for bbox, label, color in zip(bboxes, labels, colors):
label = str(label)
bbox_int = bbox[0, :4].astype(np.int32)
# roughly estimate the proper font size
text_size, text_baseline = cv2.getTextSize(label,
Expand All @@ -74,9 +167,17 @@ def imshow_bboxes(img,
text_y2 = text_y1 + text_size[1] + text_baseline
cv2.rectangle(img, (text_x1, text_y1), (text_x2, text_y2), color,
cv2.FILLED)
cv2.putText(img, label, (text_x1, text_y2 - text_baseline),
cv2.FONT_HERSHEY_DUPLEX, font_scale,
mmcv.color_val(text_color), thickness)
# cv2.putText(img, label, (text_x1, text_y2 - text_baseline),
# cv2.FONT_HERSHEY_DUPLEX, font_scale,
# mmcv.color_val(text_color), thickness)

# support chinese text
# TODO: Unify the font of cv2 and PIL, and auto get font_size according to the font_scale
img = put_text(
img, (text_x1, text_y1),
text=label,
fill=text_color,
size=font_size)

if show:
mmcv.imshow(img, win_name, wait_time)
Expand Down
35 changes: 35 additions & 0 deletions easycv/datasets/classification/raw.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import torch
from PIL import Image

from easycv.core.visualization.image import imshow_label
from easycv.datasets.registry import DATASETS
from easycv.datasets.shared.base import BaseDataset

Expand Down Expand Up @@ -59,3 +60,37 @@ def evaluate(self, results, evaluators, logger=None, topk=(1, 5)):
eval_res = evaluators[0].evaluate(results, gt_labels)

return eval_res

def visualize(self, results, vis_num=10, **kwargs):
"""Visulaize the model output on validation data.
Args:
results: A dictionary containing
class: List of length number of test images.
img_metas: List of length number of test images,
dict of image meta info, containing filename, img_shape,
origin_img_shape and so on.
vis_num: number of images visualized
Returns: A dictionary containing
images: Visulaized images, list of np.ndarray.
img_metas: List of length number of test images,
dict of image meta info, containing filename, img_shape,
origin_img_shape and so on.
"""
vis_imgs = []

# TODO: support img_metas for torch.jit
if results.get('img_metas', None) is None:
return {}

img_metas = results['img_metas'][:vis_num]
labels = results['class']

for i, img_meta in enumerate(img_metas):
filename = img_meta['filename']

vis_img = imshow_label(img=filename, labels=labels, show=False)
vis_imgs.append(vis_img)

output = {'images': vis_imgs, 'img_metas': img_metas}

return output
56 changes: 4 additions & 52 deletions easycv/datasets/detection/mix.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,13 @@
import torch

from easycv.datasets.registry import DATASETS, PIPELINES
from easycv.datasets.shared.base import BaseDataset
from easycv.utils import build_from_cfg
from easycv.utils.bbox_util import batched_xyxy2cxcywh_with_shape
from easycv.utils.bbox_util import xyxy2xywh as xyxy2cxcywh
from .raw import DetDataset


@DATASETS.register_module
class DetImagesMixDataset(BaseDataset):
class DetImagesMixDataset(DetDataset):
"""A wrapper of multiple images mixed dataset.
Suitable for training on multiple images mixed data augmentation like
Expand Down Expand Up @@ -50,7 +49,7 @@ def __init__(self,
label_padding=True):

super(DetImagesMixDataset, self).__init__(
data_source, pipeline, profiling=profiling)
data_source, pipeline, profiling=profiling, classes=classes)

if skip_type_keys is not None:
assert all([
Expand All @@ -70,10 +69,9 @@ def __init__(self,
else:
raise TypeError('pipeline must be a dict')

self.CLASSES = classes
if hasattr(self.data_source, 'flag'):
self.flag = self.data_source.flag
self.num_samples = self.data_source.get_length()

if dynamic_scale is not None:
assert isinstance(dynamic_scale, tuple)

Expand All @@ -83,9 +81,6 @@ def __init__(self,
self.label_padding = label_padding
self.max_labels_num = 120

def __len__(self):
return self.num_samples

def __getitem__(self, idx):
results = copy.deepcopy(self.data_source.get_sample(idx))
for (transform, transform_type) in zip(self.pipeline_yolox,
Expand Down Expand Up @@ -116,21 +111,8 @@ def __getitem__(self, idx):
if 'img_scale' in results:
results.pop('img_scale')

# print(result.keys())

# if self.yolo_format:
# # print(type(results['img_metas']), results['img_metas'])
# # print(type(results['img_metas']._data), results['img_metas']._data)
# img_shape = results['img_metas']._data['img_shape'][:2]
# # print(type(results['gt_bboxes']))
# gt_bboxes = xyxy2cxcywh_with_shape(results['gt_bboxes']._data, img_shape)
# results['gt_bboxes'] = gt_bboxes.float()

if self.label_padding:

cxcywh_gt_bboxes = xyxy2cxcywh(results['gt_bboxes']._data)
# cxcywh_gt_bboxes = results['gt_bboxes']._data

padded_gt_bboxes = torch.zeros((self.max_labels_num, 4),
device=cxcywh_gt_bboxes.device)
padded_gt_bboxes[range(cxcywh_gt_bboxes.shape[0])[:self.max_labels_num]] = \
Expand All @@ -146,9 +128,6 @@ def __getitem__(self, idx):
results['gt_bboxes'] = padded_gt_bboxes
results['gt_labels'] = padded_labels

# ['img_metas', 'img', 'gt_bboxes', 'gt_labels']
# results.pop('img_metas')
# print(results['img_metas'], "hhh", idx)
return results

def update_skip_type_keys(self, skip_type_keys):
Expand Down Expand Up @@ -240,30 +219,3 @@ def format_results(self, results, jsonfile_prefix=None, **kwargs):
tmp_dir = None
result_files = self.results2json(results, jsonfile_prefix)
return result_files, tmp_dir

def evaluate(self, results, evaluators=None, logger=None):
'''results: a dict of list of Tensors, list length equals to number of test images
'''

eval_result = dict()

groundtruth_dict = {}
groundtruth_dict['groundtruth_boxes'] = [
batched_xyxy2cxcywh_with_shape(
self.data_source.get_ann_info(idx)['bboxes'],
results['img_metas'][idx]['ori_img_shape'])
for idx in range(len(results['img_metas']))
]
groundtruth_dict['groundtruth_classes'] = [
self.data_source.get_ann_info(idx)['labels']
for idx in range(len(results['img_metas']))
]
groundtruth_dict['groundtruth_is_crowd'] = [
self.data_source.get_ann_info(idx)['groundtruth_is_crowd']
for idx in range(len(results['img_metas']))
]

for evaluator in evaluators:
eval_result.update(evaluator.evaluate(results, groundtruth_dict))
# print(eval_result)
return eval_result
Loading

0 comments on commit 9a3826f

Please sign in to comment.