Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ensure reproducibility when using MT Geant4 with Celeritas offloading #1061

Merged
merged 14 commits into from
Dec 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions app/celer-g4/EventAction.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,9 @@ void EventAction::BeginOfEventAction(G4Event const* event)
if (SharedParams::CeleritasDisabled())
return;

// Set event ID in local transporter
// Set event ID in local transporter and reseed Celeritas RNG
ExceptionConverter call_g4exception{"celer0002"};
CELER_TRY_HANDLE(transport_->SetEventId(event->GetEventID()),
CELER_TRY_HANDLE(transport_->InitializeEvent(event->GetEventID()),
call_g4exception);
}

Expand Down
19 changes: 17 additions & 2 deletions app/celer-g4/GeantDiagnostics.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@
#include "corecel/sys/MemRegistry.hh"
#include "corecel/sys/MultiExceptionHandler.hh"
#include "celeritas/Types.hh"
#include "celeritas/global/ActionRegistry.hh"
#include "celeritas/global/CoreParams.hh"
#include "celeritas/user/StepDiagnostic.hh"

#include "GlobalSetup.hh"

Expand Down Expand Up @@ -57,9 +59,22 @@ GeantDiagnostics::GeantDiagnostics(SharedParams const& params)
if (global_setup.StepDiagnostic())
{
// Create the track step diagnostic and add to output registry
step_diagnostic_ = std::make_shared<GeantStepDiagnostic>(
global_setup.GetStepDiagnosticBins(), num_threads);
auto num_bins = GlobalSetup::Instance()->GetStepDiagnosticBins();
step_diagnostic_
= std::make_shared<GeantStepDiagnostic>(num_bins, num_threads);
output_reg->insert(step_diagnostic_);

// Add the Celeritas step diagnostic if Celeritas offloading is enabled
if (params)
{
auto step_diagnostic = std::make_shared<celeritas::StepDiagnostic>(
params.Params()->action_reg()->next_id(),
params.Params()->particle(),
num_bins,
num_threads);
params.Params()->action_reg()->insert(step_diagnostic);
output_reg->insert(step_diagnostic);
}
}

if (!params)
Expand Down
15 changes: 13 additions & 2 deletions src/accel/LocalTransporter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include <string>
#include <type_traits>
#include <CLHEP/Units/SystemOfUnits.h>
#include <G4MTRunManager.hh>
#include <G4ParticleDefinition.hh>
#include <G4Threading.hh>
#include <G4ThreeVector.hh>
Expand Down Expand Up @@ -114,14 +115,24 @@ LocalTransporter::LocalTransporter(SetupOptions const& options,

//---------------------------------------------------------------------------//
/*!
* Set the event ID at the start of an event.
* Set the event ID and reseed the Celeritas RNG at the start of an event.
*/
void LocalTransporter::SetEventId(int id)
void LocalTransporter::InitializeEvent(int id)
{
CELER_EXPECT(*this);
CELER_EXPECT(id >= 0);

event_id_ = EventId(id);
track_counter_ = 0;

if (!(G4Threading::IsMultithreadedApplication()
&& G4MTRunManager::SeedOncePerCommunication()))
{
// Since Geant4 schedules events dynamically, reseed the Celeritas RNGs
// using the Geant4 event ID for reproducibility. This guarantees that
// an event can be reproduced given the event ID.
step_->reseed(event_id_);
}
sethrj marked this conversation as resolved.
Show resolved Hide resolved
}

//---------------------------------------------------------------------------//
Expand Down
7 changes: 5 additions & 2 deletions src/accel/LocalTransporter.hh
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,11 @@ class LocalTransporter
inline void
Initialize(SetupOptions const& options, SharedParams const& params);

// Set the event ID
void SetEventId(int);
// Set the event ID and reseed the Celeritas RNG (remove in v1.0)
[[deprecated]] void SetEventId(int id) { this->InitializeEvent(id); }

// Set the event ID and reseed the Celeritas RNG at the start of an event
void InitializeEvent(int);

// Offload this track
void Push(G4Track const&);
Expand Down
7 changes: 4 additions & 3 deletions src/accel/SimpleOffload.cc
Original file line number Diff line number Diff line change
Expand Up @@ -87,16 +87,17 @@ void SimpleOffload::BeginOfRunAction(G4Run const*)

//---------------------------------------------------------------------------//
/*!
* Send Celeritas the event ID.
* Send Celeritas the event ID and reseed the Celeritas RNG.
*/
void SimpleOffload::BeginOfEventAction(G4Event const* event)
{
if (!*this)
return;

// Set event ID in local transporter
// Set event ID in local transporter and reseed RNG for reproducibility
ExceptionConverter call_g4exception{"celer0002"};
CELER_TRY_HANDLE(local_->SetEventId(event->GetEventID()), call_g4exception);
CELER_TRY_HANDLE(local_->InitializeEvent(event->GetEventID()),
call_g4exception);
}

//---------------------------------------------------------------------------//
Expand Down
1 change: 1 addition & 0 deletions src/celeritas/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,7 @@ celeritas_polysource(global/alongstep/AlongStepUniformMscAction)
celeritas_polysource(global/alongstep/AlongStepRZMapFieldMscAction)
celeritas_polysource(phys/detail/DiscreteSelectAction)
celeritas_polysource(phys/detail/PreStepAction)
celeritas_polysource(random/RngReseed)
celeritas_polysource(random/detail/CuHipRngStateInit)
celeritas_polysource(track/detail/TrackInitAlgorithms)
celeritas_polysource(track/detail/TrackSortUtils)
Expand Down
16 changes: 16 additions & 0 deletions src/celeritas/global/Stepper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
#include "corecel/sys/ScopedProfiling.hh"
#include "orange/OrangeData.hh"
#include "celeritas/Types.hh"
#include "celeritas/random/RngParams.hh"
#include "celeritas/random/RngReseed.hh"
#include "celeritas/track/TrackInitParams.hh"

#include "CoreParams.hh"
Expand Down Expand Up @@ -103,6 +105,20 @@ auto Stepper<M>::operator()(SpanConstPrimary primaries) -> result_type
return (*this)();
}

//---------------------------------------------------------------------------//
/*!
 * Reinitialize this stepper's RNG states for the given event.
 *
 * The RNG parameters (viewed in memory space \c M), this stepper's RNG state
 * reference, and the numeric value of \c event_id are forwarded to
 * \c reseed_rng. Calling this at the start of an event is intended to make
 * that event reproducible from its event ID alone.
 *
 * \param event_id ID of the event about to be transported.
 */
template<MemSpace M>
void Stepper<M>::reseed(EventId event_id)
{
    auto const& rng_params = get_ref<M>(*params_->rng());
    auto const& rng_state = state_.ref().rng;
    auto const event_index = event_id.get();

    reseed_rng(rng_params, rng_state, event_index);
}

//---------------------------------------------------------------------------//
// EXPLICIT INSTANTIATION
//---------------------------------------------------------------------------//
Expand Down
6 changes: 6 additions & 0 deletions src/celeritas/global/Stepper.hh
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,9 @@ class StepperInterface
// Transport existing states and these new primaries
virtual StepperResult operator()(SpanConstPrimary primaries) = 0;

// Reseed the RNGs at the start of an event for reproducibility
virtual void reseed(EventId event_id) = 0;

//! Get action sequence for timing diagnostics
virtual ActionSequence const& actions() const = 0;

Expand Down Expand Up @@ -139,6 +142,9 @@ class Stepper final : public StepperInterface
// Transport existing states and these new primaries
StepperResult operator()(SpanConstPrimary primaries) final;

// Reseed the RNGs at the start of an event for reproducibility
void reseed(EventId event_id) final;

//! Get action sequence for timing diagnostics
ActionSequence const& actions() const final { return *actions_; }

Expand Down
40 changes: 40 additions & 0 deletions src/celeritas/random/RngReseed.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
//----------------------------------*-C++-*----------------------------------//
// Copyright 2023 UT-Battelle, LLC, and other Celeritas developers.
// See the top-level COPYRIGHT file for details.
// SPDX-License-Identifier: (Apache-2.0 OR MIT)
//---------------------------------------------------------------------------//
//! \file celeritas/random/RngReseed.cc
//---------------------------------------------------------------------------//
#include "RngReseed.hh"

#include "corecel/cont/Range.hh"
#include "corecel/sys/ThreadId.hh"

#include "RngEngine.hh"

namespace celeritas
{
//---------------------------------------------------------------------------//
/*!
 * Reinitialize the RNG states on host at the start of an event.
 *
 * Each track slot's state is initialized using the same seed
 * (\c params.seed) and a different subsequence,
 * <code>event_id * state.size() + slot</code>, so the sequences in
 * different track slots (and in different events) are meant not to have
 * statistically correlated values.
 *
 * The subsequence is computed in the type of the initializer's
 * \c subsequence field rather than in \c size_type, so that the product
 * cannot wrap around before it is stored.
 * NOTE(review): this matters when \c size_type is narrower than the
 * subsequence type (presumably 32-bit in device-enabled builds -- confirm).
 *
 * \param params RNG parameters providing the seed.
 * \param state RNG states to overwrite, one per track slot.
 * \param event_id Numeric ID of the event being started.
 */
void reseed_rng(HostCRef<RngParamsData> const& params,
                HostRef<RngStateData> const& state,
                size_type event_id)
{
    for (auto tid : range(TrackSlotId{state.size()}))
    {
        RngEngine::Initializer_t init;
        init.seed = params.seed;

        // Widen *before* multiplying to avoid wraparound for large event IDs
        using Subsequence = decltype(init.subsequence);
        init.subsequence = static_cast<Subsequence>(event_id) * state.size()
                           + tid.get();

        RngEngine engine(params, state, tid);
        engine = init;
    }
}

//---------------------------------------------------------------------------//
} // namespace celeritas
67 changes: 67 additions & 0 deletions src/celeritas/random/RngReseed.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
//---------------------------------*-CUDA-*----------------------------------//
// Copyright 2023 UT-Battelle, LLC, and other Celeritas developers.
// See the top-level COPYRIGHT file for details.
// SPDX-License-Identifier: (Apache-2.0 OR MIT)
//---------------------------------------------------------------------------//
//! \file celeritas/random/RngReseed.cu
//---------------------------------------------------------------------------//
#include "RngReseed.hh"

#include "corecel/device_runtime_api.h"
#include "corecel/Assert.hh"
#include "corecel/sys/Device.hh"
#include "corecel/sys/KernelParamCalculator.device.hh"

#include "RngEngine.hh"

namespace celeritas
{
namespace
{
//---------------------------------------------------------------------------//
// KERNELS
//---------------------------------------------------------------------------//
/*!
 * Reinitialize the RNG states on device at the start of an event.
 *
 * Each kernel thread handles the track slot with its own thread index;
 * threads whose index is not below \c state.size() do nothing. The slot's
 * engine is reinitialized with \c params.seed and the subsequence
 * <code>event_id * state.size() + slot</code>.
 *
 * The subsequence is computed in the type of the initializer's
 * \c subsequence field rather than in \c size_type, so that the product
 * cannot wrap around before it is stored.
 * NOTE(review): this matters when \c size_type is narrower than the
 * subsequence type (presumably 32-bit in device-enabled builds -- confirm).
 */
__global__ void reseed_rng_kernel(DeviceCRef<RngParamsData> const params,
                                  DeviceRef<RngStateData> const state,
                                  size_type event_id)
{
    auto tid = TrackSlotId{
        celeritas::KernelParamCalculator::thread_id().unchecked_get()};
    if (tid.get() < state.size())
    {
        RngEngine::Initializer_t init;
        init.seed = params.seed;

        // Widen *before* multiplying to avoid wraparound for large event IDs
        using Subsequence = decltype(init.subsequence);
        init.subsequence = static_cast<Subsequence>(event_id) * state.size()
                           + tid.get();

        RngEngine rng(params, state, tid);
        rng = init;
    }
}

//---------------------------------------------------------------------------//
} // namespace

//---------------------------------------------------------------------------//
// KERNEL INTERFACE
//---------------------------------------------------------------------------//
/*!
 * Launch the kernel that reinitializes the device RNG states for an event.
 *
 * Both \c state and \c params must evaluate to true. The kernel is launched
 * with \c state.size() requested threads and receives the parameters, the
 * states, and the event ID; it assigns every track slot the same seed with
 * a slot- and event-dependent subsequence so that sequences in different
 * slots are meant not to have statistically correlated values.
 *
 * \param params RNG parameters providing the seed.
 * \param state RNG states to overwrite.
 * \param event_id Numeric ID of the event being started.
 */
void reseed_rng(DeviceCRef<RngParamsData> const& params,
                DeviceRef<RngStateData> const& state,
                size_type event_id)
{
    CELER_EXPECT(state);
    CELER_EXPECT(params);

    // Request one kernel thread per RNG state
    auto const num_threads = state.size();
    CELER_LAUNCH_KERNEL(reseed_rng, num_threads, 0, params, state, event_id);
}

//---------------------------------------------------------------------------//
} // namespace celeritas
43 changes: 43 additions & 0 deletions src/celeritas/random/RngReseed.hh
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
//----------------------------------*-C++-*----------------------------------//
// Copyright 2023 UT-Battelle, LLC, and other Celeritas developers.
// See the top-level COPYRIGHT file for details.
// SPDX-License-Identifier: (Apache-2.0 OR MIT)
//---------------------------------------------------------------------------//
//! \file celeritas/random/RngReseed.hh
//---------------------------------------------------------------------------//
#pragma once

#include "corecel/Assert.hh"
#include "corecel/Macros.hh"
#include "corecel/Types.hh"
#include "corecel/data/Collection.hh"

#include "RngData.hh"

namespace celeritas
{
//---------------------------------------------------------------------------//
// Reinitialize the RNG states on device at the start of an event.
// Arguments, in order: RNG parameters (source of the seed), RNG states to
// overwrite, and the numeric event ID.
void reseed_rng(DeviceCRef<RngParamsData> const&,
DeviceRef<RngStateData> const&,
size_type);

// Reinitialize the RNG states on host at the start of an event.
// Takes the same arguments as the device overload, using host references.
void reseed_rng(HostCRef<RngParamsData> const&,
HostRef<RngStateData> const&,
size_type);

#if !CELER_USE_DEVICE
//---------------------------------------------------------------------------//
/*!
 * Stub for the device overload of \c reseed_rng.
 *
 * This definition is only compiled when \c CELER_USE_DEVICE is false (see
 * the enclosing preprocessor condition). All arguments are ignored and the
 * body asserts that this code path is unreachable.
 */
inline void reseed_rng(DeviceCRef<RngParamsData> const&,
DeviceRef<RngStateData> const&,
size_type)
{
CELER_ASSERT_UNREACHABLE();
}
#endif

//---------------------------------------------------------------------------//
} // namespace celeritas
2 changes: 1 addition & 1 deletion src/celeritas/random/XorwowRngData.hh
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ struct XorwowRngParamsData
*/
struct XorwowRngInitializer
{
ull_int seed{0};
Array<unsigned int, 1> seed{0};
ull_int subsequence{0};
ull_int offset{0};
};
Expand Down
2 changes: 1 addition & 1 deletion src/celeritas/random/XorwowRngEngine.hh
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ XorwowRngEngine::operator=(Initializer_t const& init)
auto& s = state_->xorstate;

// Initialize the state from the seed
SplitMix64 rng{init.seed};
SplitMix64 rng{init.seed[0]};
uint64_t seed = rng();
s[0] = static_cast<uint_t>(seed);
s[1] = static_cast<uint_t>(seed >> 32);
Expand Down
1 change: 1 addition & 0 deletions test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -557,6 +557,7 @@ set(CELERITASTEST_PREFIX celeritas/random)

celeritas_add_device_test(celeritas/random/RngEngine)
celeritas_add_test(celeritas/random/Selector.test.cc)
celeritas_add_test(celeritas/random/RngReseed.test.cc)
celeritas_add_test(celeritas/random/XorwowRngEngine.test.cc GPU)

celeritas_add_test(celeritas/random/distribution/BernoulliDistribution.test.cc)
Expand Down
Loading