Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a simple spinlock mutex type #607

Merged
merged 3 commits into from
Jun 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions device/common/include/traccc/device/impl/mutex.ipp
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/**
* traccc library, part of the ACTS project (R&D line)
*
* (c) 2024 CERN for the benefit of the ACTS project
*
* Mozilla Public License Version 2.0
*/

#pragma once

#include <cassert>
#include <vecmem/memory/device_atomic_ref.hpp>

namespace traccc::device {
template <typename T>
mutex<T>::mutex(T& p) : m_atomic(p) {}

template <typename T>
mutex<T>::mutex(const vecmem::device_atomic_ref<T>& r) : m_atomic(r) {}

template <typename T>
void mutex<T>::lock() {
while (!try_lock())
;
}

template <typename T>
bool mutex<T>::try_lock() {
assert(!m_is_locked);

T false_v = static_cast<T>(false);
bool s = m_atomic.compare_exchange_strong(false_v, static_cast<T>(true),
vecmem::memory_order::acquire);

#ifndef NDEBUG
m_is_locked |= s;
#endif

return s;
}

template <typename T>
void mutex<T>::unlock() {
assert(m_is_locked);

m_atomic.store(static_cast<T>(false), vecmem::memory_order::release);
}
} // namespace traccc::device
65 changes: 65 additions & 0 deletions device/common/include/traccc/device/impl/unique_lock.ipp
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
/**
* traccc library, part of the ACTS project (R&D line)
*
* (c) 2024 CERN for the benefit of the ACTS project
*
* Mozilla Public License Version 2.0
*/

#pragma once

namespace traccc::device {
template <typename Mutex>
TRACCC_HOST_DEVICE unique_lock<Mutex>::unique_lock(mutex_type& m,
std::defer_lock_t) {
m_mutex_ptr = &m;
m_owns_lock = false;
}

template <typename Mutex>
TRACCC_HOST_DEVICE unique_lock<Mutex>::unique_lock(mutex_type& m,
std::try_to_lock_t) {
m_mutex_ptr = &m;
m_owns_lock = m_mutex_ptr->try_lock();
}

template <typename Mutex>
TRACCC_HOST_DEVICE unique_lock<Mutex>::unique_lock(mutex_type& m,
std::adopt_lock_t) {
m_mutex_ptr = &m;
m_owns_lock = true;
}

template <typename Mutex>
TRACCC_HOST_DEVICE unique_lock<Mutex>::~unique_lock() {
if (m_owns_lock) {
m_mutex_ptr->unlock();
}
}

template <typename Mutex>
TRACCC_HOST_DEVICE void unique_lock<Mutex>::lock() {
assert(!m_owns_lock);
m_mutex_ptr->lock();
m_owns_lock = true;
}

template <typename Mutex>
TRACCC_HOST_DEVICE bool unique_lock<Mutex>::try_lock() {
assert(!m_owns_lock);
m_owns_lock = m_mutex_ptr->try_lock();
return m_owns_lock;
}

template <typename Mutex>
TRACCC_HOST_DEVICE void unique_lock<Mutex>::unlock() {
assert(m_owns_lock);
m_mutex_ptr->unlock();
m_owns_lock = false;
}

template <typename Mutex>
TRACCC_HOST_DEVICE bool unique_lock<Mutex>::owns_lock() const {
return m_owns_lock;
}
} // namespace traccc::device
77 changes: 77 additions & 0 deletions device/common/include/traccc/device/mutex.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
/**
* traccc library, part of the ACTS project (R&D line)
*
* (c) 2024 CERN for the benefit of the ACTS project
*
* Mozilla Public License Version 2.0
*/

#pragma once

#include <cstdint>
#include <vecmem/memory/device_atomic_ref.hpp>

#include "traccc/definitions/qualifiers.hpp"

namespace traccc::device {
/*
* A mutex object over type T.
*
* @warning This class assumes that the value is written to _only_ by mutex
* class objects. Writing to the given pointers or atomic references in any
* other way is undefined behaviour. Furthermore, it is assumed that the
* initial value of the underlying pointer is false.
*
* @warning This is a spinlock. Do not use when more efficient implementations
* are available.
*/
template <typename T = uint32_t>
class mutex {
public:
/*
* Construct a mutex from a pointer.
*/
TRACCC_HOST_DEVICE
mutex(T &);

/*
* Construct a mutex from a vecmem atomic reference.
*/
TRACCC_HOST_DEVICE
mutex(const vecmem::device_atomic_ref<T> &);

/*
* Attempt to acquire a lock on the mutex. This method spins until a lock
* is acquired.
*
* @warning On lockstep devices, only one thread per thread group (e.g.
* warp) should call this function!
*/
TRACCC_HOST_DEVICE
void lock();

/*
* Try to acquire a lock on the mutex, returning whether the operation
* succeeded or not. */
TRACCC_HOST_DEVICE
bool try_lock();

/*
* Unlock the mutex.
*
* @warning Using this method on a mutex that is not locked is undefined
* behaviour.
*/
TRACCC_HOST_DEVICE
void unlock();

private:
const vecmem::device_atomic_ref<T> m_atomic;

#ifndef NDEBUG
bool m_is_locked = false;
#endif
};
} // namespace traccc::device

#include "impl/mutex.ipp"
75 changes: 75 additions & 0 deletions device/common/include/traccc/device/unique_lock.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
/**
* traccc library, part of the ACTS project (R&D line)
*
* (c) 2024 CERN for the benefit of the ACTS project
*
* Mozilla Public License Version 2.0
*/

#pragma once

#include <mutex>

#include "traccc/definitions/qualifiers.hpp"

namespace traccc::device {
template <typename Mutex>
class unique_lock {
public:
using mutex_type = Mutex;

/*
* Construct a unique lock without locking.
*/
TRACCC_HOST_DEVICE
unique_lock(mutex_type& m, std::defer_lock_t);

/*
* Construct a unique lock, attempting to lock it.
*/
TRACCC_HOST_DEVICE
unique_lock(mutex_type& m, std::try_to_lock_t);

/*
* Construct a unique lock which was previously locked.
*/
TRACCC_HOST_DEVICE
unique_lock(mutex_type& m, std::adopt_lock_t);

/*
* Destroy a lock, freeing the underlying mutex.
*/
TRACCC_HOST_DEVICE
~unique_lock();

/*
* Lock the lock, blocking until the operation succeeds.
*/
TRACCC_HOST_DEVICE
void lock();

/*
* Try to lock the lock without blocking.
*/
TRACCC_HOST_DEVICE
bool try_lock();

/*
* Explicitly lock the underlying lock.
*/
TRACCC_HOST_DEVICE
void unlock();

/*
* Check if the lock is locked by this object.
*/
TRACCC_HOST_DEVICE
bool owns_lock() const;

private:
mutex_type* m_mutex_ptr = nullptr;
bool m_owns_lock;
};
} // namespace traccc::device

#include "impl/unique_lock.ipp"
2 changes: 2 additions & 0 deletions tests/cuda/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ traccc_add_test(
test_thrust.cu
test_sync.cu
test_array_wrapper.cu
test_mutex.cu
test_unique_lock.cu

LINK_LIBRARIES
CUDA::cudart
Expand Down
48 changes: 48 additions & 0 deletions tests/cuda/test_mutex.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/**
* traccc library, part of the ACTS project (R&D line)
*
* (c) 2024 CERN for the benefit of the ACTS project
*
* Mozilla Public License Version 2.0
*/

#include <gtest/gtest.h>

#include <vecmem/memory/cuda/managed_memory_resource.hpp>
#include <vecmem/memory/unique_ptr.hpp>

#include "../../device/cuda/src/utils/cuda_error_handling.hpp"
#include "traccc/device/mutex.hpp"

__global__ void mutex_add_kernel(uint32_t *out, uint32_t *lock) {
traccc::device::mutex m(*lock);

if (threadIdx.x == 0) {
m.lock();
uint32_t tmp = *out;
tmp += 1;
*out = tmp;
m.unlock();
}
}

TEST(CUDAMutex, MassAdditionKernel) {
vecmem::cuda::managed_memory_resource mr;

vecmem::unique_alloc_ptr<uint32_t> out =
vecmem::make_unique_alloc<uint32_t>(mr);
vecmem::unique_alloc_ptr<uint32_t> lock =
vecmem::make_unique_alloc<uint32_t>(mr);

TRACCC_CUDA_ERROR_CHECK(cudaMemset(lock.get(), 0, sizeof(uint32_t)));

uint32_t n_blocks = 262144;
uint32_t n_threads = 32;

mutex_add_kernel<<<n_blocks, n_threads>>>(out.get(), lock.get());

TRACCC_CUDA_ERROR_CHECK(cudaGetLastError());
TRACCC_CUDA_ERROR_CHECK(cudaDeviceSynchronize());

EXPECT_EQ(n_blocks, *out.get());
}
Loading
Loading