Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor TrackSortUtils #1047

Merged
merged 4 commits into from
Dec 5, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 33 additions & 56 deletions src/celeritas/track/detail/TrackSortUtils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ using ThreadItems
using TrackSlots = ThreadItems<TrackSlotId::size_type>;

template<class F>
void partition_impl(TrackSlots const& track_slots, F&& func, StreamId)
void partition_impl(TrackSlots const& track_slots, F&& func)
{
auto* start = track_slots.data().get();
std::partition(start, start + track_slots.size(), std::forward<F>(func));
Expand All @@ -38,43 +38,12 @@ void partition_impl(TrackSlots const& track_slots, F&& func, StreamId)
//---------------------------------------------------------------------------//

template<class F>
void sort_impl(TrackSlots const& track_slots, F&& func, StreamId)
void sort_impl(TrackSlots const& track_slots, F&& func)
{
auto* start = track_slots.data().get();
std::sort(start, start + track_slots.size(), std::forward<F>(func));
}

/*!
 * Record, for every action, the first thread index at which it occurs.
 *
 * Precondition: the actions seen through \c get_action are sorted by thread
 * index, i.e. i <= j implies get_action(i) <= get_action(j).
 *
 * \param offsets one entry per action; each is reset to a default ThreadId
 *        and then set to the first thread index that has that action
 * \param size number of threads to inspect
 * \param get_action callable mapping a ThreadId to its ActionId
 *
 * The offsets are finally handed to \c backfill_action_count (defined
 * elsewhere) together with \c size.
 */
template<class F>
void count_tracks_per_action_impl(Span<ThreadId> offsets,
                                  size_type size,
                                  F&& get_action)
{
    // Start from a clean slate: every action is "not seen yet"
    for (auto&& offset : offsets)
    {
        offset = ThreadId{};
    }

    // An offset is written only where the action differs from the previous
    // thread's action, so thread 0 can never be recorded inside this loop.
#pragma omp parallel for
    for (size_type tid = 1; tid < size; ++tid)
    {
        ActionId const action = get_action(ThreadId{tid});
        if (action && action != get_action(ThreadId{tid - 1}))
        {
            offsets[action.unchecked_get()] = ThreadId{tid};
        }
    }

    // Thread 0 has no predecessor: record its action explicitly if valid
    ActionId const first = get_action(ThreadId{0});
    if (first)
    {
        offsets[first.unchecked_get()] = ThreadId{0};
    }

    backfill_action_count(offsets, size);
}

//---------------------------------------------------------------------------//
} // namespace

Expand Down Expand Up @@ -111,18 +80,15 @@ void sort_tracks(HostRef<CoreStateData> const& states, TrackOrder order)
{
case TrackOrder::partition_status:
return partition_impl(states.track_slots,
alive_predicate{states.sim.status.data()},
states.stream_id);
alive_predicate{states.sim.status.data()});
case TrackOrder::sort_along_step_action:
return sort_impl(
states.track_slots,
action_comparator{states.sim.along_step_action.data()},
states.stream_id);
case TrackOrder::sort_step_limit_action:
return sort_impl(states.track_slots,
id_comparator{get_action_ptr(states, order)});
case TrackOrder::sort_particle_type:
return sort_impl(
states.track_slots,
action_comparator{states.sim.post_step_action.data()},
states.stream_id);
id_comparator{states.particles.particle_id.data()});
default:
CELER_ASSERT_UNREACHABLE();
}
Expand All @@ -140,23 +106,34 @@ void count_tracks_per_action(
Collection<ThreadId, Ownership::value, MemSpace::host, ActionId>&,
TrackOrder order)
{
switch (order)
CELER_ASSERT(order == TrackOrder::sort_along_step_action
|| order == TrackOrder::sort_step_limit_action);

ActionAccessor get_action{get_action_ptr(states, order),
states.track_slots.data()};

std::fill(offsets.begin(), offsets.end(), ThreadId{});
auto const size = states.size();
// if get_action(0) != get_action(1), get_action(0) never gets initialized
#pragma omp parallel for
for (size_type i = 1; i < size; ++i)
{
case TrackOrder::sort_along_step_action:
return count_tracks_per_action_impl(
offsets,
states.size(),
ActionAccessor{states.sim.along_step_action.data(),
states.track_slots.data()});
case TrackOrder::sort_step_limit_action:
return count_tracks_per_action_impl(
offsets,
states.size(),
ActionAccessor{states.sim.post_step_action.data(),
states.track_slots.data()});
default:
return;
ActionId current_action = get_action(ThreadId{i});
if (!current_action)
continue;

if (current_action != get_action(ThreadId{i - 1}))
{
offsets[current_action.unchecked_get()] = ThreadId{i};
}
}

// so make sure get_action(0) is initialized
if (ActionId first = get_action(ThreadId{0}))
{
offsets[first.unchecked_get()] = ThreadId{0};
}
backfill_action_count(offsets, size);
}

void backfill_action_count(Span<ThreadId> offsets, size_type num_actions)
Expand Down
108 changes: 37 additions & 71 deletions src/celeritas/track/detail/TrackSortUtils.cu
Original file line number Diff line number Diff line change
Expand Up @@ -96,13 +96,15 @@ void sort_impl(TrackSlots const& track_slots,
CELER_DEVICE_CHECK_ERROR();
}

// PRE: get_action is sorted, i.e. i <= j ==> get_action(i) <=
// get_action(j)
template<class F>
__device__ void
tracks_per_action_impl(Span<ThreadId> offsets, size_type size, F&& get_action)
// PRE: actions are sorted
esseivaju marked this conversation as resolved.
Show resolved Hide resolved
__global__ void
tracks_per_action_kernel(ObserverPtr<ActionId const> actions,
ObserverPtr<TrackSlotId::size_type const> track_slots,
Span<ThreadId> offsets,
size_type size)
{
ThreadId tid = celeritas::KernelParamCalculator::thread_id();
ActionAccessor get_action{actions, track_slots};

if ((tid < size) && tid != ThreadId{0})
{
Expand All @@ -123,30 +125,6 @@ tracks_per_action_impl(Span<ThreadId> offsets, size_type size, F&& get_action)
}
}

/*!
 * Kernel entry point that selects the action array matching \c order and
 * forwards it, wrapped in an ActionAccessor, to \c tracks_per_action_impl.
 *
 * \param states device state reference providing the action arrays and the
 *        track slot indirection
 * \param offsets per-action output passed through to the implementation
 * \param size number of threads passed through to the implementation
 * \param order must be one of the two action-sorting orders; any other value
 *        reaches CELER_ASSERT_UNREACHABLE
 */
__global__ void tracks_per_action_kernel(DeviceRef<CoreStateData> const states,
                                         Span<ThreadId> offsets,
                                         size_type size,
                                         TrackOrder order)
{
    if (order == TrackOrder::sort_along_step_action)
    {
        // Sorted by along-step action
        tracks_per_action_impl(
            offsets,
            size,
            ActionAccessor{states.sim.along_step_action.data(),
                           states.track_slots.data()});
    }
    else if (order == TrackOrder::sort_step_limit_action)
    {
        // Sorted by post-step (step-limiting) action
        tracks_per_action_impl(
            offsets,
            size,
            ActionAccessor{states.sim.post_step_action.data(),
                           states.track_slots.data()});
    }
    else
    {
        CELER_ASSERT_UNREACHABLE();
    }
}

//---------------------------------------------------------------------------//
} // namespace

Expand Down Expand Up @@ -199,20 +177,11 @@ void sort_tracks(DeviceRef<CoreStateData> const& states, TrackOrder order)
return partition_impl(states.track_slots,
alive_predicate{states.sim.status.data()},
states.stream_id);
case TrackOrder::sort_along_step_action: {
using Id =
typename decltype(states.sim.along_step_action)::value_type;
return sort_impl<Id>(states.track_slots,
states.sim.along_step_action.data(),
states.stream_id);
}
case TrackOrder::sort_step_limit_action: {
using Id =
typename decltype(states.sim.post_step_action)::value_type;
return sort_impl<Id>(states.track_slots,
states.sim.post_step_action.data(),
states.stream_id);
}
case TrackOrder::sort_along_step_action:
case TrackOrder::sort_step_limit_action:
return sort_impl(states.track_slots,
get_action_ptr(states, order),
states.stream_id);
case TrackOrder::sort_particle_type: {
using Id =
typename decltype(states.particles.particle_id)::value_type;
Expand All @@ -237,34 +206,31 @@ void count_tracks_per_action(
Collection<ThreadId, Ownership::value, MemSpace::mapped, ActionId>& out,
TrackOrder order)
{
if (order == TrackOrder::sort_along_step_action
|| order == TrackOrder::sort_step_limit_action)
{
// dispatch in the kernel since CELER_LAUNCH_KERNEL doesn't work
// with templated kernels
auto start = device_pointer_cast(make_observer(offsets.data()));
thrust::fill(thrust_execute_on(states.stream_id),
start,
start + offsets.size(),
ThreadId{});
CELER_DEVICE_CHECK_ERROR();
auto* stream = celeritas::device().stream(states.stream_id).get();
CELER_LAUNCH_KERNEL(tracks_per_action,
states.size(),
stream,
states,
offsets,
states.size(),
order);

Span<ThreadId> sout = out[AllItems<ThreadId, MemSpace::mapped>{}];
Copier<ThreadId, MemSpace::host> copy_to_host{sout, states.stream_id};
copy_to_host(MemSpace::device, offsets);

// Copies must be complete before backfilling
CELER_DEVICE_CALL_PREFIX(StreamSynchronize(stream));
backfill_action_count(sout, states.size());
}
CELER_ASSERT(order == TrackOrder::sort_along_step_action
|| order == TrackOrder::sort_step_limit_action);

auto start = device_pointer_cast(make_observer(offsets.data()));
thrust::fill(thrust_execute_on(states.stream_id),
start,
start + offsets.size(),
ThreadId{});
CELER_DEVICE_CHECK_ERROR();
auto* stream = celeritas::device().stream(states.stream_id).get();
CELER_LAUNCH_KERNEL(tracks_per_action,
states.size(),
stream,
get_action_ptr(states, order),
states.track_slots.data(),
offsets,
states.size());

Span<ThreadId> sout = out[AllItems<ThreadId, MemSpace::mapped>{}];
Copier<ThreadId, MemSpace::host> copy_to_host{sout, states.stream_id};
copy_to_host(MemSpace::device, offsets);

// Copies must be complete before backfilling
CELER_DEVICE_CALL_PREFIX(StreamSynchronize(stream));
backfill_action_count(sout, states.size());
}

//---------------------------------------------------------------------------//
Expand Down
43 changes: 31 additions & 12 deletions src/celeritas/track/detail/TrackSortUtils.hh
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,8 @@ namespace detail

// Initialize the default thread-to-track-slot mapping: track_slots[i] = i
// TODO: move to global/detail and overload using ObserverPtr
template<MemSpace M,
typename Size,
typename = std::enable_if_t<std::is_unsigned_v<Size>>>
void fill_track_slots(Span<Size> track_slots, StreamId);
template<MemSpace M>
void fill_track_slots(Span<TrackSlotId::size_type> track_slots, StreamId);

template<>
void fill_track_slots<MemSpace::host>(Span<TrackSlotId::size_type> track_slots,
Expand All @@ -42,10 +40,8 @@ void fill_track_slots<MemSpace::device>(Span<TrackSlotId::size_type> track_slots
//---------------------------------------------------------------------------//
// Shuffle tracks
// TODO: move to global/detail and overload using ObserverPtr
template<MemSpace M,
typename Size,
typename = std::enable_if_t<std::is_unsigned_v<Size>>>
void shuffle_track_slots(Span<Size> track_slots, StreamId);
template<MemSpace M>
void shuffle_track_slots(Span<TrackSlotId::size_type> track_slots, StreamId);

template<>
void shuffle_track_slots<MemSpace::host>(
Expand Down Expand Up @@ -77,7 +73,7 @@ void count_tracks_per_action(
void backfill_action_count(Span<ThreadId>, size_type);

//---------------------------------------------------------------------------//
// HELPER CLASSES
// HELPER CLASSES AND FUNCTIONS
//---------------------------------------------------------------------------//
struct alive_predicate
{
Expand All @@ -89,13 +85,14 @@ struct alive_predicate
}
};

struct action_comparator
template<class Id>
struct id_comparator
esseivaju marked this conversation as resolved.
Show resolved Hide resolved
{
ObserverPtr<ActionId const> action_;
ObserverPtr<Id const> ids_;

CELER_FUNCTION bool operator()(size_type a, size_type b) const
{
return action_.get()[a] < action_.get()[b];
return ids_.get()[a] < ids_.get()[b];
}
};

Expand All @@ -110,6 +107,28 @@ struct ActionAccessor
}
};

/*!
 * Select the action ID array that corresponds to an action-based track order.
 *
 * \param states core state data whose \c sim member holds the action arrays
 * \param order track ordering being applied
 * \return pointer to \c sim.along_step_action for
 *         \c TrackOrder::sort_along_step_action, or to
 *         \c sim.post_step_action for \c TrackOrder::sort_step_limit_action
 *
 * Any other \c order reaches CELER_ASSERT_UNREACHABLE.
 */
template<Ownership W, MemSpace M>
CELER_FUNCTION ObserverPtr<ActionId const>
get_action_ptr(CoreStateData<W, M> const& states, TrackOrder order)
{
    switch (order)
    {
        case TrackOrder::sort_along_step_action:
            return states.sim.along_step_action.data();
        case TrackOrder::sort_step_limit_action:
            return states.sim.post_step_action.data();
        default:
            // Not an action-based ordering: fall through to the assertion
            break;
    }
    CELER_ASSERT_UNREACHABLE();
}

//---------------------------------------------------------------------------//
// DEDUCTION GUIDES
//---------------------------------------------------------------------------//

// Let `id_comparator{ptr}` deduce its template argument from the pointee
// type of the ObserverPtr given as the sole initializer, so call sites do
// not have to spell out the ID type.
template<class Id>
id_comparator(ObserverPtr<Id>) -> id_comparator<Id>;

//---------------------------------------------------------------------------//
// INLINE DEFINITIONS
//---------------------------------------------------------------------------//
Expand Down
2 changes: 1 addition & 1 deletion src/corecel/sys/KernelParamCalculator.device.hh
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@
#NAME, NAME##_kernel<T1>); \
auto grid_ = calc_launch_params_(THREADS); \
\
CELER_LAUNCH_KERNEL_IMPL(NAME##_kernel, \
CELER_LAUNCH_KERNEL_IMPL(NAME##_kernel<T1>, \
grid_.blocks_per_grid, \
grid_.threads_per_block, \
0, \
Expand Down