Skip to content

Commit

Permalink
[fiber] Implement std concurrency interfaces
Browse files Browse the repository at this point in the history
  • Loading branch information
salkinium committed Apr 21, 2024
1 parent 2a688fd commit 6803ce0
Show file tree
Hide file tree
Showing 16 changed files with 1,330 additions and 11 deletions.
4 changes: 4 additions & 0 deletions ext/gcc/assert.cpp.in
Original file line number Diff line number Diff line change
Expand Up @@ -76,5 +76,9 @@ _GLIBCXX_BEGIN_NAMESPACE_VERSION
void
__throw_bad_any_cast()
{ __modm_stdcpp_failure("bad_any_cast"); }

void
__throw_system_error(int errc __attribute__((unused)))
{ __modm_stdcpp_failure("system_error"); }
_GLIBCXX_END_NAMESPACE_VERSION
} // namespace
84 changes: 84 additions & 0 deletions src/modm/processing/fiber/barrier.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
/*
* Copyright (c) 2024, Niklas Hauser
*
* This file is part of the modm project.
*
* This Source Code Form is subject to the terms of the Mozilla Public
* License, v. 2.0. If a copy of the MPL was not distributed with this
* file, You can obtain one at http://mozilla.org/MPL/2.0/.
*/
// ----------------------------------------------------------------------------

#pragma once

#include "fiber.hpp"

namespace modm::fiber
{

/// @ingroup modm_processing_fiber
/// @{

/// Implements the `std::barrier` interface for fibers.
/// @see https://en.cppreference.com/w/cpp/thread/barrier
template< class CompletionFunction = decltype([](){}) >
class barrier
{
barrier(const barrier&) = delete;
barrier& operator=(const barrier&) = delete;
using count_t = uint16_t;

const CompletionFunction completion;
volatile count_t expected;
volatile count_t count;
volatile count_t sequence{};

public:
using arrival_token = count_t;

constexpr explicit
barrier(std::ptrdiff_t expected, CompletionFunction f = CompletionFunction())
: completion(std::move(f)), expected(expected), count(expected) {}

[[nodiscard]]
static constexpr std::ptrdiff_t
max() { return count_t(-1); }

[[nodiscard]]
arrival_token
arrive(std::ptrdiff_t n=1)
{
count_t last_arrival{sequence};
count -= n;
if (count == 0)
{
count = expected;
sequence++;
completion();
}
return last_arrival;
}

void
wait(arrival_token&& arrival) const
{
while (arrival == sequence) modm::this_fiber::yield();
}

void
arrive_and_wait()
{
wait(arrive());
}

void
arrive_and_drop()
{
expected--;
arrive();
}
};

/// @}

}
169 changes: 169 additions & 0 deletions src/modm/processing/fiber/condition_variable.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
/*
* Copyright (c) 2024, Niklas Hauser
*
* This file is part of the modm project.
*
* This Source Code Form is subject to the terms of the Mozilla Public
* License, v. 2.0. If a copy of the MPL was not distributed with this
* file, You can obtain one at http://mozilla.org/MPL/2.0/.
*/
// ----------------------------------------------------------------------------

#pragma once

#include "fiber.hpp"
#include "stop_token.hpp"
#include <atomic>


namespace modm::fiber
{

/// @ingroup modm_processing_fiber
/// @{

enum class
cv_status
{
no_timeout,
timeout
};

/// Implements the `std::condition_variable_any` interface for fibers.
/// @see https://en.cppreference.com/w/cpp/thread/condition_variable
class condition_variable_any
{
condition_variable_any(const condition_variable_any&) = delete;
condition_variable_any& operator=(const condition_variable_any&) = delete;

std::atomic<uint16_t> sequence{};
public:
constexpr condition_variable_any() = default;

void inline
notify_one()
{
sequence++;
}

void inline
notify_any()
{
sequence++;
}


template< class Lock >
void
wait(Lock& lock)
{
lock.unlock();
const auto my_sequence = sequence.load();
while(my_sequence == sequence.load()) modm::this_fiber::yield();
lock.lock();
}

template< class Lock, class Predicate >
void
wait(Lock& lock, Predicate&& pred)
{
while (not pred()) wait(lock);
}

template< class Lock, class Predicate >
bool
wait(Lock& lock, stop_token stoken, Predicate pred)
{
while (not stoken.stop_requested())
{
if (pred()) return true;
wait(lock);
}
return pred();
}


template< class Lock, class Rep, class Period >
cv_status
wait_for(Lock& lock, std::chrono::duration<Rep, Period> rel_time)
{
lock.unlock();
const auto condition = [this, my_sequence = sequence.load()]()
{ return my_sequence != sequence.load(); };
const bool result = this_fiber::poll_for(rel_time, condition);
lock.lock();
return result ? cv_status::no_timeout : cv_status::timeout;
}

template< class Lock, class Rep, class Period, class Predicate >
bool
wait_for(Lock& lock, std::chrono::duration<Rep, Period> rel_time, Predicate&& pred)
{
while (not pred())
{
if (wait_for(lock, rel_time) == cv_status::timeout)
return pred();
}
return true;
}

template< class Lock, class Rep, class Period, class Predicate >
bool
wait_for(Lock& lock, stop_token stoken,
std::chrono::duration<Rep, Period> rel_time, Predicate&& pred)
{
while (not stoken.stop_requested())
{
if (pred()) return true;
if (wait_for(lock, rel_time) == cv_status::timeout)
return pred();
}
return pred();
}


template< class Lock, class Clock, class Duration >
cv_status
wait_until(Lock& lock, std::chrono::time_point<Clock, Duration> abs_time)
{
lock.unlock();
const auto condition = [this, my_sequence = sequence.load()]()
{ return my_sequence != sequence.load(); };
const bool result = this_fiber::poll_until(abs_time, condition);
lock.lock();
return result ? cv_status::no_timeout : cv_status::timeout;
}

template< class Lock, class Clock, class Duration, class Predicate >
bool
wait_until(Lock& lock, std::chrono::time_point<Clock, Duration> abs_time, Predicate&& pred)
{
while (not pred())
{
if (wait_until(lock, abs_time) == cv_status::timeout)
return pred();
}
return true;
}

template< class Lock, class Clock, class Duration, class Predicate >
bool
wait_until(Lock& lock, stop_token stoken,
std::chrono::time_point<Clock, Duration> abs_time, Predicate&& pred)
{
while (not stoken.stop_requested())
{
if (pred()) return true;
if (wait_until(lock, abs_time) == cv_status::timeout)
return pred();
}
return pred();
}
};

// There is no specialization for std::unique_lock.
using condition_variable = condition_variable_any;

/// @}

}
71 changes: 71 additions & 0 deletions src/modm/processing/fiber/latch.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
/*
* Copyright (c) 2024, Niklas Hauser
*
* This file is part of the modm project.
*
* This Source Code Form is subject to the terms of the Mozilla Public
* License, v. 2.0. If a copy of the MPL was not distributed with this
* file, You can obtain one at http://mozilla.org/MPL/2.0/.
*/
// ----------------------------------------------------------------------------

#pragma once

#include "fiber.hpp"
#include <atomic>

namespace modm::fiber
{

/// @ingroup modm_processing_fiber
/// @{

/// Implements the `std::latch` interface for fibers.
/// @see https://en.cppreference.com/w/cpp/thread/latch
class latch
{
latch(const latch&) = delete;
latch& operator=(const latch&) = delete;

using count_t = uint16_t;
std::atomic<count_t> count;

public:
constexpr explicit
latch(std::ptrdiff_t expected)
: count(expected) {}

[[nodiscard]]
static constexpr std::ptrdiff_t
max() { return count_t(-1); }

void inline
count_down(std::ptrdiff_t n=1)
{
count -= n;
}

[[nodiscard]]
bool inline
try_wait() const
{
return count.load() == 0;
}

void inline
wait() const
{
while(not try_wait()) modm::this_fiber::yield();
}

void inline
arrive_and_wait(std::ptrdiff_t n=1)
{
count_down(n);
wait();
}
};

/// @}

}
11 changes: 10 additions & 1 deletion src/modm/processing/fiber/module.lb
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def is_enabled(env):
not env.has_module(":processing:protothread")

def prepare(module, options):
module.depends(":processing:timer")
module.depends(":processing:timer", ":architecture:atomic")

module.add_query(
EnvironmentQuery(name="__enabled", factory=is_enabled))
Expand All @@ -46,6 +46,7 @@ def build(env):
"with_fpu": with_fpu,
"target": env[":target"].identifier,
"multicore": env.has_module(":platform:multicore"),
"num_cores": 1,
}
if env.has_module(":platform:multicore"):
cores = int(env[":target"].identifier.cores)
Expand Down Expand Up @@ -77,3 +78,11 @@ def build(env):
env.copy("task.hpp")
env.copy("functions.hpp")
env.copy("fiber.hpp")

env.copy("mutex.hpp")
env.copy("shared_mutex.hpp")
env.copy("semaphore.hpp")
env.copy("latch.hpp")
env.copy("barrier.hpp")
env.copy("stop_token.hpp")
env.copy("condition_variable.hpp")
Loading

0 comments on commit 6803ce0

Please sign in to comment.