@@ -14,24 +14,24 @@ def test_simple_gpu(self):
1414 a = torch .tensor ([[[0 , 0 , 0 ], [1 , 0 , 0 ], [2 , 0 , 0 ]], [[0 , 0 , 0 ], [1 , 0 , 0 ], [2 , 0 , 0 ]]]).to (torch .float ).cuda ()
1515 b = torch .tensor ([[[0 , 0 , 0 ]], [[3 , 0 , 0 ]]]).to (torch .float ).cuda ()
1616 idx , dist = ball_query (1.01 , 2 , a , b )
17- torch .testing .assert_allclose (idx .long (). cpu (), torch .tensor ([[[0 , 1 ]], [[2 , 2 ]]]))
17+ torch .testing .assert_allclose (idx .cpu (), torch .tensor ([[[0 , 1 ]], [[2 , 2 ]]]))
1818 torch .testing .assert_allclose (dist .cpu (), torch .tensor ([[[0 , 1 ]], [[1 , - 1 ]]]).float ())
1919
2020 def test_simple_cpu (self ):
2121 a = torch .tensor ([[[0 , 0 , 0 ], [1 , 0 , 0 ], [2 , 0 , 0 ]], [[0 , 0 , 0 ], [1 , 0 , 0 ], [2 , 0 , 0 ]]]).to (torch .float )
2222 b = torch .tensor ([[[0 , 0 , 0 ]], [[3 , 0 , 0 ]]]).to (torch .float )
23- idx , dist = ball_query (1.01 , 2 , a , b )
24- torch .testing .assert_allclose (idx . long () , torch .tensor ([[[0 , 1 ]], [[2 , 2 ]]]))
23+ idx , dist = ball_query (1.01 , 2 , a , b , sort = True )
24+ torch .testing .assert_allclose (idx , torch .tensor ([[[0 , 1 ]], [[2 , 2 ]]]))
2525 torch .testing .assert_allclose (dist , torch .tensor ([[[0 , 1 ]], [[1 , - 1 ]]]).float ())
2626
2727 a = torch .tensor ([[[0 , 0 , 0 ], [1 , 0 , 0 ], [1 , 1 , 0 ]]]).to (torch .float )
28- idx , dist = ball_query (1.01 , 3 , a , a )
29- torch .testing .assert_allclose (idx . long (), torch .tensor ([[[0 , 1 , 0 ],[1 ,0 , 2 ],[2 ,1 , 2 ]]]))
28+ idx , dist = ball_query (1.01 , 3 , a , a , sort = True )
29+ torch .testing .assert_allclose (idx , torch .tensor ([[[0 , 1 , 0 ], [1 , 0 , 2 ], [2 , 1 , 2 ]]]))
3030
3131 @run_if_cuda
3232 def test_larger_gpu (self ):
3333 a = torch .randn (32 , 4096 , 3 ).to (torch .float ).cuda ()
34- idx ,dist = ball_query (1 , 64 , a , a )
34+ idx , dist = ball_query (1 , 64 , a , a )
3535 self .assertGreaterEqual (idx .min (), 0 )
3636
3737 @run_if_cuda
@@ -70,7 +70,7 @@ def test_simple_gpu(self):
7070 dist2 = dist2 .detach ().cpu ().numpy ()
7171
7272 idx_answer = np .asarray ([[1 , - 1 ]])
73- dist2_answer = np .asarray ([[0.0100 , - 1.0000 ]]).astype (np .float32 )
73+ dist2_answer = np .asarray ([[0.100 , - 1.0000 ]]).astype (np .float32 )
7474
7575 npt .assert_array_almost_equal (idx , idx_answer )
7676 npt .assert_array_almost_equal (dist2 , dist2_answer )
@@ -88,7 +88,7 @@ def test_simple_cpu(self):
8888 dist2 = dist2 .detach ().cpu ().numpy ()
8989
9090 idx_answer = np .asarray ([[1 , - 1 ]])
91- dist2_answer = np .asarray ([[0.0100 , - 1.0000 ]]).astype (np .float32 )
91+ dist2_answer = np .asarray ([[0.100 , - 1.0000 ]]).astype (np .float32 )
9292
9393 npt .assert_array_almost_equal (idx , idx_answer )
9494 npt .assert_array_almost_equal (dist2 , dist2_answer )
@@ -100,9 +100,13 @@ def test_random_cpu(self):
100100 batch_b = torch .tensor ([0 for i in range (b .shape [0 ] // 2 )] + [1 for i in range (b .shape [0 ] // 2 , b .shape [0 ])])
101101 R = 1
102102
103- idx , dist = ball_query (R , 15 , a , b , mode = "PARTIAL_DENSE" , batch_x = batch_a , batch_y = batch_b )
104- idx1 , dist = ball_query (R , 15 , a , b , mode = "PARTIAL_DENSE" , batch_x = batch_a , batch_y = batch_b )
103+ idx , dist = ball_query (R , 15 , a , b , mode = "PARTIAL_DENSE" , batch_x = batch_a , batch_y = batch_b , sort = True )
104+ idx1 , dist = ball_query (R , 15 , a , b , mode = "PARTIAL_DENSE" , batch_x = batch_a , batch_y = batch_b , sort = True )
105105 torch .testing .assert_allclose (idx1 , idx )
106+ with self .assertRaises (AssertionError ):
107+ idx , dist = ball_query (R , 15 , a , b , mode = "PARTIAL_DENSE" , batch_x = batch_a , batch_y = batch_b , sort = False )
108+ idx1 , dist = ball_query (R , 15 , a , b , mode = "PARTIAL_DENSE" , batch_x = batch_a , batch_y = batch_b , sort = False )
109+ torch .testing .assert_allclose (idx1 , idx )
106110
107111 self .assertEqual (idx .shape [0 ], b .shape [0 ])
108112 self .assertEqual (dist .shape [0 ], b .shape [0 ])
0 commit comments