diff --git a/core/src/Cabana_Experimental_NeighborList.hpp b/core/src/Cabana_Experimental_NeighborList.hpp index a699f31d1..414e798f1 100644 --- a/core/src/Cabana_Experimental_NeighborList.hpp +++ b/core/src/Cabana_Experimental_NeighborList.hpp @@ -77,7 +77,7 @@ struct Access( x( i, 0 ) ), @@ -114,40 +114,86 @@ namespace Experimental namespace Impl { -// Custom callback for ArborX::BVH::query() template -struct NeighborDiscriminatorCallback; +struct CollisionFilter; + +template <> +struct CollisionFilter +{ + KOKKOS_FUNCTION bool static keep( int i, int j ) noexcept + { + return i != j; // discard self-collision + } +}; template <> -struct NeighborDiscriminatorCallback +struct CollisionFilter +{ + KOKKOS_FUNCTION static bool keep( int i, int j ) noexcept { return i > j; } +}; + +// Custom callback for ArborX::BVH::query() +template +struct NeighborDiscriminatorCallback { using tag = ArborX::Details::InlineCallbackTag; template - KOKKOS_FUNCTION void operator()( Predicate const &predicate, int i, + KOKKOS_FUNCTION void operator()( Predicate const &predicate, + int primitive_index, OutputFunctor const &out ) const { - if ( getData( predicate ) != i ) // discard self-collision + int const predicate_index = getData( predicate ); + if ( CollisionFilter::keep( predicate_index, primitive_index ) ) { - out( i ); + out( primitive_index ); } } }; -template <> -struct NeighborDiscriminatorCallback +// Count in the first pass +template +struct NeighborDiscriminatorCallback2D_FirstPass { + Counts counts; using tag = ArborX::Details::InlineCallbackTag; template - KOKKOS_FUNCTION void operator()( Predicate const &predicate, int i, - OutputFunctor const &out ) const + KOKKOS_FUNCTION void operator()( Predicate const &predicate, + int primitive_index, + OutputFunctor const & ) const + { + int const predicate_index = getData( predicate ); + if ( CollisionFilter::keep( predicate_index, primitive_index ) ) + { + ++counts( predicate_index ); // WARNING see below** + } + } +}; + +// Fill in the second pass +template +struct NeighborDiscriminatorCallback2D_SecondPass +{ + Counts counts; + Neighbors neighbors; + using tag = ArborX::Details::InlineCallbackTag; + template + KOKKOS_FUNCTION void operator()( Predicate const &predicate, + int primitive_index, + OutputFunctor const & ) const { - if ( getData( predicate ) > i ) + int const predicate_index = getData( predicate ); + if ( CollisionFilter::keep( predicate_index, primitive_index ) ) { - out( i ); + assert( counts( predicate_index ) < (int)neighbors.extent( 1 ) ); + neighbors( predicate_index, counts( predicate_index )++ ) = + primitive_index; // WARNING see below** } } }; +// NOTE** Taking advantage of the knowledge that one predicate is processed by a +// single thread. Count increment should be atomic otherwise. + } // namespace Impl template @@ -179,8 +225,59 @@ auto makeNeighborList( Tag, Slice const &coordinate_slice, Impl::makePredicates( coordinate_slice, first, last, radius ), Impl::NeighborDiscriminatorCallback{}, indices, offset ); - return CrsGraph{ - std::move( indices ), std::move( offset ), first, bvh.size()}; + return CrsGraph{std::move( indices ), std::move( offset ), + first, bvh.size()}; +} + +template +struct Dense +{ + Kokkos::View cnt; + Kokkos::View val; + typename MemorySpace::size_type shift; + typename MemorySpace::size_type total; +}; + +template +auto make2DNeighborList( Tag, Slice const &coordinate_slice, + typename Slice::size_type first, + typename Slice::size_type last, + typename Slice::value_type radius ) +{ + using MemorySpace = typename DeviceType::memory_space; + using ExecutionSpace = typename DeviceType::execution_space; + ExecutionSpace space{}; + + ArborX::BVH bvh( space, coordinate_slice ); + + Kokkos::View indices( "indices", 0 ); + Kokkos::View offset( "offset", 0 ); + + auto const predicates = + Impl::makePredicates( coordinate_slice, first, last, radius ); + + auto const n_queries = ArborX::Traits:: + Access::size( + predicates ); + + Kokkos::View counts( "counts", n_queries ); + bvh.query( + space, predicates, + Impl::NeighborDiscriminatorCallback2D_FirstPass{counts}, + indices, offset ); + + Kokkos::View neighbors( + Kokkos::view_alloc( "neighbors", Kokkos::WithoutInitializing ), + n_queries, ArborX::max( space, counts ) ); + Kokkos::deep_copy( counts, 0 ); // reset counts to zero + bvh.query( + space, predicates, + Impl::NeighborDiscriminatorCallback2D_SecondPass< + decltype( counts ), decltype( neighbors ), Tag>{counts, neighbors}, + indices, offset ); + + return Dense{counts, neighbors, first, bvh.size()}; } } // namespace Experimental @@ -205,13 +302,38 @@ class NeighborList> static KOKKOS_FUNCTION size_type getNeighbor( crs_graph_type const &crs_graph, size_type p, size_type n ) { - assert( (int)p >= 0 && p < crs_graph.total ); assert( n < numNeighbor( crs_graph, p ) ); p -= crs_graph.shift; return crs_graph.col_ind( crs_graph.row_ptr( p ) + n ); } }; +template +class NeighborList> +{ + using size_type = std::size_t; + using specialization_type = Experimental::Dense; + + public: + using memory_space = MemorySpace; + static KOKKOS_FUNCTION size_type numNeighbor( specialization_type const &d, + size_type p ) + { + assert( (int)p >= 0 && p < d.total ); + p -= d.shift; + if ( (int)p < 0 || p >= d.cnt.size() ) + return 0; + return d.cnt( p ); + } + static KOKKOS_FUNCTION size_type getNeighbor( specialization_type const &d, + size_type p, size_type n ) + { + assert( n < numNeighbor( d, p ) ); + p -= d.shift; + return d.val( p, n ); + } +}; + } // namespace Cabana #endif diff --git a/core/unit_test/tstNeighborListArborX.hpp b/core/unit_test/tstNeighborListArborX.hpp index ef729a0f2..fe58687ca 100644 --- a/core/unit_test/tstNeighborListArborX.hpp +++ b/core/unit_test/tstNeighborListArborX.hpp @@ -53,13 +53,26 @@ void testArborXListHalf() auto aosoa = createParticles( num_particle, box_min, box_max ); auto position = Cabana::slice<0>( aosoa ); - // Create the neighbor list. - using device_type = TEST_MEMSPACE; // sigh... - auto const nlist = Cabana::Experimental::makeNeighborList( - Cabana::HalfNeighborTag{}, position, 0, aosoa.size(), test_radius ); - - // Check the neighbor list. - checkHalfNeighborList( nlist, position, test_radius ); + { + // Create the neighbor list. + using device_type = TEST_MEMSPACE; // sigh... + auto const nlist = Cabana::Experimental::makeNeighborList( + Cabana::HalfNeighborTag{}, position, 0, aosoa.size(), test_radius ); + + // Check the neighbor list. + checkHalfNeighborList( nlist, position, test_radius ); + } + { + // Create the neighbor list. + using device_type = TEST_MEMSPACE; // sigh... + auto const nlist = + Cabana::Experimental::make2DNeighborList( + Cabana::HalfNeighborTag{}, position, 0, aosoa.size(), + test_radius ); + + // Check the neighbor list. + checkHalfNeighborList( nlist, position, test_radius ); + } } //---------------------------------------------------------------------------// @@ -74,14 +87,28 @@ void testArborXListFullPartialRange() auto aosoa = createParticles( num_particle, box_min, box_max ); auto position = Cabana::slice<0>( aosoa ); - // Create the neighbor list. - using device_type = TEST_MEMSPACE; // sigh... - auto const nlist = Cabana::Experimental::makeNeighborList( - Cabana::FullNeighborTag{}, position, 0, num_ignore, test_radius ); - - // Check the neighbor list. - checkFullNeighborListPartialRange( nlist, position, test_radius, - num_ignore ); + { + // Create the neighbor list. + using device_type = TEST_MEMSPACE; // sigh... + auto const nlist = Cabana::Experimental::makeNeighborList( + Cabana::FullNeighborTag{}, position, 0, num_ignore, test_radius ); + + // Check the neighbor list. + checkFullNeighborListPartialRange( nlist, position, test_radius, + num_ignore ); + } + { + // Create the neighbor list. + using device_type = TEST_MEMSPACE; // sigh... + auto const nlist = + Cabana::Experimental::make2DNeighborList( + Cabana::FullNeighborTag{}, position, 0, num_ignore, + test_radius ); + + // Check the neighbor list. + checkFullNeighborListPartialRange( nlist, position, test_radius, + num_ignore ); + } } //---------------------------------------------------------------------------// @@ -95,14 +122,28 @@ void testNeighborArborXParallelFor() auto aosoa = createParticles( num_particle, box_min, box_max ); auto position = Cabana::slice<0>( aosoa ); - // Create the neighbor list. - using device_type = TEST_MEMSPACE; // sigh... - auto const nlist = Cabana::Experimental::makeNeighborList( - Cabana::FullNeighborTag{}, position, 0, aosoa.size(), test_radius ); + { + // Create the neighbor list. + using device_type = TEST_MEMSPACE; // sigh... + auto const nlist = Cabana::Experimental::makeNeighborList( + Cabana::FullNeighborTag{}, position, 0, aosoa.size(), test_radius ); - checkFirstNeighborParallelFor( nlist, position, test_radius ); + checkFirstNeighborParallelFor( nlist, position, test_radius ); - checkSecondNeighborParallelFor( nlist, position, test_radius ); + checkSecondNeighborParallelFor( nlist, position, test_radius ); + } + { + // Create the neighbor list. + using device_type = TEST_MEMSPACE; // sigh... + auto const nlist = + Cabana::Experimental::make2DNeighborList( + Cabana::FullNeighborTag{}, position, 0, aosoa.size(), + test_radius ); + + checkFirstNeighborParallelFor( nlist, position, test_radius ); + + checkSecondNeighborParallelFor( nlist, position, test_radius ); + } } //---------------------------------------------------------------------------// @@ -116,14 +157,28 @@ void testNeighborArborXParallelReduce() auto aosoa = createParticles( num_particle, box_min, box_max ); auto position = Cabana::slice<0>( aosoa ); - // Create the neighbor list. - using device_type = TEST_MEMSPACE; // sigh... - auto const nlist = Cabana::Experimental::makeNeighborList( - Cabana::FullNeighborTag{}, position, 0, aosoa.size(), test_radius ); + { + // Create the neighbor list. + using device_type = TEST_MEMSPACE; // sigh... + auto const nlist = Cabana::Experimental::makeNeighborList( + Cabana::FullNeighborTag{}, position, 0, aosoa.size(), test_radius ); + + checkFirstNeighborParallelReduce( nlist, aosoa, test_radius ); + + checkSecondNeighborParallelReduce( nlist, aosoa, test_radius ); + } + { + // Create the neighbor list. + using device_type = TEST_MEMSPACE; // sigh... + auto const nlist = + Cabana::Experimental::make2DNeighborList( + Cabana::FullNeighborTag{}, position, 0, aosoa.size(), + test_radius ); - checkFirstNeighborParallelReduce( nlist, aosoa, test_radius ); + checkFirstNeighborParallelReduce( nlist, aosoa, test_radius ); - checkSecondNeighborParallelReduce( nlist, aosoa, test_radius ); + checkSecondNeighborParallelReduce( nlist, aosoa, test_radius ); + } } //---------------------------------------------------------------------------//