Skip to content

Commit 240210c

Browse files
oke-adityadatumbox
andauthored
Add utility to draw bounding boxes (#2785)
* initital prototype * flake * Adds documentation * minimal working bboxes * Adds label display * adds colors :-) * adds suggestions and fixes CI * handles image of dim 4 * fixes image handling * removes dev file * adds suggested changes * Updating the API. * Update test. * Implementing code review improvements. * Further refactoring and adding test. * Replace random to white to reduce size and change font on tests. Co-authored-by: Vasilis Vryniotis <vvryniotis@fb.com>
1 parent b3adace commit 240210c

File tree

6 files changed

+86
-11
lines changed

6 files changed

+86
-11
lines changed

docs/source/utils.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,4 @@ torchvision.utils
77

88
.. autofunction:: save_image
99

10+
.. autofunction:: draw_bounding_boxes
490 Bytes
Loading

test/common_utils.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import torch
1010
import warnings
1111
import __main__
12+
import random
1213

1314
from numbers import Number
1415
from torch._six import string_classes
@@ -30,6 +31,12 @@ def get_tmp_dir(src=None, **kwargs):
3031
shutil.rmtree(tmp_dir)
3132

3233

34+
def set_rng_seed(seed):
    """Seed the torch, stdlib ``random`` and NumPy random generators with ``seed``."""
    # Seeding order: torch, then the stdlib generator, then NumPy.
    for seed_fn in (torch.manual_seed, random.seed, np.random.seed):
        seed_fn(seed)
38+
39+
3340
ACCEPT = os.getenv('EXPECTTEST_ACCEPT')
3441
TEST_WITH_SLOW = os.getenv('PYTORCH_TEST_WITH_SLOW', '0') == '1'
3542
# TEST_WITH_SLOW = True # TODO: Delete this line once there is a PYTORCH_TEST_WITH_SLOW aware CI job

test/test_models.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,15 @@
1-
from common_utils import TestCase, map_nested_tensor_object, freeze_rng_state
1+
from common_utils import TestCase, map_nested_tensor_object, freeze_rng_state, set_rng_seed
22
from collections import OrderedDict
33
from itertools import product
44
import functools
55
import operator
66
import torch
77
import torch.nn as nn
8-
import numpy as np
98
from torchvision import models
109
import unittest
11-
import random
1210
import warnings
1311

1412

15-
def set_rng_seed(seed):
16-
torch.manual_seed(seed)
17-
random.seed(seed)
18-
np.random.seed(seed)
19-
20-
2113
def get_available_classification_models():
    """List the names exported by ``torchvision.models`` that look like model builders.

    A name qualifies when its value is callable, its first character is
    unchanged by ``lower()``, and it does not start with an underscore.
    """
    # TODO add a registration mechanism to torchvision.models
    model_names = []
    for name, obj in models.__dict__.items():
        if not callable(obj):
            continue
        first_char = name[0]
        if first_char.lower() != first_char or first_char == "_":
            continue
        model_names.append(name)
    return model_names

test/test_utils.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import unittest
77
from io import BytesIO
88
import torchvision.transforms.functional as F
9+
from torchvision.io.image import read_image
910
from PIL import Image
1011

1112

@@ -79,6 +80,21 @@ def test_save_image_single_pixel_file_object(self):
7980
self.assertTrue(torch.equal(F.to_tensor(img_orig), F.to_tensor(img_bytes)),
8081
'Pixel Image not stored in file object')
8182

83+
def test_draw_boxes(self):
    """Check ``utils.draw_bounding_boxes`` output against the stored reference PNG."""
    # Plain white 100x100 RGB canvas.
    canvas = torch.full((3, 100, 100), 255, dtype=torch.uint8)
    box_coords = torch.tensor(
        [
            [0, 0, 20, 20],
            [0, 0, 0, 0],
            [10, 15, 30, 35],
            [23, 35, 93, 95],
        ],
        dtype=torch.float,
    )
    # Colors are deliberately given as a mix of names, hex strings and RGB tuples.
    result = utils.draw_bounding_boxes(
        canvas,
        box_coords,
        labels=["a", "b", "c", "d"],
        colors=["green", "#FF00FF", (0, 255, 0), "red"],
    )

    test_dir = os.path.dirname(os.path.abspath(__file__))
    reference_png = os.path.join(test_dir, "assets", "fakedata", "draw_boxes_util.png")
    # If the reference file is missing, create it from the current output.
    if not os.path.exists(reference_png):
        hwc_array = result.permute(1, 2, 0).numpy()
        Image.fromarray(hwc_array).save(reference_png)

    self.assertTrue(torch.equal(result, read_image(reference_png)))
97+
8298

8399
if __name__ == '__main__':
84100
unittest.main()

torchvision/utils.py

Lines changed: 61 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,13 @@
11
from typing import Union, Optional, List, Tuple, Text, BinaryIO
2-
import io
32
import pathlib
43
import torch
54
import math
5+
import numpy as np
6+
from PIL import Image, ImageDraw
7+
from PIL import ImageFont
8+
9+
__all__ = ["make_grid", "save_image", "draw_bounding_boxes"]
10+
611
irange = range
712

813

@@ -121,10 +126,64 @@ def save_image(
121126
If a file object was used instead of a filename, this parameter should always be used.
122127
**kwargs: Other arguments are documented in ``make_grid``.
123128
"""
124-
from PIL import Image
125129
grid = make_grid(tensor, nrow=nrow, padding=padding, pad_value=pad_value,
126130
normalize=normalize, range=range, scale_each=scale_each)
127131
# Add 0.5 after unnormalizing to [0, 255] to round to nearest integer
128132
ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
129133
im = Image.fromarray(ndarr)
130134
im.save(fp, format=format)
135+
136+
137+
@torch.no_grad()
def draw_bounding_boxes(
    image: torch.Tensor,
    boxes: torch.Tensor,
    labels: Optional[List[str]] = None,
    colors: Optional[List[Union[str, Tuple[int, int, int]]]] = None,
    width: int = 1,
    font: Optional[str] = None,
    font_size: int = 10
) -> torch.Tensor:
    """
    Draws bounding boxes on given image.
    The values of the input image should be uint8 between 0 and 255.

    Args:
        image (Tensor): Tensor of shape (C x H x W)
        boxes (Tensor): Tensor of size (N, 4) containing bounding boxes in (xmin, ymin, xmax, ymax) format. Note that
            the boxes are absolute coordinates with respect to the image. In other words: `0 <= xmin < xmax < W` and
            `0 <= ymin < ymax < H`. The coordinates are converted to integers before drawing.
        labels (List[str]): List containing the labels of bounding boxes. ``labels[i]`` is written at the
            (xmin, ymin) corner of box ``i``.
        colors (List[Union[str, Tuple[int, int, int]]]): List containing the colors of bounding boxes. The colors can
            be represented as `str` or `Tuple[int, int, int]`. ``colors[i]`` is used for both the outline and the
            label of box ``i``.
        width (int): Width of bounding box.
        font (str): A filename containing a TrueType font. If the file is not found in this filename, the loader may
            also search in other directories, such as the `fonts/` directory on Windows or `/Library/Fonts/`,
            `/System/Library/Fonts/` and `~/Library/Fonts/` on macOS.
        font_size (int): The requested font size in points. Only used when ``font`` is given; otherwise the
            default PIL font is loaded without a size.

    Returns:
        Tensor: The image with the boxes (and labels, if given) drawn on it, in (C x H x W) layout.

    Raises:
        TypeError: If ``image`` is not a Tensor.
        ValueError: If ``image`` is not of dtype uint8 or is not 3-dimensional.
    """

    if not isinstance(image, torch.Tensor):
        raise TypeError(f"Tensor expected, got {type(image)}")
    elif image.dtype != torch.uint8:
        raise ValueError(f"Tensor uint8 expected, got {image.dtype}")
    elif image.dim() != 3:
        raise ValueError("Pass individual images, not batches")

    # PIL expects channels-last (H x W x C) data.
    ndarr = image.permute(1, 2, 0).numpy()
    img_to_draw = Image.fromarray(ndarr)

    img_boxes = boxes.to(torch.int64).tolist()

    draw = ImageDraw.Draw(img_to_draw)

    # The font does not depend on the box, so resolve it once up front instead of
    # once per box. It is only loaded when at least one label will be drawn.
    txt_font = None
    if labels is not None and img_boxes:
        txt_font = ImageFont.load_default() if font is None else ImageFont.truetype(font=font, size=font_size)

    for i, bbox in enumerate(img_boxes):
        color = None if colors is None else colors[i]
        draw.rectangle(bbox, width=width, outline=color)

        if labels is not None:
            draw.text((bbox[0], bbox[1]), labels[i], fill=color, font=txt_font)

    # Back to channels-first (C x H x W) for the caller.
    return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1)

0 commit comments

Comments
 (0)