Skip to content
16 changes: 11 additions & 5 deletions py/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,6 +489,10 @@ def firefox_options(request):
except (AttributeError, TypeError):
raise Exception("This test requires a --driver to be specified")

# skip if not Firefox or Remote
if driver_class not in ("firefox", "remote"):
pytest.skip(f"This test requires Firefox or Remote. Got {driver_class}")

# skip tests in the 'remote' directory if run with a local driver
if request.node.path.parts[-2] == "remote" and getattr(_supported_drivers, driver_class) != "Remote":
pytest.skip(f"Remote tests can't be run with driver '{driver_class}'")
Expand All @@ -506,15 +510,17 @@ def chromium_options(request):
except (AttributeError, TypeError):
raise Exception("This test requires a --driver to be specified")

# Skip if not Chrome or Edge
if driver_class not in ("chrome", "edge"):
pytest.skip(f"This test requires Chrome or Edge, got {driver_class}")
# skip if not Chrome, Edge, or Remote
if driver_class not in ("chrome", "edge", "remote"):
pytest.skip(f"This test requires Chrome, Edge, or Remote. Got {driver_class}")

# skip tests in the 'remote' directory if run with a local driver
if request.node.path.parts[-2] == "remote" and getattr(_supported_drivers, driver_class) != "Remote":
pytest.skip(f"Remote tests can't be run with driver '{driver_class}'")

if driver_class in ("chrome", "edge"):
options = Driver.clean_options(driver_class, request)
if driver_class in ("chrome", "remote"):
options = Driver.clean_options("chrome", request)
else:
options = Driver.clean_options("edge", request)

return options
19 changes: 12 additions & 7 deletions py/selenium/webdriver/remote/client_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,22 +50,19 @@ class ClientConfig:
keep_alive = _ClientConfigDescriptor("_keep_alive")
"""Gets and Sets Keep Alive value."""
proxy = _ClientConfigDescriptor("_proxy")
"""Gets and Sets the proxy used for communicating to the driver/server."""
"""Gets and Sets the proxy used for communicating with the driver/server."""
ignore_certificates = _ClientConfigDescriptor("_ignore_certificates")
"""Gets and Sets the ignore certificate check value."""
init_args_for_pool_manager = _ClientConfigDescriptor("_init_args_for_pool_manager")
"""Gets and Sets the ignore certificate check."""
timeout = _ClientConfigDescriptor("_timeout")
"""Gets and Sets the timeout (in seconds) used for communicating to the
driver/server."""
"""Gets and Sets the timeout (in seconds) used for communicating with the driver/server."""
ca_certs = _ClientConfigDescriptor("_ca_certs")
"""Gets and Sets the path to bundle of CA certificates."""
username = _ClientConfigDescriptor("_username")
"""Gets and Sets the username used for basic authentication to the
remote."""
"""Gets and Sets the username used for basic authentication to the remote."""
password = _ClientConfigDescriptor("_password")
"""Gets and Sets the password used for basic authentication to the
remote."""
"""Gets and Sets the password used for basic authentication to the remote."""
auth_type = _ClientConfigDescriptor("_auth_type")
"""Gets and Sets the type of authentication to the remote server."""
token = _ClientConfigDescriptor("_token")
Expand All @@ -74,6 +71,10 @@ class ClientConfig:
"""Gets and Sets user agent to be added to the request headers."""
extra_headers = _ClientConfigDescriptor("_extra_headers")
"""Gets and Sets extra headers to be added to the request."""
websocket_timeout = _ClientConfigDescriptor("_websocket_timeout")
"""Gets and Sets the WebSocket response wait timeout (in seconds) used for communicating with the browser."""
websocket_interval = _ClientConfigDescriptor("_websocket_interval")
"""Gets and Sets the WebSocket response wait interval (in seconds) used for communicating with the browser."""

def __init__(
self,
Expand All @@ -90,6 +91,8 @@ def __init__(
token: Optional[str] = None,
user_agent: Optional[str] = None,
extra_headers: Optional[dict] = None,
websocket_timeout: Optional[float] = 30.0,
websocket_interval: Optional[float] = 0.1,
) -> None:
self.remote_server_addr = remote_server_addr
self.keep_alive = keep_alive
Expand All @@ -103,6 +106,8 @@ def __init__(
self.token = token
self.user_agent = user_agent
self.extra_headers = extra_headers
self.websocket_timeout = websocket_timeout
self.websocket_interval = websocket_interval

self.ca_certs = (
(os.getenv("REQUESTS_CA_BUNDLE") if "REQUESTS_CA_BUNDLE" in os.environ else certifi.where())
Expand Down
12 changes: 10 additions & 2 deletions py/selenium/webdriver/remote/webdriver.py
Original file line number Diff line number Diff line change
Expand Up @@ -1211,7 +1211,11 @@ def start_devtools(self):
return self._devtools, self._websocket_connection
if self.caps["browserName"].lower() == "firefox":
raise RuntimeError("CDP support for Firefox has been removed. Please switch to WebDriver BiDi.")
self._websocket_connection = WebSocketConnection(ws_url)
self._websocket_connection = WebSocketConnection(
ws_url,
self.command_executor.client_config.websocket_timeout,
self.command_executor.client_config.websocket_interval,
)
targets = self._websocket_connection.execute(self._devtools.target.get_targets())
for target in targets:
if target.target_id == self.current_window_handle:
Expand Down Expand Up @@ -1260,7 +1264,11 @@ def _start_bidi(self):
else:
raise WebDriverException("Unable to find url to connect to from capabilities")

self._websocket_connection = WebSocketConnection(ws_url)
self._websocket_connection = WebSocketConnection(
ws_url,
self.command_executor.client_config.websocket_timeout,
self.command_executor.client_config.websocket_interval,
)

@property
def network(self):
Expand Down
23 changes: 14 additions & 9 deletions py/selenium/webdriver/remote/websocket_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import json
import logging
from ssl import CERT_NONE
Expand All @@ -28,16 +29,20 @@


class WebSocketConnection:
_response_wait_timeout = 30
_response_wait_interval = 0.1

_max_log_message_size = 9999

def __init__(self, url):
self.callbacks = {}
self.session_id = None
def __init__(self, url, timeout, interval):
if not isinstance(timeout, (int, float)) or timeout < 0:
raise WebDriverException("timeout must be a positive number")
if not isinstance(interval, (int, float)) or timeout < 0:
raise WebDriverException("interval must be a positive number")

self.url = url
self.response_wait_timeout = timeout
self.response_wait_interval = interval

self.callbacks = {}
self.session_id = None
self._id = 0
self._messages = {}
self._started = False
Expand All @@ -46,7 +51,7 @@ def __init__(self, url):
self._wait_until(lambda: self._started)

def close(self):
self._ws_thread.join(timeout=self._response_wait_timeout)
self._ws_thread.join(timeout=self.response_wait_timeout)
self._ws.close()
self._started = False
self._ws = None
Expand Down Expand Up @@ -142,8 +147,8 @@ def _process_message(self, message):
Thread(target=callback, args=(params,)).start()

def _wait_until(self, condition):
timeout = self._response_wait_timeout
interval = self._response_wait_interval
timeout = self.response_wait_timeout
interval = self.response_wait_interval

while timeout > 0:
result = condition()
Expand Down
30 changes: 28 additions & 2 deletions py/test/selenium/webdriver/remote/remote_connection_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.

import base64
import time

import filetype
import pytest
Expand All @@ -40,8 +41,8 @@ def test_remote_webdriver_with_http_timeout(firefox_options, webserver):
set less than the implicit wait timeout, and verifies the http timeout
is triggered first when waiting for an element.
"""
http_timeout = 6
wait_timeout = 8
http_timeout = 4
wait_timeout = 6
server_addr = f"http://{webserver.host}:{webserver.port}"
client_config = ClientConfig(remote_server_addr=server_addr, timeout=http_timeout)
assert client_config.timeout == http_timeout
Expand All @@ -50,3 +51,28 @@ def test_remote_webdriver_with_http_timeout(firefox_options, webserver):
driver.implicitly_wait(wait_timeout)
with pytest.raises(ReadTimeoutError):
driver.find_element(By.ID, "no_element_to_be_found")


def test_remote_webdriver_with_websocket_timeout(firefox_options, webserver):
"""This test starts a remote webdriver that uses websockets, and has a websocket
client timeout less than the default. It verifies the websocket times out according
to this value.
"""
websocket_timeout = 2.0
websocket_interval = 1.0

server_addr = f"http://{webserver.host}:{webserver.port}"
client_config = ClientConfig(
remote_server_addr=server_addr, websocket_timeout=websocket_timeout, websocket_interval=websocket_interval
)
assert client_config.websocket_timeout == websocket_timeout
firefox_options.enable_bidi = True
with webdriver.Remote(options=firefox_options, client_config=client_config) as driver:
driver._start_bidi()
assert driver._websocket_connection.response_wait_timeout == websocket_timeout
assert driver._websocket_connection.response_wait_interval == websocket_interval
start = time.time()
driver._websocket_connection.close()
elapsed = time.time() - start
assert elapsed >= websocket_timeout
assert elapsed < websocket_timeout + 10
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,12 @@ def test_execute_custom_command(mock_request, remote_connection):
assert response == {"status": 200, "value": "OK"}


def test_default_websocket_settings():
config = ClientConfig(remote_server_addr="http://localhost:4444")
assert config.websocket_timeout == 30.0
assert config.websocket_interval == 0.1


def test_get_remote_connection_headers_defaults():
url = "http://remote"
headers = RemoteConnection.get_remote_connection_headers(parse.urlparse(url))
Expand Down