Skip to content

Commit

Permalink
Port test/test_utils.py to pytest (#3917)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiqwang authored May 25, 2021
1 parent 1b6fe68 commit 154283b
Showing 1 changed file with 124 additions and 115 deletions.
239 changes: 124 additions & 115 deletions test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import tempfile
import torch
import torchvision.utils as utils
import unittest

from io import BytesIO
import torchvision.transforms.functional as F
from PIL import Image, __version__ as PILLOW_VERSION, ImageColor
Expand All @@ -18,122 +18,131 @@
[10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float)


class Tester(unittest.TestCase):

def test_make_grid_not_inplace(self):
t = torch.rand(5, 3, 10, 10)
t_clone = t.clone()

utils.make_grid(t, normalize=False)
assert_equal(t, t_clone, msg='make_grid modified tensor in-place')

utils.make_grid(t, normalize=True, scale_each=False)
assert_equal(t, t_clone, msg='make_grid modified tensor in-place')

utils.make_grid(t, normalize=True, scale_each=True)
assert_equal(t, t_clone, msg='make_grid modified tensor in-place')

def test_normalize_in_make_grid(self):
t = torch.rand(5, 3, 10, 10) * 255
norm_max = torch.tensor(1.0)
norm_min = torch.tensor(0.0)

grid = utils.make_grid(t, normalize=True)
grid_max = torch.max(grid)
grid_min = torch.min(grid)

# Rounding the result to one decimal for comparison
n_digits = 1
rounded_grid_max = torch.round(grid_max * 10 ** n_digits) / (10 ** n_digits)
rounded_grid_min = torch.round(grid_min * 10 ** n_digits) / (10 ** n_digits)

assert_equal(norm_max, rounded_grid_max, msg='Normalized max is not equal to 1')
assert_equal(norm_min, rounded_grid_min, msg='Normalized min is not equal to 0')

@unittest.skipIf(sys.platform in ('win32', 'cygwin'), 'temporarily disabled on Windows')
def test_save_image(self):
with tempfile.NamedTemporaryFile(suffix='.png') as f:
t = torch.rand(2, 3, 64, 64)
utils.save_image(t, f.name)
self.assertTrue(os.path.exists(f.name), 'The image is not present after save')

@unittest.skipIf(sys.platform in ('win32', 'cygwin'), 'temporarily disabled on Windows')
def test_save_image_single_pixel(self):
with tempfile.NamedTemporaryFile(suffix='.png') as f:
t = torch.rand(1, 3, 1, 1)
utils.save_image(t, f.name)
self.assertTrue(os.path.exists(f.name), 'The pixel image is not present after save')

@unittest.skipIf(sys.platform in ('win32', 'cygwin'), 'temporarily disabled on Windows')
def test_save_image_file_object(self):
with tempfile.NamedTemporaryFile(suffix='.png') as f:
t = torch.rand(2, 3, 64, 64)
utils.save_image(t, f.name)
img_orig = Image.open(f.name)
fp = BytesIO()
utils.save_image(t, fp, format='png')
img_bytes = Image.open(fp)
assert_equal(F.to_tensor(img_orig), F.to_tensor(img_bytes), msg='Image not stored in file object')

@unittest.skipIf(sys.platform in ('win32', 'cygwin'), 'temporarily disabled on Windows')
def test_save_image_single_pixel_file_object(self):
with tempfile.NamedTemporaryFile(suffix='.png') as f:
t = torch.rand(1, 3, 1, 1)
utils.save_image(t, f.name)
img_orig = Image.open(f.name)
fp = BytesIO()
utils.save_image(t, fp, format='png')
img_bytes = Image.open(fp)
assert_equal(F.to_tensor(img_orig), F.to_tensor(img_bytes), msg='Image not stored in file object')

def test_draw_boxes(self):
img = torch.full((3, 100, 100), 255, dtype=torch.uint8)
img_cp = img.clone()
boxes_cp = boxes.clone()
labels = ["a", "b", "c", "d"]
colors = ["green", "#FF00FF", (0, 255, 0), "red"]
result = utils.draw_bounding_boxes(img, boxes, labels=labels, colors=colors, fill=True)

path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "fakedata", "draw_boxes_util.png")
if not os.path.exists(path):
res = Image.fromarray(result.permute(1, 2, 0).contiguous().numpy())
res.save(path)

if PILLOW_VERSION >= (8, 2):
# The reference image is only valid for new PIL versions
expected = torch.as_tensor(np.array(Image.open(path))).permute(2, 0, 1)
assert_equal(result, expected)

# Check if modification is not in place
assert_equal(boxes, boxes_cp)
assert_equal(img, img_cp)

def test_draw_boxes_vanilla(self):
img = torch.full((3, 100, 100), 0, dtype=torch.uint8)
img_cp = img.clone()
boxes_cp = boxes.clone()
result = utils.draw_bounding_boxes(img, boxes, fill=False, width=7)

path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "fakedata", "draw_boxes_vanilla.png")
if not os.path.exists(path):
res = Image.fromarray(result.permute(1, 2, 0).contiguous().numpy())
res.save(path)
def test_make_grid_not_inplace():
t = torch.rand(5, 3, 10, 10)
t_clone = t.clone()

utils.make_grid(t, normalize=False)
assert_equal(t, t_clone, msg='make_grid modified tensor in-place')

utils.make_grid(t, normalize=True, scale_each=False)
assert_equal(t, t_clone, msg='make_grid modified tensor in-place')

utils.make_grid(t, normalize=True, scale_each=True)
assert_equal(t, t_clone, msg='make_grid modified tensor in-place')


def test_normalize_in_make_grid():
t = torch.rand(5, 3, 10, 10) * 255
norm_max = torch.tensor(1.0)
norm_min = torch.tensor(0.0)

grid = utils.make_grid(t, normalize=True)
grid_max = torch.max(grid)
grid_min = torch.min(grid)

# Rounding the result to one decimal for comparison
n_digits = 1
rounded_grid_max = torch.round(grid_max * 10 ** n_digits) / (10 ** n_digits)
rounded_grid_min = torch.round(grid_min * 10 ** n_digits) / (10 ** n_digits)

assert_equal(norm_max, rounded_grid_max, msg='Normalized max is not equal to 1')
assert_equal(norm_min, rounded_grid_min, msg='Normalized min is not equal to 0')


@pytest.mark.skipif(sys.platform in ('win32', 'cygwin'), reason='temporarily disabled on Windows')
def test_save_image():
with tempfile.NamedTemporaryFile(suffix='.png') as f:
t = torch.rand(2, 3, 64, 64)
utils.save_image(t, f.name)
assert os.path.exists(f.name), 'The image is not present after save'


@pytest.mark.skipif(sys.platform in ('win32', 'cygwin'), reason='temporarily disabled on Windows')
def test_save_image_single_pixel():
with tempfile.NamedTemporaryFile(suffix='.png') as f:
t = torch.rand(1, 3, 1, 1)
utils.save_image(t, f.name)
assert os.path.exists(f.name), 'The pixel image is not present after save'


@pytest.mark.skipif(sys.platform in ('win32', 'cygwin'), reason='temporarily disabled on Windows')
def test_save_image_file_object():
with tempfile.NamedTemporaryFile(suffix='.png') as f:
t = torch.rand(2, 3, 64, 64)
utils.save_image(t, f.name)
img_orig = Image.open(f.name)
fp = BytesIO()
utils.save_image(t, fp, format='png')
img_bytes = Image.open(fp)
assert_equal(F.to_tensor(img_orig), F.to_tensor(img_bytes), msg='Image not stored in file object')


@pytest.mark.skipif(sys.platform in ('win32', 'cygwin'), reason='temporarily disabled on Windows')
def test_save_image_single_pixel_file_object():
with tempfile.NamedTemporaryFile(suffix='.png') as f:
t = torch.rand(1, 3, 1, 1)
utils.save_image(t, f.name)
img_orig = Image.open(f.name)
fp = BytesIO()
utils.save_image(t, fp, format='png')
img_bytes = Image.open(fp)
assert_equal(F.to_tensor(img_orig), F.to_tensor(img_bytes), msg='Image not stored in file object')


def test_draw_boxes():
img = torch.full((3, 100, 100), 255, dtype=torch.uint8)
img_cp = img.clone()
boxes_cp = boxes.clone()
labels = ["a", "b", "c", "d"]
colors = ["green", "#FF00FF", (0, 255, 0), "red"]
result = utils.draw_bounding_boxes(img, boxes, labels=labels, colors=colors, fill=True)

path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "fakedata", "draw_boxes_util.png")
if not os.path.exists(path):
res = Image.fromarray(result.permute(1, 2, 0).contiguous().numpy())
res.save(path)

if PILLOW_VERSION >= (8, 2):
# The reference image is only valid for new PIL versions
expected = torch.as_tensor(np.array(Image.open(path))).permute(2, 0, 1)
assert_equal(result, expected)
# Check if modification is not in place
assert_equal(boxes, boxes_cp)
assert_equal(img, img_cp)

def test_draw_invalid_boxes(self):
img_tp = ((1, 1, 1), (1, 2, 3))
img_wrong1 = torch.full((3, 5, 5), 255, dtype=torch.float)
img_wrong2 = torch.full((1, 3, 5, 5), 255, dtype=torch.uint8)
boxes = torch.tensor([[0, 0, 20, 20], [0, 0, 0, 0],
[10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float)
self.assertRaises(TypeError, utils.draw_bounding_boxes, img_tp, boxes)
self.assertRaises(ValueError, utils.draw_bounding_boxes, img_wrong1, boxes)
self.assertRaises(ValueError, utils.draw_bounding_boxes, img_wrong2, boxes)
# Check if modification is not in place
assert_equal(boxes, boxes_cp)
assert_equal(img, img_cp)


def test_draw_boxes_vanilla():
img = torch.full((3, 100, 100), 0, dtype=torch.uint8)
img_cp = img.clone()
boxes_cp = boxes.clone()
result = utils.draw_bounding_boxes(img, boxes, fill=False, width=7)

path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "fakedata", "draw_boxes_vanilla.png")
if not os.path.exists(path):
res = Image.fromarray(result.permute(1, 2, 0).contiguous().numpy())
res.save(path)

expected = torch.as_tensor(np.array(Image.open(path))).permute(2, 0, 1)
assert_equal(result, expected)
# Check if modification is not in place
assert_equal(boxes, boxes_cp)
assert_equal(img, img_cp)


def test_draw_invalid_boxes():
img_tp = ((1, 1, 1), (1, 2, 3))
img_wrong1 = torch.full((3, 5, 5), 255, dtype=torch.float)
img_wrong2 = torch.full((1, 3, 5, 5), 255, dtype=torch.uint8)
boxes = torch.tensor([[0, 0, 20, 20], [0, 0, 0, 0],
[10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float)
with pytest.raises(TypeError, match="Tensor expected"):
utils.draw_bounding_boxes(img_tp, boxes)
with pytest.raises(ValueError, match="Tensor uint8 expected"):
utils.draw_bounding_boxes(img_wrong1, boxes)
with pytest.raises(ValueError, match="Pass individual images, not batches"):
utils.draw_bounding_boxes(img_wrong2, boxes)


@pytest.mark.parametrize('colors', [
Expand Down Expand Up @@ -218,5 +227,5 @@ def test_draw_segmentation_masks_errors():
utils.draw_segmentation_masks(image=img, masks=masks, colors=bad_colors)


if __name__ == '__main__':
unittest.main()
if __name__ == "__main__":
pytest.main([__file__])

0 comments on commit 154283b

Please sign in to comment.