diff --git a/src/spatial/detail/ArborX_BruteForceImpl.hpp b/src/spatial/detail/ArborX_BruteForceImpl.hpp index cc5437b1f..d849ea5d3 100644 --- a/src/spatial/detail/ArborX_BruteForceImpl.hpp +++ b/src/spatial/detail/ArborX_BruteForceImpl.hpp @@ -154,7 +154,9 @@ struct BruteForceImpl int const n_indexables = values.size(); int const n_predicates = predicates.size(); - NearestBufferProvider buffer_provider(space, predicates); + using Coordinate = decltype(predicates(0).distance(indexables(0))); + NearestBufferProvider buffer_provider(space, + predicates); Kokkos::parallel_for( "ArborX::BruteForce::query::nearest::" @@ -168,7 +170,7 @@ struct BruteForceImpl return; using PairIndexDistance = - typename NearestBufferProvider::PairIndexDistance; + typename decltype(buffer_provider)::PairIndexDistance; struct CompareDistance { KOKKOS_INLINE_FUNCTION bool diff --git a/src/spatial/detail/ArborX_NearestBufferProvider.hpp b/src/spatial/detail/ArborX_NearestBufferProvider.hpp index 592d6b7ec..1cede7387 100644 --- a/src/spatial/detail/ArborX_NearestBufferProvider.hpp +++ b/src/spatial/detail/ArborX_NearestBufferProvider.hpp @@ -19,17 +19,20 @@ namespace ArborX::Details { -template +template struct NearestBufferProvider { static_assert(Kokkos::is_memory_space_v); - using PairIndexDistance = Kokkos::pair; + using PairIndexDistance = Kokkos::pair; Kokkos::View _buffer; Kokkos::View _offset; - NearestBufferProvider() = default; + NearestBufferProvider() + : _buffer("ArborX::NearestBufferProvider::buffer", 0) + , _offset("ArborX::NearestBufferProvider::offset", 0) + {} template NearestBufferProvider(ExecutionSpace const &space, @@ -46,11 +49,6 @@ struct NearestBufferProvider Kokkos::make_pair(_offset(i), _offset(i + 1))); } - // Enclosing function for an extended __host__ __device__ lambda cannot have - // private or protected access within its class -#ifndef KOKKOS_COMPILER_NVCC -private: -#endif template void allocateBuffer(ExecutionSpace const &space, Predicates const &predicates) { diff --git a/src/spatial/detail/ArborX_TreeTraversal.hpp b/src/spatial/detail/ArborX_TreeTraversal.hpp index b742fb5b8..8564bfafa 100644 --- a/src/spatial/detail/ArborX_TreeTraversal.hpp +++ b/src/spatial/detail/ArborX_TreeTraversal.hpp @@ -128,7 +128,10 @@ struct TreeTraversal Predicates _predicates; Callback _callback; - NearestBufferProvider _buffer; + using Coordinate = decltype(std::declval()(0).distance( + HappyTreeFriends::getIndexable(_bvh, 0))); + + NearestBufferProvider _buffer; template TreeTraversal(ExecutionSpace const &space, BVH const &bvh, @@ -151,7 +154,7 @@ struct TreeTraversal } else { - _buffer = NearestBufferProvider(space, predicates); + _buffer.allocateBuffer(space, predicates); Kokkos::parallel_for("ArborX::TreeTraversal::nearest", Kokkos::RangePolicy(space, 0, predicates.size()), @@ -184,8 +187,7 @@ struct TreeTraversal if (k < 1) return; - using PairIndexDistance = - typename NearestBufferProvider::PairIndexDistance; + using PairIndexDistance = typename decltype(_buffer)::PairIndexDistance; struct CompareDistance { KOKKOS_INLINE_FUNCTION bool operator()(PairIndexDistance const &lhs, @@ -217,7 +219,7 @@ struct TreeTraversal auto *stack_ptr = stack; *stack_ptr++ = SENTINEL; #if !defined(__CUDA_ARCH__) - float stack_distance[64]; + Coordinate stack_distance[64]; auto *stack_distance_ptr = stack_distance; *stack_distance_ptr++ = 0.f; #endif @@ -226,14 +228,14 @@ struct TreeTraversal int left_child; int right_child; - float distance_left = 0.f; - float distance_right = 0.f; - float distance_node = 0.f; + Coordinate distance_left = 0; + Coordinate distance_right = 0; + Coordinate distance_node = 0; // Nodes with a distance that exceed that radius can safely be // discarded. Initialize the radius to infinity and tighten it once k // neighbors have been found. - auto radius = KokkosExt::ArithmeticTraits::infinity::value; + auto radius = KokkosExt::ArithmeticTraits::infinity::value; do {