From 2a688fd99285de69f0eb0389916baeaa22be7cbf Mon Sep 17 00:00:00 2001 From: Niklas Hauser Date: Sun, 21 Apr 2024 18:01:40 +0200 Subject: [PATCH] [fiber] Implement stop_token interface for Task --- src/modm/processing/fiber/scheduler.hpp.in | 6 + src/modm/processing/fiber/stop_token.hpp | 159 +++++++++++++++++++++ src/modm/processing/fiber/task.hpp | 103 +++++++++++-- test/modm/processing/fiber/fiber_test.cpp | 109 +++++++++++--- test/modm/processing/fiber/fiber_test.hpp | 3 + 5 files changed, 346 insertions(+), 34 deletions(-) create mode 100644 src/modm/processing/fiber/stop_token.hpp diff --git a/src/modm/processing/fiber/scheduler.hpp.in b/src/modm/processing/fiber/scheduler.hpp.in index f639b75343..40d08dac6a 100644 --- a/src/modm/processing/fiber/scheduler.hpp.in +++ b/src/modm/processing/fiber/scheduler.hpp.in @@ -155,6 +155,12 @@ protected: public: constexpr Scheduler() = default; + static constexpr unsigned int + hardware_concurrency() + { + return {{num_cores}}; + } + /// Runs the currently active scheduler. static void run() diff --git a/src/modm/processing/fiber/stop_token.hpp b/src/modm/processing/fiber/stop_token.hpp new file mode 100644 index 0000000000..4185236707 --- /dev/null +++ b/src/modm/processing/fiber/stop_token.hpp @@ -0,0 +1,159 @@ +/* + * Copyright (c) 2024, Niklas Hauser + * + * This file is part of the modm project. + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. + */ +// ---------------------------------------------------------------------------- + +#pragma once + +#include + +namespace modm::fiber +{ + +/// @ingroup modm_processing_fiber +/// @{ + +/// @cond +class stop_state +{ + std::atomic_bool requested{false}; + +public: + [[nodiscard]] + bool inline + stop_requested() const + { + return requested.load(); + } + + bool inline + request_stop() + { + return not requested.exchange(true); + } +}; +/// @endcond + +/// Implements the `std::stop_token` interface for fibers. +/// @see https://en.cppreference.com/w/cpp/thread/stop_token +class stop_token +{ + friend class Task; + friend class stop_source; + const stop_state *state{nullptr}; + + constexpr explicit stop_token(const stop_state *state) : state(state) {} +public: + constexpr stop_token() = default; + constexpr stop_token(const stop_token&) = default; + constexpr stop_token(stop_token&&) = default; + constexpr ~stop_token() = default; + constexpr stop_token& operator=(const stop_token&) = default; + constexpr stop_token& operator=(stop_token&&) = default; + + [[nodiscard]] + bool inline + stop_possible() const + { + return state; + } + + [[nodiscard]] + bool inline + stop_requested() const + { + return stop_possible() and state->stop_requested(); + } + + // void inline + // swap(stop_token& rhs) + // { + // std::swap(state, rhs.state); + // } + + [[nodiscard]] + friend bool inline + operator==(const stop_token& a, const stop_token& b) + { + return a.state == b.state; + } + + // friend void inline + // swap(stop_token& lhs, stop_token& rhs) + // { + // lhs.swap(rhs); + // } +}; + +/// Implements the `std::stop_source` interface for fibers. +/// @see https://en.cppreference.com/w/cpp/thread/stop_source +class stop_source +{ + friend class Task; + stop_state *state{nullptr}; + explicit constexpr stop_source(stop_state *state) : state(state) {} + +public: + constexpr stop_source() = default; + constexpr stop_source(const stop_source&) = default; + constexpr stop_source(stop_source&&) = default; + constexpr ~stop_source() = default; + constexpr stop_source& operator=(const stop_source&) = default; + constexpr stop_source& operator=(stop_source&&) = default; + + [[nodiscard]] + constexpr bool + stop_possible() const + { + return state; + } + + [[nodiscard]] + bool inline + stop_requested() const + { + return stop_possible() and state->stop_requested(); + } + + bool inline + request_stop() + { + return stop_possible() and state->request_stop(); + } + + [[nodiscard]] + stop_token inline + get_token() const + { + return stop_token{state}; + } + + // void inline + // swap(stop_token& rhs) + // { + // std::swap(state, rhs.state); + // } + + [[nodiscard]] + friend bool inline + operator==(const stop_source& a, const stop_source& b) + { + return a.state == b.state; + } + + // friend void inline + // swap(stop_token& lhs, stop_token& rhs) + // { + // lhs.swap(rhs); + // } +}; + +/// @} + +} diff --git a/src/modm/processing/fiber/task.hpp b/src/modm/processing/fiber/task.hpp index 1b850983d1..1d83f515e9 100644 --- a/src/modm/processing/fiber/task.hpp +++ b/src/modm/processing/fiber/task.hpp @@ -13,11 +13,16 @@ #include "stack.hpp" #include "context.h" +#include "stop_token.hpp" #include +// forward declaration +namespace modm::this_fiber { void yield(); } + namespace modm::fiber { +// forward declaration class Scheduler; /// The Fiber scheduling policy. @@ -60,17 +65,70 @@ class Task modm_context_t ctx; Task* next; Scheduler *scheduler{nullptr}; + stop_state stop{}; public: /// @param stack A stack object that is *NOT* shared with other tasks. - /// @param closure A callable object of signature `void(*)()`. + /// @param closure A callable object of signature `void()`. /// @param start When to start this task. - template - Task(Stack& stack, T&& closure, Start start=Start::Now); + template + Task(Stack& stack, Callable&& closure, Start start=Start::Now); + + inline + ~Task() + { + request_stop(); + join(); + } + + /// Returns the number of concurrent threads supported by the implementation. + [[nodiscard]] static constexpr unsigned int + hardware_concurrency(); + + /// Returns a value of std::thread::id identifying the thread associated + /// with `*this`. + [[nodiscard]] modm::fiber::id inline + get_id() const + { + return reinterpret_cast(this); + } + + /// Checks if the Task object identifies an active fiber of execution. + [[nodiscard]] bool + joinable() const; + + /// Blocks the current fiber until the fiber identified by `*this` + /// finishes its execution. Returns immediately if the thread is not joinable. + void inline + join() + { + if (joinable()) while(isRunning()) modm::this_fiber::yield(); + } + + [[nodiscard]] + stop_source inline + get_stop_source() + { + return stop_source{&stop}; + } + + [[nodiscard]] + stop_token inline + get_stop_token() + { + return stop_token{&stop}; + } + + bool inline + request_stop() + { + return stop.request_stop(); + } + /// Watermarks the stack to measure `stack_usage()` later. /// @see `modm_context_watermark()`. - void + void inline watermark_stack() { modm_context_watermark(&ctx); @@ -78,7 +136,7 @@ class Task /// @returns the stack usage as measured by a watermark level. /// @see `modm_context_stack_usage()`. - [[nodiscard]] size_t + [[nodiscard]] size_t inline stack_usage() const { return modm_context_stack_usage(&ctx); @@ -86,7 +144,7 @@ class Task /// @returns if the bottom word on the stack has been overwritten. /// @see `modm_context_stack_overflow()`. - [[nodiscard]] bool + [[nodiscard]] bool inline stack_overflow() const { return modm_context_stack_overflow(&ctx); @@ -98,7 +156,7 @@ class Task start(); /// @returns if the fiber is attached to a scheduler. - [[nodiscard]] bool + [[nodiscard]] bool inline isRunning() const { return scheduler; @@ -117,16 +175,21 @@ namespace modm::fiber template Task::Task(Stack& stack, T&& closure, Start start) { - if constexpr (std::is_convertible_v) + constexpr bool with_stop_token = std::is_invocable_r_v; + if constexpr (std::is_convertible_v or + std::is_convertible_v) { // A plain function without closure - auto caller = (uintptr_t) +[](void(*fn)()) + using Callable = std::conditional_t; + auto caller = (uintptr_t) +[](Callable fn) { - fn(); + if constexpr (with_stop_token) { + fn(fiber::Scheduler::instance().current->get_stop_token()); + } else fn(); fiber::Scheduler::instance().unschedule(); }; modm_context_init(&ctx, stack.memory, stack.memory + stack.words, - caller, (uintptr_t) static_cast(closure)); + caller, (uintptr_t) static_cast(closure)); } else { @@ -134,7 +197,7 @@ Task::Task(Stack& stack, T&& closure, Start start) constexpr size_t align_mask = std::max(StackAlignment, alignof(std::decay_t)) - 1u; constexpr size_t closure_size = (sizeof(std::decay_t) + align_mask) & ~align_mask; static_assert(Size >= closure_size + StackSizeMinimum, - "Stack size must ≥({{min_stack_size}}B + aligned sizeof(closure))!"); + "Stack size must be larger than minimum stack size + aligned sizeof(closure))!"); // Find a suitable aligned area at the top of stack to allocate the closure const uintptr_t ptr = uintptr_t(stack.memory + stack.words) - closure_size; // construct closure in place @@ -142,7 +205,9 @@ Task::Task(Stack& stack, T&& closure, Start start) // Encapsulate the proper ABI function call into a simpler function auto caller = (uintptr_t) +[](std::decay_t* closure) { - (*closure)(); + if constexpr (with_stop_token) { + (*closure)(fiber::Scheduler::instance().current->get_stop_token()); + } else (*closure)(); fiber::Scheduler::instance().unschedule(); }; // initialize the stack below the allocated closure @@ -160,5 +225,17 @@ Task::start() return true; } +constexpr unsigned int +Task::hardware_concurrency() +{ + return fiber::Scheduler::hardware_concurrency(); +} + +bool inline +Task::joinable() const +{ + return isRunning() and get_id() != modm::fiber::Scheduler::instance().get_id(); +} + } /// @endcond diff --git a/test/modm/processing/fiber/fiber_test.cpp b/test/modm/processing/fiber/fiber_test.cpp index 89e969ab1a..5f7ef621d0 100644 --- a/test/modm/processing/fiber/fiber_test.cpp +++ b/test/modm/processing/fiber/fiber_test.cpp @@ -24,8 +24,12 @@ enum State { INVALID, F1_START, + F1_YIELD1, + F1_YIELD2, + F1_YIELD3, F1_END, F2_START, + F2_JOIN, F2_END, F3_START, F3_END, @@ -40,6 +44,12 @@ enum State F6_SLEEP1, F6_SLEEP2, F6_END, + F7_START, + F7_YIELD, + F7_END, + F8_START, + F8_REQUEST_STOP, + F8_END, SUBROUTINE_START, SUBROUTINE_END, CONSUMER_START, @@ -61,20 +71,16 @@ f1() ADD_STATE(F1_END); } -static void -f2() -{ - ADD_STATE(F2_START); - modm::this_fiber::yield(); - ADD_STATE(F2_END); -} - void FiberTest::testOneFiber() { states_pos = 0; modm::fiber::Task fiber(stack1, f1); + TEST_ASSERT_DIFFERS(fiber.get_id(), modm::fiber::id(0)); + TEST_ASSERT_TRUE(fiber.joinable()); modm::fiber::Scheduler::run(); + + TEST_ASSERT_FALSE(fiber.joinable()); TEST_ASSERT_EQUALS(states_pos, 2u); TEST_ASSERT_EQUALS(states[0], F1_START); TEST_ASSERT_EQUALS(states[1], F1_END); @@ -84,13 +90,38 @@ void FiberTest::testTwoFibers() { states_pos = 0; - modm::fiber::Task fiber1(stack1, f1), fiber2(stack2, f2); + modm::fiber::Task fiber1(stack1, [&]() + { + ADD_STATE(F1_START); + modm::this_fiber::yield(); + ADD_STATE(F1_YIELD1); + modm::this_fiber::yield(); + ADD_STATE(F1_YIELD2); + modm::this_fiber::yield(); + ADD_STATE(F1_YIELD3); + modm::this_fiber::yield(); + ADD_STATE(F1_END); + }); + modm::fiber::Task fiber2(stack2, [&]() + { + ADD_STATE(F2_START); + TEST_ASSERT_TRUE(fiber1.joinable()); + TEST_ASSERT_FALSE(fiber2.joinable()); + fiber1.join(); // should wait + ADD_STATE(F2_JOIN); + fiber2.join(); // should not hang + ADD_STATE(F2_END); + }); modm::fiber::Scheduler::run(); - TEST_ASSERT_EQUALS(states_pos, 4u); + TEST_ASSERT_EQUALS(states_pos, 8u); TEST_ASSERT_EQUALS(states[0], F1_START); TEST_ASSERT_EQUALS(states[1], F2_START); - TEST_ASSERT_EQUALS(states[2], F1_END); - TEST_ASSERT_EQUALS(states[3], F2_END); + TEST_ASSERT_EQUALS(states[2], F1_YIELD1); + TEST_ASSERT_EQUALS(states[3], F1_YIELD2); + TEST_ASSERT_EQUALS(states[4], F1_YIELD3); + TEST_ASSERT_EQUALS(states[5], F1_END); + TEST_ASSERT_EQUALS(states[6], F2_JOIN); + TEST_ASSERT_EQUALS(states[7], F2_END); } static __attribute__((noinline)) void @@ -101,19 +132,18 @@ subroutine() ADD_STATE(SUBROUTINE_END); } -static void -f3() -{ - ADD_STATE(F3_START); - subroutine(); - ADD_STATE(F3_END); -} - void FiberTest::testYieldFromSubroutine() { states_pos = 0; - modm::fiber::Task fiber1(stack1, f1), fiber2(stack2, f3); + modm::fiber::Task fiber1(stack1, f1), fiber2(stack2, [&]() + { + ADD_STATE(F3_START); + TEST_ASSERT_TRUE(fiber1.joinable()); + subroutine(); + TEST_ASSERT_FALSE(fiber1.joinable()); + ADD_STATE(F3_END); + }); modm::fiber::Scheduler::run(); TEST_ASSERT_EQUALS(states_pos, 6u); TEST_ASSERT_EQUALS(states[0], F1_START); @@ -234,3 +264,40 @@ FiberTest::testSleepUntil() runSleepUntil(1502); runSleepUntil(0xffff'ffff - 30); } + +static void +f7(modm::fiber::stop_token stoken) +{ + ADD_STATE(F7_START); + while(not stoken.stop_requested()) + { + ADD_STATE(F7_YIELD); + modm::this_fiber::yield(); + } + ADD_STATE(F7_END); +} + +void +FiberTest::testStopToken() +{ + states_pos = 0; + modm::fiber::Task fiber1(stack1, f7), fiber2(stack2, [&]() + { + ADD_STATE(F8_START); + modm::this_fiber::yield(); + ADD_STATE(F8_REQUEST_STOP); + fiber1.request_stop(); + modm::this_fiber::yield(); + ADD_STATE(F8_END); + }); + modm::fiber::Scheduler::run(); + + TEST_ASSERT_EQUALS(states_pos, 7u); + TEST_ASSERT_EQUALS(states[0], F7_START); + TEST_ASSERT_EQUALS(states[1], F7_YIELD); + TEST_ASSERT_EQUALS(states[2], F8_START); + TEST_ASSERT_EQUALS(states[3], F7_YIELD); + TEST_ASSERT_EQUALS(states[4], F8_REQUEST_STOP); + TEST_ASSERT_EQUALS(states[5], F7_END); + TEST_ASSERT_EQUALS(states[6], F8_END); +} diff --git a/test/modm/processing/fiber/fiber_test.hpp b/test/modm/processing/fiber/fiber_test.hpp index 7e80f923aa..7f64240cbf 100644 --- a/test/modm/processing/fiber/fiber_test.hpp +++ b/test/modm/processing/fiber/fiber_test.hpp @@ -37,4 +37,7 @@ class FiberTest : public unittest::TestSuite void testSleepUntil(); + + void + testStopToken(); };