Skip to content

Commit

Permalink
[fiber] Implement stop_token interface for Task
Browse files Browse the repository at this point in the history
  • Loading branch information
salkinium committed Apr 21, 2024
1 parent cc3268b commit 2a688fd
Show file tree
Hide file tree
Showing 5 changed files with 346 additions and 34 deletions.
6 changes: 6 additions & 0 deletions src/modm/processing/fiber/scheduler.hpp.in
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,12 @@ protected:
public:
constexpr Scheduler() = default;

static constexpr unsigned int
hardware_concurrency()
{
return {{num_cores}};
}

/// Runs the currently active scheduler.
static void
run()
Expand Down
159 changes: 159 additions & 0 deletions src/modm/processing/fiber/stop_token.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
/*
* Copyright (c) 2024, Niklas Hauser
*
* This file is part of the modm project.
*
* This Source Code Form is subject to the terms of the Mozilla Public
* License, v. 2.0. If a copy of the MPL was not distributed with this
* file, You can obtain one at http://mozilla.org/MPL/2.0/.
*/
// ----------------------------------------------------------------------------

#pragma once

#include <atomic>

namespace modm::fiber
{

/// @ingroup modm_processing_fiber
/// @{

/// @cond
class stop_state
{
std::atomic_bool requested{false};

public:
[[nodiscard]]
bool inline
stop_requested() const
{
return requested.load();
}

bool inline
request_stop()
{
return not requested.exchange(true);
}
};
/// @endcond

/// Implements the `std::stop_token` interface for fibers.
/// @see https://en.cppreference.com/w/cpp/thread/stop_token
class stop_token
{
friend class Task;
friend class stop_source;
const stop_state *state{nullptr};

constexpr explicit stop_token(const stop_state *state) : state(state) {}
public:
constexpr stop_token() = default;
constexpr stop_token(const stop_token&) = default;
constexpr stop_token(stop_token&&) = default;
constexpr ~stop_token() = default;
constexpr stop_token& operator=(const stop_token&) = default;
constexpr stop_token& operator=(stop_token&&) = default;

[[nodiscard]]
bool inline
stop_possible() const
{
return state;
}

[[nodiscard]]
bool inline
stop_requested() const
{
return stop_possible() and state->stop_requested();
}

// void inline
// swap(stop_token& rhs)
// {
// std::swap(state, rhs.state);
// }

[[nodiscard]]
friend bool inline
operator==(const stop_token& a, const stop_token& b)
{
return a.state == b.state;
}

// friend void inline
// swap(stop_token& lhs, stop_token& rhs)
// {
// lhs.swap(rhs);
// }
};

/// Implements the `std::stop_source` interface for fibers.
/// @see https://en.cppreference.com/w/cpp/thread/stop_source
class stop_source
{
friend class Task;
stop_state *state{nullptr};
explicit constexpr stop_source(stop_state *state) : state(state) {}

public:
constexpr stop_source() = default;
constexpr stop_source(const stop_source&) = default;
constexpr stop_source(stop_source&&) = default;
constexpr ~stop_source() = default;
constexpr stop_source& operator=(const stop_source&) = default;
constexpr stop_source& operator=(stop_source&&) = default;

[[nodiscard]]
constexpr bool
stop_possible() const
{
return state;
}

[[nodiscard]]
bool inline
stop_requested() const
{
return stop_possible() and state->stop_requested();
}

bool inline
request_stop()
{
return stop_possible() and state->request_stop();
}

[[nodiscard]]
stop_token inline
get_token() const
{
return stop_token{state};
}

// void inline
// swap(stop_token& rhs)
// {
// std::swap(state, rhs.state);
// }

[[nodiscard]]
friend bool inline
operator==(const stop_source& a, const stop_source& b)
{
return a.state == b.state;
}

// friend void inline
// swap(stop_token& lhs, stop_token& rhs)
// {
// lhs.swap(rhs);
// }
};

/// @}

}
103 changes: 90 additions & 13 deletions src/modm/processing/fiber/task.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,16 @@

#include "stack.hpp"
#include "context.h"
#include "stop_token.hpp"
#include <type_traits>

// forward declaration
namespace modm::this_fiber { void yield(); }

namespace modm::fiber
{

// forward declaration
class Scheduler;

/// The Fiber scheduling policy.
Expand Down Expand Up @@ -60,33 +65,86 @@ class Task
modm_context_t ctx;
Task* next;
Scheduler *scheduler{nullptr};
stop_state stop{};

public:
/// @param stack A stack object that is *NOT* shared with other tasks.
/// @param closure A callable object of signature `void(*)()`.
/// @param closure A callable object of signature `void()`.
/// @param start When to start this task.
template<size_t Size, class T>
Task(Stack<Size>& stack, T&& closure, Start start=Start::Now);
template<size_t Size, class Callable>
Task(Stack<Size>& stack, Callable&& closure, Start start=Start::Now);

inline
~Task()
{
request_stop();
join();
}

/// Returns the number of concurrent threads supported by the implementation.
[[nodiscard]] static constexpr unsigned int
hardware_concurrency();

/// Returns a value of std::thread::id identifying the thread associated
/// with `*this`.
[[nodiscard]] modm::fiber::id inline
get_id() const
{
return reinterpret_cast<id>(this);
}

/// Checks if the Task object identifies an active fiber of execution.
[[nodiscard]] bool
joinable() const;

/// Blocks the current fiber until the fiber identified by `*this`
/// finishes its execution. Returns immediately if the thread is not joinable.
void inline
join()
{
if (joinable()) while(isRunning()) modm::this_fiber::yield();
}

[[nodiscard]]
stop_source inline
get_stop_source()
{
return stop_source{&stop};
}

[[nodiscard]]
stop_token inline
get_stop_token()
{
return stop_token{&stop};
}

bool inline
request_stop()
{
return stop.request_stop();
}


/// Watermarks the stack to measure `stack_usage()` later.
/// @see `modm_context_watermark()`.
void
void inline
watermark_stack()
{
modm_context_watermark(&ctx);
}

/// @returns the stack usage as measured by a watermark level.
/// @see `modm_context_stack_usage()`.
[[nodiscard]] size_t
[[nodiscard]] size_t inline
stack_usage() const
{
return modm_context_stack_usage(&ctx);
}

/// @returns if the bottom word on the stack has been overwritten.
/// @see `modm_context_stack_overflow()`.
[[nodiscard]] bool
[[nodiscard]] bool inline
stack_overflow() const
{
return modm_context_stack_overflow(&ctx);
Expand All @@ -98,7 +156,7 @@ class Task
start();

/// @returns if the fiber is attached to a scheduler.
[[nodiscard]] bool
[[nodiscard]] bool inline
isRunning() const
{
return scheduler;
Expand All @@ -117,32 +175,39 @@ namespace modm::fiber
template<size_t Size, class T>
Task::Task(Stack<Size>& stack, T&& closure, Start start)
{
if constexpr (std::is_convertible_v<T, void(*)()>)
constexpr bool with_stop_token = std::is_invocable_r_v<void, T, stop_token>;
if constexpr (std::is_convertible_v<T, void(*)()> or
std::is_convertible_v<T, void(*)(stop_token)>)
{
// A plain function without closure
auto caller = (uintptr_t) +[](void(*fn)())
using Callable = std::conditional_t<with_stop_token, void(*)(stop_token), void(*)()>;
auto caller = (uintptr_t) +[](Callable fn)
{
fn();
if constexpr (with_stop_token) {
fn(fiber::Scheduler::instance().current->get_stop_token());
} else fn();
fiber::Scheduler::instance().unschedule();
};
modm_context_init(&ctx, stack.memory, stack.memory + stack.words,
caller, (uintptr_t) static_cast<void(*)()>(closure));
caller, (uintptr_t) static_cast<Callable>(closure));
}
else
{
// lambda functions with a closure must be allocated on the stack ALIGNED!
constexpr size_t align_mask = std::max(StackAlignment, alignof(std::decay_t<T>)) - 1u;
constexpr size_t closure_size = (sizeof(std::decay_t<T>) + align_mask) & ~align_mask;
static_assert(Size >= closure_size + StackSizeMinimum,
"Stack size must ≥({{min_stack_size}}B + aligned sizeof(closure))!");
"Stack size must be larger than minimum stack size + aligned sizeof(closure))!");
// Find a suitable aligned area at the top of stack to allocate the closure
const uintptr_t ptr = uintptr_t(stack.memory + stack.words) - closure_size;
// construct closure in place
::new ((void*)ptr) std::decay_t<T>{std::forward<T>(closure)};
// Encapsulate the proper ABI function call into a simpler function
auto caller = (uintptr_t) +[](std::decay_t<T>* closure)
{
(*closure)();
if constexpr (with_stop_token) {
(*closure)(fiber::Scheduler::instance().current->get_stop_token());
} else (*closure)();
fiber::Scheduler::instance().unschedule();
};
// initialize the stack below the allocated closure
Expand All @@ -160,5 +225,17 @@ Task::start()
return true;
}

constexpr unsigned int
Task::hardware_concurrency()
{
return fiber::Scheduler::hardware_concurrency();
}

bool inline
Task::joinable() const
{
return isRunning() and get_id() != modm::fiber::Scheduler::instance().get_id();
}

}
/// @endcond
Loading

0 comments on commit 2a688fd

Please sign in to comment.