Skip to content

Commit

Permalink
fix: don't set attr on bound method (#539)
Browse files Browse the repository at this point in the history
  • Loading branch information
pavelfeldman authored Mar 2, 2021
1 parent 246aed8 commit 436132c
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 6 deletions.
27 changes: 21 additions & 6 deletions playwright/_impl/_impl_to_api_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,13 @@
# limitations under the License.

import inspect
from typing import Any, Callable, Dict, List, Optional
from types import MethodType
from typing import Any, Callable, Dict, List, Optional, cast

from playwright._impl._api_types import Error

INSTANCE_ATTR = "_pw_api_instance"
API_ATTR = "_pw_api_instance_"
IMPL_ATTR = "_pw_impl_instance_"


class ImplWrapper:
Expand All @@ -41,10 +43,10 @@ def from_maybe_impl(self, obj: Any) -> Any:
return [self.from_maybe_impl(item) for item in obj]
api_class = self._mapping.get(type(obj))
if api_class:
api_instance = getattr(obj, INSTANCE_ATTR, None)
api_instance = getattr(obj, API_ATTR, None)
if not api_instance:
api_instance = api_class(obj)
setattr(obj, INSTANCE_ATTR, api_instance)
setattr(obj, API_ATTR, api_instance)
return api_instance
else:
return obj
Expand Down Expand Up @@ -85,8 +87,21 @@ def wrapper_func(*args: Any) -> Any:
*list(map(lambda a: self.from_maybe_impl(a), args))[:arg_count]
)

wrapper = getattr(handler, INSTANCE_ATTR, None)
if inspect.ismethod(handler):
wrapper = getattr(
cast(MethodType, handler).__self__, IMPL_ATTR + handler.__name__, None
)
if not wrapper:
wrapper = wrapper_func
setattr(
cast(MethodType, handler).__self__,
IMPL_ATTR + handler.__name__,
wrapper,
)
return wrapper

wrapper = getattr(handler, IMPL_ATTR, None)
if not wrapper:
wrapper = wrapper_func
setattr(handler, INSTANCE_ATTR, wrapper)
setattr(handler, IMPL_ATTR, wrapper)
return wrapper
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

driver_version = "1.9.0-1614037901000"


def extractall(zip: zipfile.ZipFile, path: str) -> None:
for name in zip.namelist():
member = zip.getinfo(name)
Expand Down
14 changes: 14 additions & 0 deletions tests/async/test_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,20 @@ async def test_page_events_request_should_fire_for_navigation_requests(
assert len(requests) == 1


async def test_page_events_request_should_accept_method(page: Page, server):
class Log:
def __init__(self):
self.requests = []

def handle(self, request):
self.requests.append(request)

log = Log()
page.on("request", log.handle)
await page.goto(server.EMPTY_PAGE)
assert len(log.requests) == 1


async def test_page_events_request_should_fire_for_iframes(page, server, utils):
requests = []
page.on("request", lambda r: requests.append(r))
Expand Down

0 comments on commit 436132c

Please sign in to comment.