From 066d988de1ba44102cfb9df0b7ec0285f891246a Mon Sep 17 00:00:00 2001 From: Pavel Date: Sun, 2 Aug 2020 15:24:40 -0700 Subject: [PATCH] feat(gather): introduce sync gather for easier async tests migration --- playwright/sync_base.py | 34 +++++++++++++++++++++++++++++++++- tests/sync/test_sync.py | 10 ++++++++++ 2 files changed, 43 insertions(+), 1 deletion(-) diff --git a/playwright/sync_base.py b/playwright/sync_base.py index 3247ee7fc9..c9577adfc6 100644 --- a/playwright/sync_base.py +++ b/playwright/sync_base.py @@ -13,7 +13,8 @@ # limitations under the License. import asyncio -from typing import Any, Callable, Coroutine, Generic, Optional, TypeVar, cast +import sys +from typing import Any, Callable, Coroutine, Dict, Generic, List, Optional, TypeVar, cast import greenlet @@ -109,3 +110,34 @@ def once(self, event_name: str, handler: Any) -> None: def remove_listener(self, event_name: str, handler: Any) -> None: self._impl_obj.remove_listener(event_name, handler) + + def _gather(self, actions: List[Callable]) -> List[Any]: + g_self = greenlet.getcurrent() + results: Dict[Callable, Any] = {} + exceptions: List[Exception] = [] + + def action_wrapper(action): + def body(): + try: + results[action] = action() + except Exception as e: + results[action] = e + exceptions.append(e) + g_self.switch() + return body + + async def task(): + for action in actions: + g = greenlet.greenlet(action_wrapper(action)) + g.switch() + + self._loop.create_task(task()) + + while len(results) < len(actions): + dispatcher_fiber_.switch() + + asyncio._set_running_loop(self._loop) + if exceptions: + raise exceptions[0] + + return list(map(lambda action: results[action], actions)) diff --git a/tests/sync/test_sync.py b/tests/sync/test_sync.py index db69413fbe..ebe193537b 100644 --- a/tests/sync/test_sync.py +++ b/tests/sync/test_sync.py @@ -187,3 +187,13 @@ def test_sync_set_default_timeout(page): with pytest.raises(TimeoutError) as exc: page.waitForFunction("false") assert "Timeout 1ms exceeded." in exc.value.message + + +def test_close_should_reject_all_promises(context): + new_page = context.newPage() + with pytest.raises(Error) as exc_info: + new_page._gather([ + lambda: new_page.evaluate("() => new Promise(r => {})"), + lambda: new_page.close() + ]) + assert "Protocol error" in exc_info.value.message