Skip to content

Commit 9fa8000

Browse files
authored
Add support for flow batches in flow_to_image (#5308)
1 parent 8e874ff commit 9fa8000

File tree

2 files changed

+48
-26
lines changed

2 files changed

+48
-26
lines changed

test/test_utils.py

Lines changed: 25 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -317,29 +317,42 @@ def test_draw_keypoints_errors():
317317
utils.draw_keypoints(image=img, keypoints=invalid_keypoints)
318318

319319

320-
def test_flow_to_image():
320+
@pytest.mark.parametrize("batch", (True, False))
321+
def test_flow_to_image(batch):
321322
h, w = 100, 100
322323
flow = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
323324
flow = torch.stack(flow[::-1], dim=0).float()
324325
flow[0] -= h / 2
325326
flow[1] -= w / 2
327+
328+
if batch:
329+
flow = torch.stack([flow, flow])
330+
326331
img = utils.flow_to_image(flow)
332+
assert img.shape == ((2, 3, h, w) if batch else (3, h, w))  # parenthesized: a conditional expression binds looser than ==, so the unparenthesized form asserted a truthy tuple when batch is False
333+
327334
path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "expected_flow.pt")
328335
expected_img = torch.load(path, map_location="cpu")
329-
assert_equal(expected_img, img)
330336

337+
if batch:
338+
expected_img = torch.stack([expected_img, expected_img])
339+
340+
assert_equal(expected_img, img)
331341

332-
def test_flow_to_image_errors():
333-
wrong_flow1 = torch.full((3, 10, 10), 0, dtype=torch.float)
334-
wrong_flow2 = torch.full((2, 10), 0, dtype=torch.float)
335-
wrong_flow3 = torch.full((2, 10, 30), 0, dtype=torch.int)
336342

337-
with pytest.raises(ValueError, match="Input flow should have shape"):
338-
utils.flow_to_image(flow=wrong_flow1)
339-
with pytest.raises(ValueError, match="Input flow should have shape"):
340-
utils.flow_to_image(flow=wrong_flow2)
341-
with pytest.raises(ValueError, match="Flow should be of dtype torch.float"):
342-
utils.flow_to_image(flow=wrong_flow3)
343+
@pytest.mark.parametrize(
344+
"input_flow, match",
345+
(
346+
(torch.full((3, 10, 10), 0, dtype=torch.float), "Input flow should have shape"),
347+
(torch.full((5, 3, 10, 10), 0, dtype=torch.float), "Input flow should have shape"),
348+
(torch.full((2, 10), 0, dtype=torch.float), "Input flow should have shape"),
349+
(torch.full((5, 2, 10), 0, dtype=torch.float), "Input flow should have shape"),
350+
(torch.full((2, 10, 30), 0, dtype=torch.int), "Flow should be of dtype torch.float"),
351+
),
352+
)
353+
def test_flow_to_image_errors(input_flow, match):
354+
with pytest.raises(ValueError, match=match):
355+
utils.flow_to_image(flow=input_flow)
343356

344357

345358
if __name__ == "__main__":

torchvision/utils.py

Lines changed: 23 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -397,42 +397,51 @@ def flow_to_image(flow: torch.Tensor) -> torch.Tensor:
397397
Converts a flow to an RGB image.
398398
399399
Args:
400-
flow (Tensor): Flow of shape (2, H, W) and dtype torch.float.
400+
flow (Tensor): Flow of shape (N, 2, H, W) or (2, H, W) and dtype torch.float.
401401
402402
Returns:
403-
img (Tensor(3, H, W)): Image Tensor of dtype uint8 where each color corresponds to a given flow direction.
403+
img (Tensor): Image Tensor of dtype uint8 where each color corresponds
404+
to a given flow direction. Shape is (N, 3, H, W) or (3, H, W) depending on the input.
404405
"""
405406

406407
if flow.dtype != torch.float:
407408
raise ValueError(f"Flow should be of dtype torch.float, got {flow.dtype}.")
408409

409-
if flow.ndim != 3 or flow.size(0) != 2:
410-
raise ValueError(f"Input flow should have shape (2, H, W), got {flow.shape}.")
410+
orig_shape = flow.shape
411+
if flow.ndim == 3:
412+
flow = flow[None] # Add batch dim
411413

412-
max_norm = torch.sum(flow ** 2, dim=0).sqrt().max()
414+
if flow.ndim != 4 or flow.shape[1] != 2:
415+
raise ValueError(f"Input flow should have shape (2, H, W) or (N, 2, H, W), got {orig_shape}.")
416+
417+
max_norm = torch.sum(flow ** 2, dim=1).sqrt().max()
413418
epsilon = torch.finfo((flow).dtype).eps
414419
normalized_flow = flow / (max_norm + epsilon)
415-
return _normalized_flow_to_image(normalized_flow)
420+
img = _normalized_flow_to_image(normalized_flow)
421+
422+
if len(orig_shape) == 3:
423+
img = img[0] # Remove batch dim
424+
return img
416425

417426

418427
@torch.no_grad()
419428
def _normalized_flow_to_image(normalized_flow: torch.Tensor) -> torch.Tensor:
420429

421430
"""
422-
Converts a normalized flow to an RGB image.
431+
Converts a batch of normalized flow to an RGB image.
423432
424433
Args:
425-
normalized_flow (torch.Tensor): Normalized flow tensor of shape (2, H, W)
434+
normalized_flow (torch.Tensor): Normalized flow tensor of shape (N, 2, H, W)
426435
Returns:
427-
img (Tensor(3, H, W)): Flow visualization image of dtype uint8.
436+
img (Tensor(N, 3, H, W)): Flow visualization image of dtype uint8.
428437
"""
429438

430-
_, H, W = normalized_flow.shape
431-
flow_image = torch.zeros((3, H, W), dtype=torch.uint8)
439+
N, _, H, W = normalized_flow.shape
440+
flow_image = torch.zeros((N, 3, H, W), dtype=torch.uint8)
432441
colorwheel = _make_colorwheel() # shape [55x3]
433442
num_cols = colorwheel.shape[0]
434-
norm = torch.sum(normalized_flow ** 2, dim=0).sqrt()
435-
a = torch.atan2(-normalized_flow[1], -normalized_flow[0]) / torch.pi
443+
norm = torch.sum(normalized_flow ** 2, dim=1).sqrt()
444+
a = torch.atan2(-normalized_flow[:, 1, :, :], -normalized_flow[:, 0, :, :]) / torch.pi
436445
fk = (a + 1) / 2 * (num_cols - 1)
437446
k0 = torch.floor(fk).to(torch.long)
438447
k1 = k0 + 1
@@ -445,7 +454,7 @@ def _normalized_flow_to_image(normalized_flow: torch.Tensor) -> torch.Tensor:
445454
col1 = tmp[k1] / 255.0
446455
col = (1 - f) * col0 + f * col1
447456
col = 1 - norm * (1 - col)
448-
flow_image[c, :, :] = torch.floor(255 * col)
457+
flow_image[:, c, :, :] = torch.floor(255 * col)
449458
return flow_image
450459

451460

0 commit comments

Comments
 (0)