@@ -317,29 +317,42 @@ def test_draw_keypoints_errors():
317317 utils .draw_keypoints (image = img , keypoints = invalid_keypoints )
318318
319319
320- def test_flow_to_image ():
320+ @pytest .mark .parametrize ("batch" , (True , False ))
321+ def test_flow_to_image (batch ):
321322 h , w = 100 , 100
322323 flow = torch .meshgrid (torch .arange (h ), torch .arange (w ), indexing = "ij" )
323324 flow = torch .stack (flow [::- 1 ], dim = 0 ).float ()
324325 flow [0 ] -= h / 2
325326 flow [1 ] -= w / 2
327+
328+ if batch :
329+ flow = torch .stack ([flow , flow ])
330+
326331 img = utils .flow_to_image (flow )
332+ assert img .shape == (2 , 3 , h , w ) if batch else (3 , h , w )
333+
327334 path = os .path .join (os .path .dirname (os .path .abspath (__file__ )), "assets" , "expected_flow.pt" )
328335 expected_img = torch .load (path , map_location = "cpu" )
329- assert_equal (expected_img , img )
330336
337+ if batch :
338+ expected_img = torch .stack ([expected_img , expected_img ])
339+
340+ assert_equal (expected_img , img )
331341
332- def test_flow_to_image_errors ():
333- wrong_flow1 = torch .full ((3 , 10 , 10 ), 0 , dtype = torch .float )
334- wrong_flow2 = torch .full ((2 , 10 ), 0 , dtype = torch .float )
335- wrong_flow3 = torch .full ((2 , 10 , 30 ), 0 , dtype = torch .int )
336342
337- with pytest .raises (ValueError , match = "Input flow should have shape" ):
338- utils .flow_to_image (flow = wrong_flow1 )
339- with pytest .raises (ValueError , match = "Input flow should have shape" ):
340- utils .flow_to_image (flow = wrong_flow2 )
341- with pytest .raises (ValueError , match = "Flow should be of dtype torch.float" ):
342- utils .flow_to_image (flow = wrong_flow3 )
343+ @pytest .mark .parametrize (
344+ "input_flow, match" ,
345+ (
346+ (torch .full ((3 , 10 , 10 ), 0 , dtype = torch .float ), "Input flow should have shape" ),
347+ (torch .full ((5 , 3 , 10 , 10 ), 0 , dtype = torch .float ), "Input flow should have shape" ),
348+ (torch .full ((2 , 10 ), 0 , dtype = torch .float ), "Input flow should have shape" ),
349+ (torch .full ((5 , 2 , 10 ), 0 , dtype = torch .float ), "Input flow should have shape" ),
350+ (torch .full ((2 , 10 , 30 ), 0 , dtype = torch .int ), "Flow should be of dtype torch.float" ),
351+ ),
352+ )
353+ def test_flow_to_image_errors (input_flow , match ):
354+ with pytest .raises (ValueError , match = match ):
355+ utils .flow_to_image (flow = input_flow )
343356
344357
345358if __name__ == "__main__" :
0 commit comments