Skip to content

Commit

Permalink
Disallow task.yield under callback, provide callback way to yield
Browse files Browse the repository at this point in the history
  • Loading branch information
lukewagner committed Jul 3, 2024
1 parent 979a36c commit 0ab6048
Show file tree
Hide file tree
Showing 3 changed files with 108 additions and 60 deletions.
79 changes: 53 additions & 26 deletions design/mvp/CanonicalABI.md
Original file line number Diff line number Diff line change
Expand Up @@ -444,6 +444,27 @@ created by `canon_lift` and `Subtask`, which is created by `canon_lower`.
Additional sync-/async-specialized mutable state is added by the `SyncTask`,
`AsyncTask` and `AsyncSubtask` subclasses.

The `Task` class and its subclasses depend on the following two enums:
```python
class AsyncCallState(IntEnum):
STARTING = 0
STARTED = 1
RETURNED = 2
DONE = 3

class EventCode(IntEnum):
CALL_STARTING = AsyncCallState.STARTING
CALL_STARTED = AsyncCallState.STARTED
CALL_RETURNED = AsyncCallState.RETURNED
CALL_DONE = AsyncCallState.DONE
YIELDED = 4
```
The `AsyncCallState` enum describes the linear sequence of states that an async
call necessarily transitions through: [`STARTING`](Async.md#starting),
`STARTED`, [`RETURNING`](Async.md#returning) and `DONE`. The `EventCode` enum
shares common code values with `AsyncCallState` to define the set of integer
event codes that are delivered to [waiting](Async.md#waiting) or polling tasks.

A `Task` object is created for each call to `canon_lift` and is implicitly
threaded through all core function calls. This implicit `Task` parameter
specifies a concept of [the current task](Async.md#current-task) and inherently
Expand Down Expand Up @@ -520,8 +541,7 @@ All `Task`s (whether lifted `async` or not) are allowed to call `async`-lowered
imports. Calling an `async`-lowered import creates an `AsyncSubtask` (defined
below) which is stored in the current component instance's `async_subtasks`
table and tracked by the current task's `num_async_subtasks` counter, which is
guarded to be `0` in `Task.exit` (below) to ensure the
tree-structured-concurrency [component invariant].
guarded to be `0` in `Task.exit` (below) to ensure [structured concurrency].
```python
def add_async_subtask(self, subtask):
assert(subtask.supertask is None and subtask.index is None)
Expand Down Expand Up @@ -549,7 +569,7 @@ tree-structured-concurrency [component invariant].
if subtask.state == AsyncCallState.DONE:
self.inst.async_subtasks.remove(subtask.index)
self.num_async_subtasks -= 1
return (subtask.state, subtask.index)
return (EventCode(subtask.state), subtask.index)
```
While a task is running, it may call `wait` (via `canon task.wait` or, when a
`callback` is present, by returning to the event loop) to block until there is
Expand All @@ -573,6 +593,16 @@ another task:
return self.process_event(self.events.get_nowait())
```

A task may also cooperatively yield the current thread, explicitly allowing
the runtime to switch to another ready task, but without blocking on I/O (as
emulated in the Python code here by awaiting a `sleep(0)`).
```python
async def yield_(self):
self.inst.thread.release()
await asyncio.sleep(0)
await self.inst.thread.acquire()
```

Lastly, when a task exists, the runtime enforces the guard conditions mentioned
above and releases the `thread` lock, allowing other tasks to start or make
progress.
Expand Down Expand Up @@ -641,17 +671,6 @@ implementation should be able to avoid separately allocating
`pending_sync_tasks` by instead embedding a "next pending" linked list in the
`Subtask` table element of the caller.

The `AsyncTask` class dynamically checks that the task calls the
`canon_task_start` and `canon_task_return` (defined below) in the right order
before finishing the task. "The right order" is defined in terms of a simple
linear state machine that progresses through the following 4 states:
```python
class AsyncCallState(IntEnum):
STARTING = 0
STARTED = 1
RETURNED = 2
DONE = 3
```
The first 3 fields of `AsyncTask` are simply immutable copies of
arguments/immediates passed to `canon_lift` that are used later on. The last 2
fields are used to check the above-mentioned state machine transitions and also
Expand Down Expand Up @@ -1952,10 +1971,16 @@ async def canon_lift(opts, inst, callee, ft, caller, start_thunk, return_thunk):
if not opts.callback:
[] = await call_and_trap_on_throw(callee, task, [])
else:
[ctx] = await call_and_trap_on_throw(callee, task, [])
while ctx != 0:
event, payload = await task.wait()
[ctx] = await call_and_trap_on_throw(opts.callback, task, [ctx, event, payload])
[packed_ctx] = await call_and_trap_on_throw(callee, task, [])
while packed_ctx != 0:
is_yield = bool(packed_ctx & 1)
ctx = packed_ctx & ~1
if is_yield:
await task.yield_()
event, payload = (EventCode.YIELDED, 0)
else:
event, payload = await task.wait()
[packed_ctx] = await call_and_trap_on_throw(opts.callback, task, [ctx, event, payload])

assert(opts.post_return is None)
task.exit()
Expand Down Expand Up @@ -1983,11 +2008,13 @@ allow the callee to reclaim any memory. An async call doesn't need a

Within the async case, there are two sub-cases depending on whether the
`callback` `canonopt` was set. When `callback` is present, waiting happens in
an "event loop" inside `canon_lift`. Otherwise, waiting must happen by calling
`task.wait` (defined below), which potentially requires the runtime
implementation to use a fiber (aka. stackful coroutine) to switch to another
task. Thus, `callback` is an optimization for avoiding fiber creation for async
languages that don't need it (e.g., JS, Python, C# and Rust).
an "event loop" inside `canon_lift` which also allows yielding (i.e., allowing
other tasks to run without blocking) by setting the LSB of the returned `i32`.
Otherwise, waiting must happen by calling `task.wait` (defined below), which
potentially requires the runtime implementation to use a fiber (aka. stackful
coroutine) to switch to another task. Thus, `callback` is an optimization for
avoiding fiber creation for async languages that don't need it (e.g., JS,
Python, C# and Rust).

Uncaught Core WebAssembly [exceptions] result in a trap at component
boundaries. Thus, if a component wishes to signal an error, it must use some
Expand Down Expand Up @@ -2332,9 +2359,8 @@ Python `asyncio.sleep(0)` in the middle to make it clear that other
coroutines are allowed to acquire the `lock` and execute.
```python
async def canon_task_yield(task):
task.inst.thread.release()
await asyncio.sleep(0)
await task.inst.thread.acquire()
trap_if(task.opts.callback is not None)
await task.yield_()
return []
```

Expand Down Expand Up @@ -2415,6 +2441,7 @@ def canon_thread_hw_concurrency():
[JavaScript Embedding]: Explainer.md#JavaScript-embedding
[Adapter Functions]: FutureFeatures.md#custom-abis-via-adapter-functions
[Shared-Everything Dynamic Linking]: examples/SharedEverythingDynamicLinking.md
[Structured Concurrency]: Async.md#structured-concurrency

[Administrative Instructions]: https://webassembly.github.io/spec/core/exec/runtime.html#syntax-instr-admin
[Implementation Limits]: https://webassembly.github.io/spec/core/appendix/implementation.html
Expand Down
45 changes: 31 additions & 14 deletions design/mvp/canonical-abi/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,19 @@ def __init__(self, rep, own, scope = None):
self.scope = scope
self.lend_count = 0

class AsyncCallState(IntEnum):
STARTING = 0
STARTED = 1
RETURNED = 2
DONE = 3

class EventCode(IntEnum):
CALL_STARTING = AsyncCallState.STARTING
CALL_STARTED = AsyncCallState.STARTED
CALL_RETURNED = AsyncCallState.RETURNED
CALL_DONE = AsyncCallState.DONE
YIELDED = 4

class Task(CallContext):
caller: Optional[Task]
borrow_count: int
Expand Down Expand Up @@ -440,13 +453,18 @@ def process_event(self, subtask):
if subtask.state == AsyncCallState.DONE:
self.inst.async_subtasks.remove(subtask.index)
self.num_async_subtasks -= 1
return (subtask.state, subtask.index)
return (EventCode(subtask.state), subtask.index)

def poll(self):
if self.events.empty():
return None
return self.process_event(self.events.get_nowait())

async def yield_(self):
self.inst.thread.release()
await asyncio.sleep(0)
await self.inst.thread.acquire()

def exit(self):
assert(self.events.empty())
trap_if(self.borrow_count != 0)
Expand Down Expand Up @@ -486,12 +504,6 @@ def exit(self):
if self.inst.pending_sync_tasks:
self.inst.pending_sync_tasks.pop(0).set_result(None)

class AsyncCallState(IntEnum):
STARTING = 0
STARTED = 1
RETURNED = 2
DONE = 3

class AsyncTask(Task):
ft: FuncType
start_thunk: Callable
Expand Down Expand Up @@ -1367,10 +1379,16 @@ async def canon_lift(opts, inst, callee, ft, caller, start_thunk, return_thunk):
if not opts.callback:
[] = await call_and_trap_on_throw(callee, task, [])
else:
[ctx] = await call_and_trap_on_throw(callee, task, [])
while ctx != 0:
event, payload = await task.wait()
[ctx] = await call_and_trap_on_throw(opts.callback, task, [ctx, event, payload])
[packed_ctx] = await call_and_trap_on_throw(callee, task, [])
while packed_ctx != 0:
is_yield = bool(packed_ctx & 1)
ctx = packed_ctx & ~1
if is_yield:
await task.yield_()
event, payload = (EventCode.YIELDED, 0)
else:
event, payload = await task.wait()
[packed_ctx] = await call_and_trap_on_throw(opts.callback, task, [ctx, event, payload])

assert(opts.post_return is None)
task.exit()
Expand Down Expand Up @@ -1512,7 +1530,6 @@ async def canon_task_poll(task, ptr):
### `canon task.yield`

async def canon_task_yield(task):
task.inst.thread.release()
await asyncio.sleep(0)
await task.inst.thread.acquire()
trap_if(task.opts.callback is not None)
await task.yield_()
return []
44 changes: 24 additions & 20 deletions design/mvp/canonical-abi/run_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,19 +534,19 @@ async def consumer(task, args):
consumer_heap.memory[argp] = 83
consumer_heap.memory[argp+1] = 84
fut1.set_result(None)
state, callidx = await task.wait()
assert(state == AsyncCallState.STARTED)
event, callidx = await task.wait()
assert(event == EventCode.CALL_STARTED)
assert(callidx == 1)
assert(consumer_heap.memory[retp] == 15)
fut2.set_result(None)
state, callidx = await task.wait()
assert(state == AsyncCallState.RETURNED)
event, callidx = await task.wait()
assert(event == EventCode.CALL_RETURNED)
assert(callidx == 1)
assert(consumer_heap.memory[retp] == 44)
fut3.set_result(None)
assert(task.num_async_subtasks == 1)
state, callidx = await task.wait()
assert(state == AsyncCallState.DONE)
event, callidx = await task.wait()
assert(event == EventCode.CALL_DONE)
assert(callidx == 1)
assert(task.num_async_subtasks == 0)

Expand All @@ -567,8 +567,8 @@ async def dtor(task, args):
assert(task.num_async_subtasks == 1)
assert(dtor_value is None)
dtor_fut.set_result(None)
state, callidx = await task.wait()
assert(state == AsyncCallState.DONE)
event, callidx = await task.wait()
assert(event == AsyncCallState.DONE)
assert(callidx == 1)
assert(task.num_async_subtasks == 0)

Expand Down Expand Up @@ -623,13 +623,17 @@ async def consumer(task, args):
async def callback(task, args):
assert(len(args) == 3)
if args[0] == 42:
assert(args[1] == AsyncCallState.DONE)
assert(args[1] == EventCode.CALL_DONE)
assert(args[2] == 1)
return [53]
elif args[0] == 52:
assert(args[1] == EventCode.YIELDED)
assert(args[2] == 0)
fut2.set_result(None)
return [43]
return [62]
else:
assert(args[0] == 43)
assert(args[1] == AsyncCallState.DONE)
assert(args[0] == 62)
assert(args[1] == EventCode.CALL_DONE)
assert(args[2] == 2)
[] = await canon_task_return(task, CoreFuncType(['i32'],[]), [83])
return [0]
Expand Down Expand Up @@ -693,16 +697,16 @@ async def consumer(task, args):

fut.set_result(None)
assert(producer1_done == False)
state, callidx = await task.wait()
assert(state == AsyncCallState.DONE)
event, callidx = await task.wait()
assert(event == EventCode.CALL_DONE)
assert(callidx == 1)
assert(producer1_done == True)

assert(producer2_done == False)
await canon_task_yield(task)
assert(producer2_done == True)
state, callidx = task.poll()
assert(state == AsyncCallState.DONE)
event, callidx = task.poll()
assert(event == EventCode.CALL_DONE)
assert(callidx == 2)
assert(producer2_done == True)

Expand Down Expand Up @@ -748,12 +752,12 @@ async def core_func(task, args):
assert(ret == (2 | (AsyncCallState.STARTED << 30)))

fut1.set_result(None)
state, callidx = await task.wait()
assert(state == AsyncCallState.DONE)
event, callidx = await task.wait()
assert(event == EventCode.CALL_DONE)
assert(callidx == 1)
fut2.set_result(None)
state, callidx = await task.wait()
assert(state == AsyncCallState.DONE)
event, callidx = await task.wait()
assert(event == EventCode.CALL_DONE)
assert(callidx == 2)
return []

Expand Down

0 comments on commit 0ab6048

Please sign in to comment.