@@ -30,7 +30,10 @@ def furthest_point_sample(xyz, npoint):
3030 (B, npoint) tensor containing the set
3131 """
3232 if npoint > xyz .shape [1 ]:
33- raise ValueError ("caanot sample %i points from an input set of %i points" % (npoint , xyz .shape [1 ]))
33+ raise ValueError (
34+ "caanot sample %i points from an input set of %i points"
35+ % (npoint , xyz .shape [1 ])
36+ )
3437 if xyz .is_cuda :
3538 return tpcuda .furthest_point_sampling (xyz , npoint )
3639 else :
@@ -99,9 +102,13 @@ def backward(ctx, grad_out):
99102 idx , weight , m = ctx .three_interpolate_for_backward
100103
101104 if grad_out .is_cuda :
102- grad_features = tpcuda .three_interpolate_grad (grad_out .contiguous (), idx , weight , m )
105+ grad_features = tpcuda .three_interpolate_grad (
106+ grad_out .contiguous (), idx , weight , m
107+ )
103108 else :
104- grad_features = tpcpu .knn_interpolate_grad (grad_out .contiguous (), idx , weight , m )
109+ grad_features = tpcpu .knn_interpolate_grad (
110+ grad_out .contiguous (), idx , weight , m
111+ )
105112
106113 return grad_features , None , None
107114
@@ -143,17 +150,23 @@ def grouping_operation(features, idx):
143150 all_idx = idx .reshape (idx .shape [0 ], - 1 )
144151 all_idx = all_idx .unsqueeze (1 ).repeat (1 , features .shape [1 ], 1 )
145152 grouped_features = features .gather (2 , all_idx )
146- return grouped_features .reshape (idx .shape [0 ], features .shape [1 ], idx .shape [1 ], idx .shape [2 ])
153+ return grouped_features .reshape (
154+ idx .shape [0 ], features .shape [1 ], idx .shape [1 ], idx .shape [2 ]
155+ )
147156
148157
149- def ball_query_dense (radius , nsample , xyz , new_xyz , batch_xyz = None , batch_new_xyz = None , sort = False ):
158+ def ball_query_dense (
159+ radius , nsample , xyz , new_xyz , batch_xyz = None , batch_new_xyz = None , sort = False
160+ ):
150161 # type: (Any, float, int, torch.Tensor, torch.Tensor) -> torch.Tensor
151162 if new_xyz .is_cuda :
152163 if sort :
153164 raise NotImplementedError ("CUDA version does not sort the neighbors" )
154165 ind , dist = tpcuda .ball_query_dense (new_xyz , xyz , radius , nsample )
155166 else :
156- ind , dist = tpcpu .dense_ball_query (new_xyz , xyz , radius , nsample , mode = 0 , sorted = sort )
167+ ind , dist = tpcpu .dense_ball_query (
168+ new_xyz , xyz , radius , nsample , mode = 0 , sorted = sort
169+ )
157170 return ind , dist
158171
159172
@@ -162,9 +175,13 @@ def ball_query_partial_dense(radius, nsample, x, y, batch_x, batch_y, sort=False
162175 if x .is_cuda :
163176 if sort :
164177 raise NotImplementedError ("CUDA version does not sort the neighbors" )
165- ind , dist = tpcuda .ball_query_partial_dense (x , y , batch_x , batch_y , radius , nsample )
178+ ind , dist = tpcuda .ball_query_partial_dense (
179+ x , y , batch_x , batch_y , radius , nsample
180+ )
166181 else :
167- ind , dist = tpcpu .batch_ball_query (x , y , batch_x , batch_y , radius , nsample , mode = 0 , sorted = sort )
182+ ind , dist = tpcpu .batch_ball_query (
183+ x , y , batch_x , batch_y , radius , nsample , mode = 0 , sorted = sort
184+ )
168185 return ind , dist
169186
170187
@@ -207,7 +224,9 @@ def ball_query(
207224 assert x .size (0 ) == batch_x .size (0 )
208225 assert y .size (0 ) == batch_y .size (0 )
209226 assert x .dim () == 2
210- return ball_query_partial_dense (radius , nsample , x , y , batch_x , batch_y , sort = sort )
227+ return ball_query_partial_dense (
228+ radius , nsample , x , y , batch_x , batch_y , sort = sort
229+ )
211230
212231 elif mode .lower () == "dense" :
213232 if (batch_x is not None ) or (batch_y is not None ):
0 commit comments