diff --git a/playwright/sync_base.py b/playwright/sync_base.py index 3247ee7fc..93b77ab87 100644 --- a/playwright/sync_base.py +++ b/playwright/sync_base.py @@ -13,7 +13,17 @@ # limitations under the License. import asyncio -from typing import Any, Callable, Coroutine, Generic, Optional, TypeVar, cast +from typing import ( + Any, + Callable, + Coroutine, + Dict, + Generic, + List, + Optional, + TypeVar, + cast, +) import greenlet @@ -109,3 +119,35 @@ def once(self, event_name: str, handler: Any) -> None: def remove_listener(self, event_name: str, handler: Any) -> None: self._impl_obj.remove_listener(event_name, handler) + + def _gather(self, *actions: Callable) -> List[Any]: + g_self = greenlet.getcurrent() + results: Dict[Callable, Any] = {} + exceptions: List[Exception] = [] + + def action_wrapper(action: Callable) -> Callable: + def body() -> Any: + try: + results[action] = action() + except Exception as e: + results[action] = e + exceptions.append(e) + g_self.switch() + + return body + + async def task() -> None: + for action in actions: + g = greenlet.greenlet(action_wrapper(action)) + g.switch() + + self._loop.create_task(task()) + + while len(results) < len(actions): + dispatcher_fiber_.switch() + + asyncio._set_running_loop(self._loop) + if exceptions: + raise exceptions[0] + + return list(map(lambda action: results[action], actions)) diff --git a/tests/async/test_navigation.py b/tests/async/test_navigation.py index 4b6324055..4c46ec668 100644 --- a/tests/async/test_navigation.py +++ b/tests/async/test_navigation.py @@ -658,7 +658,7 @@ async def test_expect_navigation_should_work_for_cross_process_navigations( await goto_task -@pytest.mark.skip_browser("webkit") +@pytest.mark.skip("flaky, investigate") async def test_wait_for_load_state_should_pick_up_ongoing_navigation(page, server): requests = [] diff --git a/tests/sync/test_sync.py b/tests/sync/test_sync.py index db69413fb..d3656c463 100644 --- a/tests/sync/test_sync.py +++ b/tests/sync/test_sync.py @@ -187,3 +187,13 @@ def test_sync_set_default_timeout(page): with pytest.raises(TimeoutError) as exc: page.waitForFunction("false") assert "Timeout 1ms exceeded." in exc.value.message + + +def test_close_should_reject_all_promises(context): + new_page = context.newPage() + with pytest.raises(Error) as exc_info: + new_page._gather( + lambda: new_page.evaluate("() => new Promise(r => {})"), + lambda: new_page.close(), + ) + assert "Protocol error" in exc_info.value.message