Skip to content

Commit

Permalink
Insert SortTrackAction when sorting by particle type (#1059)
Browse files Browse the repository at this point in the history
* insert sort_particle_type action
* explicitly list all variants to trigger compiler warning if one is forgotten
* abstract access to CoreState offsets collections
* Add test
  • Loading branch information
esseivaju authored Dec 13, 2023
1 parent d47a259 commit 7b7bfb2
Show file tree
Hide file tree
Showing 5 changed files with 177 additions and 53 deletions.
5 changes: 4 additions & 1 deletion src/celeritas/global/CoreParams.cc
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,7 @@ CoreParams::CoreParams(Input input) : input_(std::move(input))
case TrackOrder::partition_status:
case TrackOrder::sort_step_limit_action:
case TrackOrder::sort_along_step_action:
case TrackOrder::sort_particle_type:
// Sort with just the given track order
insert_sort_tracks_action(track_order);
break;
Expand All @@ -263,7 +264,9 @@ CoreParams::CoreParams(Input input) : input_(std::move(input))
insert_sort_tracks_action(TrackOrder::sort_step_limit_action);
insert_sort_tracks_action(TrackOrder::sort_along_step_action);
break;
default:
case TrackOrder::unsorted:
case TrackOrder::shuffled:
case TrackOrder::size_:
break;
}

Expand Down
20 changes: 3 additions & 17 deletions src/celeritas/global/CoreState.cc
Original file line number Diff line number Diff line change
Expand Up @@ -77,16 +77,6 @@ void CoreState<M>::insert_primaries(Span<Primary const> host_primaries)
copy_to_temp(MemSpace::host, host_primaries);
}

//---------------------------------------------------------------------------//
/*!
* Reference to the ActionThread collection matching the state memory space
*/
template<MemSpace M>
auto CoreState<M>::native_action_thread_offsets() -> ActionThreads<M>&
{
return thread_offsets_;
}

//---------------------------------------------------------------------------//
/*!
* Get a range delimiting the [start, end) of the track partition assigned
Expand All @@ -95,7 +85,7 @@ auto CoreState<M>::native_action_thread_offsets() -> ActionThreads<M>&
template<MemSpace M>
Range<ThreadId> CoreState<M>::get_action_range(ActionId action_id) const
{
auto const& thread_offsets = action_thread_offsets();
auto const& thread_offsets = offsets_.host_action_thread_offsets();
CELER_EXPECT((action_id + 1) < thread_offsets.size());
return {thread_offsets[action_id], thread_offsets[action_id + 1]};
}
Expand All @@ -107,11 +97,7 @@ Range<ThreadId> CoreState<M>::get_action_range(ActionId action_id) const
template<MemSpace M>
void CoreState<M>::num_actions(size_type n)
{
resize(&thread_offsets_, n);
if constexpr (M == MemSpace::device)
{
resize(&host_thread_offsets_, n);
}
offsets_.resize(n);
}

//---------------------------------------------------------------------------//
Expand All @@ -121,7 +107,7 @@ void CoreState<M>::num_actions(size_type n)
template<MemSpace M>
size_type CoreState<M>::num_actions() const
{
return thread_offsets_.size();
return offsets_.host_action_thread_offsets().size();
}

//---------------------------------------------------------------------------//
Expand Down
39 changes: 17 additions & 22 deletions src/celeritas/global/CoreState.hh
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "corecel/data/Ref.hh"
#include "corecel/sys/ThreadId.hh"
#include "celeritas/global/CoreTrackData.hh"
#include "celeritas/global/detail/CoreStateThreadOffsets.hh"
#include "celeritas/phys/Primary.hh"
#include "celeritas/track/CoreStateCounters.hh"

Expand Down Expand Up @@ -70,7 +71,8 @@ class CoreState final : public CoreStateInterface
using Ptr = ObserverPtr<Ref, M>;
using PrimaryCRef = Collection<Primary, Ownership::const_reference, M>;
template<MemSpace M2>
using ActionThreads = Collection<ThreadId, Ownership::value, M2, ActionId>;
using ActionThreads =
typename detail::CoreStateThreadOffsets<M>::template ActionThreads<M2>;
//!@}

public:
Expand Down Expand Up @@ -140,17 +142,14 @@ class CoreState final : public CoreStateInterface

// Reference to the ActionThread collection matching the state memory
// space
ActionThreads<M>& native_action_thread_offsets();
inline auto& native_action_thread_offsets();

private:
// State data
CollectionStateStore<CoreStateData, M> states_;

// Indices of first thread assigned to a given action
ActionThreads<M> thread_offsets_;

// Only used if M == device for D2H copy of thread_offsets_
ActionThreads<MemSpace::mapped> host_thread_offsets_;
detail::CoreStateThreadOffsets<M> offsets_;

// Primaries to be added
Collection<Primary, Ownership::value, M> primaries_;
Expand Down Expand Up @@ -193,14 +192,7 @@ auto CoreState<M>::primary_storage() const -> PrimaryCRef
template<MemSpace M>
auto& CoreState<M>::action_thread_offsets()
{
if constexpr (M == MemSpace::device)
{
return host_thread_offsets_;
}
else
{
return thread_offsets_;
}
return offsets_.host_action_thread_offsets();
}

//---------------------------------------------------------------------------//
Expand All @@ -211,14 +203,17 @@ auto& CoreState<M>::action_thread_offsets()
template<MemSpace M>
auto const& CoreState<M>::action_thread_offsets() const
{
if constexpr (M == MemSpace::device)
{
return host_thread_offsets_;
}
else
{
return thread_offsets_;
}
return offsets_.host_action_thread_offsets();
}

//---------------------------------------------------------------------------//
/*!
* Reference to the ActionThread collection matching the state memory space
*/
template<MemSpace M>
auto& CoreState<M>::native_action_thread_offsets()
{
return offsets_.native_action_thread_offsets();
}

//---------------------------------------------------------------------------//
Expand Down
93 changes: 93 additions & 0 deletions src/celeritas/global/detail/CoreStateThreadOffsets.hh
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
//----------------------------------*-C++-*----------------------------------//
// Copyright 2023 UT-Battelle, LLC, and other Celeritas developers.
// See the top-level COPYRIGHT file for details.
// SPDX-License-Identifier: (Apache-2.0 OR MIT)
//---------------------------------------------------------------------------//
//! \file celeritas/global/detail/CoreStateThreadOffsets.hh
//---------------------------------------------------------------------------//
#pragma once

#include "corecel/Types.hh"
#include "corecel/data/Collection.hh"
#include "corecel/data/CollectionBuilder.hh"
#include "corecel/sys/ThreadId.hh"
#include "celeritas/Types.hh"

namespace celeritas
{
namespace detail
{
/*!
* Holds Collections used by CoreState to store thread offsets. This is
* specialized for device memory space as two collections are needed, one for
* the host and one for the device. Using pinned mapped memory would be less
* efficient.
*/
template<MemSpace M>
class CoreStateThreadOffsets
{
public:
//!@{
//! \name Type aliases
template<MemSpace M2>
using ActionThreads = Collection<ThreadId, Ownership::value, M2, ActionId>;
//!@}

public:
constexpr auto& host_action_thread_offsets() { return thread_offsets_; }
constexpr auto const& host_action_thread_offsets() const
{
return thread_offsets_;
}
constexpr auto& native_action_thread_offsets()
{
return host_action_thread_offsets();
}
constexpr auto const& native_action_thread_offsets() const
{
return host_action_thread_offsets();
}
void resize(size_type n) { celeritas::resize(&thread_offsets_, n); }

private:
ActionThreads<M> thread_offsets_;
};

template<>
class CoreStateThreadOffsets<MemSpace::device>
{
public:
//!@{
//! \name Type aliases
template<MemSpace M>
using ActionThreads = Collection<ThreadId, Ownership::value, M, ActionId>;
//!@}

public:
constexpr auto& host_action_thread_offsets()
{
return host_thread_offsets_;
}
constexpr auto const& host_action_thread_offsets() const
{
return host_thread_offsets_;
}
constexpr auto& native_action_thread_offsets() { return thread_offsets_; }
constexpr auto const& native_action_thread_offsets() const
{
return thread_offsets_;
}
void resize(size_type n)
{
celeritas::resize(&thread_offsets_, n);
celeritas::resize(&host_thread_offsets_, n);
}

private:
ActionThreads<MemSpace::device> thread_offsets_;
ActionThreads<MemSpace::mapped> host_thread_offsets_;
};

//---------------------------------------------------------------------------//
} // namespace detail
} // namespace celeritas
73 changes: 60 additions & 13 deletions test/celeritas/track/TrackSort.test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,11 @@
#include <vector>

#include "corecel/data/Collection.hh"
#include "corecel/io/LogContextException.hh"
#include "celeritas/Types.hh"
#include "celeritas/ext/GeantPhysicsOptions.hh"
#include "celeritas/global/ActionRegistry.hh"
#include "celeritas/global/CoreParams.hh"
#include "celeritas/global/CoreTrackData.hh"
#include "celeritas/global/Stepper.hh"
#include "celeritas/phys/PDGNumber.hh"
Expand Down Expand Up @@ -61,6 +63,16 @@ class TestEm3NoMsc : public TestEm3Base
}
return result;
}

protected:
auto build_init() -> SPConstTrackInit override
{
TrackInitParams::Input input;
input.capacity = 4096;
input.max_events = 4096;
input.track_order = TrackOrder::sort_step_limit_action;
return std::make_shared<TrackInitParams>(input);
}
};

class TrackSortTestBase : virtual public GlobalTestBase
Expand Down Expand Up @@ -114,14 +126,14 @@ class TestTrackSortActionIdEm3Stepper : public TestEm3NoMsc,

#define TestActionCountEm3Stepper \
TEST_IF_CELERITAS_GEANT(TestActionCountEm3Stepper)
template<MemSpace M>
class TestActionCountEm3Stepper : public TestEm3NoMsc, public TrackSortTestBase
{
protected:
template<MemSpace M>
using ActionThreads =
typename CoreState<MemSpace::device>::ActionThreads<M>;
template<MemSpace M>
using ActionThreadsItems = AllItems<ThreadId, M>;
template<MemSpace M2>
using ActionThreads = typename CoreState<M>::template ActionThreads<M2>;
template<MemSpace M2>
using ActionThreadsItems = AllItems<ThreadId, M2>;

auto build_init() -> SPConstTrackInit override
{
Expand All @@ -132,10 +144,11 @@ class TestActionCountEm3Stepper : public TestEm3NoMsc, public TrackSortTestBase
return std::make_shared<TrackInitParams>(input);
}

template<MemSpace M, MemSpace M2>
void
check_action_count(ActionThreads<M2> const& items, Stepper<M> const& step)
template<MemSpace M2>
void check_action_count(ActionThreads<M2> const& items, size_t size)
{
static_assert(M2 == MemSpace::host || M2 == MemSpace::mapped,
"ActionThreads must be host or mapped");
auto total_threads = 0;
Span<ThreadId const> items_span = items[ActionThreadsItems<M2>{}];
auto pos = std::find(items_span.begin(), items_span.end(), ThreadId{});
Expand All @@ -146,14 +159,46 @@ class TestActionCountEm3Stepper : public TestEm3NoMsc, public TrackSortTestBase
total_threads += r.size();
ASSERT_LE(items[ActionId{i}], items[ActionId{i + 1}]);
}
ASSERT_EQ(total_threads, step.state().size());
ASSERT_EQ(total_threads, size);
}
};

//---------------------------------------------------------------------------//
// TESTS
//---------------------------------------------------------------------------//

TEST_F(TestEm3NoMsc, host_is_sorting)
{
CoreState<MemSpace::host> state{*this->core(), StreamId{0}, 128};
auto execute = [&](std::string const& label) {
ActionId action_id = this->action_reg()->find_action(label);
CELER_VALIDATE(action_id, << "no '" << label << "' action found");
auto action = dynamic_cast<ExplicitActionInterface const*>(
this->action_reg()->action(action_id).get());
CELER_VALIDATE(action, << "action '" << label << "' cannot execute");
CELER_TRY_HANDLE(action->execute(*this->core(), state),
LogContextException{this->output_reg().get()});
};

auto primaries = this->make_primaries(state.size());
state.insert_primaries(make_span(primaries));
state.num_actions(this->action_reg()->num_actions() + 1);
execute("extend-from-primaries");
execute("initialize-tracks");
execute("pre-step");
execute("sort-tracks-post-step");
auto track_slots = state.ref().track_slots.data();
auto actions = detail::get_action_ptr(
state.ref(), this->core()->init()->host_ref().track_order);
detail::ActionAccessor action_accessor{actions, track_slots};
for (std::uint32_t i = 1; i < state.size(); ++i)
{
ASSERT_LE(action_accessor(ThreadId{i - 1}),
action_accessor(ThreadId{i}))
<< "Track slots are not sorted by action";
}
}

TEST_F(TestTrackPartitionEm3Stepper, host_is_partitioned)
{
// Create stepper and primaries, and take a step
Expand Down Expand Up @@ -322,7 +367,8 @@ TEST_F(TestTrackSortActionIdEm3Stepper, TEST_IF_CELER_DEVICE(device_is_sorted))
}
}

TEST_F(TestActionCountEm3Stepper, host_count_actions)
using TestActionCountEm3StepperH = TestActionCountEm3Stepper<MemSpace::host>;
TEST_F(TestActionCountEm3StepperH, host_count_actions)
{
using ActionThreadsH = ActionThreads<MemSpace::host>;
using ActionThreadsItemsH = ActionThreadsItems<MemSpace::host>;
Expand All @@ -348,7 +394,7 @@ TEST_F(TestActionCountEm3Stepper, host_count_actions)
buffer,
TrackOrder::sort_step_limit_action);

check_action_count(buffer, step);
check_action_count(buffer, step.state().size());
step();
};

Expand All @@ -365,7 +411,8 @@ TEST_F(TestActionCountEm3Stepper, host_count_actions)
}
}

TEST_F(TestActionCountEm3Stepper, TEST_IF_CELER_DEVICE(device_count_actions))
using TestActionCountEm3StepperD = TestActionCountEm3Stepper<MemSpace::device>;
TEST_F(TestActionCountEm3StepperD, TEST_IF_CELER_DEVICE(device_count_actions))
{
// Initialize some primaries and take a step
auto step = this->make_stepper<MemSpace::device>(128);
Expand All @@ -391,7 +438,7 @@ TEST_F(TestActionCountEm3Stepper, TEST_IF_CELER_DEVICE(device_count_actions))
buffer_h,
TrackOrder::sort_step_limit_action);

check_action_count(buffer_h, step);
check_action_count(buffer_h, step.state().size());
step();
};

Expand Down

0 comments on commit 7b7bfb2

Please sign in to comment.