Skip to content

Commit

Permalink
Properly clear state counters and data for reuse (#1367)
Browse files Browse the repository at this point in the history
* Add a helper method to reset state data
* Add a generic function to fill a collection with a sequence
* Move track shuffle util to global/detail
  • Loading branch information
amandalund authored Aug 16, 2024
1 parent ac99201 commit 1e4aa73
Show file tree
Hide file tree
Showing 18 changed files with 208 additions and 135 deletions.
8 changes: 8 additions & 0 deletions app/celer-sim/Transporter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,14 @@ auto Transporter<M>::operator()(SpanConstPrimary primaries) -> TransporterResult

result.num_aborted = track_counts.alive + track_counts.queued;
result.num_track_slots = stepper_->state().size();

if (result.num_aborted > 0)
{
// Reset the state data for the next event if the stepping loop was
// aborted early
step.reset_state();
}

return result;
}

Expand Down
1 change: 1 addition & 0 deletions src/celeritas/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,7 @@ celeritas_polysource(global/alongstep/AlongStepGeneralLinearAction)
celeritas_polysource(global/alongstep/AlongStepNeutralAction)
celeritas_polysource(global/alongstep/AlongStepUniformMscAction)
celeritas_polysource(global/alongstep/AlongStepRZMapFieldMscAction)
celeritas_polysource(global/detail/TrackSlotUtils)
celeritas_polysource(neutron/model/ChipsNeutronElasticModel)
celeritas_polysource(neutron/model/NeutronInelasticModel)
celeritas_polysource(optical/detail/CerenkovOffloadAction)
Expand Down
25 changes: 22 additions & 3 deletions src/celeritas/global/CoreState.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@
//---------------------------------------------------------------------------//
#include "CoreState.hh"

#include "corecel/data/CollectionAlgorithms.hh"
#include "corecel/data/Copier.hh"
#include "corecel/io/Logger.hh"
#include "corecel/sys/ScopedProfiling.hh"
#include "celeritas/track/TrackInitParams.hh"
#include "celeritas/track/detail/TrackSortUtils.hh"

#include "ActionRegistry.hh"
#include "CoreParams.hh"
Expand Down Expand Up @@ -43,8 +43,6 @@ CoreState<M>::CoreState(CoreParams const& params,
params.host_ref(), stream_id, num_track_slots);

counters_.num_vacancies = num_track_slots;
counters_.num_primaries = 0;
counters_.num_initializers = 0;

if constexpr (M == MemSpace::device)
{
Expand Down Expand Up @@ -110,6 +108,27 @@ Range<ThreadId> CoreState<M>::get_action_range(ActionId action_id) const
return {thread_offsets[action_id], thread_offsets[action_id + 1]};
}

//---------------------------------------------------------------------------//
/*!
 * Restore the state so that it can be reused for a new event.
 *
 * The counters are returned to their default values with every track slot
 * counted as a vacancy, all track statuses are set to inactive, and the
 * vacancy list is refilled with a sequence. This should only be needed when
 * the previous event was aborted before completing.
 */
template<MemSpace M>
void CoreState<M>::reset()
{
    // Default-initialize the counters, then count every slot as vacant
    counters_ = CoreStateCounters{};
    counters_.num_vacancies = this->size();

    auto&& state = this->ref();

    // No track slot holds an active track any more
    fill(TrackStatus::inactive, &state.sim.status);

    // Refill the vacancy list so that every slot is marked as empty
    fill_sequence(&state.init.vacancies, this->stream_id());
}

//---------------------------------------------------------------------------//
// EXPLICIT INSTANTIATION
//---------------------------------------------------------------------------//
Expand Down
3 changes: 3 additions & 0 deletions src/celeritas/global/CoreState.hh
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,9 @@ class CoreState final : public CoreStateInterface
//! Get a native-memspace pointer to the mutable state data
Ptr ptr() { return ptr_; }

//! Reset the state data
void reset();

//// COUNTERS ////

//! Track initialization counters
Expand Down
12 changes: 6 additions & 6 deletions src/celeritas/global/CoreTrackData.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
//---------------------------------------------------------------------------//
#include "CoreTrackData.hh"

#include "celeritas/track/detail/TrackSortUtils.hh"
#include "corecel/data/CollectionAlgorithms.hh"

#include "detail/TrackSlotUtils.hh"

namespace celeritas
{
Expand Down Expand Up @@ -38,18 +40,16 @@ void resize(CoreStateData<Ownership::value, M>* state,
resize(&state->physics, params.physics, size);
resize(&state->rng, params.rng, stream_id, size);
resize(&state->sim, size);
resize(&state->init, params.init, size);
resize(&state->init, params.init, stream_id, size);
state->stream_id = stream_id;

if (params.init.track_order != TrackOrder::unsorted)
{
resize(&state->track_slots, size);
Span track_slots{
state->track_slots[AllItems<TrackSlotId::size_type, M>{}]};
detail::fill_track_slots<M>(track_slots, stream_id);
fill_sequence(&state->track_slots, stream_id);
if (params.init.track_order == TrackOrder::shuffled)
{
detail::shuffle_track_slots<M>(track_slots, stream_id);
detail::shuffle_track_slots(&state->track_slots, stream_id);
}
}

Expand Down
3 changes: 3 additions & 0 deletions src/celeritas/global/Stepper.hh
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,9 @@ class Stepper final : public StepperInterface
//! Get the core state interface for diagnostic output
CoreStateInterface const& state() const final { return state_; }

//! Reset the core state counters and data (forwards to the state's reset()) so the state can be reused
void reset_state() { state_.reset(); }

private:
// Params data
std::shared_ptr<CoreParams const> params_;
Expand Down
36 changes: 36 additions & 0 deletions src/celeritas/global/detail/TrackSlotUtils.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
//----------------------------------*-C++-*----------------------------------//
// Copyright 2024 UT-Battelle, LLC, and other Celeritas developers.
// See the top-level COPYRIGHT file for details.
// SPDX-License-Identifier: (Apache-2.0 OR MIT)
//---------------------------------------------------------------------------//
//! \file celeritas/global/detail/TrackSlotUtils.cc
//---------------------------------------------------------------------------//
#include "TrackSlotUtils.hh"

#include <algorithm>
#include <random>

namespace celeritas
{
namespace detail
{
//---------------------------------------------------------------------------//
/*!
 * Randomly permute the track slot indices stored in host memory.
 *
 * The generator is seeded with the number of track slots, so for a given
 * standard library implementation the permutation depends only on the
 * collection size. The stream ID argument is unused.
 */
void shuffle_track_slots(
    Collection<TrackSlotId::size_type, Ownership::value, MemSpace::host, ThreadId>*
        track_slots,
    StreamId)
{
    CELER_EXPECT(track_slots);

    auto const num_slots = track_slots->size();
    auto* const first
        = static_cast<TrackSlotId::size_type*>(track_slots->data().get());

    // Seed from the slot count to keep the shuffle reproducible
    std::mt19937 engine{static_cast<unsigned int>(num_slots)};
    std::shuffle(first, first + num_slots, engine);
}

//---------------------------------------------------------------------------//
} // namespace detail
} // namespace celeritas
42 changes: 42 additions & 0 deletions src/celeritas/global/detail/TrackSlotUtils.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
//---------------------------------*-CUDA-*----------------------------------//
// Copyright 2024 UT-Battelle, LLC, and other Celeritas developers.
// See the top-level COPYRIGHT file for details.
// SPDX-License-Identifier: (Apache-2.0 OR MIT)
//---------------------------------------------------------------------------//
//! \file celeritas/global/detail/TrackSlotUtils.cu
//---------------------------------------------------------------------------//
#include "TrackSlotUtils.hh"

#include <thrust/device_ptr.h>
#include <thrust/execution_policy.h>
#include <thrust/random.h>
#include <thrust/shuffle.h>

#include "corecel/sys/Thrust.device.hh"

namespace celeritas
{
namespace detail
{
//---------------------------------------------------------------------------//
/*!
 * Randomly permute the track slot indices stored in device memory.
 *
 * The Thrust random engine is seeded with the number of track slots, and the
 * stream ID is handed to \c thrust_execute_on to obtain the execution policy
 * for the shuffle.
 */
void shuffle_track_slots(
    Collection<TrackSlotId::size_type, Ownership::value, MemSpace::device, ThreadId>*
        track_slots,
    StreamId stream)
{
    CELER_EXPECT(track_slots);

    using Engine = thrust::default_random_engine;

    auto const num_slots = track_slots->size();
    Engine engine{static_cast<Engine::result_type>(num_slots)};

    auto const first = thrust::device_pointer_cast(track_slots->data().get());
    thrust::shuffle(
        thrust_execute_on(stream), first, first + num_slots, engine);
    CELER_DEVICE_CHECK_ERROR();
}

//---------------------------------------------------------------------------//
} // namespace detail
} // namespace celeritas
42 changes: 42 additions & 0 deletions src/celeritas/global/detail/TrackSlotUtils.hh
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
//----------------------------------*-C++-*----------------------------------//
// Copyright 2024 UT-Battelle, LLC, and other Celeritas developers.
// See the top-level COPYRIGHT file for details.
// SPDX-License-Identifier: (Apache-2.0 OR MIT)
//---------------------------------------------------------------------------//
//! \file celeritas/global/detail/TrackSlotUtils.hh
//---------------------------------------------------------------------------//
#pragma once

#include "corecel/Assert.hh"
#include "corecel/Macros.hh"
#include "corecel/Types.hh"
#include "corecel/data/Collection.hh"
#include "corecel/sys/ThreadId.hh"

namespace celeritas
{
namespace detail
{
//---------------------------------------------------------------------------//
// Shuffle track slot indices held in a host-memory collection
void shuffle_track_slots(
Collection<TrackSlotId::size_type, Ownership::value, MemSpace::host, ThreadId>*,
StreamId);
// Shuffle track slot indices held in a device-memory collection
void shuffle_track_slots(
Collection<TrackSlotId::size_type, Ownership::value, MemSpace::device, ThreadId>*,
StreamId);

//---------------------------------------------------------------------------//
// INLINE DEFINITIONS
//---------------------------------------------------------------------------//
#if !CELER_USE_DEVICE
// Stand-in for the device overload, compiled only when CELER_USE_DEVICE is
// false: both arguments are ignored.
// NOTE(review): CELER_NOT_CONFIGURED presumably reports that CUDA/HIP support
// is unavailable rather than returning normally -- confirm against its
// definition.
inline void shuffle_track_slots(
Collection<TrackSlotId::size_type, Ownership::value, MemSpace::device, ThreadId>*,
StreamId)
{
CELER_NOT_CONFIGURED("CUDA or HIP");
}
#endif
//---------------------------------------------------------------------------//
} // namespace detail
} // namespace celeritas
2 changes: 1 addition & 1 deletion src/celeritas/optical/TrackData.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ void resize(CoreStateData<Ownership::value, M>* state,
resize(&state->physics, params.physics, size);
resize(&state->rng, params.rng, stream_id, size);
resize(&state->sim, size);
resize(&state->init, params.init, size);
resize(&state->init, params.init, stream_id, size);
state->stream_id = stream_id;

CELER_ENSURE(state);
Expand Down
6 changes: 0 additions & 6 deletions src/celeritas/track/SimData.hh
Original file line number Diff line number Diff line change
Expand Up @@ -183,15 +183,9 @@ void resize(SimStateData<Ownership::value, M>* data, size_type size)
resize(&data->track_ids, size);
resize(&data->parent_ids, size);
resize(&data->event_ids, size);

resize(&data->num_steps, size);
fill(size_type{0}, &data->num_steps);

resize(&data->num_looping_steps, size);
fill(size_type{0}, &data->num_looping_steps);

resize(&data->time, size);
fill(real_type{0}, &data->time);

resize(&data->status, size);
fill(TrackStatus::inactive, &data->status);
Expand Down
14 changes: 4 additions & 10 deletions src/celeritas/track/TrackInitData.hh
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "corecel/Types.hh"
#include "corecel/cont/Range.hh"
#include "corecel/data/Collection.hh"
#include "corecel/data/CollectionAlgorithms.hh"
#include "corecel/data/CollectionBuilder.hh"
#include "corecel/sys/Device.hh"
#include "corecel/sys/ThreadId.hh"
Expand Down Expand Up @@ -145,9 +146,6 @@ struct TrackInitStateData
}
};

using TrackInitStateDeviceRef = DeviceRef<TrackInitStateData>;
using TrackInitStateHostRef = HostRef<TrackInitStateData>;

//---------------------------------------------------------------------------//
/*!
* Resize and initialize track initializer data.
Expand All @@ -162,6 +160,7 @@ using TrackInitStateHostRef = HostRef<TrackInitStateData>;
template<MemSpace M>
void resize(TrackInitStateData<Ownership::value, M>* data,
HostCRef<TrackInitParamsData> const& params,
StreamId stream,
size_type size)
{
CELER_EXPECT(params);
Expand All @@ -177,13 +176,8 @@ void resize(TrackInitStateData<Ownership::value, M>* data,
fill(size_type(0), &data->track_counters);

// Initialize vacancies to mark all track slots as empty
StateCollection<TrackSlotId, Ownership::value, MemSpace::host> vacancies;
resize(&vacancies, size);
for (auto i : range(size))
{
vacancies[TrackSlotId{i}] = TrackSlotId{i};
}
data->vacancies = std::move(vacancies);
resize(&data->vacancies, size);
fill_sequence(&data->vacancies, stream);

// Reserve space for initializers
resize(&data->initializers, params.capacity);
Expand Down
27 changes: 0 additions & 27 deletions src/celeritas/track/detail/TrackSortUtils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
#include <algorithm>
#include <iterator>
#include <numeric>
#include <random>

#include "corecel/data/Collection.hh"

Expand Down Expand Up @@ -65,32 +64,6 @@ IdLess(ObserverPtr<Id>) -> IdLess<Id>;
//---------------------------------------------------------------------------//
} // namespace

//---------------------------------------------------------------------------//
/*!
 * Initialize default threads to track_slots mapping.
 *
 * This sets \code track_slots[i] = i \endcode . The stream ID argument is
 * unused on the host.
 */
template<>
void fill_track_slots<MemSpace::host>(Span<TrackSlotId::size_type> track_slots,
                                      StreamId)
{
    // Start the sequence from a value of the element type: an `int` literal
    // would make std::iota count with a signed int and convert on every store
    std::iota(track_slots.data(),
              track_slots.data() + track_slots.size(),
              TrackSlotId::size_type{0});
}

//---------------------------------------------------------------------------//
/*!
 * Randomly permute the track slots in host memory.
 *
 * The generator is seeded with the number of track slots; the stream ID
 * argument is unused.
 */
template<>
void shuffle_track_slots<MemSpace::host>(
    Span<TrackSlotId::size_type> track_slots, StreamId)
{
    std::mt19937 engine{static_cast<unsigned int>(track_slots.size())};
    std::shuffle(track_slots.begin(), track_slots.end(), engine);
}

//---------------------------------------------------------------------------//
/*!
* Sort or partition tracks.
Expand Down
Loading

0 comments on commit 1e4aa73

Please sign in to comment.