From 288190f7f80486a1b82cfda15e1f9b05792552cb Mon Sep 17 00:00:00 2001 From: "Yuichiro Tachibana (Tsuchiya)" Date: Sun, 19 Mar 2023 18:18:28 +0900 Subject: [PATCH] Support top-level await, adding ast.PyCF_ALLOW_TOP_LEVEL_AWAIT to the compile()'s flag and replacing exec() with await eval() to handle the coroutine obtained from compile() with the top-level await flag --- lib/streamlit/runtime/scriptrunner/exec_code.py | 8 ++++---- .../runtime/scriptrunner/script_cache.py | 3 ++- .../runtime/scriptrunner/script_runner.py | 15 ++++++++++----- 3 files changed, 16 insertions(+), 10 deletions(-) diff --git a/lib/streamlit/runtime/scriptrunner/exec_code.py b/lib/streamlit/runtime/scriptrunner/exec_code.py index 169329efccca7..92b03527ba1c3 100644 --- a/lib/streamlit/runtime/scriptrunner/exec_code.py +++ b/lib/streamlit/runtime/scriptrunner/exec_code.py @@ -14,7 +14,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Callable +from typing import TYPE_CHECKING, Any, Awaitable, Callable from streamlit.delta_generator_singletons import ( context_dg_stack, @@ -32,8 +32,8 @@ from streamlit.runtime.scriptrunner_utils.script_run_context import ScriptRunContext -def exec_func_with_error_handling( - func: Callable[[], Any], ctx: ScriptRunContext +async def exec_func_with_error_handling( + func: Callable[[], Awaitable[Any]], ctx: ScriptRunContext ) -> tuple[ Any | None, bool, @@ -85,7 +85,7 @@ def exec_func_with_error_handling( uncaught_exception: Exception | None = None try: - result = func() + result = await func() except RerunException as e: rerun_exception_data = e.rerun_data diff --git a/lib/streamlit/runtime/scriptrunner/script_cache.py b/lib/streamlit/runtime/scriptrunner/script_cache.py index b4bafd25d80d1..cec5a17988f33 100644 --- a/lib/streamlit/runtime/scriptrunner/script_cache.py +++ b/lib/streamlit/runtime/scriptrunner/script_cache.py @@ -14,6 +14,7 @@ from __future__ import annotations +import ast import os.path import threading from typing import Any @@ -79,7 +80,7 @@ def get_bytecode(self, script_path: str) -> Any: # mode (as opposed to "eval" or "single"). mode="exec", # Don't inherit any flags or "future" statements. - flags=0, + flags=ast.PyCF_ALLOW_TOP_LEVEL_AWAIT, dont_inherit=1, # Use the default optimization options. optimize=-1, diff --git a/lib/streamlit/runtime/scriptrunner/script_runner.py b/lib/streamlit/runtime/scriptrunner/script_runner.py index eefc8d7e3a987..fc6bcfe34160e 100644 --- a/lib/streamlit/runtime/scriptrunner/script_runner.py +++ b/lib/streamlit/runtime/scriptrunner/script_runner.py @@ -21,6 +21,7 @@ import types from contextlib import contextmanager from enum import Enum +from inspect import CO_COROUTINE from timeit import default_timer as timer from typing import TYPE_CHECKING, Callable, Final @@ -316,7 +317,7 @@ async def _run_script_thread(self) -> None: # request that we'll handle immediately. When the script finishes, # it's possible that another request has come in that we need to # handle, which is why we call _run_script in a loop. - self._run_script(request.rerun_data) + await self._run_script(request.rerun_data) request = self._requests.on_scriptrunner_ready() assert request.type == ScriptRequestType.STOP @@ -400,7 +401,7 @@ def _set_execing_flag(self): finally: self._execing = False - def _run_script(self, rerun_data: RerunData) -> None: + async def _run_script(self, rerun_data: RerunData) -> None: """Run our script. Parameters @@ -532,7 +533,7 @@ def _run_script(self, rerun_data: RerunData) -> None: # assume is the main script directory. module.__dict__["__file__"] = script_path - def code_to_exec(code=code, module=module, ctx=ctx, rerun_data=rerun_data): + async def code_to_exec(code=code, module=module, ctx=ctx, rerun_data=rerun_data): with modified_sys_path( self._main_script_path ), self._set_execing_flag(): @@ -576,7 +577,11 @@ def code_to_exec(code=code, module=module, ctx=ctx, rerun_data=rerun_data): pass else: - exec(code, module.__dict__) + if code.co_flags & CO_COROUTINE: + # The source code includes top-level awaits, so the compiled code object is a coroutine. + await eval(code, module.__dict__) + else: + exec(code, module.__dict__) self._fragment_storage.clear( new_fragment_ids=ctx.new_fragment_ids ) @@ -592,7 +597,7 @@ def code_to_exec(code=code, module=module, ctx=ctx, rerun_data=rerun_data): rerun_exception_data, premature_stop, uncaught_exception, - ) = exec_func_with_error_handling(code_to_exec, ctx) + ) = await exec_func_with_error_handling(code_to_exec, ctx) # setting the session state here triggers a yield-callback call # which reads self._requests and checks for rerun data self._session_state[SCRIPT_RUN_WITHOUT_ERRORS_KEY] = run_without_errors