diff --git a/src/modm/processing/fiber/barrier.hpp b/src/modm/processing/fiber/barrier.hpp new file mode 100644 index 0000000000..110497e18d --- /dev/null +++ b/src/modm/processing/fiber/barrier.hpp @@ -0,0 +1,80 @@ +/* + * Copyright (c) 2023, Niklas Hauser + * + * This file is part of the modm project. + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. + */ +// ---------------------------------------------------------------------------- + +#pragma once + +#include "fiber.hpp" + +namespace modm::fiber +{ + +/// @ingroup modm_processing_fiber +/// @{ + +/// Implements the `std::barrier` interface for fibers. +/// @see https://en.cppreference.com/w/cpp/thread/barrier +template< class CompletionFunction = decltype([](){}) > +class barrier +{ + barrier(const barrier&) = delete; + barrier& operator=(const barrier&) = delete; + using count_t = uint16_t; + + const CompletionFunction completion; + volatile count_t expected; + volatile count_t count; + volatile count_t sequence{}; +public: + using arrival_token = count_t; + constexpr explicit barrier(std::ptrdiff_t expected, + CompletionFunction f = CompletionFunction()) + : completion(std::move(f)), expected(expected), count(expected) {} + + [[nodiscard]] static constexpr std::ptrdiff_t + max() { return count_t(-1); } + + [[nodiscard]] arrival_token + arrive(std::ptrdiff_t n=1) + { + count_t last_arrival{sequence}; + count -= n; + if (count == 0) + { + count = expected; + sequence++; + completion(); + } + return last_arrival; + } + + void + wait(arrival_token&& arrival) const + { + while (arrival == sequence) modm::this_fiber::yield(); + } + + void + arrive_and_wait() + { + wait(arrive()); + } + + void + arrive_and_drop() + { + expected--; + arrive(); + } +}; + +/// @} + +} diff --git a/src/modm/processing/fiber/condition_variable.hpp b/src/modm/processing/fiber/condition_variable.hpp new file mode 100644 index 0000000000..0acf8246f6 --- /dev/null +++ b/src/modm/processing/fiber/condition_variable.hpp @@ -0,0 +1,161 @@ +/* + * Copyright (c) 2023, Niklas Hauser + * + * This file is part of the modm project. + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. + */ +// ---------------------------------------------------------------------------- + +#pragma once + +#include "fiber.hpp" +#include +#include + +namespace modm::fiber +{ + +/// @ingroup modm_processing_fiber +/// @{ + +/// Implements the `std::condition_variable_any` interface for fibers. +/// @see https://en.cppreference.com/w/cpp/thread/condition_variable +class condition_variable_any +{ + condition_variable_any(const condition_variable_any&) = delete; + condition_variable_any& operator=(const condition_variable_any&) = delete; + + std::atomic sequence{}; +public: + constexpr condition_variable_any() = default; + + void inline + notify_one() + { + sequence++; + } + + void inline + notify_any() + { + sequence++; + } + + + template< class Lock > + void + wait(Lock& lock) + { + lock.unlock(); + const auto my_sequence = sequence.load(); + while(my_sequence == sequence.load()) modm::this_fiber::yield(); + lock.lock(); + } + + template< class Lock, class Predicate > + void + wait(Lock& lock, Predicate&& pred) + { + while (not pred()) wait(lock); + } + + template< class Lock, class Predicate > + bool + wait(Lock& lock, std::stop_token stoken, Predicate pred) + { + while (not stoken.stop_requested()) + { + if (pred()) return true; + wait(lock); + } + return pred(); + } + + + template< class Lock, class Rep, class Period > + std::cv_status + wait_for(Lock& lock, std::chrono::duration rel_time) + { + lock.unlock(); + const auto condition = [this, my_sequence = sequence.load()]() + { return my_sequence != sequence.load(); }; + const bool result = this_fiber::poll_for(rel_time, condition); + lock.lock(); + return result ? std::cv_status::no_timeout : std::cv_status::timeout; + } + + template< class Lock, class Rep, class Period, class Predicate > + bool + wait_for(Lock& lock, std::chrono::duration rel_time, Predicate&& pred) + { + while (not pred()) + { + if (wait_for(lock, rel_time) == std::cv_status::timeout) + return pred(); + } + return true; + } + + template< class Lock, class Rep, class Period, class Predicate > + bool + wait_for(Lock& lock, std::stop_token stoken, + std::chrono::duration rel_time, Predicate&& pred) + { + while (not stoken.stop_requested()) + { + if (pred()) return true; + if (wait_for(lock, rel_time) == std::cv_status::timeout) + return pred(); + } + return pred(); + } + + + template< class Lock, class Clock, class Duration > + std::cv_status + wait_until(Lock& lock, std::chrono::time_point abs_time) + { + lock.unlock(); + const auto condition = [this, my_sequence = sequence.load()]() + { return my_sequence != sequence.load(); }; + const bool result = this_fiber::poll_until(abs_time, condition); + lock.lock(); + return result ? std::cv_status::no_timeout : std::cv_status::timeout; + } + + template< class Lock, class Clock, class Duration, class Predicate > + bool + wait_until(Lock& lock, std::chrono::time_point abs_time, Predicate&& pred) + { + while (not pred()) + { + if (wait_until(lock, abs_time) == std::cv_status::timeout) + return pred(); + } + return true; + } + + template< class Lock, class Clock, class Duration, class Predicate > + bool + wait_until(Lock& lock, std::stop_token stoken, + std::chrono::time_point abs_time, Predicate&& pred) + { + while (not stoken.stop_requested()) + { + if (pred()) return true; + if (wait_until(lock, abs_time) == std::cv_status::timeout) + return pred(); + } + return pred(); + } +}; + +// There is no specialization for std::unique_lock. +using condition_variable = condition_variable_any; + +/// @} + +} diff --git a/src/modm/processing/fiber/latch.hpp b/src/modm/processing/fiber/latch.hpp new file mode 100644 index 0000000000..b6b70a5479 --- /dev/null +++ b/src/modm/processing/fiber/latch.hpp @@ -0,0 +1,68 @@ +/* + * Copyright (c) 2023, Niklas Hauser + * + * This file is part of the modm project. + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. + */ +// ---------------------------------------------------------------------------- + +#pragma once + +#include "fiber.hpp" +#include + +namespace modm::fiber +{ + +/// @ingroup modm_processing_fiber +/// @{ + +/// Implements the `std::latch` interface for fibers. +/// @see https://en.cppreference.com/w/cpp/thread/latch +class latch +{ + latch(const latch&) = delete; + latch& operator=(const latch&) = delete; + + using count_t = uint16_t; + std::atomic count; +public: + constexpr explicit + latch(std::ptrdiff_t expected) + : count(expected) {} + + [[nodiscard]] static constexpr std::ptrdiff_t + max() { return count_t(-1); } + + void inline + count_down(std::ptrdiff_t n=1) + { + count -= n; + } + + [[nodiscard]] bool inline + try_wait() const + { + return count.load() == 0; + } + + void inline + wait() const + { + while(not try_wait()) modm::this_fiber::yield(); + } + + void inline + arrive_and_wait(std::ptrdiff_t n=1) + { + count_down(n); + wait(); + } +}; + +/// @} + +} diff --git a/src/modm/processing/fiber/module.lb b/src/modm/processing/fiber/module.lb index 9289a2608f..8bc7bc7395 100644 --- a/src/modm/processing/fiber/module.lb +++ b/src/modm/processing/fiber/module.lb @@ -21,7 +21,7 @@ def is_enabled(env): not env.has_module(":processing:protothread") def prepare(module, options): - module.depends(":processing:timer") + module.depends(":processing:timer", ":architecture:atomic") module.add_query( EnvironmentQuery(name="__enabled", factory=is_enabled)) @@ -77,3 +77,10 @@ def build(env): env.copy("task.hpp") env.copy("functions.hpp") env.copy("fiber.hpp") + + env.copy("mutex.hpp") + env.copy("shared_mutex.hpp") + env.copy("semaphore.hpp") + env.copy("latch.hpp") + env.copy("barrier.hpp") + env.copy("condition_variable.hpp") diff --git a/src/modm/processing/fiber/mutex.hpp b/src/modm/processing/fiber/mutex.hpp new file mode 100644 index 0000000000..6b6e502832 --- /dev/null +++ b/src/modm/processing/fiber/mutex.hpp @@ -0,0 +1,150 @@ +/* + * Copyright (c) 2023, Niklas Hauser + * + * This file is part of the modm project. + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. + */ +// ---------------------------------------------------------------------------- + +#pragma once + +#include "fiber.hpp" +#include +#include + +namespace modm::fiber +{ + +/// @ingroup modm_processing_fiber +/// @{ + +/// Implements the `std::mutex` interface for fibers. +/// @see https://en.cppreference.com/w/cpp/thread/mutex +class mutex +{ + mutex(const mutex&) = delete; + mutex& operator=(const mutex&) = delete; + + std::atomic_bool locked{false}; +public: + constexpr mutex() = default; + + [[nodiscard]] bool inline + try_lock() + { + bool expected{false}; + return locked.compare_exchange_strong(expected, true); + } + + void inline + lock() + { + while(not try_lock()) modm::this_fiber::yield(); + } + + void inline + unlock() + { + locked = false; + } +}; + +/// Implements the `std::timed_mutex` interface for fibers. +/// @see https://en.cppreference.com/w/cpp/thread/timed_mutex +class timed_mutex : public mutex +{ +public: + template< typename Rep, typename Period > + [[nodiscard]] bool + try_lock_for(std::chrono::duration sleep_duration) + { + return this_fiber::poll_for(sleep_duration, [this](){ return try_lock(); }); + } + + template< class Clock, class Duration > + [[nodiscard]] bool + try_lock_until(std::chrono::time_point sleep_time) + { + return this_fiber::poll_until(sleep_time, [this](){ return try_lock(); }); + } +}; + +/// Implements the `std::recursive_mutex` interface for fibers. +/// @see https://en.cppreference.com/w/cpp/thread/recursive_mutex +class recursive_mutex +{ + recursive_mutex(const recursive_mutex&) = delete; + recursive_mutex& operator=(const recursive_mutex&) = delete; + using count_t = uint16_t; + + static constexpr fiber::id NoOwner{fiber::id(-1)}; + volatile fiber::id owner{NoOwner}; + static constexpr count_t countMax{count_t(-1)}; + volatile count_t count{1}; + +public: + constexpr recursive_mutex() = default; + + [[nodiscard]] bool inline + try_lock() + { + const auto id = modm::this_fiber::get_id(); + { + modm::atomic::Lock _; + if (owner == NoOwner) { + owner = id; + // count = 1; is implicit + return true; + } + if (owner == id and count < countMax) { + count++; + return true; + } + } + return false; + } + + void inline + lock() + { + while(not try_lock()) modm::this_fiber::yield(); + } + + void inline + unlock() + { + modm::atomic::Lock _; + if (count > 1) count--; + else { + // count = 1; is implicit + owner = NoOwner; + } + } +}; + +/// Implements the `std::timed_recursive_mutex` interface for fibers. +/// @see https://en.cppreference.com/w/cpp/thread/recursive_mutex +class timed_recursive_mutex : public recursive_mutex +{ +public: + template< typename Rep, typename Period > + [[nodiscard]] bool + try_lock_for(std::chrono::duration sleep_duration) + { + return this_fiber::poll_for(sleep_duration, [this](){ return try_lock(); }); + } + + template< class Clock, class Duration > + [[nodiscard]] bool + try_lock_until(std::chrono::time_point sleep_time) + { + return this_fiber::poll_until(sleep_time, [this](){ return try_lock(); }); + } +}; + +/// @} + +} diff --git a/src/modm/processing/fiber/semaphore.hpp b/src/modm/processing/fiber/semaphore.hpp new file mode 100644 index 0000000000..95d84eb158 --- /dev/null +++ b/src/modm/processing/fiber/semaphore.hpp @@ -0,0 +1,84 @@ +/* + * Copyright (c) 2023, Niklas Hauser + * + * This file is part of the modm project. + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. + */ +// ---------------------------------------------------------------------------- + +#pragma once + +#include "fiber.hpp" +#include + +namespace modm::fiber +{ + +/// @ingroup modm_processing_fiber +/// @{ + +/// Implements the `std::counting_semaphore` interface for fibers. +/// @see https://en.cppreference.com/w/cpp/thread/counting_semaphore +template< std::ptrdiff_t LeastMaxValue = 255 > +class counting_semaphore +{ + counting_semaphore(const counting_semaphore&) = delete; + counting_semaphore& operator=(const counting_semaphore&) = delete; + + static_assert(LeastMaxValue <= uint16_t(-1), "counting_semaphore uses a 16-bit counter!"); + using count_t = std::conditional_t; + std::atomic count{}; +public: + constexpr explicit + counting_semaphore(std::ptrdiff_t desired) + : count(desired) {} + + [[nodiscard]] static constexpr std::ptrdiff_t + max() { return count_t(-1); } + + [[nodiscard]] bool inline + try_acquire() + { + count_t current = count.load(); + do if (current == 0) return false; + while(count.compare_exchange_weak(current, current - 1) == false); + return true; + } + + void inline + acquire() + { + while(not try_acquire()) modm::this_fiber::yield(); + } + + void inline + release() + { + count++; + } + + template< typename Rep, typename Period > + [[nodiscard]] bool + try_acquire_for(std::chrono::duration sleep_duration) + { + return this_fiber::poll_for(sleep_duration, [this](){ return try_acquire(); }); + } + + template< class Clock, class Duration > + [[nodiscard]] bool + try_acquire_until(std::chrono::time_point sleep_time) + { + return this_fiber::poll_until(sleep_time, [this](){ return try_acquire(); }); + } +}; + +/// Implements the `std::binary_semaphore` interface for fibers. +/// @see https://en.cppreference.com/w/cpp/thread/counting_semaphore +using binary_semaphore = counting_semaphore<1>; + +/// @} + +} diff --git a/src/modm/processing/fiber/shared_mutex.hpp b/src/modm/processing/fiber/shared_mutex.hpp new file mode 100644 index 0000000000..9ab765f560 --- /dev/null +++ b/src/modm/processing/fiber/shared_mutex.hpp @@ -0,0 +1,114 @@ +/* + * Copyright (c) 2023, Niklas Hauser + * + * This file is part of the modm project. + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. + */ +// ---------------------------------------------------------------------------- + +#pragma once + +#include "fiber.hpp" +#include + +namespace modm::fiber +{ + +/// @ingroup modm_processing_fiber +/// @{ + +/// Implements the `std::shared_mutex` interface for fibers. +/// @see https://en.cppreference.com/w/cpp/thread/shared_mutex +class shared_mutex +{ + shared_mutex(const shared_mutex&) = delete; + shared_mutex& operator=(const shared_mutex&) = delete; + + static constexpr fiber::id NoOwner{fiber::id(-1)}; + static constexpr fiber::id SharedOwner{fiber::id(-2)}; + std::atomic owner{NoOwner}; +public: + constexpr shared_mutex() = default; + + [[nodiscard]] bool inline + try_lock() + { + const fiber::id new_owner = modm::this_fiber::get_id(); + fiber::id expected{NoOwner}; + return owner.compare_exchange_strong(expected, new_owner); + } + + void inline + lock() + { + while(not try_lock()) modm::this_fiber::yield(); + } + + void inline + unlock() + { + owner = NoOwner; + } + + [[nodiscard]] bool inline + try_lock_shared() + { + fiber::id current = owner.load(); + do if (current < SharedOwner) return false; + while(owner.compare_exchange_weak(current, SharedOwner) == false); + return true; + } + + void inline + lock_shared() + { + while(not try_lock_shared()) modm::this_fiber::yield(); + } + + void inline + unlock_shared() + { + owner = NoOwner; + } +}; + +/// Implements the `std::timed_shared_mutex` interface for fibers. +/// @see https://en.cppreference.com/w/cpp/thread/timed_shared_mutex +class timed_shared_mutex : public shared_mutex +{ +public: + template< typename Rep, typename Period > + [[nodiscard]] bool + try_lock_for(std::chrono::duration sleep_duration) + { + return this_fiber::poll_for(sleep_duration, [this](){ return try_lock(); }); + } + + template< class Clock, class Duration > + [[nodiscard]] bool + try_lock_until(std::chrono::time_point sleep_time) + { + return this_fiber::poll_until(sleep_time, [this](){ return try_lock(); }); + } + + template< typename Rep, typename Period > + [[nodiscard]] bool + try_lock_shared_for(std::chrono::duration sleep_duration) + { + return this_fiber::poll_for(sleep_duration, [this](){ return try_lock_shared(); }); + } + + template< class Clock, class Duration > + [[nodiscard]] bool + try_lock_shared_until(std::chrono::time_point sleep_time) + { + return this_fiber::poll_until(sleep_time, [this](){ return try_lock_shared(); }); + } +}; + +/// @} + +} diff --git a/test/modm/processing/fiber/fiber_mutex_test.cpp b/test/modm/processing/fiber/fiber_mutex_test.cpp new file mode 100644 index 0000000000..70a711f498 --- /dev/null +++ b/test/modm/processing/fiber/fiber_mutex_test.cpp @@ -0,0 +1,388 @@ +/* + * Copyright (c) 2024, Niklas Hauser + * + * This file is part of the modm project. + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. + */ +// ---------------------------------------------------------------------------- + +#include "fiber_mutex_test.hpp" +#include "shared.hpp" +#include +#include + +enum State : uint8_t +{ + INVALID, + + F1_START, + F1_LOCK1, + F1_LOCK2, + F1_LOCK3, + F1_LOCK4, + F1_END, + + F2_START, + F2_UNLOCK1, + F2_UNLOCK2, + F2_END, + + F3_START, + F3_LOCK1, + F3_LOCK2, + F3_UNLOCK1, + F3_UNLOCK2, + F3_UNLOCK3, + F3_UNLOCK4, + F3_END, + + F4_START, + F4_LOCK1, + F4_LOCK2, + F4_LOCK3, + F4_UNLOCK1, + F4_UNLOCK2, + F4_UNLOCK3, + F4_END, + + F5_START, + F5_LOCK1, + F5_LOCK2, + F5_UNLOCK1, + F5_UNLOCK2, + F5_END, + + F6_START, + F6_LOCK1, + F6_LOCK2, + F6_UNLOCK1, + F6_UNLOCK2, + F6_END, +}; + +// ================================== MUTEX =================================== +static modm::fiber::mutex mtx; + +static void +f1() +{ + ADD_STATE(F1_START); + TEST_ASSERT_TRUE(mtx.try_lock()); + TEST_ASSERT_FALSE(mtx.try_lock()); + TEST_ASSERT_FALSE(mtx.try_lock()); + mtx.unlock(); + mtx.unlock(); + + ADD_STATE(F1_LOCK1); + mtx.lock(); // should not yield + ADD_STATE(F1_LOCK2); + mtx.lock(); // yields + mtx.unlock(); + mtx.unlock(); + ADD_STATE(F1_LOCK3); + mtx.lock(); // should not yield + ADD_STATE(F1_LOCK4); + mtx.lock(); // yields again + + ADD_STATE(F1_END); +} + +static void +f2() +{ + ADD_STATE(F2_START); + // let f1 wait for a while + modm::this_fiber::yield(); + modm::this_fiber::yield(); + modm::this_fiber::yield(); + ADD_STATE(F2_UNLOCK1); + mtx.unlock(); + modm::this_fiber::yield(); + ADD_STATE(F2_UNLOCK2); + mtx.unlock(); + modm::this_fiber::yield(); + ADD_STATE(F2_END); +} + +void +FiberMutexTest::testMutex() +{ + // should not block + TEST_ASSERT_TRUE(mtx.try_lock()); + TEST_ASSERT_FALSE(mtx.try_lock()); + TEST_ASSERT_FALSE(mtx.try_lock()); + mtx.unlock(); + // multiple unlock calls should be fine too + mtx.unlock(); + mtx.unlock(); + mtx.unlock(); + // shouldn't block without scheduler + mtx.lock(); + mtx.unlock(); + + states_pos = 0; + modm::fiber::Task fiber1(stack1, f1), fiber2(stack2, f2); + modm::fiber::Scheduler::run(); + TEST_ASSERT_EQUALS(states_pos, 10u); + TEST_ASSERT_EQUALS(states[0], F1_START); + TEST_ASSERT_EQUALS(states[1], F1_LOCK1); + TEST_ASSERT_EQUALS(states[2], F1_LOCK2); + + TEST_ASSERT_EQUALS(states[3], F2_START); + TEST_ASSERT_EQUALS(states[4], F2_UNLOCK1); + + TEST_ASSERT_EQUALS(states[5], F1_LOCK3); + TEST_ASSERT_EQUALS(states[6], F1_LOCK4); + + TEST_ASSERT_EQUALS(states[7], F2_UNLOCK2); + + TEST_ASSERT_EQUALS(states[8], F1_END); + TEST_ASSERT_EQUALS(states[9], F2_END); +} + +// ============================== RECURSIVE MUTEX ============================= +static modm::fiber::recursive_mutex rc_mtx; + +static void +f3() +{ + ADD_STATE(F3_START); + + ADD_STATE(F3_LOCK1); + TEST_ASSERT_TRUE(rc_mtx.try_lock()); + TEST_ASSERT_TRUE(rc_mtx.try_lock()); + TEST_ASSERT_TRUE(rc_mtx.try_lock()); + modm::this_fiber::yield(); // goto F4_START + + ADD_STATE(F3_UNLOCK1); + rc_mtx.unlock(); + modm::this_fiber::yield(); + + ADD_STATE(F3_UNLOCK2); + rc_mtx.unlock(); + modm::this_fiber::yield(); + + ADD_STATE(F3_UNLOCK3); + rc_mtx.unlock(); + rc_mtx.unlock(); // more than necessary + rc_mtx.unlock(); + rc_mtx.unlock(); + modm::this_fiber::yield(); // goto F4_LOCK3 + + ADD_STATE(F3_LOCK2); + rc_mtx.lock(); // goto F4_UNLOCK1 + + ADD_STATE(F3_UNLOCK4); + rc_mtx.unlock(); + rc_mtx.unlock(); + + ADD_STATE(F3_END); +} + +static void +f4() +{ + ADD_STATE(F4_START); + ADD_STATE(F4_LOCK1); + TEST_ASSERT_FALSE(rc_mtx.try_lock()); + TEST_ASSERT_FALSE(rc_mtx.try_lock()); + TEST_ASSERT_FALSE(rc_mtx.try_lock()); + + ADD_STATE(F4_LOCK2); + rc_mtx.lock(); // goto F3_UNLOCK1 + + ADD_STATE(F4_LOCK3); + rc_mtx.lock(); + rc_mtx.lock(); + modm::this_fiber::yield(); // goto F3_LOCK2 + + ADD_STATE(F4_UNLOCK1); + rc_mtx.unlock(); + modm::this_fiber::yield(); + ADD_STATE(F4_UNLOCK2); + rc_mtx.unlock(); + modm::this_fiber::yield(); + ADD_STATE(F4_UNLOCK3); + rc_mtx.unlock(); + modm::this_fiber::yield(); // goto F3_UNLOCK4 + + ADD_STATE(F4_END); +} + +void +FiberMutexTest::testRecursiveMutex() +{ + // this should also work without scheduler, since fiber id is zero then + TEST_ASSERT_TRUE(rc_mtx.try_lock()); + TEST_ASSERT_TRUE(rc_mtx.try_lock()); + TEST_ASSERT_TRUE(rc_mtx.try_lock()); + rc_mtx.unlock(); + rc_mtx.unlock(); + rc_mtx.unlock(); + // more unlocks should be fine + rc_mtx.unlock(); + rc_mtx.unlock(); + + // should not block either without scheduler + rc_mtx.lock(); + rc_mtx.lock(); + rc_mtx.lock(); + rc_mtx.unlock(); + rc_mtx.unlock(); + rc_mtx.unlock(); + rc_mtx.unlock(); + + states_pos = 0; + modm::fiber::Task fiber1(stack1, f3), fiber2(stack2, f4); + modm::fiber::Scheduler::run(); + + TEST_ASSERT_TRUE(rc_mtx.try_lock()); + rc_mtx.unlock(); + rc_mtx.unlock(); + + TEST_ASSERT_EQUALS(states_pos, 16u); + TEST_ASSERT_EQUALS(states[0], F3_START); + TEST_ASSERT_EQUALS(states[1], F3_LOCK1); + + TEST_ASSERT_EQUALS(states[2], F4_START); + TEST_ASSERT_EQUALS(states[3], F4_LOCK1); + TEST_ASSERT_EQUALS(states[4], F4_LOCK2); + + TEST_ASSERT_EQUALS(states[5], F3_UNLOCK1); + TEST_ASSERT_EQUALS(states[6], F3_UNLOCK2); + TEST_ASSERT_EQUALS(states[7], F3_UNLOCK3); + + TEST_ASSERT_EQUALS(states[8], F4_LOCK3); + + TEST_ASSERT_EQUALS(states[9], F3_LOCK2); + + TEST_ASSERT_EQUALS(states[10], F4_UNLOCK1); + TEST_ASSERT_EQUALS(states[11], F4_UNLOCK2); + TEST_ASSERT_EQUALS(states[12], F4_UNLOCK3); + + TEST_ASSERT_EQUALS(states[13], F3_UNLOCK4); + TEST_ASSERT_EQUALS(states[14], F3_END); + + TEST_ASSERT_EQUALS(states[15], F4_END); +} + +// =============================== SHARED MUTEX =============================== +static modm::fiber::shared_mutex sh_mtx; + +static void +f5() +{ + ADD_STATE(F5_START); + ADD_STATE(F5_LOCK1); + // get the exclusive lock + sh_mtx.lock(); + TEST_ASSERT_FALSE(sh_mtx.try_lock()); + modm::this_fiber::yield(); // goto F6_LOCK1 + + ADD_STATE(F5_UNLOCK1); + sh_mtx.unlock(); + modm::this_fiber::yield(); // goto F6_UNLOCK1 + + ADD_STATE(F5_LOCK2); + // get the shared lock + sh_mtx.lock_shared(); + sh_mtx.lock_shared(); + modm::this_fiber::yield(); // goto F6_LOCK2 + + ADD_STATE(F5_UNLOCK2); + modm::this_fiber::yield(); + modm::this_fiber::yield(); + modm::this_fiber::yield(); + // still locked + sh_mtx.unlock_shared(); + modm::this_fiber::yield(); // goto F6_UNLOCK2 + + ADD_STATE(F5_END); +} + +static void +f6() +{ + ADD_STATE(F6_START); + ADD_STATE(F6_LOCK1); + // cannot get exclusive lock + sh_mtx.lock(); // goto F5_UNLOCK1 + + ADD_STATE(F6_UNLOCK1); + TEST_ASSERT_FALSE(sh_mtx.try_lock()); + sh_mtx.unlock(); + modm::this_fiber::yield(); // goto F5_LOCK2 + + ADD_STATE(F6_LOCK2); + // can get shared lock + sh_mtx.lock_shared(); + sh_mtx.lock_shared(); + // cannot get exclusive lock + sh_mtx.lock(); // goto F5_UNLOCK2 + + ADD_STATE(F6_UNLOCK2); + sh_mtx.unlock(); + + ADD_STATE(F6_END); +} + +void +FiberMutexTest::testSharedMutex() +{ + TEST_ASSERT_TRUE(sh_mtx.try_lock()); + TEST_ASSERT_FALSE(sh_mtx.try_lock()); + TEST_ASSERT_FALSE(sh_mtx.try_lock()); + sh_mtx.unlock(); + // more unlocks should be fine + sh_mtx.unlock(); + sh_mtx.unlock(); + + TEST_ASSERT_TRUE(sh_mtx.try_lock_shared()); + TEST_ASSERT_TRUE(sh_mtx.try_lock_shared()); + TEST_ASSERT_TRUE(sh_mtx.try_lock_shared()); + sh_mtx.unlock(); + // more unlocks should be fine + sh_mtx.unlock(); + sh_mtx.unlock(); + + + states_pos = 0; + modm::fiber::Task fiber1(stack1, f5), fiber2(stack2, f6); + modm::fiber::Scheduler::run(); + + TEST_ASSERT_TRUE(rc_mtx.try_lock()); + rc_mtx.unlock(); + rc_mtx.unlock(); + + TEST_ASSERT_EQUALS(states_pos, 12u); + TEST_ASSERT_EQUALS(states[0], F5_START); + TEST_ASSERT_EQUALS(states[1], F5_LOCK1); + + TEST_ASSERT_EQUALS(states[2], F6_START); + TEST_ASSERT_EQUALS(states[3], F6_LOCK1); + + TEST_ASSERT_EQUALS(states[4], F5_UNLOCK1); + + TEST_ASSERT_EQUALS(states[5], F6_UNLOCK1); + + TEST_ASSERT_EQUALS(states[6], F5_LOCK2); + + TEST_ASSERT_EQUALS(states[7], F6_LOCK2); + + TEST_ASSERT_EQUALS(states[8], F5_UNLOCK2); + + TEST_ASSERT_EQUALS(states[9], F6_UNLOCK2); + + TEST_ASSERT_EQUALS(states[10], F6_END); + + TEST_ASSERT_EQUALS(states[11], F5_END); +} + +// =============================== TIMED MUTEX ================================ +// =========================== TIMED RECURSIVE MUTEX ========================== +// ============================ TIMED SHARED MUTEX ============================ +// All implementations only add poll_for/_until with try_lock*() function. +// Explicit tests are omitted, since they are tested in FiberTest::testSleep*(). diff --git a/test/modm/processing/fiber/fiber_mutex_test.hpp b/test/modm/processing/fiber/fiber_mutex_test.hpp new file mode 100644 index 0000000000..f88095f92c --- /dev/null +++ b/test/modm/processing/fiber/fiber_mutex_test.hpp @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2020, Erik Henriksson + * + * This file is part of the modm project. + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. + */ +// ---------------------------------------------------------------------------- + +#pragma once + +#include + +/// @ingroup modm_test_test_architecture +class FiberMutexTest : public unittest::TestSuite +{ +public: + void + testMutex(); + + void + testRecursiveMutex(); + + void + testSharedMutex(); +}; diff --git a/test/modm/processing/fiber/fiber_semaphore_test.cpp b/test/modm/processing/fiber/fiber_semaphore_test.cpp new file mode 100644 index 0000000000..56585e4553 --- /dev/null +++ b/test/modm/processing/fiber/fiber_semaphore_test.cpp @@ -0,0 +1,122 @@ +/* + * Copyright (c) 2024, Niklas Hauser + * + * This file is part of the modm project. + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. + */ +// ---------------------------------------------------------------------------- + +#include "fiber_semaphore_test.hpp" +#include "shared.hpp" +#include + +enum State : uint8_t +{ + INVALID, + + F1_START, + F1_ACQUIRE1, + F1_ACQUIRE2, + F1_RELEASE1, + F1_LOCK4, + F1_END, + + F2_START, + F2_ACQUIRE1, + F2_RELEASE1, + F2_RELEASE2, + F2_END, +}; + +// ================================== MUTEX =================================== +static modm::fiber::counting_semaphore sem{3}; + +static void +f1() +{ + ADD_STATE(F1_START); + TEST_ASSERT_TRUE(sem.try_acquire()); // 2 + TEST_ASSERT_TRUE(sem.try_acquire()); // 1 + TEST_ASSERT_TRUE(sem.try_acquire()); // 0 + TEST_ASSERT_FALSE(sem.try_acquire()); // 0 + TEST_ASSERT_FALSE(sem.try_acquire()); // 0 + sem.release(); // 1 + sem.acquire(); // 0 + ADD_STATE(F1_ACQUIRE1); + sem.acquire(); // goto F2_START, 0 + + ADD_STATE(F1_RELEASE1); + sem.release(); // 1 + sem.release(); // 2 + modm::this_fiber::yield(); // goto F2_ACQUIRE1 + + ADD_STATE(F1_ACQUIRE2); + sem.acquire(); + sem.acquire(); // goto F2_RELEASE2, 0 + + ADD_STATE(F1_END); +} + +static void +f2() +{ + ADD_STATE(F2_START); + modm::this_fiber::yield(); + + ADD_STATE(F2_RELEASE1); + sem.release(); // 1 + modm::this_fiber::yield(); // goto F1_RELEASE1 + + ADD_STATE(F2_ACQUIRE1); + sem.acquire(); // 1 + modm::this_fiber::yield(); // goto F1_ACQUIRE2 + + ADD_STATE(F2_RELEASE2); + sem.release(); // 1 + sem.release(); // 2 + sem.release(); // 3 + + ADD_STATE(F2_END); +} + +void +FiberSemaphoreTest::testCountingSemaphore() +{ + // should not block + TEST_ASSERT_TRUE(sem.try_acquire()); + TEST_ASSERT_TRUE(sem.try_acquire()); + TEST_ASSERT_TRUE(sem.try_acquire()); + TEST_ASSERT_FALSE(sem.try_acquire()); + TEST_ASSERT_FALSE(sem.try_acquire()); + sem.release(); + // shouldn't block without scheduler + sem.acquire(); + sem.release(); + sem.release(); + sem.release(); + + states_pos = 0; + modm::fiber::Task fiber1(stack1, f1), fiber2(stack2, f2); + modm::fiber::Scheduler::run(); + TEST_ASSERT_EQUALS(states_pos, 10u); + TEST_ASSERT_EQUALS(states[0], F1_START); + TEST_ASSERT_EQUALS(states[1], F1_ACQUIRE1); + + TEST_ASSERT_EQUALS(states[2], F2_START); + TEST_ASSERT_EQUALS(states[3], F2_RELEASE1); + + TEST_ASSERT_EQUALS(states[4], F1_RELEASE1); + + TEST_ASSERT_EQUALS(states[5], F2_ACQUIRE1); + + TEST_ASSERT_EQUALS(states[6], F1_ACQUIRE2); + + TEST_ASSERT_EQUALS(states[7], F2_RELEASE2); + TEST_ASSERT_EQUALS(states[8], F2_END); + + TEST_ASSERT_EQUALS(states[9], F1_END); +} + diff --git a/test/modm/processing/fiber/fiber_semaphore_test.hpp b/test/modm/processing/fiber/fiber_semaphore_test.hpp new file mode 100644 index 0000000000..eea7b4eaf9 --- /dev/null +++ b/test/modm/processing/fiber/fiber_semaphore_test.hpp @@ -0,0 +1,22 @@ +/* + * Copyright (c) 2020, Erik Henriksson + * + * This file is part of the modm project. + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. + */ +// ---------------------------------------------------------------------------- + +#pragma once + +#include + +/// @ingroup modm_test_test_architecture +class FiberSemaphoreTest : public unittest::TestSuite +{ +public: + void + testCountingSemaphore(); +}; diff --git a/test/modm/processing/fiber/fiber_test.cpp b/test/modm/processing/fiber/fiber_test.cpp index 89e969ab1a..9acf3d54de 100644 --- a/test/modm/processing/fiber/fiber_test.cpp +++ b/test/modm/processing/fiber/fiber_test.cpp @@ -11,16 +11,14 @@ // ---------------------------------------------------------------------------- #include "fiber_test.hpp" +#include "shared.hpp" -#include -#include -#include #include using namespace std::chrono_literals; using test_clock = modm_test::chrono::milli_clock; -enum State +enum State : uint8_t { INVALID, F1_START, @@ -48,11 +46,6 @@ enum State PRODUCER_END, }; -static std::array states = {}; -static size_t states_pos = 0; -static modm::fiber::Stack<1024> stack1, stack2; -#define ADD_STATE(state) states[states_pos++] = state; - static void f1() { diff --git a/test/modm/processing/fiber/shared.cpp b/test/modm/processing/fiber/shared.cpp new file mode 100644 index 0000000000..7d3ce4a0ed --- /dev/null +++ b/test/modm/processing/fiber/shared.cpp @@ -0,0 +1,16 @@ +/* + * Copyright (c) 2024, Niklas Hauser + * + * This file is part of the modm project. + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. + */ +// ---------------------------------------------------------------------------- + +#include "shared.hpp" + +std::array states{}; +uint8_t states_pos{}; +modm::fiber::Stack<> stack1, stack2; diff --git a/test/modm/processing/fiber/shared.hpp b/test/modm/processing/fiber/shared.hpp new file mode 100644 index 0000000000..8ce7a44086 --- /dev/null +++ b/test/modm/processing/fiber/shared.hpp @@ -0,0 +1,21 @@ +/* + * Copyright (c) 2024, Niklas Hauser + * + * This file is part of the modm project. + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. + */ +// ---------------------------------------------------------------------------- + +#pragma once +#include +#include + +// shared objects to reduce memory consumption + +extern std::array states; +extern uint8_t states_pos; +extern modm::fiber::Stack<> stack1, stack2; +#define ADD_STATE(state) states[states_pos++] = state;