Skip to content

Commit

Permalink
co_await support for geode Task
Browse files Browse the repository at this point in the history
see comment at the bottom of the header for more information
  • Loading branch information
matcool committed Nov 12, 2024
1 parent acad3d2 commit e61b2c0
Showing 1 changed file with 135 additions and 0 deletions.
135 changes: 135 additions & 0 deletions loader/include/Geode/utils/Task.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,17 @@
#include "../loader/Loader.hpp"
#include <mutex>
#include <string_view>
#include <coroutine>

namespace geode {
namespace geode_internal {
template <class T>
struct TaskPromise;

template <class T, class P>
struct TaskAwaiter;
}

/**
* Tasks represent an asynchronous operation that will be finished at some
* unknown point in the future. Tasks can report their progress, and will
Expand Down Expand Up @@ -152,6 +161,12 @@ namespace geode {
template <std::move_constructible T2, std::move_constructible P2>
friend class Task;

template <class>
friend struct geode_internal::TaskPromise;

template <class, class>
friend struct geode_internal::TaskAwaiter;

public:
Handle(PrivateMarker, std::string_view name) : m_name(name) {}
~Handle() {
Expand Down Expand Up @@ -307,6 +322,12 @@ namespace geode {
template <std::move_constructible T2, std::move_constructible P2>
friend class Task;

template <class>
friend struct geode_internal::TaskPromise;

template <class, class>
friend struct geode_internal::TaskAwaiter;

public:
// Allow default-construction
Task() : m_handle(nullptr) {}
Expand Down Expand Up @@ -883,3 +904,117 @@ namespace geode {

static_assert(is_filter<Task<int>>, "The Task class must be a valid event filter!");
}

// - C++20 coroutine support for Task - //

// Example usage (function must return a Task):
// ```
// Task<int> someTask() {
// auto response = co_await web::WebRequest().get("https://example.com");
// co_return response.code();
// }
// ```
// This will create a Task that will finish with the response code of the
// web request.
//
// Note: If the Task the coroutine is waiting on is cancelled, the coroutine
// will be destroyed and the Task will be cancelled as well. If the Task returned
// by the coroutine is cancelled, the coroutine will be destroyed as well and execution
// stops as soon as possible.
//
// The body of the coroutine is ran in whatever thread it got called in.
// TODO: maybe guarantee main thread?

namespace geode {
namespace geode_internal {
template <class T>
struct TaskPromise {
using MyTask = Task<T>;
std::weak_ptr<typename MyTask::Handle> m_handle;

~TaskPromise() {
// does nothing if its not pending
MyTask::cancel(m_handle.lock());
}

std::suspend_never initial_suspend() noexcept { return {}; }
std::suspend_never final_suspend() noexcept { return {}; }
// TODO: do something here?
void unhandled_exception() {}

MyTask get_return_object() {
auto handle = MyTask::Handle::create("<Coroutine Task>");
m_handle = handle;
return handle;
}

void return_value(T&& x) {
MyTask::finish(m_handle.lock(), std::move(x));
}

bool isCancelled() {
if (auto p = m_handle.lock()) {
return p->is(MyTask::Status::Cancelled);
}
return true;
}
};

template <class T, class P>
struct TaskAwaiter {
Task<T, P> task;

bool await_ready() {
return task.isFinished();
}

template <class U>
void await_suspend(std::coroutine_handle<TaskPromise<U>> handle) {
if (handle.promise().isCancelled()) {
handle.destroy();
return;
}
// this should be fine because the parent task can only have
// one pending task at a time
std::shared_ptr<Task<U>::Handle> parentHandle = handle.promise().m_handle.lock();
if (!parentHandle) {
handle.destroy();
return;
}
parentHandle->m_extraData = std::make_unique<typename Task<U>::Handle::ExtraData>(
static_cast<void*>(new EventListener<Task<T, P>>(
[handle](auto* event) {
if (event->getValue()) {
handle.resume();
}
if (event->isCancelled()) {
handle.destroy();
}
},
task
)),
+[](void* ptr) {
delete static_cast<EventListener<Task<T, P>>*>(ptr);
},
+[](void* ptr) {
static_cast<EventListener<Task<T, P>>*>(ptr)->getFilter().cancel();
}
);
}

T await_resume() {
return std::move(*task.getFinishedValue());
}
};
}
}

template <class T, class P>
auto operator co_await(geode::Task<T, P> task) {
return geode::geode_internal::TaskAwaiter<T, P>{task};
}

template <class T, class... Args>
struct std::coroutine_traits<geode::Task<T>, Args...> {
using promise_type = geode::geode_internal::TaskPromise<T>;
};

0 comments on commit e61b2c0

Please sign in to comment.