Skip to content

Commit

Permalink
[py] Add low-level sync API to use DevTools (SeleniumHQ#13977)
Browse files Browse the repository at this point in the history
  • Loading branch information
p0deje authored and sandeepsuryaprasad committed Oct 29, 2024
1 parent eefec24 commit aa4dd0d
Show file tree
Hide file tree
Showing 7 changed files with 198 additions and 0 deletions.
2 changes: 2 additions & 0 deletions py/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@ py_library(
requirement("trio_websocket"),
requirement("urllib3"),
requirement("certifi"),
requirement("websocket-client"),
],
)

Expand Down Expand Up @@ -263,6 +264,7 @@ py_wheel(
"trio-websocket~=0.9",
"certifi>=2021.10.8",
"typing_extensions>=4.9.0",
"websocket-client>=1.8.0",
],
strip_path_prefixes = [
"py/",
Expand Down
1 change: 1 addition & 0 deletions py/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ def event_class(method):
''' A decorator that registers a class as an event class. '''
def decorate(cls):
_event_parsers[method] = cls
cls.event_class = method
return cls
return decorate
Expand Down
1 change: 1 addition & 0 deletions py/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -32,5 +32,6 @@ trio-websocket==0.9.2
twine==4.0.2
typing_extensions==4.9.0
urllib3[socks]==2.0.7
websocket-client==1.8.0
wsproto==1.2.0
zipp==3.17.0
4 changes: 4 additions & 0 deletions py/requirements_lock.txt
Original file line number Diff line number Diff line change
Expand Up @@ -521,6 +521,10 @@ urllib3[socks]==2.0.7 \
# -r py/requirements.txt
# requests
# twine
websocket-client==1.8.0 \
--hash=sha256:17b44cc997f5c498e809b22cdf2d9c7a9e71c02c8cc2b6c56e7c2d1239bfa526 \
--hash=sha256:3239df9f44da632f96012472805d40a23281a991027ce11d2f45a6f24ac4c3da
# via -r py/requirements.txt
wsproto==1.2.0 \
--hash=sha256:ad565f26ecb92588a3e43bc3d96164de84cd9902482b130d0ddbaa9664a85065 \
--hash=sha256:b9acddd652b585d75b20477888c56642fdade28bdfd3579aa24a4d2c037dd736
Expand Down
29 changes: 29 additions & 0 deletions py/selenium/webdriver/remote/webdriver.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,10 @@
from .shadowroot import ShadowRoot
from .switch_to import SwitchTo
from .webelement import WebElement
from .websocket_connection import WebSocketConnection

cdp = None
devtools = None


def import_cdp():
Expand Down Expand Up @@ -207,6 +209,7 @@ def __init__(
self._authenticator_id = None
self.start_client()
self.start_session(capabilities)
self._websocket_connection = None

def __repr__(self):
return f'<{type(self).__module__}.{type(self).__name__} (session="{self.session_id}")>'
Expand Down Expand Up @@ -1018,6 +1021,32 @@ def get_log(self, log_type):
"""
return self.execute(Command.GET_LOG, {"type": log_type})["value"]

def start_devtools(self):
global devtools
if self._websocket_connection:
return devtools, self._websocket_connection
else:
global cdp
import_cdp()

if not devtools:
if self.caps.get("se:cdp"):
ws_url = self.caps.get("se:cdp")
version = self.caps.get("se:cdpVersion").split(".")[0]
else:
version, ws_url = self._get_cdp_details()

if not ws_url:
raise WebDriverException("Unable to find url to connect to from capabilities")

devtools = cdp.import_devtools(version)
self._websocket_connection = WebSocketConnection(ws_url)
targets = self._websocket_connection.execute(devtools.target.get_targets())
target_id = targets[0].target_id
session = self._websocket_connection.execute(devtools.target.attach_to_target(target_id, True))
self._websocket_connection.session_id = session
return devtools, self._websocket_connection

@asynccontextmanager
async def bidi_connection(self):
global cdp
Expand Down
125 changes: 125 additions & 0 deletions py/selenium/webdriver/remote/websocket_connection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
# Licensed to the Software Freedom Conservancy (SFC) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The SFC licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import json
import logging
from ssl import CERT_NONE
from threading import Thread
from time import sleep

from websocket import WebSocketApp

logger = logging.getLogger("websocket")


class WebSocketConnection:
_response_wait_timeout = 30
_response_wait_interval = 0.1

_max_log_message_size = 9999

def __init__(self, url):
self.session_id = None
self.url = url

self._id = 0
self._callbacks = {}
self._messages = {}
self._started = False

self._start_ws()
self._wait_until(lambda: self._started)

def close(self):
self._ws_thread.join(timeout=self._response_wait_timeout)
self._ws.close()
self._started = False
self._ws = None

def execute(self, command):
self._id += 1
payload = self._serialize_command(command)
payload["id"] = self._id
if self.session_id:
payload["sessionId"] = self.session_id

data = json.dumps(payload)
logger.debug(f"WebSocket -> {data}"[: self._max_log_message_size])
self._ws.send(data)

self._wait_until(lambda: self._id in self._messages)
result = self._messages.pop(self._id)["result"]
return self._deserialize_result(result, command)

def on(self, event, callback):
if event not in self._callbacks:
self._callbacks[event.event_class] = []
self._callbacks[event.event_class].append(lambda params: callback(event.from_json(params)))

def _serialize_command(self, command):
return next(command)

def _deserialize_result(self, result, command):
try:
_ = command.send(result)
raise Exception("The command's generator function did not exit when expected!")
except StopIteration as exit:
return exit.value

def _start_ws(self):
def on_open(ws):
self._started = True

def on_message(ws, message):
self._process_message(message)

def on_error(ws, error):
logger.debug(f"WebSocket error: {error}")
ws.close()

def run_socket():
if self.url.startswith("wss://"):
self._ws.run_forever(sslopt={"cert_reqs": CERT_NONE}, suppress_origin=True)
else:
self._ws.run_forever(suppress_origin=True)

self._ws = WebSocketApp(self.url, on_open=on_open, on_message=on_message, on_error=on_error)
self._ws_thread = Thread(target=run_socket)
self._ws_thread.start()

def _process_message(self, message):
message = json.loads(message)
logger.debug(f"WebSocket <- {message}"[: self._max_log_message_size])

if "id" in message:
self._messages[message["id"]] = message

if "method" in message:
params = message["params"]
for callback in self._callbacks.get(message["method"], []):
callback(params)

def _wait_until(self, condition):
timeout = self._response_wait_timeout
interval = self._response_wait_interval

while timeout > 0:
result = condition()
if result:
return result
else:
timeout -= interval
sleep(interval)
36 changes: 36 additions & 0 deletions py/test/selenium/webdriver/common/devtools_tests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Licensed to the Software Freedom Conservancy (SFC) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The SFC licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pytest

from selenium.webdriver.support.ui import WebDriverWait


@pytest.mark.xfail_safari
def test_check_console_messages(driver, pages):
devtools, connection = driver.start_devtools()
console_api_calls = []

connection.execute(devtools.runtime.enable())
connection.on(devtools.runtime.ConsoleAPICalled, console_api_calls.append)
driver.execute_script("console.log('I love cheese')")
driver.execute_script("console.error('I love bread')")
WebDriverWait(driver, 5).until(lambda _: len(console_api_calls) == 2)

assert console_api_calls[0].type_ == "log"
assert console_api_calls[0].args[0].value == "I love cheese"
assert console_api_calls[1].type_ == "error"
assert console_api_calls[1].args[0].value == "I love bread"

0 comments on commit aa4dd0d

Please sign in to comment.