@@ -63,7 +63,7 @@ def test_simple_gpu(self):
6363 idx = idx .detach ().cpu ().numpy ()
6464 dist2 = dist2 .detach ().cpu ().numpy ()
6565
66- idx_answer = np .asarray ([[1 , 4 ]])
66+ idx_answer = np .asarray ([[1 , - 1 ]])
6767 dist2_answer = np .asarray ([[0.0100 , - 1.0000 ]]).astype (np .float32 )
6868
6969 npt .assert_array_almost_equal (idx , idx_answer )
@@ -81,7 +81,7 @@ def test_simple_cpu(self):
8181 idx = idx .detach ().cpu ().numpy ()
8282 dist2 = dist2 .detach ().cpu ().numpy ()
8383
84- idx_answer = np .asarray ([[1 , 4 ]])
84+ idx_answer = np .asarray ([[1 , - 1 ]])
8585 dist2_answer = np .asarray ([[0.0100 , - 1.0000 ]]).astype (np .float32 )
8686
8787 npt .assert_array_almost_equal (idx , idx_answer )
@@ -95,6 +95,9 @@ def test_random_cpu(self):
9595 R = 1
9696
9797 idx , dist = ball_query (R , 15 , a , b , mode = "PARTIAL_DENSE" , batch_x = batch_a , batch_y = batch_b )
98+ idx1 , dist = ball_query (R , 15 , a , b , mode = "PARTIAL_DENSE" , batch_x = batch_a , batch_y = batch_b )
99+ torch .testing .assert_allclose (idx1 , idx )
100+
98101 self .assertEqual (idx .shape [0 ], b .shape [0 ])
99102 self .assertEqual (dist .shape [0 ], b .shape [0 ])
100103 self .assertLessEqual (idx .max ().item (), len (batch_a ))
@@ -104,7 +107,7 @@ def test_random_cpu(self):
104107 idx3_sk = tree .query_radius (b .detach ().numpy (), r = R )
105108 i = np .random .randint (len (batch_b ))
106109 for p in idx [i ].detach ().numpy ():
107- if p < len (batch_a ):
110+ if p >= 0 and p < len (batch_a ):
108111 assert p in idx3_sk [i ]
109112
110113
0 commit comments