Skip to content

Commit

Permalink
Move device reactivation to corecel
Browse files Browse the repository at this point in the history
  • Loading branch information
sethrj committed Jun 1, 2023
1 parent dae5633 commit e71fa0c
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 6 deletions.
7 changes: 1 addition & 6 deletions app/celer-sim/celer-sim.cc
Original file line number Diff line number Diff line change
Expand Up @@ -127,12 +127,7 @@ void run(std::istream* is, std::shared_ptr<OutputRegistry> output)
#endif
for (size_type event = 0; event < run_stream.num_events(); ++event)
{
if (device())
{
// See
// https://developer.nvidia.com/blog/cuda-pro-tip-always-set-current-device-avoid-multithreading-bugs/
CELER_DEVICE_CALL_PREFIX(SetDevice(device().device_id()));
}
activate_device_local();

// Run a single event on a single thread
CELER_TRY_HANDLE(result.events[event] = run_stream(
Expand Down
17 changes: 17 additions & 0 deletions src/corecel/sys/Device.cc
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,23 @@ void activate_device(MpiCommunicator const& comm)
}
}

//---------------------------------------------------------------------------//
/*!
* Call cudaSetDevice using the existing device, for thread-local safety.
*
* See
* https://developer.nvidia.com/blog/cuda-pro-tip-always-set-current-device-avoid-multithreading-bugs
*
* \pre activate_device was called to set \c device()
*/
void activate_device_local()
{
if (device())
{
CELER_DEVICE_CALL_PREFIX(SetDevice(device().device_id()));
}
}

//---------------------------------------------------------------------------//
/*!
* Print device info.
Expand Down
3 changes: 3 additions & 0 deletions src/corecel/sys/Device.hh
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,9 @@ void activate_device();
// Initialize a device in a round-robin fashion from a communicator.
void activate_device(MpiCommunicator const&);

// Call cudaSetDevice using the existing device, for thread-local safety
void activate_device_local();

// Print device info
std::ostream& operator<<(std::ostream&, Device const&);

Expand Down

0 comments on commit e71fa0c

Please sign in to comment.