@@ -15,7 +15,7 @@ std::pair<at::Tensor, at::Tensor> ball_query(at::Tensor support, at::Tensor quer
1515
1616    at::Tensor out;
1717    at::Tensor out_dists;
18-     std::vector<long > neighbors_indices (query.size (0 ), 0 );
18+     std::vector<int64_t > neighbors_indices (query.size (0 ), 0 );
1919    std::vector<float > neighbors_dists (query.size (0 ), -1 );
2020
2121    auto  options = torch::TensorOptions ().dtype (torch::kLong ).device (torch::kCPU );
@@ -34,7 +34,7 @@ std::pair<at::Tensor, at::Tensor> ball_query(at::Tensor support, at::Tensor quer
3434                                                  neighbors_dists, radius, max_num, mode, sorted);
3535    });
3636    auto  neighbors_dists_ptr = neighbors_dists.data ();
37-     long * neighbors_indices_ptr = neighbors_indices.data ();
37+     int64_t * neighbors_indices_ptr = neighbors_indices.data ();
3838    if  (mode == 0 )
3939    {
4040        out =
@@ -73,7 +73,7 @@ std::pair<at::Tensor, at::Tensor> batch_ball_query(at::Tensor support, at::Tenso
7373    at::Tensor idx;
7474
7575    at::Tensor dist;
76-     std::vector<long > neighbors_indices;
76+     std::vector<int64_t > neighbors_indices;
7777    std::vector<float > neighbors_dists;
7878
7979    auto  options = torch::TensorOptions ().dtype (torch::kLong ).device (torch::kCPU );
@@ -91,10 +91,11 @@ std::pair<at::Tensor, at::Tensor> batch_ball_query(at::Tensor support, at::Tenso
9191    query_batch = at::cat ({at::zeros (1 , query_batch.options ()), query_batch.cumsum (0 )}, 0 );
9292    support_batch = degree (support_batch, batch_size);
9393    support_batch = at::cat ({at::zeros (1 , support_batch.options ()), support_batch.cumsum (0 )}, 0 );
94-     std::vector<long > query_batch_stl (query_batch.DATA_PTR <long >(),
95-                                       query_batch.DATA_PTR <long >() + query_batch.numel ());
96-     std::vector<long > support_batch_stl (support_batch.DATA_PTR <long >(),
97-                                         support_batch.DATA_PTR <long >() + support_batch.numel ());
94+     std::vector<int64_t > query_batch_stl (query_batch.DATA_PTR <int64_t >(),
95+                                          query_batch.DATA_PTR <int64_t >() + query_batch.numel ());
96+     std::vector<int64_t > support_batch_stl (support_batch.DATA_PTR <int64_t >(),
97+                                            support_batch.DATA_PTR <int64_t >() +
98+                                                support_batch.numel ());
9899
99100    AT_DISPATCH_ALL_TYPES (query.scalar_type (), " batch_radius_search"  , [&] {
100101        std::vector<scalar_t > queries_stl (query.DATA_PTR <scalar_t >(),
@@ -107,7 +108,7 @@ std::pair<at::Tensor, at::Tensor> batch_ball_query(at::Tensor support, at::Tenso
107108            neighbors_dists, radius, max_num, mode, sorted);
108109    });
109110    auto  neighbors_dists_ptr = neighbors_dists.data ();
110-     long * neighbors_indices_ptr = neighbors_indices.data ();
111+     int64_t * neighbors_indices_ptr = neighbors_indices.data ();
111112
112113    if  (mode == 0 )
113114    {
0 commit comments