Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Properly clear state counters and data for reuse #1367

Merged
merged 1 commit into from
Aug 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions app/celer-sim/Transporter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,14 @@ auto Transporter<M>::operator()(SpanConstPrimary primaries) -> TransporterResult

result.num_aborted = track_counts.alive + track_counts.queued;
result.num_track_slots = stepper_->state().size();

if (result.num_aborted > 0)
{
// Reset the state data for the next event if the stepping loop was
// aborted early
step.reset_state();
}

return result;
}

Expand Down
1 change: 1 addition & 0 deletions src/celeritas/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,7 @@ celeritas_polysource(global/alongstep/AlongStepGeneralLinearAction)
celeritas_polysource(global/alongstep/AlongStepNeutralAction)
celeritas_polysource(global/alongstep/AlongStepUniformMscAction)
celeritas_polysource(global/alongstep/AlongStepRZMapFieldMscAction)
celeritas_polysource(global/detail/TrackSlotUtils)
celeritas_polysource(neutron/model/ChipsNeutronElasticModel)
celeritas_polysource(neutron/model/NeutronInelasticModel)
celeritas_polysource(optical/detail/CerenkovOffloadAction)
Expand Down
25 changes: 22 additions & 3 deletions src/celeritas/global/CoreState.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@
//---------------------------------------------------------------------------//
#include "CoreState.hh"

#include "corecel/data/CollectionAlgorithms.hh"
#include "corecel/data/Copier.hh"
#include "corecel/io/Logger.hh"
#include "corecel/sys/ScopedProfiling.hh"
#include "celeritas/track/TrackInitParams.hh"
#include "celeritas/track/detail/TrackSortUtils.hh"

#include "ActionRegistry.hh"
#include "CoreParams.hh"
Expand Down Expand Up @@ -43,8 +43,6 @@ CoreState<M>::CoreState(CoreParams const& params,
params.host_ref(), stream_id, num_track_slots);

counters_.num_vacancies = num_track_slots;
counters_.num_primaries = 0;
counters_.num_initializers = 0;

if constexpr (M == MemSpace::device)
{
Expand Down Expand Up @@ -110,6 +108,27 @@ Range<ThreadId> CoreState<M>::get_action_range(ActionId action_id) const
return {thread_offsets[action_id], thread_offsets[action_id + 1]};
}

//---------------------------------------------------------------------------//
/*!
 * Clear the state so that it can be reused for a new event.
 *
 * The counters are replaced by default-constructed ones whose vacancy count
 * equals the number of track slots. Every track slot status is then set to
 * inactive and the vacancy list is refilled with a sequence over the slots.
 * Calling this should only be needed when the previous event aborted early.
 */
template<MemSpace M>
void CoreState<M>::reset()
{
    // Fresh counters: only the vacancy count differs from the defaults
    CoreStateCounters cleared{};
    cleared.num_vacancies = this->size();
    counters_ = cleared;

    auto& state = this->ref();

    // No track slot holds a live track anymore
    fill(TrackStatus::inactive, &state.sim.status);

    // Every track slot is available again
    fill_sequence(&state.init.vacancies, this->stream_id());
}

//---------------------------------------------------------------------------//
// EXPLICIT INSTANTIATION
//---------------------------------------------------------------------------//
Expand Down
3 changes: 3 additions & 0 deletions src/celeritas/global/CoreState.hh
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,9 @@ class CoreState final : public CoreStateInterface
//! Get a native-memspace pointer to the mutable state data
Ptr ptr() { return ptr_; }

//! Reset the counters and state data so the state can be reused for a new event
void reset();

//// COUNTERS ////

//! Track initialization counters
Expand Down
12 changes: 6 additions & 6 deletions src/celeritas/global/CoreTrackData.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
//---------------------------------------------------------------------------//
#include "CoreTrackData.hh"

#include "celeritas/track/detail/TrackSortUtils.hh"
#include "corecel/data/CollectionAlgorithms.hh"

#include "detail/TrackSlotUtils.hh"

namespace celeritas
{
Expand Down Expand Up @@ -38,18 +40,16 @@ void resize(CoreStateData<Ownership::value, M>* state,
resize(&state->physics, params.physics, size);
resize(&state->rng, params.rng, stream_id, size);
resize(&state->sim, size);
resize(&state->init, params.init, size);
resize(&state->init, params.init, stream_id, size);
state->stream_id = stream_id;

if (params.init.track_order != TrackOrder::unsorted)
{
resize(&state->track_slots, size);
Span track_slots{
state->track_slots[AllItems<TrackSlotId::size_type, M>{}]};
detail::fill_track_slots<M>(track_slots, stream_id);
fill_sequence(&state->track_slots, stream_id);
if (params.init.track_order == TrackOrder::shuffled)
{
detail::shuffle_track_slots<M>(track_slots, stream_id);
detail::shuffle_track_slots(&state->track_slots, stream_id);
}
}

Expand Down
3 changes: 3 additions & 0 deletions src/celeritas/global/Stepper.hh
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,9 @@ class Stepper final : public StepperInterface
//! Get the core state interface for diagnostic output
CoreStateInterface const& state() const final { return state_; }

//! Reset the core state counters and data so it can be reused
//! (forwards to state_.reset(); intended for use after an aborted event)
void reset_state() { state_.reset(); }

private:
// Params data
std::shared_ptr<CoreParams const> params_;
Expand Down
36 changes: 36 additions & 0 deletions src/celeritas/global/detail/TrackSlotUtils.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
//----------------------------------*-C++-*----------------------------------//
// Copyright 2024 UT-Battelle, LLC, and other Celeritas developers.
// See the top-level COPYRIGHT file for details.
// SPDX-License-Identifier: (Apache-2.0 OR MIT)
//---------------------------------------------------------------------------//
//! \file celeritas/global/detail/TrackSlotUtils.cc
//---------------------------------------------------------------------------//
#include "TrackSlotUtils.hh"

#include <algorithm>
#include <random>

namespace celeritas
{
namespace detail
{
//---------------------------------------------------------------------------//
/*!
 * Randomly permute the track slot indices stored in a host collection.
 *
 * The random engine is seeded with the number of track slots, so the
 * permutation is reproducible for a given collection size. The stream ID
 * argument is unused.
 */
void shuffle_track_slots(
    Collection<TrackSlotId::size_type, Ownership::value, MemSpace::host, ThreadId>*
        track_slots,
    StreamId)
{
    CELER_EXPECT(track_slots);
    auto* const first
        = static_cast<TrackSlotId::size_type*>(track_slots->data().get());
    auto const count = track_slots->size();

    // Seed from the slot count to keep the shuffle deterministic
    std::mt19937 engine{static_cast<unsigned int>(count)};
    std::shuffle(first, first + count, engine);
}

//---------------------------------------------------------------------------//
} // namespace detail
} // namespace celeritas
42 changes: 42 additions & 0 deletions src/celeritas/global/detail/TrackSlotUtils.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
//---------------------------------*-CUDA-*----------------------------------//
// Copyright 2024 UT-Battelle, LLC, and other Celeritas developers.
// See the top-level COPYRIGHT file for details.
// SPDX-License-Identifier: (Apache-2.0 OR MIT)
//---------------------------------------------------------------------------//
//! \file celeritas/global/detail/TrackSlotUtils.cu
//---------------------------------------------------------------------------//
#include "TrackSlotUtils.hh"

#include <thrust/device_ptr.h>
#include <thrust/execution_policy.h>
#include <thrust/random.h>
#include <thrust/shuffle.h>

#include "corecel/sys/Thrust.device.hh"

namespace celeritas
{
namespace detail
{
//---------------------------------------------------------------------------//
/*!
 * Randomly permute the track slot indices stored in a device collection.
 *
 * The random engine is seeded with the number of track slots, and the shuffle
 * is launched with the execution policy obtained for the given stream.
 */
void shuffle_track_slots(
    Collection<TrackSlotId::size_type, Ownership::value, MemSpace::device, ThreadId>*
        track_slots,
    StreamId stream)
{
    CELER_EXPECT(track_slots);
    using Engine = thrust::default_random_engine;

    // Seed from the slot count to keep the shuffle deterministic
    Engine engine{static_cast<Engine::result_type>(track_slots->size())};

    auto first = thrust::device_pointer_cast(track_slots->data().get());
    auto last = first + track_slots->size();
    thrust::shuffle(thrust_execute_on(stream), first, last, engine);
    CELER_DEVICE_CHECK_ERROR();
}

//---------------------------------------------------------------------------//
} // namespace detail
} // namespace celeritas
42 changes: 42 additions & 0 deletions src/celeritas/global/detail/TrackSlotUtils.hh
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
//----------------------------------*-C++-*----------------------------------//
// Copyright 2024 UT-Battelle, LLC, and other Celeritas developers.
// See the top-level COPYRIGHT file for details.
// SPDX-License-Identifier: (Apache-2.0 OR MIT)
//---------------------------------------------------------------------------//
//! \file celeritas/global/detail/TrackSlotUtils.hh
//---------------------------------------------------------------------------//
#pragma once

#include "corecel/Assert.hh"
#include "corecel/Macros.hh"
#include "corecel/Types.hh"
#include "corecel/data/Collection.hh"
#include "corecel/sys/ThreadId.hh"

namespace celeritas
{
namespace detail
{
//---------------------------------------------------------------------------//
// Shuffle track slot indices in place.
// Two overloads are declared, distinguished by the memory space of the
// collection: the first takes host data, the second takes device data.
void shuffle_track_slots(
Collection<TrackSlotId::size_type, Ownership::value, MemSpace::host, ThreadId>*,
StreamId);
void shuffle_track_slots(
Collection<TrackSlotId::size_type, Ownership::value, MemSpace::device, ThreadId>*,
StreamId);

//---------------------------------------------------------------------------//
// INLINE DEFINITIONS
//---------------------------------------------------------------------------//
#if !CELER_USE_DEVICE
// Fallback definition of the device overload, compiled only when
// CELER_USE_DEVICE is false: it ignores its arguments and only reports that
// CUDA or HIP support is not configured.
// NOTE(review): CELER_NOT_CONFIGURED presumably raises an error; confirm.
inline void shuffle_track_slots(
Collection<TrackSlotId::size_type, Ownership::value, MemSpace::device, ThreadId>*,
StreamId)
{
CELER_NOT_CONFIGURED("CUDA or HIP");
}
#endif
//---------------------------------------------------------------------------//
} // namespace detail
} // namespace celeritas
2 changes: 1 addition & 1 deletion src/celeritas/optical/TrackData.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ void resize(CoreStateData<Ownership::value, M>* state,
resize(&state->physics, params.physics, size);
resize(&state->rng, params.rng, stream_id, size);
resize(&state->sim, size);
resize(&state->init, params.init, size);
resize(&state->init, params.init, stream_id, size);
state->stream_id = stream_id;

CELER_ENSURE(state);
Expand Down
6 changes: 0 additions & 6 deletions src/celeritas/track/SimData.hh
Original file line number Diff line number Diff line change
Expand Up @@ -183,15 +183,9 @@ void resize(SimStateData<Ownership::value, M>* data, size_type size)
resize(&data->track_ids, size);
resize(&data->parent_ids, size);
resize(&data->event_ids, size);

resize(&data->num_steps, size);
fill(size_type{0}, &data->num_steps);

resize(&data->num_looping_steps, size);
fill(size_type{0}, &data->num_looping_steps);

resize(&data->time, size);
fill(real_type{0}, &data->time);

resize(&data->status, size);
fill(TrackStatus::inactive, &data->status);
Expand Down
14 changes: 4 additions & 10 deletions src/celeritas/track/TrackInitData.hh
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "corecel/Types.hh"
#include "corecel/cont/Range.hh"
#include "corecel/data/Collection.hh"
#include "corecel/data/CollectionAlgorithms.hh"
#include "corecel/data/CollectionBuilder.hh"
#include "corecel/sys/Device.hh"
#include "corecel/sys/ThreadId.hh"
Expand Down Expand Up @@ -145,9 +146,6 @@ struct TrackInitStateData
}
};

using TrackInitStateDeviceRef = DeviceRef<TrackInitStateData>;
using TrackInitStateHostRef = HostRef<TrackInitStateData>;

//---------------------------------------------------------------------------//
/*!
* Resize and initialize track initializer data.
Expand All @@ -162,6 +160,7 @@ using TrackInitStateHostRef = HostRef<TrackInitStateData>;
template<MemSpace M>
void resize(TrackInitStateData<Ownership::value, M>* data,
HostCRef<TrackInitParamsData> const& params,
StreamId stream,
size_type size)
{
CELER_EXPECT(params);
Expand All @@ -177,13 +176,8 @@ void resize(TrackInitStateData<Ownership::value, M>* data,
fill(size_type(0), &data->track_counters);

// Initialize vacancies to mark all track slots as empty
StateCollection<TrackSlotId, Ownership::value, MemSpace::host> vacancies;
resize(&vacancies, size);
for (auto i : range(size))
{
vacancies[TrackSlotId{i}] = TrackSlotId{i};
}
data->vacancies = std::move(vacancies);
resize(&data->vacancies, size);
fill_sequence(&data->vacancies, stream);

// Reserve space for initializers
resize(&data->initializers, params.capacity);
Expand Down
27 changes: 0 additions & 27 deletions src/celeritas/track/detail/TrackSortUtils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
#include <algorithm>
#include <iterator>
#include <numeric>
#include <random>

#include "corecel/data/Collection.hh"

Expand Down Expand Up @@ -65,32 +64,6 @@ IdLess(ObserverPtr<Id>) -> IdLess<Id>;
//---------------------------------------------------------------------------//
} // namespace

//---------------------------------------------------------------------------//
/*!
 * Set up the identity mapping from threads to track slots on the host.
 *
 * After the call, each element holds its own index:
 * \code track_slots[i] = i \endcode . The stream ID argument is unused.
 */
template<>
void fill_track_slots<MemSpace::host>(Span<TrackSlotId::size_type> track_slots,
                                      StreamId)
{
    auto* const first = track_slots.data();
    auto* const last = first + track_slots.size();
    std::iota(first, last, 0);
}

//---------------------------------------------------------------------------//
/*!
 * Randomly permute the track slots on the host.
 *
 * The random engine is seeded with the number of track slots, so the
 * permutation is reproducible for a given span size. The stream ID argument
 * is unused.
 */
template<>
void shuffle_track_slots<MemSpace::host>(
    Span<TrackSlotId::size_type> track_slots, StreamId)
{
    // Seed from the slot count to keep the shuffle deterministic
    std::mt19937 engine{static_cast<unsigned int>(track_slots.size())};
    std::shuffle(track_slots.begin(), track_slots.end(), engine);
}

//---------------------------------------------------------------------------//
/*!
* Sort or partition tracks.
Expand Down
Loading
Loading