diff --git a/py/selenium/webdriver/chrome/remote_connection.py b/py/selenium/webdriver/chrome/remote_connection.py index d20ac581f36d7..1aa34dbfa4b31 100644 --- a/py/selenium/webdriver/chrome/remote_connection.py +++ b/py/selenium/webdriver/chrome/remote_connection.py @@ -14,10 +14,12 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -import typing + +from typing import Optional from selenium.webdriver import DesiredCapabilities from selenium.webdriver.chromium.remote_connection import ChromiumRemoteConnection +from selenium.webdriver.remote.client_config import ClientConfig class ChromeRemoteConnection(ChromiumRemoteConnection): @@ -27,7 +29,8 @@ def __init__( self, remote_server_addr: str, keep_alive: bool = True, - ignore_proxy: typing.Optional[bool] = False, + ignore_proxy: Optional[bool] = False, + client_config: Optional[ClientConfig] = None, ) -> None: super().__init__( remote_server_addr=remote_server_addr, @@ -35,4 +38,5 @@ def __init__( browser_name=ChromeRemoteConnection.browser_name, keep_alive=keep_alive, ignore_proxy=ignore_proxy, + client_config=client_config, ) diff --git a/py/selenium/webdriver/chromium/remote_connection.py b/py/selenium/webdriver/chromium/remote_connection.py index 29d33499111cf..021c47737cd17 100644 --- a/py/selenium/webdriver/chromium/remote_connection.py +++ b/py/selenium/webdriver/chromium/remote_connection.py @@ -14,7 +14,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from typing import Optional +from selenium.webdriver.remote.client_config import ClientConfig from selenium.webdriver.remote.remote_connection import RemoteConnection @@ -25,9 +27,15 @@ def __init__( vendor_prefix: str, browser_name: str, keep_alive: bool = True, - ignore_proxy: bool = False, + ignore_proxy: Optional[bool] = False, + client_config: Optional[ClientConfig] = None, ) -> None: - super().__init__(remote_server_addr, keep_alive, ignore_proxy) + super().__init__( + remote_server_addr=remote_server_addr, + keep_alive=keep_alive, + ignore_proxy=ignore_proxy, + client_config=client_config, + ) self.browser_name = browser_name commands = self._remote_commands(vendor_prefix) for key, value in commands.items(): diff --git a/py/selenium/webdriver/common/options.py b/py/selenium/webdriver/common/options.py index e938191c79adb..3e754d2537ba6 100644 --- a/py/selenium/webdriver/common/options.py +++ b/py/selenium/webdriver/common/options.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. import typing +import warnings from abc import ABCMeta from abc import abstractmethod from enum import Enum @@ -514,6 +515,15 @@ def add_argument(self, argument) -> None: def ignore_local_proxy_environment_variables(self) -> None: """By calling this you will ignore HTTP_PROXY and HTTPS_PROXY from being picked up and used.""" + warnings.warn( + "using ignore_local_proxy_environment_variables in Options has been deprecated, " + "instead, create a Proxy instance with ProxyType.DIRECT to ignore proxy settings, " + "pass the proxy instance into a ClientConfig constructor, " + "pass the client config instance into the Webdriver constructor", + DeprecationWarning, + stacklevel=2, + ) + super().ignore_local_proxy_environment_variables() def to_capabilities(self): diff --git a/py/selenium/webdriver/edge/remote_connection.py b/py/selenium/webdriver/edge/remote_connection.py index 5e4a3739ba0e8..8f74c9f52c5af 100644 --- a/py/selenium/webdriver/edge/remote_connection.py +++ b/py/selenium/webdriver/edge/remote_connection.py @@ -14,10 +14,12 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -import typing + +from typing import Optional from selenium.webdriver import DesiredCapabilities from selenium.webdriver.chromium.remote_connection import ChromiumRemoteConnection +from selenium.webdriver.remote.client_config import ClientConfig class EdgeRemoteConnection(ChromiumRemoteConnection): @@ -27,7 +29,8 @@ def __init__( self, remote_server_addr: str, keep_alive: bool = True, - ignore_proxy: typing.Optional[bool] = False, + ignore_proxy: Optional[bool] = False, + client_config: Optional[ClientConfig] = None, ) -> None: super().__init__( remote_server_addr=remote_server_addr, @@ -35,4 +38,5 @@ def __init__( browser_name=EdgeRemoteConnection.browser_name, keep_alive=keep_alive, ignore_proxy=ignore_proxy, + client_config=client_config, ) diff --git a/py/selenium/webdriver/firefox/remote_connection.py b/py/selenium/webdriver/firefox/remote_connection.py index 1147a6a6aaadd..502c144c41622 100644 --- a/py/selenium/webdriver/firefox/remote_connection.py +++ b/py/selenium/webdriver/firefox/remote_connection.py @@ -15,15 +15,29 @@ # specific language governing permissions and limitations # under the License. +from typing import Optional + from selenium.webdriver.common.desired_capabilities import DesiredCapabilities +from selenium.webdriver.remote.client_config import ClientConfig from selenium.webdriver.remote.remote_connection import RemoteConnection class FirefoxRemoteConnection(RemoteConnection): browser_name = DesiredCapabilities.FIREFOX["browserName"] - def __init__(self, remote_server_addr, keep_alive=True, ignore_proxy=False) -> None: - super().__init__(remote_server_addr, keep_alive, ignore_proxy) + def __init__( + self, + remote_server_addr: str, + keep_alive: bool = True, + ignore_proxy: Optional[bool] = False, + client_config: Optional[ClientConfig] = None, + ) -> None: + super().__init__( + remote_server_addr=remote_server_addr, + keep_alive=keep_alive, + ignore_proxy=ignore_proxy, + client_config=client_config, + ) self._commands["GET_CONTEXT"] = ("GET", "/session/$sessionId/moz/context") self._commands["SET_CONTEXT"] = ("POST", "/session/$sessionId/moz/context") diff --git a/py/selenium/webdriver/remote/client_config.py b/py/selenium/webdriver/remote/client_config.py new file mode 100644 index 0000000000000..62ba82076946b --- /dev/null +++ b/py/selenium/webdriver/remote/client_config.py @@ -0,0 +1,258 @@ +# Licensed to the Software Freedom Conservancy (SFC) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The SFC licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import base64 +import os +import socket +from typing import Optional +from urllib import parse + +import certifi + +from selenium.webdriver.common.proxy import Proxy +from selenium.webdriver.common.proxy import ProxyType + + +class ClientConfig: + def __init__( + self, + remote_server_addr: str, + keep_alive: Optional[bool] = True, + proxy: Optional[Proxy] = Proxy(raw={"proxyType": ProxyType.SYSTEM}), + ignore_certificates: Optional[bool] = False, + init_args_for_pool_manager: Optional[dict] = None, + timeout: Optional[int] = None, + ca_certs: Optional[str] = None, + username: Optional[str] = None, + password: Optional[str] = None, + auth_type: Optional[str] = "Basic", + token: Optional[str] = None, + ) -> None: + self.remote_server_addr = remote_server_addr + self.keep_alive = keep_alive + self.proxy = proxy + self.ignore_certificates = ignore_certificates + self.init_args_for_pool_manager = init_args_for_pool_manager or {} + self.timeout = timeout + self.username = username + self.password = password + self.auth_type = auth_type + self.token = token + + self.timeout = ( + ( + float(os.getenv("GLOBAL_DEFAULT_TIMEOUT", str(socket.getdefaulttimeout()))) + if os.getenv("GLOBAL_DEFAULT_TIMEOUT") is not None + else socket.getdefaulttimeout() + ) + if timeout is None + else timeout + ) + + self.ca_certs = ( + (os.getenv("REQUESTS_CA_BUNDLE") if "REQUESTS_CA_BUNDLE" in os.environ else certifi.where()) + if ca_certs is None + else ca_certs + ) + + @property + def remote_server_addr(self) -> str: + """:Returns: The address of the remote server.""" + return self._remote_server_addr + + @remote_server_addr.setter + def remote_server_addr(self, value: str) -> None: + """Provides the address of the remote server.""" + self._remote_server_addr = value + + @property + def keep_alive(self) -> bool: + """:Returns: The keep alive value.""" + return self._keep_alive + + @keep_alive.setter + def keep_alive(self, value: bool) -> None: + """Toggles the keep alive value. + + :Args: + - value: whether to keep the http connection alive + """ + self._keep_alive = value + + @property + def proxy(self) -> Proxy: + """:Returns: The proxy used for communicating to the driver/server.""" + return self._proxy + + @proxy.setter + def proxy(self, proxy: Proxy) -> None: + """Provides the information for communicating with the driver or + server. + For example: Proxy(raw={"proxyType": ProxyType.SYSTEM}) + + :Args: + - value: the proxy information to use to communicate with the driver or server + """ + self._proxy = proxy + + @property + def ignore_certificates(self) -> bool: + """:Returns: The ignore certificate check value.""" + return self._ignore_certificates + + @ignore_certificates.setter + def ignore_certificates(self, ignore_certificates: bool) -> None: + """Toggles the ignore certificate check. + + :Args: + - value: value of ignore certificate check + """ + self._ignore_certificates = ignore_certificates + + @property + def init_args_for_pool_manager(self) -> dict: + """:Returns: The dictionary of arguments will be appended while + initializing the pool manager.""" + return self._init_args_for_pool_manager + + @init_args_for_pool_manager.setter + def init_args_for_pool_manager(self, init_args_for_pool_manager: dict) -> None: + """Provides dictionary of arguments will be appended while initializing the pool manager. + For example: {"init_args_for_pool_manager": {"retries": 3, "block": True}} + + :Args: + - value: the dictionary of arguments will be appended while initializing the pool manager + """ + self._init_args_for_pool_manager = init_args_for_pool_manager + + @property + def timeout(self) -> int: + """:Returns: The timeout (in seconds) used for communicating to the + driver/server.""" + return self._timeout + + @timeout.setter + def timeout(self, timeout: int) -> None: + """Provides the timeout (in seconds) for communicating with the driver + or server. + + :Args: + - value: the timeout (in seconds) to use to communicate with the driver or server + """ + self._timeout = timeout + + def reset_timeout(self) -> None: + """Resets the timeout to the default value of socket.""" + self._timeout = socket.getdefaulttimeout() + + @property + def ca_certs(self) -> str: + """:Returns: The path to bundle of CA certificates.""" + return self._ca_certs + + @ca_certs.setter + def ca_certs(self, ca_certs: str) -> None: + """Provides the path to bundle of CA certificates for establishing + secure connections. + + :Args: + - value: the path to bundle of CA certificates for establishing secure connections + """ + self._ca_certs = ca_certs + + @property + def username(self) -> str: + """Returns the username used for basic authentication to the remote + server.""" + return self._username + + @username.setter + def username(self, value: str) -> None: + """Sets the username used for basic authentication to the remote + server.""" + self._username = value + + @property + def password(self) -> str: + """Returns the password used for basic authentication to the remote + server.""" + return self._password + + @password.setter + def password(self, value: str) -> None: + """Sets the password used for basic authentication to the remote + server.""" + self._password = value + + @property + def auth_type(self) -> str: + """Returns the type of authentication to the remote server.""" + return self._auth_type + + @auth_type.setter + def auth_type(self, value: str) -> None: + """Sets the type of authentication to the remote server if it is not + using basic with username and password.""" + self._auth_type = value + + @property + def token(self) -> str: + """Returns the token used for authentication to the remote server.""" + return self._token + + @token.setter + def token(self, value: str) -> None: + """Sets the token used for authentication to the remote server if + auth_type is not basic.""" + self._token = value + + def get_proxy_url(self) -> Optional[str]: + """Returns the proxy URL to use for the connection.""" + proxy_type = self.proxy.proxy_type + remote_add = parse.urlparse(self.remote_server_addr) + if proxy_type is ProxyType.DIRECT: + return None + if proxy_type is ProxyType.SYSTEM: + _no_proxy = os.environ.get("no_proxy", os.environ.get("NO_PROXY")) + if _no_proxy: + for entry in map(str.strip, _no_proxy.split(",")): + if entry == "*": + return None + n_url = parse.urlparse(entry) + if n_url.netloc and remote_add.netloc == n_url.netloc: + return None + if n_url.path in remote_add.netloc: + return None + return os.environ.get( + "https_proxy" if self.remote_server_addr.startswith("https://") else "http_proxy", + os.environ.get("HTTPS_PROXY" if self.remote_server_addr.startswith("https://") else "HTTP_PROXY"), + ) + if proxy_type is ProxyType.MANUAL: + return self.proxy.sslProxy if self.remote_server_addr.startswith("https://") else self.proxy.http_proxy + return None + + def get_auth_header(self) -> Optional[dict]: + """Returns the authorization to add to the request headers.""" + auth_type = self.auth_type.lower() + if auth_type == "basic" and self.username and self.password: + credentials = f"{self.username}:{self.password}" + encoded_credentials = base64.b64encode(credentials.encode("utf-8")).decode("utf-8") + return {"Authorization": f"Basic {encoded_credentials}"} + if auth_type == "bearer" and self.token: + return {"Authorization": f"Bearer {self.token}"} + if auth_type == "oauth" and self.token: + return {"Authorization": f"OAuth {self.token}"} + return None diff --git a/py/selenium/webdriver/remote/remote_connection.py b/py/selenium/webdriver/remote/remote_connection.py index 0a4bde22c40d2..57f38b5806574 100644 --- a/py/selenium/webdriver/remote/remote_connection.py +++ b/py/selenium/webdriver/remote/remote_connection.py @@ -16,19 +16,19 @@ # under the License. import logging -import os import platform -import socket import string +import warnings from base64 import b64encode +from typing import Optional from urllib import parse -import certifi import urllib3 from selenium import __version__ from . import utils +from .client_config import ClientConfig from .command import Command from .errorhandler import ErrorCode @@ -136,12 +136,7 @@ class RemoteConnection: """ browser_name = None - _timeout = ( - float(os.getenv("GLOBAL_DEFAULT_TIMEOUT", str(socket.getdefaulttimeout()))) - if os.getenv("GLOBAL_DEFAULT_TIMEOUT") is not None - else socket.getdefaulttimeout() - ) - _ca_certs = os.getenv("REQUESTS_CA_BUNDLE") if "REQUESTS_CA_BUNDLE" in os.environ else certifi.where() + _client_config: ClientConfig = None system = platform.system().lower() if system == "darwin": @@ -158,7 +153,12 @@ def get_timeout(cls): Timeout value in seconds for all http requests made to the Remote Connection """ - return None if cls._timeout == socket._GLOBAL_DEFAULT_TIMEOUT else cls._timeout + warnings.warn( + "get_timeout() in RemoteConnection is deprecated, get timeout from ClientConfig instance instead", + DeprecationWarning, + stacklevel=2, + ) + return cls._client_config.timeout @classmethod def set_timeout(cls, timeout): @@ -167,12 +167,22 @@ def set_timeout(cls, timeout): :Args: - timeout - timeout value for http requests in seconds """ - cls._timeout = timeout + warnings.warn( + "set_timeout() in RemoteConnection is deprecated, set timeout to ClientConfig instance in constructor instead", + DeprecationWarning, + stacklevel=2, + ) + cls._client_config.timeout = timeout @classmethod def reset_timeout(cls): """Reset the http request timeout to socket._GLOBAL_DEFAULT_TIMEOUT.""" - cls._timeout = socket._GLOBAL_DEFAULT_TIMEOUT + warnings.warn( + "reset_timeout() in RemoteConnection is deprecated, use reset_timeout() in ClientConfig instance instead", + DeprecationWarning, + stacklevel=2, + ) + cls._client_config.reset_timeout() @classmethod def get_certificate_bundle_path(cls): @@ -182,7 +192,12 @@ def get_certificate_bundle_path(cls): command executor. Defaults to certifi.where() or REQUESTS_CA_BUNDLE env variable if set. """ - return cls._ca_certs + warnings.warn( + "get_certificate_bundle_path() in RemoteConnection is deprecated, get ca_certs from ClientConfig instance instead", + DeprecationWarning, + stacklevel=2, + ) + return cls._client_config.ca_certs @classmethod def set_certificate_bundle_path(cls, path): @@ -193,7 +208,12 @@ def set_certificate_bundle_path(cls, path): :Args: - path - path of a .pem encoded certificate chain. """ - cls._ca_certs = path + warnings.warn( + "set_certificate_bundle_path() in RemoteConnection is deprecated, set ca_certs to ClientConfig instance in constructor instead", + DeprecationWarning, + stacklevel=2, + ) + cls._client_config.ca_certs = path @classmethod def get_remote_connection_headers(cls, parsed_url, keep_alive=False): @@ -222,12 +242,6 @@ def get_remote_connection_headers(cls, parsed_url, keep_alive=False): return headers - def _get_proxy_url(self): - if self._url.startswith("https://"): - return os.environ.get("https_proxy", os.environ.get("HTTPS_PROXY")) - if self._url.startswith("http://"): - return os.environ.get("http_proxy", os.environ.get("HTTP_PROXY")) - def _identify_http_proxy_auth(self): url = self._proxy_url url = url[url.find(":") + 3 :] @@ -242,15 +256,17 @@ def _separate_http_proxy_auth(self): return proxy_without_auth, auth def _get_connection_manager(self): - pool_manager_init_args = {"timeout": self.get_timeout()} - pool_manager_init_args.update(self._init_args_for_pool_manager.get("init_args_for_pool_manager", {})) + pool_manager_init_args = {"timeout": self._client_config.timeout} + pool_manager_init_args.update( + self._client_config.init_args_for_pool_manager.get("init_args_for_pool_manager", {}) + ) - if self._ignore_certificates: + if self._client_config.ignore_certificates: pool_manager_init_args["cert_reqs"] = "CERT_NONE" urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) - elif self._ca_certs: + elif self._client_config.ca_certs: pool_manager_init_args["cert_reqs"] = "CERT_REQUIRED" - pool_manager_init_args["ca_certs"] = self._ca_certs + pool_manager_init_args["ca_certs"] = self._client_config.ca_certs if self._proxy_url: if self._proxy_url.lower().startswith("sock"): @@ -266,38 +282,61 @@ def _get_connection_manager(self): def __init__( self, - remote_server_addr: str, - keep_alive: bool = False, - ignore_proxy: bool = False, - ignore_certificates: bool = False, - init_args_for_pool_manager: dict = None, + remote_server_addr: Optional[str] = None, + keep_alive: Optional[bool] = True, + ignore_proxy: Optional[bool] = False, + ignore_certificates: Optional[bool] = False, + init_args_for_pool_manager: Optional[dict] = None, + client_config: Optional[ClientConfig] = None, ): - self.keep_alive = keep_alive - self._url = remote_server_addr - self._ignore_certificates = ignore_certificates - self._init_args_for_pool_manager = init_args_for_pool_manager or {} - - # Env var NO_PROXY will override this part of the code - _no_proxy = os.environ.get("no_proxy", os.environ.get("NO_PROXY")) - if _no_proxy: - for npu in _no_proxy.split(","): - npu = npu.strip() - if npu == "*": - ignore_proxy = True - break - n_url = parse.urlparse(npu) - remote_add = parse.urlparse(self._url) - if n_url.netloc: - if remote_add.netloc == n_url.netloc: - ignore_proxy = True - break - else: - if n_url.path in remote_add.netloc: - ignore_proxy = True - break - - self._proxy_url = self._get_proxy_url() if not ignore_proxy else None - if keep_alive: + self._client_config = client_config or ClientConfig( + remote_server_addr=remote_server_addr, + keep_alive=keep_alive, + ignore_certificates=ignore_certificates, + init_args_for_pool_manager=init_args_for_pool_manager, + ) + + RemoteConnection._client_config = self._client_config + + if remote_server_addr: + warnings.warn( + "setting remote_server_addr in RemoteConnection() is deprecated, set in ClientConfig instance instead", + DeprecationWarning, + stacklevel=2, + ) + + if not keep_alive: + warnings.warn( + "setting keep_alive in RemoteConnection() is deprecated, set in ClientConfig instance instead", + DeprecationWarning, + stacklevel=2, + ) + + if ignore_certificates: + warnings.warn( + "setting ignore_certificates in RemoteConnection() is deprecated, set in ClientConfig instance instead", + DeprecationWarning, + stacklevel=2, + ) + + if init_args_for_pool_manager: + warnings.warn( + "setting init_args_for_pool_manager in RemoteConnection() is deprecated, set in ClientConfig instance instead", + DeprecationWarning, + stacklevel=2, + ) + + if ignore_proxy: + warnings.warn( + "setting ignore_proxy in RemoteConnection() is deprecated, set in ClientConfig instance instead", + DeprecationWarning, + stacklevel=2, + ) + self._proxy_url = None + else: + self._proxy_url = self._client_config.get_proxy_url() + + if self._client_config.keep_alive: self._conn = self._get_connection_manager() self._commands = remote_commands @@ -331,7 +370,7 @@ def execute(self, command, params): for word in substitute_params: del params[word] data = utils.dump_json(params) - url = f"{self._url}{path}" + url = f"{self._client_config.remote_server_addr}{path}" trimmed = self._trim_large_entries(params) LOGGER.debug("%s %s %s", command_info[0], url, str(trimmed)) return self._request(command_info[0], url, body=data) @@ -348,12 +387,16 @@ def _request(self, method, url, body=None, timeout=120): A dictionary with the server's parsed JSON response. """ parsed_url = parse.urlparse(url) - headers = self.get_remote_connection_headers(parsed_url, self.keep_alive) - response = None + headers = self.get_remote_connection_headers(parsed_url, self._client_config.keep_alive) + auth_header = self._client_config.get_auth_header() + + if auth_header: + headers.update(auth_header) + if body and method not in ("POST", "PUT"): body = None - if self.keep_alive: + if self._client_config.keep_alive: response = self._conn.request(method, url, body=body, headers=headers, timeout=timeout) statuscode = response.status else: diff --git a/py/selenium/webdriver/remote/webdriver.py b/py/selenium/webdriver/remote/webdriver.py index 8ef6292012089..bae7f4e8d28c1 100644 --- a/py/selenium/webdriver/remote/webdriver.py +++ b/py/selenium/webdriver/remote/webdriver.py @@ -54,6 +54,7 @@ from selenium.webdriver.support.relative_locator import RelativeBy from .bidi_connection import BidiConnection +from .client_config import ClientConfig from .command import Command from .errorhandler import ErrorHandler from .file_detector import FileDetector @@ -96,7 +97,17 @@ def _create_caps(caps): return {"capabilities": {"firstMatch": [{}], "alwaysMatch": always_match}} -def get_remote_connection(capabilities, command_executor, keep_alive, ignore_local_proxy=False): +def get_remote_connection( + capabilities: dict, + command_executor: Union[str, RemoteConnection], + keep_alive: bool, + ignore_local_proxy: bool, + client_config: Optional[ClientConfig] = None, +) -> RemoteConnection: + if isinstance(command_executor, str): + client_config = client_config or ClientConfig(remote_server_addr=command_executor) + client_config.remote_server_addr = command_executor + command_executor = RemoteConnection(client_config=client_config) from selenium.webdriver.chrome.remote_connection import ChromeRemoteConnection from selenium.webdriver.edge.remote_connection import EdgeRemoteConnection from selenium.webdriver.firefox.remote_connection import FirefoxRemoteConnection @@ -105,7 +116,12 @@ def get_remote_connection(capabilities, command_executor, keep_alive, ignore_loc candidates = [ChromeRemoteConnection, EdgeRemoteConnection, SafariRemoteConnection, FirefoxRemoteConnection] handler = next((c for c in candidates if c.browser_name == capabilities.get("browserName")), RemoteConnection) - return handler(command_executor, keep_alive=keep_alive, ignore_proxy=ignore_local_proxy) + return handler( + remote_server_addr=command_executor, + keep_alive=keep_alive, + ignore_proxy=ignore_local_proxy, + client_config=client_config, + ) def create_matches(options: List[BaseOptions]) -> Dict: @@ -174,6 +190,7 @@ def __init__( options: Optional[Union[BaseOptions, List[BaseOptions]]] = None, locator_converter: Optional[LocatorConverter] = None, web_element_cls: Optional[type] = None, + client_config: Optional[ClientConfig] = None, ) -> None: """Create a new driver that will issue commands using the wire protocol. @@ -181,13 +198,14 @@ def __init__( :Args: - command_executor - Either a string representing URL of the remote server or a custom remote_connection.RemoteConnection object. Defaults to 'http://127.0.0.1:4444/wd/hub'. - - keep_alive - Whether to configure remote_connection.RemoteConnection to use + - keep_alive - (Deprecated) Whether to configure remote_connection.RemoteConnection to use HTTP keep-alive. Defaults to True. - file_detector - Pass custom file detector object during instantiation. If None, then default LocalFileDetector() will be used. - options - instance of a driver options.Options class - locator_converter - Custom locator converter to use. Defaults to None. - web_element_cls - Custom class to use for web elements. Defaults to WebElement. + - client_config - Custom client configuration to use. Defaults to None. """ if isinstance(options, list): @@ -203,6 +221,7 @@ def __init__( command_executor=command_executor, keep_alive=keep_alive, ignore_local_proxy=_ignore_local_proxy, + client_config=client_config, ) self._is_remote = True self.session_id = None diff --git a/py/selenium/webdriver/safari/remote_connection.py b/py/selenium/webdriver/safari/remote_connection.py index a97f614a98585..05dbfb379b4c2 100644 --- a/py/selenium/webdriver/safari/remote_connection.py +++ b/py/selenium/webdriver/safari/remote_connection.py @@ -15,15 +15,29 @@ # specific language governing permissions and limitations # under the License. +from typing import Optional + from selenium.webdriver.common.desired_capabilities import DesiredCapabilities +from selenium.webdriver.remote.client_config import ClientConfig from selenium.webdriver.remote.remote_connection import RemoteConnection class SafariRemoteConnection(RemoteConnection): browser_name = DesiredCapabilities.SAFARI["browserName"] - def __init__(self, remote_server_addr: str, keep_alive: bool = True, ignore_proxy: bool = False) -> None: - super().__init__(remote_server_addr, keep_alive, ignore_proxy) + def __init__( + self, + remote_server_addr: str, + keep_alive: bool = True, + ignore_proxy: Optional[bool] = False, + client_config: Optional[ClientConfig] = None, + ) -> None: + super().__init__( + remote_server_addr=remote_server_addr, + keep_alive=keep_alive, + ignore_proxy=ignore_proxy, + client_config=client_config, + ) self._commands["GET_PERMISSIONS"] = ("GET", "/session/$sessionId/apple/permissions") self._commands["SET_PERMISSIONS"] = ("POST", "/session/$sessionId/apple/permissions") diff --git a/py/test/unit/selenium/webdriver/remote/remote_connection_tests.py b/py/test/unit/selenium/webdriver/remote/remote_connection_tests.py index 260214e5c918d..e3d8e29f38e69 100644 --- a/py/test/unit/selenium/webdriver/remote/remote_connection_tests.py +++ b/py/test/unit/selenium/webdriver/remote/remote_connection_tests.py @@ -22,6 +22,7 @@ import urllib3 from selenium import __version__ +from selenium.webdriver.remote.remote_connection import ClientConfig from selenium.webdriver.remote.remote_connection import RemoteConnection @@ -76,26 +77,35 @@ def test_get_remote_connection_headers_adds_keep_alive_if_requested(): def test_get_proxy_url_http(mock_proxy_settings): proxy = "http://http_proxy.com:8080" remote_connection = RemoteConnection("http://remote", keep_alive=False) - proxy_url = remote_connection._get_proxy_url() + proxy_url = remote_connection._client_config.get_proxy_url() assert proxy_url == proxy +def test_get_auth_header_if_client_config_pass(): + custom_config = ClientConfig( + remote_server_addr="http://remote", keep_alive=True, username="user", password="pass", auth_type="Basic" + ) + remote_connection = RemoteConnection(custom_config.remote_server_addr, client_config=custom_config) + headers = remote_connection._client_config.get_auth_header() + assert headers.get("Authorization") == "Basic dXNlcjpwYXNz" + + def test_get_proxy_url_https(mock_proxy_settings): proxy = "http://https_proxy.com:8080" remote_connection = RemoteConnection("https://remote", keep_alive=False) - proxy_url = remote_connection._get_proxy_url() + proxy_url = remote_connection._client_config.get_proxy_url() assert proxy_url == proxy def test_get_proxy_url_none(mock_proxy_settings_missing): remote_connection = RemoteConnection("https://remote", keep_alive=False) - proxy_url = remote_connection._get_proxy_url() + proxy_url = remote_connection._client_config.get_proxy_url() assert proxy_url is None def test_get_proxy_url_http_auth(mock_proxy_auth_settings): remote_connection = RemoteConnection("http://remote", keep_alive=False) - proxy_url = remote_connection._get_proxy_url() + proxy_url = remote_connection._client_config.get_proxy_url() raw_proxy_url, basic_auth_string = remote_connection._separate_http_proxy_auth() assert proxy_url == "http://user:password@http_proxy.com:8080" assert raw_proxy_url == "http://http_proxy.com:8080" @@ -104,7 +114,7 @@ def test_get_proxy_url_http_auth(mock_proxy_auth_settings): def test_get_proxy_url_https_auth(mock_proxy_auth_settings): remote_connection = RemoteConnection("https://remote", keep_alive=False) - proxy_url = remote_connection._get_proxy_url() + proxy_url = remote_connection._client_config.get_proxy_url() raw_proxy_url, basic_auth_string = remote_connection._separate_http_proxy_auth() assert proxy_url == "https://user:password@https_proxy.com:8080" assert raw_proxy_url == "https://https_proxy.com:8080" @@ -117,9 +127,10 @@ def test_get_connection_manager_without_proxy(mock_proxy_settings_missing): assert isinstance(conn, urllib3.PoolManager) -def test_get_connection_manager_for_certs_and_timeout(monkeypatch): - monkeypatch.setattr(RemoteConnection, "get_timeout", lambda _: 10) # Class state; leaks into subsequent tests. +def test_get_connection_manager_for_certs_and_timeout(): remote_connection = RemoteConnection("http://remote", keep_alive=False) + remote_connection.set_timeout(10) + assert remote_connection.get_timeout() == 10 conn = remote_connection._get_connection_manager() assert conn.connection_pool_kw["timeout"] == 10 assert conn.connection_pool_kw["cert_reqs"] == "CERT_REQUIRED" @@ -296,21 +307,74 @@ def test_register_extra_headers(mock_request, remote_connection): assert headers["Foo"] == "bar" -def test_get_connection_manager_ignores_certificates(monkeypatch): - monkeypatch.setattr(RemoteConnection, "get_timeout", lambda _: 10) - remote_connection = RemoteConnection("http://remote", ignore_certificates=True) +def test_get_connection_manager_with_timeout_from_client_config(): + remote_connection = RemoteConnection(remote_server_addr="http://remote", keep_alive=False) + remote_connection.set_timeout(10) + conn = remote_connection._get_connection_manager() + assert remote_connection.get_timeout() == 10 + assert conn.connection_pool_kw["timeout"] == 10 + assert isinstance(conn, urllib3.PoolManager) + + client_config = ClientConfig("http://remote", timeout=300) + remote_connection = RemoteConnection(client_config=client_config) + conn = remote_connection._get_connection_manager() + assert conn.connection_pool_kw["timeout"] == 300 + assert isinstance(conn, urllib3.PoolManager) + + +def test_get_connection_manager_with_ca_certs_from_client_config(): + remote_connection = RemoteConnection(remote_server_addr="http://remote") + remote_connection.set_certificate_bundle_path("/path/to/cacert.pem") + conn = remote_connection._get_connection_manager() + assert conn.connection_pool_kw["timeout"] is None + assert conn.connection_pool_kw["cert_reqs"] == "CERT_REQUIRED" + assert conn.connection_pool_kw["ca_certs"] == "/path/to/cacert.pem" + assert isinstance(conn, urllib3.PoolManager) + + client_config = ClientConfig(remote_server_addr="http://remote", ca_certs="/path/to/cacert.pem") + remote_connection = RemoteConnection(client_config=client_config) + conn = remote_connection._get_connection_manager() + assert conn.connection_pool_kw["timeout"] is None + assert conn.connection_pool_kw["cert_reqs"] == "CERT_REQUIRED" + assert conn.connection_pool_kw["ca_certs"] == "/path/to/cacert.pem" + assert isinstance(conn, urllib3.PoolManager) + + +def test_get_connection_manager_ignores_certificates(): + remote_connection = RemoteConnection(remote_server_addr="http://remote", keep_alive=False, ignore_certificates=True) + remote_connection.set_timeout(10) conn = remote_connection._get_connection_manager() + assert conn.connection_pool_kw["timeout"] == 10 + assert conn.connection_pool_kw["cert_reqs"] == "CERT_NONE" + assert isinstance(conn, urllib3.PoolManager) + client_config = ClientConfig(remote_server_addr="http://remote", ignore_certificates=True, timeout=10) + remote_connection = RemoteConnection(client_config=client_config) + conn = remote_connection._get_connection_manager() assert conn.connection_pool_kw["timeout"] == 10 assert conn.connection_pool_kw["cert_reqs"] == "CERT_NONE" assert isinstance(conn, urllib3.PoolManager) + remote_connection.reset_timeout() + assert remote_connection.get_timeout() is None + def test_get_connection_manager_with_custom_args(): custom_args = {"init_args_for_pool_manager": {"retries": 3, "block": True}} - remote_connection = RemoteConnection("http://remote", keep_alive=False, init_args_for_pool_manager=custom_args) + + remote_connection = RemoteConnection( + remote_server_addr="http://remote", keep_alive=False, init_args_for_pool_manager=custom_args + ) conn = remote_connection._get_connection_manager() + assert isinstance(conn, urllib3.PoolManager) + assert conn.connection_pool_kw["retries"] == 3 + assert conn.connection_pool_kw["block"] is True + client_config = ClientConfig( + remote_server_addr="http://remote", keep_alive=False, init_args_for_pool_manager=custom_args + ) + remote_connection = RemoteConnection(client_config=client_config) + conn = remote_connection._get_connection_manager() assert isinstance(conn, urllib3.PoolManager) assert conn.connection_pool_kw["retries"] == 3 assert conn.connection_pool_kw["block"] is True