Skip to content

Commit

Permalink
Support top-level await, adding ast.PyCF_ALLOW_TOP_LEVEL_AWAIT to the…
Browse files Browse the repository at this point in the history
… compile()'s flag and replacing exec() with await eval() to handle the coroutine obtained from compile() with the top-level await flag
  • Loading branch information
whitphx committed Dec 11, 2024
1 parent 65d277b commit 288190f
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 10 deletions.
8 changes: 4 additions & 4 deletions lib/streamlit/runtime/scriptrunner/exec_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from __future__ import annotations

from typing import TYPE_CHECKING, Any, Callable
from typing import TYPE_CHECKING, Any, Awaitable, Callable

from streamlit.delta_generator_singletons import (
context_dg_stack,
Expand All @@ -32,8 +32,8 @@
from streamlit.runtime.scriptrunner_utils.script_run_context import ScriptRunContext


def exec_func_with_error_handling(
func: Callable[[], Any], ctx: ScriptRunContext
async def exec_func_with_error_handling(
func: Callable[[], Awaitable[Any]], ctx: ScriptRunContext
) -> tuple[
Any | None,
bool,
Expand Down Expand Up @@ -85,7 +85,7 @@ def exec_func_with_error_handling(
uncaught_exception: Exception | None = None

try:
result = func()
result = await func()
except RerunException as e:
rerun_exception_data = e.rerun_data

Expand Down
3 changes: 2 additions & 1 deletion lib/streamlit/runtime/scriptrunner/script_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from __future__ import annotations

import ast
import os.path
import threading
from typing import Any
Expand Down Expand Up @@ -79,7 +80,7 @@ def get_bytecode(self, script_path: str) -> Any:
# mode (as opposed to "eval" or "single").
mode="exec",
# Don't inherit any flags or "future" statements.
flags=0,
flags=ast.PyCF_ALLOW_TOP_LEVEL_AWAIT,
dont_inherit=1,
# Use the default optimization options.
optimize=-1,
Expand Down
15 changes: 10 additions & 5 deletions lib/streamlit/runtime/scriptrunner/script_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import types
from contextlib import contextmanager
from enum import Enum
from inspect import CO_COROUTINE
from timeit import default_timer as timer
from typing import TYPE_CHECKING, Callable, Final

Expand Down Expand Up @@ -316,7 +317,7 @@ async def _run_script_thread(self) -> None:
# request that we'll handle immediately. When the script finishes,
# it's possible that another request has come in that we need to
# handle, which is why we call _run_script in a loop.
self._run_script(request.rerun_data)
await self._run_script(request.rerun_data)
request = self._requests.on_scriptrunner_ready()

assert request.type == ScriptRequestType.STOP
Expand Down Expand Up @@ -400,7 +401,7 @@ def _set_execing_flag(self):
finally:
self._execing = False

def _run_script(self, rerun_data: RerunData) -> None:
async def _run_script(self, rerun_data: RerunData) -> None:
"""Run our script.
Parameters
Expand Down Expand Up @@ -532,7 +533,7 @@ def _run_script(self, rerun_data: RerunData) -> None:
# assume is the main script directory.
module.__dict__["__file__"] = script_path

def code_to_exec(code=code, module=module, ctx=ctx, rerun_data=rerun_data):
async def code_to_exec(code=code, module=module, ctx=ctx, rerun_data=rerun_data):
with modified_sys_path(
self._main_script_path
), self._set_execing_flag():
Expand Down Expand Up @@ -576,7 +577,11 @@ def code_to_exec(code=code, module=module, ctx=ctx, rerun_data=rerun_data):
pass

else:
exec(code, module.__dict__)
if code.co_flags & CO_COROUTINE:
# The source code includes top-level awaits, so the compiled code object is a coroutine.
await eval(code, module.__dict__)
else:
exec(code, module.__dict__)
self._fragment_storage.clear(
new_fragment_ids=ctx.new_fragment_ids
)
Expand All @@ -592,7 +597,7 @@ def code_to_exec(code=code, module=module, ctx=ctx, rerun_data=rerun_data):
rerun_exception_data,
premature_stop,
uncaught_exception,
) = exec_func_with_error_handling(code_to_exec, ctx)
) = await exec_func_with_error_handling(code_to_exec, ctx)
# setting the session state here triggers a yield-callback call
# which reads self._requests and checks for rerun data
self._session_state[SCRIPT_RUN_WITHOUT_ERRORS_KEY] = run_without_errors
Expand Down

0 comments on commit 288190f

Please sign in to comment.