diff --git a/CMakePresets.json b/CMakePresets.json index d18cb80..2b7a740 100644 --- a/CMakePresets.json +++ b/CMakePresets.json @@ -10,8 +10,8 @@ "cacheVariables": { "CMAKE_BUILD_TYPE": "Debug", "MACORO_FETCH_AUTO": "ON", - "MACORO_CPP_20": false, - "CMAKE_INSTALL_PREFIX": "${sourceDir}/out/install/${presetName}", + "MACORO_CPP_VER": "20", + "CMAKE_INSTALL_PREFIX": "${sourceDir}/out/install/${presetName}" //"CMAKE_C_COMPILER": "clang-12", //"CMAKE_CXX_COMPILER": "clang++-12" }, @@ -33,7 +33,7 @@ "cacheVariables": { "CMAKE_BUILD_TYPE": "Debug", "MACORO_FETCH_AUTO": "ON", - "MACORO_CPP_VER": "14", + "MACORO_CPP_VER": "20", "CMAKE_INSTALL_PREFIX": "${sourceDir}/out/install/${presetName}" }, "vendor": { "microsoft.com/VisualStudioSettings/CMake/1.0": { "hostOS": [ "Windows" ] } } diff --git a/macoro/CMakeLists.txt b/macoro/CMakeLists.txt index 1699180..4336d9c 100644 --- a/macoro/CMakeLists.txt +++ b/macoro/CMakeLists.txt @@ -14,7 +14,7 @@ if(MSVC) detail/win32.cpp) endif() -add_library(macoro STATIC ${SRC}) +add_library(macoro STATIC ${SRC} "trace.h" "barrier.h") diff --git a/macoro/async_scope.h b/macoro/async_scope.h new file mode 100644 index 0000000..37af3cc --- /dev/null +++ b/macoro/async_scope.h @@ -0,0 +1,355 @@ +#pragma once + +#include "macoro/trace.h" +#include "macoro/coroutine_handle.h" +#include "macoro/barrier.h" +#include "macoro/awaiter.h" +#include +#include "macoro/detail/scoped_task_promise.h" + +namespace macoro +{ + template + struct scoped_task + { + friend struct when_all_scope; + + using promise_type = detail::scoped_task_promise; + + scoped_task() = default; + scoped_task(scoped_task&& other) + { + *this = std::move(other); + } + scoped_task& operator=(scoped_task&& other) noexcept + { + if (std::addressof(other) != this) + { + if (m_coroutine) + this->~scoped_task(); + m_coroutine = std::exchange(other.m_coroutine, nullptr); + } + return *this; + } + + scoped_task(coroutine_handle&& c) + : m_coroutine(c) + {} + + ~scoped_task(); + + auto MACORO_OPERATOR_COAWAIT() && noexcept + { + return detail::scoped_task_awaiter{ std::exchange(m_coroutine, nullptr) }; + } + + bool is_ready() const + { + if (m_coroutine) + { + auto v = m_coroutine.promise().m_completion_state.load(std::memory_order_relaxed); + return v & (unsigned char)promise_type::completion_state::has_completed; + } + return false; + + } + private: + + coroutine_handle m_coroutine; + }; + + // an exception representing a one or more sub exceptions + // throw from within async_scope::add(...) but not manually joined. + // when async_scope is awaited, the unhandled exceptions from add(...) + // will be collected and throw using async_scope_exception. + struct async_scope_exception : std::exception + { + friend struct async_scope; + + std::vector m_exceptions; + + mutable std::string m_what; + private: + async_scope_exception(std::vector ex) + : m_exceptions(std::move(ex)) + {} + + char const* what() const noexcept override + { + try { + if (m_exceptions.size()) + { + if (m_what.size() == 0) + { + std::stringstream ss; + ss << "async_scope_exception: "; + for (auto i = 0ull; i < m_exceptions.size(); ++i) + { + ss << i << ": "; + try { + std::rethrow_exception(m_exceptions[i]); + } + catch (std::exception& e) + { + ss << e.what() << "\n"; + } + catch (...) + { + ss << "unknown exception\n"; + } + } + + m_what = ss.str(); + } + return m_what.c_str(); + } + } + catch (...) {} + + return ""; + } + + // get the list of exceptions. + auto& exceptions() { return m_exceptions; } + + }; + + // a scope object that the caller can add awaitables to + // via add(...). When the scope object is awaited, e.g. + // + // co_await scope; + // + // then all tasks will be joined. individual tasks can be + // joined by awaiting the scoped_task that is returned by + // add(...). e.g. + // + // scoped_task task = scope.add(...); + // ... + // int i = co_await std::move(task); + // + // exceptions will be propegated correct. + // If some tasks are not explicitly awaited and they throw + // an exception, then awaiting the scope will throw + // async_scope_exception that will contain a vector with + // all exceptions that were thrown. + struct async_scope : public traceable + { + void add_child(traceable* child) final override + { + assert(0); + } + + void remove_child(traceable* child) final override + { + assert(0); + } + + + template + friend struct scoped_task; + + template + friend struct detail::scoped_task_promise; + + friend struct when_all_scope; + + // add an awaitable to the scope. The awaitable + // will be executed immediately. A scoped_task + // is returned. The caller may await the scoped_task + // or may let it get destroyed. The underlying + // awaitable will continued to be executed. All + // awaitables can be joined by calling join(). + // exceptions + template + scoped_task>> + add(Awaitable a, std::source_location loc = std::source_location::current()) noexcept + { + co_return co_await std::move(a); + } + + + // add an awaitable to the scope. The awaitable + // will be executed immediately. A scoped_task + // is returned. The caller may await the scoped_task + // or may let it get destroyed. The underlying + // awaitable will continued to be executed. All + // awaitables can be joined by calling join(). + // exceptions + template + scoped_task>> + add(Awaitable& a, std::source_location loc = std::source_location::current()) noexcept + { + co_return co_await a; + } + + struct join_awaiter + { + barrier_awaiter m_barrier; + std::vector m_exceptions; + + bool await_ready() const noexcept { return m_barrier.await_ready(); } + + template + coroutine_handle<> await_suspend( + coroutine_handle awaiting_coroutine, + std::source_location loc = std::source_location::current()) noexcept + { + return m_barrier.await_suspend(awaiting_coroutine, loc); + } + + template + std::coroutine_handle<> await_suspend( + std::coroutine_handle awaiting_coroutine, + std::source_location loc = std::source_location::current()) noexcept + { + return m_barrier.await_suspend(awaiting_coroutine, loc); + } + + void await_resume() + { + m_barrier.await_resume(); + if (m_exceptions.size()) + throw async_scope_exception(std::move(m_exceptions)); + } + }; + + // returns an awaiter that can be co_awaited to suspend + // until all added awaitables have been completed. Only + // one + join_awaiter join() noexcept + { + return MACORO_OPERATOR_COAWAIT(); + } + + // returns an awaiter that can be co_awaited to suspend + // until all added awaitables have been completed. + join_awaiter MACORO_OPERATOR_COAWAIT() & noexcept + { + return { m_barrier.MACORO_OPERATOR_COAWAIT(), std::move(m_exceptions) }; + } + + private: + // used to count how many active tasks exists. barrier + // then then be awaited to + traceable* m_parent = nullptr; + barrier m_barrier; + std::mutex m_exception_mutex; + std::vector m_exceptions; + }; + + + namespace detail + { + + + template + template + scoped_task_promise::scoped_task_promise(async_scope& scope, Awaiter&&, std::source_location& loc) noexcept + : m_completion_state((unsigned char)completion_state::inprogress) + , m_scope(&scope) + { + set_parent(&scope, loc); + scope.m_barrier.increment(); + } + + + + template + bool scoped_task_promise::final_awaitable::await_ready() noexcept + { + scoped_task_promise& promise = m_promise; + async_scope* scope = promise.m_scope; + bool done = false; + + // race to see if we need to call the completion handle + // it might not have been set yet (or ever). + // release our result. acquire their continuation (if they beat us). + completion_state b = promise.completion_state_fetch_or( + completion_state::has_completed, + std::memory_order_acq_rel); + + // if b, then we have finished after the continuation was set. + if (b == completion_state::has_continuation_or_dropped) + { + // if we dont have a continuation, that means the + // scoped_task has been dropped. Therefore, no one + // wants the result and we can destroy the coro. + // + // if we do have a continuation, that means we have + // a scoped_task::awaiter somewhere and they will + // call destroy for us, once the continuation consumes + // the return value. + if (promise.m_continuation) + { + m_run_continutation = true; + done = false; + } + else + { + if (promise.m_result_type == result_type::exception) + { + // todo, replace vector by a lock free linked list, either to the full + // coro frame (no allocations) or to the execptions (node allocations). + // for now, locking on expcetions should be fine. + std::lock_guard lock(promise.m_scope->m_exception_mutex); + promise.m_scope->m_exceptions.push_back(std::move(promise.m_exception)); + } + + // destroy the coro. + // this is the same as coro.destroy() + // but there is a msvc bug if you do that. + done = true; + } + } + else + { + // the handle is still alive and the user might set + // a continuation. If they set a continuation, the + // awaiter will destroy the coro once the continuation + // constumes the result. + // + // If a continuation is never set, then the scoped_task + // will destroy the coro in its distructor. + assert(b == completion_state::inprogress); + done = false; + } + + // notify the scope that we are done. + scope->m_barrier.decrement(); + return done; + } + } + + + template + scoped_task::~scoped_task() + { + if (m_coroutine) + { + auto& promise = m_coroutine.promise(); + // if we are here, then the coro has been started but + // we have not set a continuation. As such, we need to + // query the promise to see if the coro has completed. + // If so, we need to destroy it. Otherwise, we let it + // know that there wont be a continuation and the coro + // will destroy itself onces it completes. + auto b = + promise.try_set_continuation(coroutine_handle<>{}); + + if (b == promise_type::completion_state::has_completed) + { + if (promise.m_result_type == promise_type::result_type::exception) + { + // todo, replace vector by a lock free linked list, either to the full + // coro frame (no allocations) or to the execptions (node allocations). + // for now, locking on expcetions should be fine. + std::lock_guard lock(promise.m_scope->m_exception_mutex); + promise.m_scope->m_exceptions.push_back(std::move(promise.m_exception)); + } + + m_coroutine.destroy(); + } + } + } + +} \ No newline at end of file diff --git a/macoro/barrier.h b/macoro/barrier.h new file mode 100644 index 0000000..2698b5d --- /dev/null +++ b/macoro/barrier.h @@ -0,0 +1,140 @@ +#pragma once + +#include "macoro/config.h" +#include "macoro/coroutine_handle.h" +#include "macoro/trace.h" + +#include +#include + +namespace macoro +{ + class barrier; + + class barrier_awaiter : basic_traceable + { + barrier& m_barrier; + + public: + + barrier_awaiter(barrier& barrier) + : m_barrier(barrier) + {} + + bool await_ready() const noexcept { return false; } + + template + coroutine_handle<> await_suspend( + coroutine_handle continuation, + std::source_location loc = std::source_location::current()) noexcept; + + template + std::coroutine_handle<> await_suspend( + std::coroutine_handle continuation, + std::source_location loc = std::source_location::current()) noexcept + { + return await_suspend(coroutine_handle(continuation), loc).std_cast(); + } + + void await_resume() const noexcept {} + }; + + // a barrier is used to release a waiting coroutine + // once a count reaches 0. The count can be initialy + // set and the incremented and decremented. Once it + // hits zero, an await coroutine is resumed. + // + // for exmaple: + // + // barrier b(1); // count = 1 + // b.add(2); // count = 3 + // b.increment();// count = 4 + // co_await b; // suspend + // b.decrement();// count = 3 + // b.decrement();// count = 2 + // b.decrement();// count = 1 + // b.decrement();// count = 0 + // resumed... + // + // The barrior can be reused once it reaches zero. + // + class barrier + { + friend barrier_awaiter; + + // the coro that is awaiting the barrier + coroutine_handle<> m_continuation; + + // the current count, once it decrements to 0, + // m_continuation is called if set. + std::atomic m_count; + public: + + // construct a new barrier with the given initial value. + explicit barrier(std::size_t initial_count = 0) noexcept + : m_count(initial_count) + {} + + // increament the count by amount and return the new count. + std::size_t add(size_t amount) noexcept + { + auto old = m_count.fetch_add(amount, std::memory_order_relaxed); + auto ret = old + amount; + assert(ret >= old); + return ret; + } + + // increment the count by 1 + std::size_t increment() noexcept + { + return add(1); + } + + // returns the current count. + std::size_t count() const noexcept + { + return m_count.load(std::memory_order_acquire); + } + + // decrease the count by 1. Returns the new count. + std::size_t decrement() noexcept + { + const std::size_t old_count = m_count.fetch_sub(1, std::memory_order_acq_rel); + + assert(old_count); + + if (old_count == 1 && m_continuation) + { + std::exchange(m_continuation, nullptr).resume(); + } + return old_count - 1; + } + + // suspend until the count reached zero. Only + // one caller can await the barrier at a time. + // if needed, the impl could be extended to support + // more callers, see async_manual_reset_event. + auto MACORO_OPERATOR_COAWAIT() noexcept + { + return barrier_awaiter{ *this }; + } + }; + + template + coroutine_handle<> barrier_awaiter::await_suspend( + coroutine_handle continuation, + std::source_location loc) noexcept + { + coroutine_handle<> ret = continuation; + if (m_barrier.increment() > 1) + { + set_parent(get_traceable(continuation), loc); + assert(m_barrier.m_continuation == nullptr); + m_barrier.m_continuation = std::exchange(ret, noop_coroutine()); + } + m_barrier.decrement(); + + return ret; + } + +} \ No newline at end of file diff --git a/macoro/coro_frame.h b/macoro/coro_frame.h index 14079f0..094034e 100644 --- a/macoro/coro_frame.h +++ b/macoro/coro_frame.h @@ -55,6 +55,8 @@ namespace macoro assert(value); return value; } + + auto return_value() { return get_handle(); } }; template<> @@ -70,6 +72,9 @@ namespace macoro { return noop_coroutine(); } + + auto return_value() { return value; } + }; template diff --git a/macoro/detail/scoped_task_promise.h b/macoro/detail/scoped_task_promise.h new file mode 100644 index 0000000..6d3983d --- /dev/null +++ b/macoro/detail/scoped_task_promise.h @@ -0,0 +1,403 @@ +#pragma once + +#include "macoro/trace.h" +#include "macoro/coroutine_handle.h" +#include "macoro/barrier.h" +#include "macoro/awaiter.h" +#include + +namespace macoro +{ + struct async_scope; + struct when_all_scope; + + template + struct scoped_task; + namespace detail + { + // A CRTP class that customizes the value type + // based on if its a ref, value or void. + template + class scoped_task_promise_storage; + + + // A CRTP class that customizes the value type + // based on if its a ref, value or void. + template + class scoped_task_promise_storage + { + public: + Self& self() { return *(Self*)this; } + + using value_type = T&; + + template< + typename VALUE, + typename = enable_if_t::value> + > + void return_value(VALUE&& value) + noexcept(std::is_nothrow_constructible::value) + { + m_value = &value; + self().m_result_type = Self::result_type::value; + } + + // the actual storage, only one is ever active. + union + { + T* m_value; + std::exception_ptr m_exception; + }; + + scoped_task_promise_storage() {} + + ~scoped_task_promise_storage() + { + if (self().m_result_type == Self::result_type::exception) + m_exception.~exception_ptr(); + } + + T& result() + { + if (self().m_result_type == Self::result_type::exception) + std::rethrow_exception(self().m_exception); + assert(self().m_result_type == Self::result_type::value); + return *self().m_value; + } + }; + + + // A CRTP class that customizes the value type + // based on if its a ref, value or void. + template + class scoped_task_promise_storage + { + public: + Self& self() { return *(Self*)this; } + + using value_type = T&&; + + template< + typename VALUE, + typename = enable_if_t::value> + > + void return_value(VALUE&& value) + noexcept(std::is_nothrow_constructible::value) + { + m_value = &value; + self().m_result_type = Self::result_type::value; + } + + // the actual storage, only one is ever active. + union + { + T* m_value; + std::exception_ptr m_exception; + }; + + scoped_task_promise_storage() {} + + ~scoped_task_promise_storage() + { + if (self().m_result_type == Self::result_type::exception) + m_exception.~exception_ptr(); + } + + T&& result()& + { + if (self().m_result_type == Self::result_type::exception) + std::rethrow_exception(self().m_exception); + assert(self().m_result_type == Self::result_type::value); + return std::move(*self().m_value); + } + }; + + + // A CRTP class that customizes the value type + // based on if its a ref, value or void. + template + class scoped_task_promise_storage + { + public: + Self& self() { return *(Self*)this; } + + void return_void() noexcept + { + self().m_result_type = Self::result_type::value; + } + + scoped_task_promise_storage() {} + ~scoped_task_promise_storage() + { + if (self().m_result_type == Self::result_type::exception) + m_exception.~exception_ptr(); + } + + // the actual storage, only one is ever active. + union { + std::exception_ptr m_exception; + }; + + void result() + { + if (self().m_result_type == Self::result_type::exception) + std::rethrow_exception(self().m_exception); + assert(self().m_result_type == Self::result_type::value); + } + }; + + + // A CRTP class that customizes the value type + // based on if its a ref, value or void. + template + class scoped_task_promise_storage + { + public: + Self& self() { return *(Self*)this; } + + using value_type = T; + + template< + typename VALUE, + typename = enable_if_t::value> + > + void return_value(VALUE&& value) + noexcept(std::is_nothrow_constructible::value) + { + ::new (static_cast(std::addressof(m_value))) value_type(std::forward(value)); + self().m_result_type = Self::result_type::value; + } + + // the actual storage, only one is ever active. + union + { + value_type m_value; + std::exception_ptr m_exception; + }; + + scoped_task_promise_storage() {} + + ~scoped_task_promise_storage() + { + switch (self().m_result_type) + { + case Self::result_type::value: + if constexpr (std::is_destructible_v) + m_value.~T(); + break; + case Self::result_type::exception: + m_exception.~exception_ptr(); + break; + default: + break; + } + } + + T& result()& + { + if (self().m_result_type == Self::result_type::exception) + std::rethrow_exception(self().m_exception); + assert(self().m_result_type == Self::result_type::value); + return self().m_value; + } + + T result()&& + { + if (self().m_result_type == Self::result_type::exception) + std::rethrow_exception(self().m_exception); + assert(self().m_result_type == Self::result_type::value); + return std::move(self().m_value); + } + }; + + template + class scoped_task_promise final : + public scoped_task_promise_storage, T>, + public basic_traceable + { + public: + using value_type = T; + + friend struct final_awaitable; + + enum class completion_state : unsigned char { + // has not completed and does not have a continuation + inprogress = 0, + + // has not completed and does have a continuation + has_continuation_or_dropped = 1, + + // has completed and does not have a continuation + has_completed = 2 + }; + + // has the task completed or the continueation been set. + std::atomic m_completion_state; + + completion_state completion_state_fetch_or(completion_state state, std::memory_order order) + { + return (completion_state)m_completion_state.fetch_or((unsigned char)state, order); + } + + // the continuation if present. + coroutine_handle<> m_continuation; + + // the containing scope. + async_scope* m_scope = nullptr; + + enum class result_type { empty, value, exception }; + + // the status of the value/exception storage. + result_type m_result_type = result_type::empty; + + // this final awaiter runs the suspend logic in await_ready + // due to a bug un msvc. Using symetric transfer and called + // destroy when done = true causes a segfault. This alternative + // implementation implicitly destroys the coro by returning + // true from await_ready. + struct final_awaitable + { + scoped_task_promise& m_promise; + + // a variable used to remember if we need to call the continuation. + bool m_run_continutation = false; + + bool await_ready() noexcept; + +#ifdef MACORO_CPP_20 + std::coroutine_handle<> await_suspend( + std::coroutine_handle coro) noexcept + { + if (m_run_continutation) + return m_promise.m_continuation.std_cast(); + else + return std::noop_coroutine(); + } +#endif + + coroutine_handle<> await_suspend( + coroutine_handle coro) noexcept + { + if (m_run_continutation) + return m_promise.m_continuation; + else + return noop_coroutine(); + } + + void await_resume() noexcept { + } + }; + + + + public: + + template + scoped_task_promise(async_scope& scope, Awaiter&&, std::source_location&) noexcept; + + ~scoped_task_promise() { + } + + suspend_never initial_suspend() noexcept { return {}; } + final_awaitable final_suspend() noexcept { return { *this }; } + +#ifdef MACORO_CPP_20 + completion_state try_set_continuation(std::coroutine_handle<> continuation) noexcept + { + return try_set_continuation(coroutine_handle<>(continuation)); + } +#endif + completion_state try_set_continuation(coroutine_handle<> continuation) noexcept + { + m_continuation = continuation; + return completion_state_fetch_or( + completion_state::has_continuation_or_dropped, + std::memory_order_acq_rel); + } + + scoped_task get_return_object() noexcept { + return + scoped_task{ + coroutine_handle::from_promise(*this, coroutine_handle_type::std) + }; + } + scoped_task macoro_get_return_object() noexcept; + + void unhandled_exception() noexcept + { + ::new (static_cast(std::addressof(this->m_exception))) std::exception_ptr( + std::current_exception()); + m_result_type = result_type::exception; + } + }; + + + template + struct scoped_task_awaiter + { + using promise_type = scoped_task_promise; + + coroutine_handle m_coroutine; + + bool await_ready() const noexcept { return false; } + + scoped_task_awaiter(coroutine_handle c) + :m_coroutine(c) + {} + + ~scoped_task_awaiter() + { + assert(m_coroutine); + m_coroutine.destroy(); + } + +#ifdef MACORO_CPP_20 + template + std::coroutine_handle<> await_suspend( + std::coroutine_handle awaitingCoroutine) noexcept + { + return await_suspend( + coroutine_handle<>(awaitingCoroutine)).std_cast(); + } +#endif + + template + coroutine_handle<> await_suspend( + coroutine_handle awaiting_coroutine) noexcept + { + // release our continuation. acquire their result (if they beat us). + auto b = m_coroutine.promise().try_set_continuation(awaiting_coroutine); + + // if b, then the result is ready and we should just resume the + // awaiting coroutine. + if (b == promise_type::completion_state::has_completed) + { + // the coro has previously completed, we can just resume the + // continue the awaiting coro. When this awaiter is destroyed, + // the awaited coro will be destroyed. + return awaiting_coroutine; + } + else + { + // otherwise we have finished first and our continuation will be resumed + // when the coroutine finishes. + return noop_coroutine(); + } + } + + decltype(auto) await_resume() + { + if (!this->m_coroutine) + { + throw broken_promise{}; + } + + promise_type& p = this->m_coroutine.promise(); + + return static_cast(p).result(); + } + }; + + + } +} \ No newline at end of file diff --git a/macoro/detail/when_all_awaitable.h b/macoro/detail/when_all_awaitable.h index c0433e6..337970e 100644 --- a/macoro/detail/when_all_awaitable.h +++ b/macoro/detail/when_all_awaitable.h @@ -9,6 +9,7 @@ #include #include "macoro/coroutine_handle.h" #include "when_all_counter.h" +#include "macoro/trace.h" #ifdef MACORO_CPP_20 #include @@ -18,267 +19,239 @@ namespace macoro { namespace detail { - template - class when_all_ready_awaitable; - - template<> - class when_all_ready_awaitable> - { - public: - constexpr when_all_ready_awaitable() noexcept {} - explicit constexpr when_all_ready_awaitable(std::tuple<>) noexcept {} - constexpr bool await_ready() const noexcept { return true; } - void await_suspend(coroutine_handle<>) noexcept {} -#ifdef MACORO_CPP_20 - void await_suspend(std::coroutine_handle<>) noexcept {} -#endif - std::tuple<> await_resume() const noexcept { return {}; } - - }; - - template - class when_all_ready_awaitable> + class when_all_ready_awaitable_base : public traceable { public: - explicit when_all_ready_awaitable(TASKS&&... tasks) - noexcept(conjunction...>::value) - : m_counter(sizeof...(TASKS)) - , m_tasks(std::move(tasks)...) - {} - - explicit when_all_ready_awaitable(std::tuple&& tasks) - noexcept(std::is_nothrow_move_constructible>::value) - : m_counter(sizeof...(TASKS)) - , m_tasks(std::move(tasks)) + when_all_ready_awaitable_base(std::size_t count) noexcept + : m_count(count + 1) + , m_awaitingCoroutine(nullptr) {} - when_all_ready_awaitable(when_all_ready_awaitable&& other) noexcept - : m_counter(sizeof...(TASKS)) - , m_tasks(std::move(other.m_tasks)) - {} - - auto MACORO_OPERATOR_COAWAIT() & noexcept + bool is_ready() const noexcept { - struct awaiter - { - awaiter(when_all_ready_awaitable& awaitable) noexcept - : m_awaitable(awaitable) - {} - - bool await_ready() const noexcept - { - return m_awaitable.is_ready(); - } - -#ifdef MACORO_CPP_20 - bool await_suspend(std::coroutine_handle<> awaitingCoroutine) noexcept - { - return await_suspend(coroutine_handle<>(awaitingCoroutine)); - } -#endif - bool await_suspend(coroutine_handle<> awaitingCoroutine) noexcept - { - return m_awaitable.try_await(awaitingCoroutine); - } - - std::tuple& await_resume() noexcept - { - return m_awaitable.m_tasks; - } - - private: - - when_all_ready_awaitable& m_awaitable; - - }; + // We consider this complete if we're asking whether it's ready + // after a coroutine has already been registered. + return static_cast(m_awaitingCoroutine); + } - return awaiter{ *this }; + bool try_await(coroutine_handle<> awaitingCoroutine) noexcept + { + m_awaitingCoroutine = awaitingCoroutine; + return m_count.fetch_sub(1, std::memory_order_acq_rel) > 1; } - auto MACORO_OPERATOR_COAWAIT() && noexcept + void notify_awaitable_completed() noexcept { - struct awaiter + if (m_count.fetch_sub(1, std::memory_order_acq_rel) == 1) { - awaiter(when_all_ready_awaitable& awaitable) noexcept - : m_awaitable(awaitable) - {} - - bool await_ready() const noexcept - { - return m_awaitable.is_ready(); - } -#ifdef MACORO_CPP_20 - bool await_suspend(std::coroutine_handle<> awaitingCoroutine) noexcept - { - return await_suspend(coroutine_handle<>(awaitingCoroutine)); - } -#endif - bool await_suspend(coroutine_handle<> awaitingCoroutine) noexcept - { - return m_awaitable.try_await(awaitingCoroutine); - } - - std::tuple&& await_resume() noexcept - { - return std::move(m_awaitable.m_tasks); - } + m_awaitingCoroutine.resume(); + } + } - private: + protected: - when_all_ready_awaitable& m_awaitable; + std::atomic m_count; + coroutine_handle<> m_awaitingCoroutine; - }; + }; - return awaiter{ *this }; - } + template struct is_tuple : std::false_type {}; - private: + template struct is_tuple> : std::true_type {}; - bool is_ready() const noexcept - { - return m_counter.is_ready(); - } + template + class when_all_ready_awaitable; - bool try_await(coroutine_handle<> awaitingCoroutine) noexcept - { - start_tasks(std::make_integer_sequence{}); - return m_counter.try_await(awaitingCoroutine); - } + template<> + class when_all_ready_awaitable> + { + public: - template - void start_tasks(std::integer_sequence) noexcept - { - (void)std::initializer_list{ - (std::get(m_tasks).start(m_counter), 0)... - }; - } + constexpr when_all_ready_awaitable() noexcept {} + explicit constexpr when_all_ready_awaitable(std::tuple<>) noexcept {} - when_all_counter m_counter; - std::tuple m_tasks; + constexpr bool await_ready() const noexcept { return true; } + void await_suspend(coroutine_handle<>) noexcept {} +#ifdef MACORO_CPP_20 + void await_suspend(std::coroutine_handle<>) noexcept {} +#endif + std::tuple<> await_resume() const noexcept { return {}; } }; + // the when_all_read awaitable that holds all the tasks. + // TASK_CONTAINER should be either an std::tuple or vector like + // container that contains when_all_task. template - class when_all_ready_awaitable + class when_all_ready_awaitable : public when_all_ready_awaitable_base { public: + size_t tasks_size(TASK_CONTAINER& tasks) + { + if constexpr (is_tuple::value) + return std::tuple_size::value; + else + return tasks.size(); + } + + explicit when_all_ready_awaitable(TASK_CONTAINER&& tasks) noexcept - : m_counter(tasks.size()) + : when_all_ready_awaitable_base(tasks_size(tasks)) , m_tasks(std::forward(tasks)) {} when_all_ready_awaitable(when_all_ready_awaitable&& other) noexcept(std::is_nothrow_move_constructible::value) - : m_counter(other.m_tasks.size()) + : when_all_ready_awaitable_base(tasks_size(other.m_tasks)) , m_tasks(std::move(other.m_tasks)) {} when_all_ready_awaitable(const when_all_ready_awaitable&) = delete; when_all_ready_awaitable& operator=(const when_all_ready_awaitable&) = delete; - auto MACORO_OPERATOR_COAWAIT() & noexcept + class lval_awaiter { - class awaiter - { - public: + public: - awaiter(when_all_ready_awaitable& awaitable) - : m_awaitable(awaitable) - {} + lval_awaiter(when_all_ready_awaitable& awaitable) + : m_awaitable(awaitable) + {} - bool await_ready() const noexcept - { - return m_awaitable.is_ready(); - } + bool await_ready() const noexcept + { + return m_awaitable.is_ready(); + } #ifdef MACORO_CPP_20 - bool await_suspend(std::coroutine_handle<> awaitingCoroutine) noexcept - { - return await_suspend(coroutine_handle<>(awaitingCoroutine)); - } + template + bool await_suspend(std::coroutine_handle awaitingCoroutine, std::source_location loc = std::source_location::current()) noexcept + { + m_awaitable.set_parent(get_traceable(awaitingCoroutine), loc); + return await_suspend(coroutine_handle<>(awaitingCoroutine), loc); + } #endif - bool await_suspend(coroutine_handle<> awaitingCoroutine) noexcept - { - return m_awaitable.try_await(awaitingCoroutine); - } + template + bool await_suspend(coroutine_handle awaitingCoroutine, std::source_location loc = std::source_location::current()) noexcept + { + m_awaitable.set_parent(get_traceable(awaitingCoroutine), loc); + return m_awaitable.try_await(awaitingCoroutine, loc); + } - TASK_CONTAINER& await_resume() noexcept - { - return m_awaitable.m_tasks; - } + TASK_CONTAINER& await_resume() noexcept + { + return m_awaitable.m_tasks; + } - private: + private: - when_all_ready_awaitable& m_awaitable; + when_all_ready_awaitable& m_awaitable; - }; + }; - return awaiter{ *this }; + auto MACORO_OPERATOR_COAWAIT() & noexcept + { + return lval_awaiter{ *this }; } - auto MACORO_OPERATOR_COAWAIT() && noexcept + class rval_awaiter { - class awaiter - { - public: + public: - awaiter(when_all_ready_awaitable& awaitable) - : m_awaitable(awaitable) - {} + rval_awaiter(when_all_ready_awaitable& awaitable) + : m_awaitable(awaitable) + {} - bool await_ready() const noexcept - { - return m_awaitable.is_ready(); - } + bool await_ready() const noexcept + { + return m_awaitable.is_ready(); + } #ifdef MACORO_CPP_20 - bool await_suspend(std::coroutine_handle<> awaitingCoroutine) noexcept - { - return await_suspend(coroutine_handle<>(awaitingCoroutine)); - } + template + bool await_suspend(std::coroutine_handle awaitingCoroutine, std::source_location loc = std::source_location::current()) noexcept + { + m_awaitable.set_parent(get_traceable(awaitingCoroutine), loc); + return m_awaitable.try_await(coroutine_handle<>(awaitingCoroutine)); + + //return await_suspend(coroutine_handle<>(awaitingCoroutine), loc); + } #endif - bool await_suspend(coroutine_handle<> awaitingCoroutine) noexcept - { - return m_awaitable.try_await(awaitingCoroutine); - } + template + bool await_suspend(coroutine_handle awaitingCoroutine, std::source_location loc = std::source_location::current()) noexcept + { + m_awaitable.set_parent(get_traceable(awaitingCoroutine), loc); + return m_awaitable.try_await(awaitingCoroutine); + } - TASK_CONTAINER&& await_resume() noexcept - { - return std::move(m_awaitable.m_tasks); - } + TASK_CONTAINER&& await_resume() noexcept + { + return std::move(m_awaitable.m_tasks); + } - private: + private: - when_all_ready_awaitable& m_awaitable; + when_all_ready_awaitable& m_awaitable; - }; + }; - return awaiter{ *this }; + auto MACORO_OPERATOR_COAWAIT() && noexcept + { + return rval_awaiter{ *this }; } - private: - bool is_ready() const noexcept + + void add_child(traceable* child) final override + { + assert(0); + } + + void remove_child(traceable* child) final override + { + assert(0); + } + + //private: + + //bool is_ready() const noexcept + //{ + // return m_counter.is_ready(); + //} + + template + void start_tasks(std::integer_sequence) noexcept { - return m_counter.is_ready(); + if constexpr (is_tuple::value) + { + (void)std::initializer_list{ + (std::get(m_tasks).start(*this), 0)... + }; + } } bool try_await(coroutine_handle<> awaitingCoroutine) noexcept { - for (auto&& task : m_tasks) + if constexpr (is_tuple::value) { - task.start(m_counter); + start_tasks( + std::make_integer_sequence::value>{}); + } + else + { + for (auto&& task : m_tasks) + { + task.start(*this); + } } - return m_counter.try_await(awaitingCoroutine); + return when_all_ready_awaitable_base::try_await(awaitingCoroutine); } - when_all_counter m_counter; TASK_CONTAINER m_tasks; }; diff --git a/macoro/detail/when_all_counter.h b/macoro/detail/when_all_counter.h index f0080c8..928866e 100644 --- a/macoro/detail/when_all_counter.h +++ b/macoro/detail/when_all_counter.h @@ -13,42 +13,6 @@ namespace macoro { namespace detail { - class when_all_counter - { - public: - - when_all_counter(std::size_t count) noexcept - : m_count(count + 1) - , m_awaitingCoroutine(nullptr) - {} - - bool is_ready() const noexcept - { - // We consider this complete if we're asking whether it's ready - // after a coroutine has already been registered. - return static_cast(m_awaitingCoroutine); - } - - bool try_await(coroutine_handle<> awaitingCoroutine) noexcept - { - m_awaitingCoroutine = awaitingCoroutine; - return m_count.fetch_sub(1, std::memory_order_acq_rel) > 1; - } - - void notify_awaitable_completed() noexcept - { - if (m_count.fetch_sub(1, std::memory_order_acq_rel) == 1) - { - m_awaitingCoroutine.resume(); - } - } - - protected: - - std::atomic m_count; - coroutine_handle<> m_awaitingCoroutine; - - }; } } diff --git a/macoro/detail/when_all_task.h b/macoro/detail/when_all_task.h index 174ef87..888ec60 100644 --- a/macoro/detail/when_all_task.h +++ b/macoro/detail/when_all_task.h @@ -6,10 +6,11 @@ #include "macoro/coroutine_handle.h" -#include "when_all_counter.h" +#include "when_all_awaitable.h" #include "macoro/type_traits.h" #include "macoro/macros.h" +#include "macoro/trace.h" #include #include @@ -72,7 +73,7 @@ namespace macoro void await_suspend(coroutine_handle_t coro) const noexcept { - coro.promise().m_counter->notify_awaitable_completed(); + coro.promise().m_awaitable->notify_awaitable_completed(); } void await_resume() const noexcept {} @@ -101,9 +102,9 @@ namespace macoro return final_suspend(); } - void start(when_all_counter& counter) noexcept + void start(when_all_ready_awaitable_base& awaitable) noexcept { - m_counter = &counter; + m_awaitable = &awaitable; coroutine_handle_t::from_promise(*this, m_type).resume(); } @@ -119,7 +120,8 @@ namespace macoro return std::forward(*m_result); } - private: + + //private: void rethrow_if_exception() { @@ -129,8 +131,13 @@ namespace macoro } } + traceable* get_traceable() + { + return m_awaitable; + } + coroutine_handle_type m_type; - when_all_counter* m_counter; + when_all_ready_awaitable_base* m_awaitable = nullptr; std::exception_ptr m_exception; std::add_pointer_t m_result; @@ -174,13 +181,13 @@ namespace macoro #ifdef MACORO_CPP_20 void await_suspend(std_coroutine_handle_t coro) const noexcept { - coro.promise().m_counter->notify_awaitable_completed(); + coro.promise().m_awaitable->notify_awaitable_completed(); } #endif void await_suspend(coroutine_handle_t coro) const noexcept { - coro.promise().m_counter->notify_awaitable_completed(); + coro.promise().m_awaitable->notify_awaitable_completed(); } void await_resume() const noexcept {} @@ -199,9 +206,9 @@ namespace macoro { } - void start(when_all_counter& counter) noexcept + void start(when_all_ready_awaitable_base& awaitable) noexcept { - m_counter = &counter; + m_awaitable = &awaitable; coroutine_handle_t::from_promise(*this, m_type).resume(); } @@ -213,10 +220,11 @@ namespace macoro } } - private: + //private: + when_all_ready_awaitable_base* m_awaitable; coroutine_handle_type m_type; - when_all_counter* m_counter; + //when_all_counter* m_counter; std::exception_ptr m_exception; }; @@ -263,14 +271,14 @@ namespace macoro return std::move(m_coroutine.promise()).result(); } - private: + //private: template friend class when_all_ready_awaitable; - void start(when_all_counter& counter) noexcept + void start(when_all_ready_awaitable_base& awaiter) noexcept { - m_coroutine.promise().start(counter); + m_coroutine.promise().start(awaiter); } coroutine_handle_t m_coroutine; diff --git a/macoro/manual_reset_event.h b/macoro/manual_reset_event.h index 066b91a..52524a3 100644 --- a/macoro/manual_reset_event.h +++ b/macoro/manual_reset_event.h @@ -88,6 +88,7 @@ namespace macoro bool await_ready() const noexcept; bool await_suspend(coroutine_handle<> awaiter) noexcept; + bool await_suspend(std::coroutine_handle<> awaiter) noexcept; void await_resume() const noexcept {} private: @@ -191,6 +192,12 @@ namespace macoro return true; } + + inline bool async_manual_reset_event_operation::await_suspend( + std::coroutine_handle<> awaiter) noexcept + { + return await_suspend(coroutine_handle<>(awaiter)); + } } #endif diff --git a/macoro/sync_wait.h b/macoro/sync_wait.h index 18f496a..766bd28 100644 --- a/macoro/sync_wait.h +++ b/macoro/sync_wait.h @@ -18,158 +18,155 @@ namespace macoro struct blocking_task; - template - struct blocking_promise_base + template + struct blocking_promise : basic_traceable { - struct final_awaiter + + // can be used to check if it allocates + //void* operator new(std::size_t n) noexcept + //{ + // return std::malloc(n); + //} + //void operator delete(void* ptr, std::size_t sz) + //{ + // std::free(ptr); + //} + + // construct the awaiter in the promise + template + blocking_promise(A&&a, std::source_location& loc) + : m_awaiter(get_awaiter(std::forward(a))) { - bool await_ready() noexcept { return false; } -#ifdef MACORO_CPP_20 - template - void await_suspend(std::coroutine_handle

h) noexcept { h.promise().set(); } -#endif - template - void await_suspend(coroutine_handle

h) noexcept { h.promise().set(); } - void await_resume() noexcept {} - }; + set_parent(nullptr, loc); + } + + using inner_awaiter = decltype(get_awaiter(std::declval())); + + // the promise will hold the awaiter. this way when + // we return the value, you can directly get it + // by calling m_awaiter.await_resume() + inner_awaiter m_awaiter; + + // any active exceptions that is thrown. + std::exception_ptr m_exception; - std::exception_ptr exception; - std::mutex mutex; - std::condition_variable cv; - bool is_set = false; + // mutex and cv to block until the caller until the result is ready. + std::mutex m_mutex; + std::condition_variable m_cv; + bool m_is_set = false; + // block the caller. void wait() { - std::unique_lock lock(mutex); - cv.wait(lock, [this] { return is_set; }); + std::unique_lock lock(m_mutex); + m_cv.wait(lock, [this] { return m_is_set; }); } + // notify the caller. void set() { - assert(is_set == false); - std::lock_guard lock(this->mutex); - this->is_set = true; - this->cv.notify_all(); + assert(m_is_set == false); + std::lock_guard lock(this->m_mutex); + this->m_is_set = true; + this->m_cv.notify_all(); } + void return_void() {} + + struct final_awaiter + { + bool await_ready() noexcept { return false; }; + + template + void await_suspend(C c) noexcept { c.promise().set(); } + + void await_resume() noexcept {} + }; suspend_always initial_suspend() noexcept { return{}; } final_awaiter final_suspend() noexcept { - return { }; + return {}; } void unhandled_exception() noexcept { - exception = std::current_exception(); + m_exception = std::current_exception(); } - }; - template - struct blocking_promise : public blocking_promise_base - { - typename std::remove_reference::type* mVal = nullptr; - - blocking_task get_return_object() noexcept; - blocking_task macoro_get_return_object() noexcept; + blocking_task get_return_object() noexcept; + blocking_task macoro_get_return_object() noexcept; - using reference_type = T&&; - void return_value(reference_type v) noexcept + // intercept the dummy co_await and manually co_await m_awaiter. + auto& await_transform(Awaitable&& a) { - mVal = std::addressof(v); + return *this; } - reference_type value() + // forward to m_awaiter + bool await_ready() //noexcept(std::declval().await_ready()) { - if (this->exception) - std::rethrow_exception(this->exception); - return static_cast(*mVal); + return m_awaiter.await_ready(); } - }; - template<> - struct blocking_promise : public blocking_promise_base - { - blocking_task get_return_object() noexcept; - blocking_task macoro_get_return_object() noexcept; + // forward to m_awaiter + auto await_suspend(std::coroutine_handle c) //noexcept(macoro::await_suspend(m_awaiter, c)) + { + return m_awaiter.await_suspend(c); + } - void return_void() {} + // onces m_awaiter completes, notify the caller that they + // can get the result. this will happen by calling m_awaiter.await_resume(); + void await_resume() { } - void value() + // get the value out of m_awaiter. + decltype(auto) value() { - if (this->exception) - std::rethrow_exception(this->exception); + if (this->m_exception) + std::rethrow_exception(this->m_exception); + assert(m_is_set); + + return m_awaiter.await_resume(); } }; - template + template struct blocking_task { - using promise_type = blocking_promise; - coroutine_handle handle; + using promise_type = blocking_promise; + coroutine_handle m_handle; blocking_task(coroutine_handle h) - : handle(h) + : m_handle(h) {} blocking_task() = delete; - blocking_task(blocking_task&& h) noexcept :handle(std::exchange(h.handle, std::nullptr_t{})) {} - blocking_task& operator=(blocking_task&& h) { handle = std::exchange(h.handle, std::nullptr_t{}); } + //blocking_task(blocking_task&& h) noexcept :m_handle(std::exchange(h.m_handle, std::nullptr_t{})) {} + //blocking_task& operator=(blocking_task&& h) { m_handle = std::exchange(h.m_handle, std::nullptr_t{}); } ~blocking_task() { - if (handle) - handle.destroy(); - } - - void start() - { - handle.resume(); + m_handle.destroy(); } decltype(auto) get() { - handle.promise().wait(); - return handle.promise().value(); + m_handle.resume(); + m_handle.promise().wait(); + return m_handle.promise().value(); } }; + template + blocking_task blocking_promise::get_return_object() noexcept { return { coroutine_handle::from_promise(*this, coroutine_handle_type::std) }; } + template + blocking_task blocking_promise::macoro_get_return_object() noexcept { return { coroutine_handle::from_promise(*this, coroutine_handle_type::macoro) }; } - - template - inline blocking_task blocking_promise::get_return_object() noexcept { return { coroutine_handle>::from_promise(*this, coroutine_handle_type::std) }; } - template - inline blocking_task blocking_promise::macoro_get_return_object() noexcept { return { coroutine_handle>::from_promise(*this, coroutine_handle_type::macoro) }; } - inline blocking_task blocking_promise::get_return_object() noexcept { return { coroutine_handle>::from_promise(*this, coroutine_handle_type::std) }; } - inline blocking_task blocking_promise::macoro_get_return_object() noexcept { return { coroutine_handle>::from_promise(*this, coroutine_handle_type::macoro) }; } - - template< - typename Awaitable, - typename ResultType = typename awaitable_traits::await_result> - enable_if_t::value, - blocking_task - > - make_blocking_task(Awaitable&& awaitable) - { -#if MACORO_MAKE_BLOCKING_20 - co_return co_await static_cast(awaitable); -#else - MC_BEGIN(blocking_task, &awaitable); - MC_RETURN_AWAIT(static_cast(awaitable)); - MC_END(); - -#endif - } - - template< - typename Awaitable, - typename ResultType = typename awaitable_traits::await_result> - enable_if_t::value, - blocking_task - > - make_blocking_task(Awaitable&& awaitable) + template + auto make_blocking_task(Awaitable&& awaitable, std::source_location loc) + -> blocking_task { #if MACORO_MAKE_BLOCKING_20 - co_await std::forward(awaitable); + co_await static_cast(awaitable); #else MC_BEGIN(blocking_task, &awaitable); MC_AWAIT(static_cast(awaitable)); @@ -180,23 +177,13 @@ namespace macoro template - typename awaitable_traits::await_result - sync_wait(Awaitable&& awaitable) + decltype(auto) sync_wait(Awaitable&& awaitable, std::source_location loc = std::source_location::current()) { - auto task = detail::make_blocking_task(std::forward(awaitable)); - task.start(); - return task.get(); + return detail::make_blocking_task(std::forward(awaitable), loc).get(); } - - struct sync_wait_t - { - }; - - inline sync_wait_t sync_wait() - { - return {}; - } + struct sync_wait_t { }; + inline sync_wait_t sync_wait() { return {}; } template diff --git a/macoro/task.h b/macoro/task.h index 4199620..fe28443 100644 --- a/macoro/task.h +++ b/macoro/task.h @@ -18,11 +18,35 @@ #include "macoro/awaiter.h" #include "macoro/type_traits.h" #include "macoro/macros.h" +#include "macoro/trace.h" namespace macoro { template class task; + struct getPromise + { + template + struct awaitable + { + bool await_ready() const noexcept { return true; } +#ifdef MACORO_CPP_20 + void await_suspend(std::coroutine_handle<> coro) noexcept + { + assert(0); + } +#endif + + void await_suspend( + coroutine_handle<> coro) noexcept + { + assert(0); + } + Promise& await_resume() noexcept { return *mProm; } + Promise* mProm; + }; + }; + namespace detail { template @@ -32,7 +56,7 @@ namespace macoro struct task_awaitable_base; template<> - class task_promise_base + class task_promise_base : public basic_traceable { friend struct final_awaitable; @@ -64,7 +88,8 @@ namespace macoro public: task_promise_base() noexcept - {} + { + } auto initial_suspend() noexcept { @@ -77,24 +102,30 @@ namespace macoro } #ifdef MACORO_CPP_20 - void set_continuation(std::coroutine_handle<> continuation) noexcept + template + void set_continuation(std::coroutine_handle continuation, std::source_location l) noexcept { assert(!m_continuation); m_continuation = coroutine_handle<>(continuation); + set_parent(get_traceable(continuation), l); } #endif - void set_continuation(coroutine_handle<> continuation) noexcept + template + void set_continuation(coroutine_handle continuation, std::source_location l) noexcept { assert(!m_continuation); m_continuation = continuation; + set_parent(get_traceable(continuation), l); } private: coroutine_handle<> m_continuation; }; - + + //static_assert(has_async_stack_frame < task_promise_base>); + template<> - class task_promise_base + class task_promise_base : public basic_traceable { friend struct final_awaitable; @@ -220,18 +251,18 @@ namespace macoro return m_value; } - // HACK: Need to have co_await of task return prvalue rather than - // rvalue-reference to work around an issue with MSVC where returning - // rvalue reference of a fundamental type from await_resume() will - // cause the value to be copied to a temporary. This breaks the - // sync_wait() implementation. - // See https://github.com/lewissbaker/cppcoro/issues/40#issuecomment-326864107 - using rvalue_type = typename std::conditional< - std::is_arithmetic::value || std::is_pointer::value, - T, - T&&>::type; - - rvalue_type result()&& + //// HACK: Need to have co_await of task return prvalue rather than + //// rvalue-reference to work around an issue with MSVC where returning + //// rvalue reference of a fundamental type from await_resume() will + //// cause the value to be copied to a temporary. This breaks the + //// sync_wait() implementation. + //// See https://github.com/lewissbaker/cppcoro/issues/40#issuecomment-326864107 + //using rvalue_type = typename std::conditional< + // std::is_arithmetic::value || std::is_pointer::value, + // T, + // T&&>::type; + + T result()&& { if (m_resultType == result_type::exception) { @@ -243,6 +274,23 @@ namespace macoro return std::move(m_value); } + getPromise::awaitable await_transform(getPromise) noexcept + { + return getPromise::awaitable{this}; + } + + + auto await_transform(get_trace&& t) noexcept + { + return get_trace::awaitable(trace(t.location, *this)); + } + + + template + decltype(auto) await_transform(A&&a) noexcept + { + return std::forward(a); + } private: enum class result_type { empty, value, exception }; @@ -283,6 +331,23 @@ namespace macoro } } + getPromise::awaitable await_transform(getPromise)noexcept + { + return getPromise::awaitable{this}; + } + + + auto await_transform(get_trace&& t) noexcept + { + return get_trace::awaitable(trace(t.location, *this)); + } + + + template + decltype(auto) await_transform(A&& a) noexcept + { + return std::forward(a); + } private: std::exception_ptr m_exception; @@ -319,6 +384,21 @@ namespace macoro return *m_value; } + getPromise::awaitable await_transform(getPromise)noexcept + { + return getPromise::awaitable{this}; + } + + auto await_transform(get_trace&& t) noexcept + { + return get_trace::awaitable(trace(t.location, *this)); + } + + template + decltype(auto) await_transform(A&& a) noexcept + { + return std::forward(a); + } private: T* m_value = nullptr; @@ -329,7 +409,7 @@ namespace macoro template - struct task_awaitable_base + struct task_awaitable_base { using promise_type = task_promise; coroutine_handle m_coroutine; @@ -348,18 +428,22 @@ namespace macoro } #ifdef MACORO_CPP_20 + template std::coroutine_handle<> await_suspend( - std::coroutine_handle<> awaitingCoroutine) noexcept + std::coroutine_handle awaitingCoroutine, + std::source_location l = std::source_location::current()) noexcept { - m_coroutine.promise().set_continuation(awaitingCoroutine); + auto& prom = m_coroutine.promise(); + prom.set_continuation(awaitingCoroutine, l); return m_coroutine.std_cast(); } #endif - + template coroutine_handle<> await_suspend( - coroutine_handle<> awaitingCoroutine) noexcept + coroutine_handle awaitingCoroutine, + std::source_location l = std::source_location::current()) noexcept { - m_coroutine.promise().set_continuation(awaitingCoroutine); + m_coroutine.promise().set_continuation(awaitingCoroutine, l); return m_coroutine; } }; @@ -385,16 +469,18 @@ namespace macoro } #ifdef MACORO_CPP_20 + template std::coroutine_handle<> await_suspend( - std::coroutine_handle<> awaitingCoroutine) noexcept + std::coroutine_handle awaitingCoroutine) noexcept { return await_suspend( coroutine_handle<>(awaitingCoroutine)).std_cast(); } #endif + template coroutine_handle<> await_suspend( - coroutine_handle<> awaitingCoroutine) noexcept + coroutine_handle awaitingCoroutine) noexcept { // release our continuation. acquire their result (if they beat us). auto b = m_coroutine.promise().try_set_continuation(awaitingCoroutine); @@ -432,9 +518,6 @@ namespace macoro using value_type = T; - private: - - public: task() noexcept @@ -477,8 +560,7 @@ namespace macoro m_coroutine.destroy(); } - m_coroutine = other.m_coroutine; - other.m_coroutine = nullptr; + m_coroutine = std::exchange(other.m_coroutine, nullptr); } return *this; @@ -490,7 +572,7 @@ namespace macoro /// Awaiting a task that is ready is guaranteed not to block/suspend. bool is_ready() const noexcept { - return !m_coroutine || m_coroutine.done(); + return m_coroutine && m_coroutine.done(); } struct ref_awaitable : detail::task_awaitable_base @@ -507,9 +589,9 @@ namespace macoro return this->m_coroutine.promise().result(); } }; + auto MACORO_OPERATOR_COAWAIT() const& noexcept { - return ref_awaitable{ m_coroutine }; } @@ -539,6 +621,7 @@ namespace macoro void await_resume() const noexcept {} }; + /// \brief /// Returns an awaitable that will await completion of the task without /// attempting to retrieve the result. @@ -555,7 +638,6 @@ namespace macoro private: coroutine_handle m_coroutine; - }; diff --git a/macoro/trace.h b/macoro/trace.h new file mode 100644 index 0000000..4c4ab5c --- /dev/null +++ b/macoro/trace.h @@ -0,0 +1,285 @@ +#pragma once + + +#include "macoro/config.h" +#include +#include +#include "macoro/coroutine_handle.h" +#include +#include + +namespace macoro +{ + struct traceable + { + // where in the parent this frame was started. + std::source_location m_location; + + // the caller. Can be null. We will set this value + traceable* m_parent = nullptr; + + void set_parent(traceable* parent, const std::source_location& l) + { + m_location = l; + m_parent = parent; + } + + virtual void get_call_stack(std::vector& stack) + { + stack.push_back(m_location); + if (m_parent) + m_parent->get_call_stack(stack); + } + + + virtual void add_child(traceable* child) = 0; + virtual void remove_child(traceable* child) = 0; + + //{ + // m_child = child; + //} + + //{ + // m_child = child; + //} + + // point to the first (active) child. + // The child will set this value. + //traceable* m_child = 0; + + + //struct multi_child + //{ + // std::mutex m_mutex; + // std::vector m_children; + //}; + + //std::optional m_multi_child; + + //void set_parent(traceable* parent, std::source_location l) + //{ + // assert(mParent == nullptr); + // if (parent) + // { + // parent->mFile = l.file_name(); + // parent->mLine = l.line(); + + // mParent = parent; + // assert(parent->mChild == nullptr); + // parent->mChild = this; + // } + //} + + //void set_detatched(traceable* parent, std::source_location l) + //{ + // assert(mParent == nullptr); + // if (parent) + // { + // parent->mFile = l.file_name(); + // parent->mLine = l.line(); + + // mParent = parent; + // if (mParent->m_multi_child.has_value() == false) + // mParent->m_multi_child.emplace(); + + // std::lock_guard lck(mParent->m_multi_child->m_mutex); + // mParent->m_multi_child->m_children.push_back(this); + // } + //} + + //~traceable() + //{ + // // we are a suspend child of some coro + // if (mParent) + // { + // if (mParent->mChild == this) + // { + // // we are a normal child. + // mParent->mChild = nullptr; + // } + // else + // { + // auto& children = mParent->m_multi_child; + // assert(children.has_value()); + // std::lock_guard lck(children->m_mutex); + + // auto iter = std::find( + // children->m_children.begin(), + // children->m_children.end(), + // this); + // assert(iter != children->m_children.end()); + // std::swap(*iter, children->m_children.back()); + // children->m_children.pop_back(); + // } + // } + + // if (mChild) + // { + // mChild->mParent = nullptr; + // } + //} + }; + + // a traceable that can only have one parent and child. + // This is not thread safe. + struct basic_traceable : public traceable + { + traceable* m_child = nullptr; + + void get_call_stack(std::vector& stack) final override + { + stack.push_back(m_location); + if (m_parent) + m_parent->get_call_stack(stack); + } + + void add_child(traceable* child) final override + { + assert(m_child == nullptr); + m_child = child; + } + + void remove_child(traceable* child) final override + { + assert(m_child == child); + m_child = nullptr; + } + + }; + + + template + concept is_traceable = + requires(T t) { + { t } -> std::convertible_to; + }; + + template + concept has_traceable = + requires(T t) { + { t.m_trace } -> std::convertible_to; + }; + + template + concept has_traceable_fn = + requires(T t) { + { t.get_traceable() } -> std::convertible_to; + }; + + template + concept has_traceable_ptr = + requires(T t) { + { t.m_trace } -> std::convertible_to; + }; + + namespace detail + { + + template + traceable* get_traceable(Promise& p) + { + if constexpr (is_traceable) + return &p; + if constexpr (has_traceable_fn) + return p.get_traceable(); + if constexpr (has_traceable) + return &p.m_trace; + if constexpr (has_traceable_ptr) + return p.m_trace; + else + return nullptr; + } + + template + traceable* get_traceable(std::coroutine_handle p) + { + if constexpr (!std::is_same_v) + return get_traceable(p.promise()); + else + return nullptr; + + } + + template + traceable* get_traceable(coroutine_handle p) + { + if constexpr (!std::is_same_v) + return get_traceable(p.promise()); + else + return nullptr; + + } + } + + + + struct trace { + + trace(std::source_location l, traceable& c) + { + stack.push_back(l); + c.get_call_stack(stack); + } + + std::vector stack; + + std::string str() + { + std::stringstream ss; + for(auto i = 0ull; i < stack.size(); ++i) + { + ss << i << " " << stack[i].file_name() << ":" << stack[i].line() << "\n"; + } + return ss.str(); + } + }; + + struct get_trace { + + get_trace(std::source_location l = std::source_location::current()) + :location(l) + { } + std::source_location location; + + struct awaitable + { + awaitable(trace&& t) + : mTrace(std::move(t)) + {} + + bool await_ready() const noexcept { return true; } +#ifdef MACORO_CPP_20 + void await_suspend(std::coroutine_handle<> coro) noexcept + { + assert(0); + } +#endif + + void await_suspend( + coroutine_handle<> coro) noexcept + { + assert(0); + } + + + trace await_resume() noexcept { return mTrace; } + trace mTrace; + }; + }; + //template + //void set_trace(Promise& p, async_stack_frame* a) + //{ + // if constexpr (has_async_stack_frame) + // return p.async_stack_frame = a; + // else + // return; + //} + + //template + //void connect_trace(Promise1& patent, Promise2 child) noexcept + //{ + // set_trace(child, get_trace(patent); + //} + + //thread_local async_stack_frame* current_async_stack_frame = nullptr; +} diff --git a/macoro/when_all_scope.h b/macoro/when_all_scope.h new file mode 100644 index 0000000..c8bd176 --- /dev/null +++ b/macoro/when_all_scope.h @@ -0,0 +1,126 @@ +#pragma once +#include "macoro/async_scope.h" +#include "macoro/manual_reset_event.h" +#include "macoro/coroutine_handle.h" + +namespace macoro +{ + + struct when_all_scope + { + struct promise_type + { + promise_type() + { + m_scope.m_barrier.increment(); + } + + struct final_awaitable + { + bool await_ready() noexcept { return false; } + void await_suspend(std::coroutine_handle h) noexcept + { + h.promise().m_scope.m_barrier.decrement(); + } + void await_resume() noexcept {} + }; + + suspend_always initial_suspend() noexcept { return {}; } + final_awaitable final_suspend() noexcept { return {}; } + + when_all_scope get_return_object() noexcept { + return { coroutine_handle::from_promise(*this, macoro::coroutine_handle_type::std) }; + } + + void unhandled_exception() noexcept { + + std::lock_guard lock(m_scope.m_exception_mutex); + m_scope.m_exceptions.push_back(std::move(std::current_exception())); + } + + void return_void() noexcept {}; + + //template + //decltype(auto) await_transform(A&& a) noexcept + //{ + // return std::forward(a); + //} + + template + decltype(auto) await_transform(A&& a, std::source_location loc = std::source_location::current()) noexcept + { + using scoped_task_of_a = scoped_task;// decltype(m_scope.add(std::forward(a))); + struct awaiter + { + scoped_task_of_a m_scoped_task; + bool await_ready() noexcept { return true; } + void await_suspend(std::coroutine_handle) noexcept {} + scoped_task_of_a await_resume()noexcept + { + return std::move(m_scoped_task); + } + }; + return awaiter{ m_scope.add(std::forward(a), loc) }; + } + + + template + decltype(auto) await_transform(scoped_task& a) + { + if (a.m_coroutine.promise().m_scope != &m_scope) + throw std::runtime_error("in a when_all_scope, only awaiting scoped_task's from this scope is supported."); + return a; + } + template + decltype(auto) await_transform(scoped_task&& a) + { + auto& prom = a.m_coroutine.promise(); + if (prom.m_scope != &m_scope) + throw std::runtime_error("in a when_all_scope, only awaiting scoped_task's from this scope is supported."); + return std::move(a); + } + + traceable* get_traceable() + { + return &m_scope; + } + + async_scope m_scope; + }; + + when_all_scope(coroutine_handle h) + : m_handle(h) + {} + + struct awaiter + { + coroutine_handle m_handle; + async_scope::join_awaiter m_join; + + bool await_ready() { return m_join.await_ready(); } + + template + std::coroutine_handle<> await_suspend( + std::coroutine_handle p, + std::source_location loc = std::source_location::current()) + { + m_handle.promise().m_scope.set_parent(get_traceable(p.promise()), loc); + m_handle.resume(); + return m_join.await_suspend(p); + } + + void await_resume() + { + m_join.await_resume(); + } + }; + + auto MACORO_OPERATOR_COAWAIT() const& noexcept + { + return awaiter{ m_handle, m_handle.promise().m_scope.MACORO_OPERATOR_COAWAIT() }; + } + + + coroutine_handle m_handle; + }; +} \ No newline at end of file diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index bd77a74..02ac605 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -14,7 +14,7 @@ add_library(macoroTests "CLP.cpp" "CLP.h" "channel_spsc_tests.cpp" - "channel_mpsc_tests.cpp") + "channel_mpsc_tests.cpp" "async_scope_tests.cpp" "async_scope_tests.h") target_link_libraries(macoroTests macoro) diff --git a/tests/async_scope_tests.cpp b/tests/async_scope_tests.cpp new file mode 100644 index 0000000..5489b23 --- /dev/null +++ b/tests/async_scope_tests.cpp @@ -0,0 +1,434 @@ +#include "macoro/async_scope.h" +#include "macoro/task.h" +#include "macoro/sync_wait.h" +#include "macoro/manual_reset_event.h" + +namespace macoro +{ + + namespace tests + { + + //struct task_of + //{ + // struct promise_type + // { + // struct final_suspend + // { + // bool await_ready() noexcept { return false; } + + // std::coroutine_handle<> await_suspend(std::coroutine_handle c)noexcept + // { + // c.destroy(); + // return std::noop_coroutine(); + // } + + // void await_resume() noexcept{} + // }; + + // std::suspend_always initial_suspend() noexcept { return {}; } + // std::suspend_always final_suspend() noexcept { return {}; } + // void unhandled_exception() noexcept {} + // task_of get_return_object() noexcept { return {}; } + + // void return_void() {} + // }; + + //}; + + //task_of foo() + //{ + // co_return; + //} + + void scoped_task_test() + { + + struct scope_guard + { + bool* destroyed = nullptr; + scope_guard(bool& d) + :destroyed(&d) + {} + + scope_guard(scope_guard&& g) + :destroyed(std::exchange(g.destroyed, nullptr)) + {} + + ~scope_guard() + { + if (destroyed) + *destroyed = true; + } + }; + + // continuation after completion + { + bool destroyed = false; + auto f = [&](scope_guard g) -> task { + co_return 42; + }; + + async_scope scope; + scoped_task t = scope.add(f({ destroyed })); + + if (t.is_ready() == false) + throw std::runtime_error(MACORO_LOCATION); + + //auto a = sync_wait(std::move(f({ destroyed }))); + auto v = sync_wait(std::move(t)); + + if (v != 42) + throw std::runtime_error(MACORO_LOCATION); + + if (!destroyed) + throw std::runtime_error(MACORO_LOCATION); + } + + // continuation before completion. + { + async_manual_reset_event e; + + + auto f = [&](scope_guard g) -> task { + co_await e; + co_return 42; + }; + async_scope scope; + bool destroyed = false; + scoped_task t = scope.add(f({ destroyed })); + + if (t.is_ready() == true) + throw std::runtime_error(MACORO_LOCATION); + + if (destroyed) + throw std::runtime_error(MACORO_LOCATION); + + auto eager = [&]() -> eager_task { + auto v = co_await std::move(t); + co_return v; + } + (); + + + if (destroyed) + throw std::runtime_error(MACORO_LOCATION); + + if (eager.is_ready() == true) + throw std::runtime_error(MACORO_LOCATION); + + e.set(); + + if (eager.is_ready() == false) + throw std::runtime_error(MACORO_LOCATION); + + if (destroyed == false) + throw std::runtime_error(MACORO_LOCATION); + + auto v = sync_wait(eager); + + if (v != 42) + throw std::runtime_error(MACORO_LOCATION); + } + + + // drops before completion + { + async_manual_reset_event e; + + bool called = false, destroyed = false; + auto f = [&](scope_guard g) -> task { + co_await e; + called = true; + co_return 42; + }; + async_scope scope; + scope.add(f({ destroyed })); + + if (called) + throw std::runtime_error(MACORO_LOCATION); + + e.set(); + + if (!called) + throw std::runtime_error(MACORO_LOCATION); + + if (destroyed == false) + throw std::runtime_error(MACORO_LOCATION); + + sync_wait(scope); + } + + + // joins before completion + { + async_manual_reset_event e; + + bool destroyed = false; + auto f = [&](scope_guard g) -> task { + co_await e; + co_return 42; + }; + async_scope scope; + scoped_task t = scope.add(f({ destroyed })); + + if (t.is_ready() == true) + throw std::runtime_error(MACORO_LOCATION); + + auto eager = [&]() -> eager_task<> { + co_await scope; + } + (); + + if (eager.is_ready() == true) + throw std::runtime_error(MACORO_LOCATION); + + e.set(); + + if (t.is_ready() == false) + throw std::runtime_error(MACORO_LOCATION); + if (eager.is_ready() == false) + throw std::runtime_error(MACORO_LOCATION); + + sync_wait(eager); + + + if (destroyed) + throw std::runtime_error(MACORO_LOCATION); + + auto v = sync_wait(std::move(t)); + + if(v != 42) + throw std::runtime_error(MACORO_LOCATION); + + if (!destroyed) + throw std::runtime_error(MACORO_LOCATION); + } + + + // joins before completion and then discard + { + async_manual_reset_event e; + + bool destroyed = false; + auto f = [&](scope_guard g) -> task { + co_await e; + co_return 42; + }; + async_scope scope; + scoped_task t = scope.add(f({ destroyed })); + + if (t.is_ready() == true) + throw std::runtime_error(MACORO_LOCATION); + + auto eager = [&]() -> eager_task<> { + co_await scope; + } + (); + + if (eager.is_ready() == true) + throw std::runtime_error(MACORO_LOCATION); + + e.set(); + + if (t.is_ready() == false) + throw std::runtime_error(MACORO_LOCATION); + if (eager.is_ready() == false) + throw std::runtime_error(MACORO_LOCATION); + + sync_wait(eager); + + if (destroyed) + throw std::runtime_error(MACORO_LOCATION); + + t = {}; + + if (!destroyed) + throw std::runtime_error(MACORO_LOCATION); + } + + + struct test_exception : std::exception + { + + }; + + // get exception + { + bool destroyed = false; + auto f = [&](scope_guard g) -> task { + throw test_exception{}; + co_return 42; + }; + async_scope scope; + + { + scoped_task t = scope.add(f({ destroyed })); + + if (t.is_ready() == false) + throw std::runtime_error(MACORO_LOCATION); + + if (destroyed == true) + throw std::runtime_error(MACORO_LOCATION); + + try + { + auto v = sync_wait(std::move(t)); + throw std::runtime_error("failed"); + } + catch (test_exception& e) + { + } + } + + + if (destroyed == false) + throw std::runtime_error(MACORO_LOCATION); + + } + + + // get exception join + { + bool destroyed = false; + auto f = [&](scope_guard g) -> task { + throw test_exception{}; + co_return 42; + }; + async_scope scope; + + { + scoped_task t = scope.add(f({ destroyed })); + + if (t.is_ready() == false) + throw std::runtime_error(MACORO_LOCATION); + + if (destroyed == true) + throw std::runtime_error(MACORO_LOCATION); + } + + if (destroyed == false) + throw std::runtime_error(MACORO_LOCATION); + + try + { + sync_wait(scope); + throw std::runtime_error("failed"); + } + catch (async_scope_exception& e) + { + if(e.m_exceptions.size() != 1) + throw std::runtime_error(MACORO_LOCATION); + + try + { + std::rethrow_exception(e.m_exceptions[0]); + } + catch (test_exception& e) + { + } + } + } + + // reference type + { + int x = 0; + bool destroyed = false; + auto f = [&](scope_guard g) -> task { + co_return x; + }; + + async_scope scope; + scoped_task t = scope.add(f({ destroyed })); + + if (t.is_ready() == false) + throw std::runtime_error(MACORO_LOCATION); + + auto& v = sync_wait(std::move(t)); + + if (&v != &x) + throw std::runtime_error(MACORO_LOCATION); + + if (!destroyed) + throw std::runtime_error(MACORO_LOCATION); + } + + // move only type + { + std::unique_ptr x(new int{ 42 }); + auto addr = x.get(); + bool destroyed = false; + auto f = [&](scope_guard g) -> task> { + co_return std::move(x); + }; + + async_scope scope; + scoped_task> t = scope.add(f({ destroyed })); + + if (t.is_ready() == false) + throw std::runtime_error(MACORO_LOCATION); + + auto v = sync_wait(std::move(t)); + + if (v.get() != addr) + throw std::runtime_error(MACORO_LOCATION); + + if (!destroyed) + throw std::runtime_error(MACORO_LOCATION); + } + + + // pointer type + { + int X = 42; + int* x = &X; + bool destroyed = false; + auto f = [&](scope_guard g) -> task { + co_return x; + }; + + async_scope scope; + scoped_task t = scope.add(f({ destroyed })); + + if (t.is_ready() == false) + throw std::runtime_error(MACORO_LOCATION); + + auto v = sync_wait(std::move(t)); + + if (v != x) + throw std::runtime_error(MACORO_LOCATION); + + if (!destroyed) + throw std::runtime_error(MACORO_LOCATION); + } + + + + // void type + { + bool destroyed = false; + auto f = [&](scope_guard g) -> task<> { + co_return; + }; + + async_scope scope; + scoped_task<> t = scope.add(f({ destroyed })); + + if (t.is_ready() == false) + throw std::runtime_error(MACORO_LOCATION); + + sync_wait(std::move(t)); + + if (!destroyed) + throw std::runtime_error(MACORO_LOCATION); + } + } + + + void async_scope_test() + { + + + } + } +} \ No newline at end of file diff --git a/tests/async_scope_tests.h b/tests/async_scope_tests.h new file mode 100644 index 0000000..087e534 --- /dev/null +++ b/tests/async_scope_tests.h @@ -0,0 +1,15 @@ +#pragma once + +namespace macoro +{ + + namespace tests + { + + void scoped_task_test(); + void async_scope_test(); + } + + + +} \ No newline at end of file diff --git a/tests/await_lifetime_tests.cpp b/tests/await_lifetime_tests.cpp index 7931397..90c15d1 100644 --- a/tests/await_lifetime_tests.cpp +++ b/tests/await_lifetime_tests.cpp @@ -13,6 +13,9 @@ + __GNUC_PATCHLEVEL__) #endif +#define test_assert(X) if(!(X)) throw std::runtime_error(MACORO_LOCATION) + + namespace macoro { @@ -535,7 +538,8 @@ namespace macoro // gcc has a bug that the awaiter is moved when it shouldnt be. This makes some tests fail. #ifdef GCC_VERSION #ifndef __llvm__ -#define HAS_GCC_MOVE_AWAITER_BUG (GCC_VERSION < 12 * 10000) +#define HAS_GCC_MOVE_AWAITER_BUG 0 +//#define HAS_GCC_MOVE_AWAITER_BUG (GCC_VERSION < 12 * 10000) #else #define HAS_GCC_MOVE_AWAITER_BUG 0 #endif @@ -618,7 +622,7 @@ namespace macoro { print(log, exp); } - assert(log == exp); + test_assert(log == exp); } else { @@ -635,7 +639,7 @@ namespace macoro }; //std::cout << "n"; //print(log); - assert(log == exp); + test_assert(log == exp); } } @@ -738,7 +742,7 @@ namespace macoro {"del promise_type", 0, 0}, {"del test_value", 5, 0} }; - assert(log == exp); + test_assert(log == exp); } else { @@ -754,7 +758,7 @@ namespace macoro {"del promise_type", 1, 0}, {"del promise_type", 0, 0} }; - assert(log == exp); + test_assert(log == exp); } } //std::cout << " passed " << std::endl;; @@ -808,7 +812,7 @@ namespace macoro gg.h.resume(); } - assert(done); + test_assert(done); auto log1 = getLog(); if (!(log1 == exp)) @@ -816,7 +820,7 @@ namespace macoro print(log1, exp); } - assert(log1 == exp); + test_assert(log1 == exp); #endif counter = 0; done = false; @@ -825,7 +829,7 @@ namespace macoro auto r = f(); r.h.resume(); } - assert(done); + test_assert(done); auto log2 = getLog(); @@ -834,7 +838,7 @@ namespace macoro { print(log2, exp); } - assert(log2 == exp); + test_assert(log2 == exp); //std::cout << " passed " << std::endl;; } @@ -895,7 +899,7 @@ namespace macoro gg.h.resume(); } - assert(done); + test_assert(done); auto log1 = getLog(); @@ -904,14 +908,14 @@ namespace macoro print(log1, exp); } - assert(log1 == exp); + test_assert(log1 == exp); #endif done = false; { auto r = f(); r.h.resume(); } - assert(done); + test_assert(done); auto log2 = getLog(); @@ -919,7 +923,7 @@ namespace macoro { print(log2, exp); } - assert(log2 == exp); + test_assert(log2 == exp); //std::cout << " passed " << std::endl;; #endif @@ -994,7 +998,7 @@ namespace macoro gg.h.resume(); } - assert(done); + test_assert(done); auto log1 = getLog(); //std::cout << "\n\n"; @@ -1004,7 +1008,7 @@ namespace macoro { print(log1, exp20); } - assert(log1 == exp20); + test_assert(log1 == exp20); //print(log1); #endif @@ -1013,7 +1017,7 @@ namespace macoro auto r = f(); r.h.resume(); } - assert(done); + test_assert(done); auto log2 = getLog(); @@ -1021,7 +1025,7 @@ namespace macoro { print(log2, exp); } - assert(log2 == exp); + test_assert(log2 == exp); //std::cout << " passed " << std::endl;; } @@ -1095,7 +1099,7 @@ namespace macoro gg.h.resume(); } - assert(done); + test_assert(done); auto log1 = getLog(); //print(log1); @@ -1106,14 +1110,14 @@ namespace macoro print(log1, exp20); } - assert(log1 == exp20); + test_assert(log1 == exp20); #endif done = false; { auto r = f(); r.h.resume(); } - assert(done); + test_assert(done); auto log2 = getLog(); @@ -1122,7 +1126,7 @@ namespace macoro { print(log2, exp); } - assert(log2 == exp); + test_assert(log2 == exp); //std::cout << " passed " << std::endl;; } } diff --git a/tests/task_tests.cpp b/tests/task_tests.cpp index 442c500..ebc3a63 100644 --- a/tests/task_tests.cpp +++ b/tests/task_tests.cpp @@ -4,14 +4,12 @@ #include "macoro/stop.h" #include #include +#include namespace { #ifdef MACORO_CPP_20 macoro::task taskInt20() { - //assert(0); - //throw std::runtime_error(""); - co_return 42; } #endif @@ -30,7 +28,6 @@ namespace macoro void task_int_test() { - //std::cout << "task_int_test "; #ifdef MACORO_CPP_20 { task t = taskInt20(); @@ -168,9 +165,6 @@ namespace macoro ++taskRef_val; assert(v == 43); } - - - //std::cout << "passed" << std::endl; } namespace { @@ -316,6 +310,7 @@ namespace macoro //std::cout << "task_blocking_int_test "; #ifdef MACORO_CPP_20 { + auto l = std::source_location::current(); task t = taskInt20(); int i = sync_wait(t); assert(i == 42); @@ -446,17 +441,17 @@ namespace macoro auto t = [](stop_token t) -> task - { - MC_BEGIN(task<>, t); - while (true) { - if (t.stop_requested()) - throw operation_cancelled(); - } - - MC_RETURN_VOID(); - MC_END(); - }; + MC_BEGIN(task<>, t); + while (true) + { + if (t.stop_requested()) + throw operation_cancelled(); + } + + MC_RETURN_VOID(); + MC_END(); + }; stop_source mSrc; diff --git a/tests/tests.cpp b/tests/tests.cpp index 2d4d0e3..295dc2e 100644 --- a/tests/tests.cpp +++ b/tests/tests.cpp @@ -9,6 +9,7 @@ #include "sequence_tests.h" #include "channel_spsc_tests.h" #include "channel_mpsc_tests.h" +#include "async_scope_tests.h" #ifdef _MSC_VER #include @@ -292,7 +293,12 @@ namespace macoro t.add("task_blocking_ex_test ", task_blocking_ex_test); t.add("task_blocking_cancel_test ", task_blocking_cancel_test); - //t.add("when_all_basic_tests ", when_all_basic_tests); + t.add("scoped_task_test ", scoped_task_test); + t.add("async_scope_test ", async_scope_test); + + t.add("when_all_basic_tests ", when_all_basic_tests); + t.add("when_all_scope_test ", when_all_scope_test); + t.add("schedule_after_test ", schedule_after); t.add("take_until_tests ", take_until_tests); t.add("schedule_after_cancaled ", schedule_after_cancaled); diff --git a/tests/when_all_tests.cpp b/tests/when_all_tests.cpp index cf6fe25..d919316 100644 --- a/tests/when_all_tests.cpp +++ b/tests/when_all_tests.cpp @@ -2,34 +2,36 @@ #include "macoro/when_all.h" #include "macoro/task.h" #include "macoro/sync_wait.h" +#include "macoro/manual_reset_event.h" +#include "macoro/async_scope.h" +#include "macoro/when_all_scope.h" namespace macoro { namespace tests { - void when_all_basic_tests() - { - auto f = []() -> task { - MC_BEGIN(task); - MC_RETURN(42); - MC_END(); + auto f = []() -> task { + //std::cout << "trace\n" << (co_await get_trace{}).str() << std::endl; + co_return(42); }; - auto g = []() -> task { - MC_BEGIN(task); - MC_RETURN(true); - MC_END(); + auto g = []() -> task { + MC_BEGIN(task); + MC_RETURN(true); + MC_END(); }; - bool b; - auto h = [&]() -> task { - MC_BEGIN(task, &b); - MC_RETURN(b); - MC_END(); + bool b = true; + auto h = []() -> task { + MC_BEGIN(task); + MC_RETURN(b); + MC_END(); }; + void when_all_basic_tests() + { static_assert( is_awaitable< task @@ -59,26 +61,51 @@ namespace macoro >::value , ""); - std::tuple < - detail::when_all_task, - detail::when_all_task - > - r = sync_wait(when_all_ready(f(), g())); - auto ff = f(); auto gg = g(); - std::tuple < + auto t = [&]()->task< + std::tuple < detail::when_all_task, detail::when_all_task - > - r2 = sync_wait(when_all_ready(std::move(ff), h())); - //auto r = sync_wait(w); - detail::when_all_task r0 = std::move(std::get<0>(r)); - assert(std::get<0>(r).result() == 42); - assert(std::get<1>(r).result() == true); + >> + { + co_return co_await when_all_ready(std::move(ff), h()); + }(); + + auto r2 = sync_wait(std::move(t)); + + detail::when_all_task r0 = std::move(std::get<0>(r2)); + detail::when_all_task r1 = std::move(std::get<1>(r2)); + + if (r0.result() != 42) + throw std::runtime_error(MACORO_LOCATION); + if (&r1.result() != &b) + throw std::runtime_error(MACORO_LOCATION); + } + + void when_all_scope_test() + { + + auto tr = []()->when_all_scope { + + scoped_task r = co_await f(); + + int i = co_await std::move(r); + co_return; + + } + (); + + sync_wait(std::move(tr)); + using namespace macoro; + using Awaitable = task; + Awaitable a = f(); + auto t = detail::make_when_all_task(std::move(a)); + + } } } diff --git a/tests/when_all_tests.h b/tests/when_all_tests.h index 94123bd..5ea1692 100644 --- a/tests/when_all_tests.h +++ b/tests/when_all_tests.h @@ -7,6 +7,8 @@ namespace macoro namespace tests { void when_all_basic_tests(); + + void when_all_scope_test(); }