@@ -397,42 +397,51 @@ def flow_to_image(flow: torch.Tensor) -> torch.Tensor:
397397 Converts a flow to an RGB image.
398398
399399 Args:
400- flow (Tensor): Flow of shape (2, H, W) and dtype torch.float.
400+ flow (Tensor): Flow of shape (N, 2, H, W) or (2, H, W) and dtype torch.float.
401401
402402 Returns:
403- img (Tensor(3, H, W)): Image Tensor of dtype uint8 where each color corresponds to a given flow direction.
403+ img (Tensor): Image Tensor of dtype uint8 where each color corresponds
404+ to a given flow direction. Shape is (N, 3, H, W) or (3, H, W) depending on the input.
404405 """
405406
406407 if flow .dtype != torch .float :
407408 raise ValueError (f"Flow should be of dtype torch.float, got { flow .dtype } ." )
408409
409- if flow .ndim != 3 or flow .size (0 ) != 2 :
410- raise ValueError (f"Input flow should have shape (2, H, W), got { flow .shape } ." )
410+ orig_shape = flow .shape
411+ if flow .ndim == 3 :
412+ flow = flow [None ] # Add batch dim
411413
412- max_norm = torch .sum (flow ** 2 , dim = 0 ).sqrt ().max ()
414+ if flow .ndim != 4 or flow .shape [1 ] != 2 :
415+ raise ValueError (f"Input flow should have shape (2, H, W) or (N, 2, H, W), got { orig_shape } ." )
416+
417+ max_norm = torch .sum (flow ** 2 , dim = 1 ).sqrt ().max ()
413418 epsilon = torch .finfo ((flow ).dtype ).eps
414419 normalized_flow = flow / (max_norm + epsilon )
415- return _normalized_flow_to_image (normalized_flow )
420+ img = _normalized_flow_to_image (normalized_flow )
421+
422+ if len (orig_shape ) == 3 :
423+ img = img [0 ] # Remove batch dim
424+ return img
416425
417426
418427@torch .no_grad ()
419428def _normalized_flow_to_image (normalized_flow : torch .Tensor ) -> torch .Tensor :
420429
421430 """
422- Converts a normalized flow to an RGB image.
431+ Converts a batch of normalized flows to a batch of RGB images.
423432
424433 Args:
425- normalized_flow (torch.Tensor): Normalized flow tensor of shape (2, H, W)
434+ normalized_flow (torch.Tensor): Normalized flow tensor of shape (N, 2, H, W)
426435 Returns:
427- img (Tensor(3, H, W)): Flow visualization image of dtype uint8.
436+ img (Tensor(N, 3, H, W)): Flow visualization image of dtype uint8.
428437 """
429438
430- _ , H , W = normalized_flow .shape
431- flow_image = torch .zeros ((3 , H , W ), dtype = torch .uint8 )
439+ N , _ , H , W = normalized_flow .shape
440+ flow_image = torch .zeros ((N , 3 , H , W ), dtype = torch .uint8 )
432441 colorwheel = _make_colorwheel () # shape [55x3]
433442 num_cols = colorwheel .shape [0 ]
434- norm = torch .sum (normalized_flow ** 2 , dim = 0 ).sqrt ()
435- a = torch .atan2 (- normalized_flow [1 ], - normalized_flow [0 ]) / torch .pi
443+ norm = torch .sum (normalized_flow ** 2 , dim = 1 ).sqrt ()
444+ a = torch .atan2 (- normalized_flow [:, 1 , :, : ], - normalized_flow [:, 0 , :, : ]) / torch .pi
436445 fk = (a + 1 ) / 2 * (num_cols - 1 )
437446 k0 = torch .floor (fk ).to (torch .long )
438447 k1 = k0 + 1
@@ -445,7 +454,7 @@ def _normalized_flow_to_image(normalized_flow: torch.Tensor) -> torch.Tensor:
445454 col1 = tmp [k1 ] / 255.0
446455 col = (1 - f ) * col0 + f * col1
447456 col = 1 - norm * (1 - col )
448- flow_image [c , :, :] = torch .floor (255 * col )
457+ flow_image [:, c , :, :] = torch .floor (255 * col )
449458 return flow_image
450459
451460
0 commit comments