From 9923366969b42dfd7786921ec17b9616e60d33ae Mon Sep 17 00:00:00 2001 From: Two Dev Date: Sat, 14 Dec 2024 12:22:45 +0700 Subject: [PATCH] bugs: fix proxy --- README.md | 12 + tests/test_proxy.py | 49 ++++ tests/test_redirects.py | 12 + tests/{test_params.py => test_urls.py} | 0 tls_requests/__version__.py | 2 +- tls_requests/client.py | 37 ++- tls_requests/models/request.py | 8 +- tls_requests/models/urls.py | 316 ++++++++++++++++--------- tls_requests/types.py | 5 +- 9 files changed, 302 insertions(+), 139 deletions(-) create mode 100644 tests/test_proxy.py rename tests/{test_params.py => test_urls.py} (100%) diff --git a/README.md b/README.md index a24127f..86ac1a8 100644 --- a/README.md +++ b/README.md @@ -110,3 +110,15 @@ Explore the full capabilities of TLS Requests in the documentation: Read the documentation: [**thewebscraping.github.io/tls-requests/**](https://thewebscraping.github.io/tls-requests/) + +**Report Issues** +----------------- + +Found a bug? Please [open an issue](https://github.com/thewebscraping/tls-requests/issues/). + +By reporting an issue you help improve the project. + +**Credits** +----------------- + +Special thanks to [bogdanfinn](https://github.com/bogdanfinn/) for creating the awesome [tls-client](https://github.com/bogdanfinn/tls-client). diff --git a/tests/test_proxy.py b/tests/test_proxy.py new file mode 100644 index 0000000..8233c79 --- /dev/null +++ b/tests/test_proxy.py @@ -0,0 +1,49 @@ +import tls_requests + + +def test_http_proxy(): + proxy = tls_requests.Proxy("http://localhost:8080") + assert proxy.scheme == "http" + assert proxy.host == "localhost" + assert proxy.port == '8080' + assert proxy.url == "http://localhost:8080" + + +def test_https_proxy(): + proxy = tls_requests.Proxy("https://localhost:8080") + assert proxy.scheme == "https" + assert proxy.host == "localhost" + assert proxy.port == '8080' + assert proxy.url == "https://localhost:8080" + + +def test_socks5_proxy(): + proxy = tls_requests.Proxy("socks5://localhost:8080") + assert proxy.scheme == "socks5" + assert proxy.host == "localhost" + assert proxy.port == '8080' + assert proxy.url == "socks5://localhost:8080" + + +def test_proxy_with_params(): + proxy = tls_requests.Proxy("http://localhost:8080?a=b", params={"foo": "bar"}) + assert proxy.scheme == "http" + assert proxy.host == "localhost" + assert proxy.port == '8080' + assert proxy.url == "http://localhost:8080" + + +def test_auth_proxy(): + proxy = tls_requests.Proxy("http://username:password@localhost:8080") + assert proxy.scheme == "http" + assert proxy.host == "localhost" + assert proxy.port == '8080' + assert proxy.auth == ("username", "password") + assert proxy.url == "http://username:password@localhost:8080" + + +def test_unsupported_proxy_scheme(): + try: + _ = tls_requests.Proxy("unknown://localhost:8080") + except Exception as e: + assert isinstance(e, tls_requests.exceptions.ProxyError) diff --git a/tests/test_redirects.py b/tests/test_redirects.py index 13d4598..3baf8f5 100644 --- a/tests/test_redirects.py +++ b/tests/test_redirects.py @@ -10,6 +10,7 @@ def test_missing_host_redirects(httpserver: HTTPServer): httpserver.expect_request("/redirects/ok").respond_with_data(b"OK") response = tls_requests.get(httpserver.url_for("/redirects/3")) assert response.status_code == 200 + assert response.history[0].status_code == 302 assert len(response.history) == 3 @@ -20,9 +21,20 @@ def test_full_path_redirects(httpserver: HTTPServer): httpserver.expect_request("/redirects/ok").respond_with_data(b"OK") response = tls_requests.get(httpserver.url_for("/redirects/3")) assert response.status_code == 200 + assert response.history[0].status_code == 302 assert len(response.history) == 3 +def test_fragment_redirects(httpserver: HTTPServer): + httpserver.expect_request("/redirects/3").respond_with_data(b"OK", status=302, headers={"Location": httpserver.url_for("/redirects/ok#fragment")}) + httpserver.expect_request("/redirects/ok").respond_with_data(b"OK") + response = tls_requests.get(httpserver.url_for("/redirects/3")) + assert response.status_code == 200 + assert response.history[0].status_code == 302 + assert len(response.history) == 1 + assert response.request.url.fragment == "fragment" + + def test_too_many_redirects(httpserver: HTTPServer): httpserver.expect_request("/redirects/3").respond_with_data(b"OK", status=302, headers={"Location": "/redirects/1"}) httpserver.expect_request("/redirects/1").respond_with_data(b"OK", status=302, headers={"Location": "/redirects/2"}) diff --git a/tests/test_params.py b/tests/test_urls.py similarity index 100% rename from tests/test_params.py rename to tests/test_urls.py diff --git a/tls_requests/__version__.py b/tls_requests/__version__.py index 22e4cc6..a23f049 100644 --- a/tls_requests/__version__.py +++ b/tls_requests/__version__.py @@ -3,5 +3,5 @@ __url__ = "https://github.com/thewebscraping/tls-requests" __author__ = "Tu Pham" __author_email__ = "thetwofarm@gmail.com" -__version__ = "1.0.6" +__version__ = "1.0.7" __license__ = "MIT" diff --git a/tls_requests/client.py b/tls_requests/client.py index 40f9da4..4197f5a 100644 --- a/tls_requests/client.py +++ b/tls_requests/client.py @@ -8,7 +8,7 @@ from typing import (Any, Callable, Literal, Mapping, Optional, Sequence, TypeVar, Union) -from .exceptions import RemoteProtocolError, TooManyRedirects +from .exceptions import ProxyError, RemoteProtocolError, TooManyRedirects from .models import (URL, Auth, BasicAuth, Cookies, Headers, Proxy, Request, Response, StatusCodes, TLSClient, TLSConfig, URLParams) from .settings import (DEFAULT_FOLLOW_REDIRECTS, DEFAULT_HEADERS, @@ -105,7 +105,7 @@ def __init__( self._headers = Headers(headers) self._hooks = hooks if isinstance(hooks, dict) else {} self.auth = auth - self.proxy = Proxy(url=proxy) if isinstance(proxy, (str, URL)) else proxy + self.proxy = self.prepare_proxy(proxy) self.timeout = timeout self.follow_redirects = follow_redirects self.max_redirects = max_redirects @@ -194,27 +194,23 @@ def prepare_params(self, params: URLParamTypes = None) -> URLParams: merged_params = self.params.copy() return merged_params.update(params) + def prepare_proxy(self, proxy: ProxyTypes = None) -> Optional[Proxy]: + if proxy is not None: + if isinstance(proxy, (bytes, str, URL, Proxy)): + return Proxy(proxy) + + raise ProxyError("Invalid proxy.") + def prepare_config(self, request: Request): """Prepare TLS Config""" - proxy = None - if self.proxy and isinstance(self.proxy, Proxy): - proxy = self.proxy.url - if self.proxy.auth: - proxy = "%s://%s@%s:%s" % ( - self.proxy.url.scheme, - ":".join(self.proxy.auth), - self.proxy.url.host, - self.proxy.url.port, - ) - config = self.config.copy_with( method=request.method, - url=str(request.url), + url=request.url, body=request.read(), headers=dict(request.headers), cookies=[dict(name=k, value=v) for k, v in request.cookies.items()], - proxy=proxy, + proxy=request.proxy.url if request.proxy else None, timeout=request.timeout, http2=True if self.http2 in ["auto", "http2", True, None] else False, verify=self.verify, @@ -249,6 +245,7 @@ def build_request( params=self.prepare_params(params), headers=self.prepare_headers(headers), cookies=self.prepare_cookies(cookies), + proxy=self.proxy, timeout=timeout or self.timeout, ) @@ -313,15 +310,14 @@ def _rebuild_redirect_url(self, request: Request, response: Response) -> URL: except KeyError: raise RemoteProtocolError("Invalid URL in Location headers: %s" % e) - for missing_field in ["scheme", "host", "port", "fragment"]: - private_field = "_%s" % missing_field - if not getattr(url, private_field, None): - setattr(url, private_field, getattr(request.url, private_field, "")) + if not url.netloc: + for missing_field in ["scheme", "host", "port"]: + setattr(url, missing_field, getattr(request.url, missing_field, "")) # TLS error transport between HTTP/1.x -> HTTP/2 if url.scheme != request.url.scheme: if request.url.scheme == "http": - url._scheme = request.url.scheme + url.scheme = request.url.scheme else: if self.http2 in ["auto", None]: self.session.destroy_session(self.config.sessionId) @@ -331,6 +327,7 @@ def _rebuild_redirect_url(self, request: Request, response: Response) -> URL: "Switching remote scheme from HTTP/2 to HTTP/1 is not supported. Please initialize Client with parameter `http2` to `auto`." ) + setattr(url, "_url", None) # reset url if not url.url: raise RemoteProtocolError("Invalid URL in Location headers: %s" % e) diff --git a/tls_requests/models/request.py b/tls_requests/models/request.py index 24a8be2..ca961f0 100644 --- a/tls_requests/models/request.py +++ b/tls_requests/models/request.py @@ -3,11 +3,11 @@ from tls_requests.models.cookies import Cookies from tls_requests.models.encoders import StreamEncoder from tls_requests.models.headers import Headers -from tls_requests.models.urls import URL +from tls_requests.models.urls import URL, Proxy from tls_requests.settings import DEFAULT_TIMEOUT from tls_requests.types import (CookieTypes, HeaderTypes, MethodTypes, - RequestData, RequestFiles, TimeoutTypes, - URLParamTypes, URLTypes) + ProxyTypes, RequestData, RequestFiles, + TimeoutTypes, URLParamTypes, URLTypes) __all__ = ["Request"] @@ -24,6 +24,7 @@ def __init__( params: URLParamTypes = None, headers: HeaderTypes = None, cookies: CookieTypes = None, + proxy: ProxyTypes = None, timeout: TimeoutTypes = None, ) -> None: self._content = None @@ -31,6 +32,7 @@ def __init__( self.url = URL(url, params=params) self.method = method.upper() self.cookies = Cookies(cookies) + self.proxy = Proxy(proxy) if proxy else None self.timeout = timeout if isinstance(timeout, (float, int)) else DEFAULT_TIMEOUT self.stream = StreamEncoder(data, files, json) self.headers = self._prepare_headers(headers) diff --git a/tls_requests/models/urls.py b/tls_requests/models/urls.py index b6a6a9e..ef50792 100644 --- a/tls_requests/models/urls.py +++ b/tls_requests/models/urls.py @@ -8,12 +8,46 @@ import idna from tls_requests.exceptions import ProxyError, URLError, URLParamsError -from tls_requests.types import URL_ALLOWED_PARAMS, URLParamTypes +from tls_requests.types import (URL_ALLOWED_PARAMS, ProxyTypes, URLParamTypes, + URLTypes) __all__ = ["URL", "URLParams", "Proxy"] class URLParams(Mapping, ABC): + """URLParams + + Represents a mapping of URL parameters with utilities for normalization, encoding, and updating. + This class provides a dictionary-like interface for managing URL parameters, ensuring that keys + and values are properly validated and normalized. + + Attributes: + - params (str): Returns the encoded URL parameters as a query string. + + Methods: + - update(params: URLParamTypes = None, **kwargs): Updates the current parameters with new ones. + - keys() -> KeysView: Returns a view of the parameter keys. + - values() -> ValuesView: Returns a view of the parameter values. + - items() -> ItemsView: Returns a view of the parameter key-value pairs. + - copy() -> URLParams: Returns a copy of the current instance. + - normalize(s: URL_ALLOWED_PARAMS): Normalizes a key or value to a string. + + Raises: + - URLParamsError: Raised for invalid keys, values, or parameter types during initialization or updates. + + Example Usage: + >>> params = URLParams({'key1': 'value1', 'key2': ['value2', 'value3']}) + >>> print(str(params)) + 'key1=value1&key2=value2&key2=value3' + + >>> params.update({'key3': 'value4'}) + >>> print(params) + 'key1=value1&key2=value2&key2=value3&key3=value4' + + >>> 'key1' in params + True + """ + def __init__(self, params: URLParamTypes = None, **kwargs): self._data = self._prepare(params, **kwargs) @@ -104,28 +138,75 @@ def normalize(self, s: URL_ALLOWED_PARAMS): class URL: - def __init__(self, url: Union[str, bytes], params: URLParamTypes = None, **kwargs): - self._url = self._prepare(url, params) + """URL + + A utility class for parsing, manipulating, and constructing URLs. It integrates with the + `URLParams` class for managing query parameters and provides easy access to various components + of a URL, such as scheme, host, port, and path. + + Attributes: + - url (str): The raw or prepared URL string. + - params (URLParams): An instance of URLParams to manage query parameters. + - parsed (ParseResult): A `ParseResult` object containing the parsed components of the URL. + - auth (tuple): A tuple of (username, password) extracted from the URL. + - fragment (str): The fragment identifier of the URL. + - host (str): The hostname (IDNA-encoded if applicable). + - path (str): The path component of the URL. + - netloc (str): The network location (host:port if port is present). + - password (str): The password extracted from the URL. + - port (str): The port number of the URL. + - query (str): The query string, incorporating both existing and additional parameters. + - scheme (str): The URL scheme (e.g., "http", "https"). + - username (str): The username extracted from the URL. + + Methods: + - _prepare_url(url: Union[U, str, bytes]) -> str: Prepares and validates a URL string or bytes. + - _build(secure: bool = False) -> str: Constructs a URL string from its components. + + Raises: + - URLError: Raised when an invalid URL or component is encountered. + + Example Usage: + >>> url = URL("https://example.com/path?q=1#fragment", params={"key": "value"}) + >>> print(url.scheme) + 'https' + >>> print(url.host) + 'example.com' + >>> print(url.query) + 'q%3D1&key%3Dvalue' + >>> print(url.params) + 'key=value' + >>> url.params.update({'key2': 'value2'}) + >>> print(url.url) + 'https://example.com/path?q%3D1&key%3Dvalue%26key2%3Dvalue2#fragment' + >>> from urllib.parse import unquote + >>> print(unquote(url.url)) + 'https://example.com/path?q=1&key=value&key2=value2#fragment' + >>> url.url = 'https://example.org/' + >>> print(unquote(url.url)) + 'https://example.org/?key=value&key2=value2' + >>> url.url = 'https://httpbin.org/get' + >>> print(unquote(url.url)) + 'https://httpbin.org/get?key=value&key2=value2' + """ + + __attrs__ = ("auth", "fragment", "host", "path", "password", "port", "scheme", "username") + + def __init__(self, url: URLTypes, params: URLParamTypes = None, **kwargs): + self._parsed = self._prepare(url) + self._url = None self._params = URLParams(params) - self._parsed = None - self._auth = None - self._fragment = None - self._host = None - self._path = None - self._netloc = None - self._password = None - self._port = None - self._query = None - self._scheme = None - self._username = None @property def url(self): + if self._url is None: + self._url = self._build(False) return self._url @url.setter def url(self, value): - self._url = self._prepare(value) + self._parsed = self._prepare(value) + self._url = self._build(False) @property def params(self): @@ -133,92 +214,27 @@ def params(self): @params.setter def params(self, value): + self._url = None self._params = URLParams(value) @property def parsed(self) -> ParseResult: - if self._parsed is None: - self._parsed = urlparse(self.url) return self._parsed - @property - def auth(self) -> Union[tuple, None]: - if self._auth is None: - if self.parsed.username or self.parsed.password: - self._auth = self.parsed.username, self.parsed.password - else: - self._auth = None, None - return self._auth - - @property - def fragment(self) -> str: - if self._fragment is None: - self._fragment = self.parsed.fragment - return self._fragment - - @property - def host(self) -> str: - if self._host is None: - try: - self._host = idna.encode(self.parsed.hostname.lower()).decode("ascii") - except AttributeError: - self._host = "" - except idna.IDNAError: - raise URLError("Invalid IDNA hostname.") - return self._host - - @property - def path(self): - if self._path is None: - self._path = self.parsed.path - return self._path - @property def netloc(self) -> str: return ":".join([self.host, self.port]) if self.port else self.host - @property - def password(self) -> str: - if self._password is None: - self._password = self.parsed.password or "" - return self._password - - @property - def port(self) -> str: - if self._port is None: - port = "" - try: - if self.parsed.port: - port = str(self.parsed.port) - except ValueError as e: - raise URLError("%s. port range must be 0 - 65535." % e.args[0]) - - self._port = port - return self._port - @property def query(self) -> str: - if self._query is None: - self._query = "" - if self.parsed.query and self.params.params: - self._query = "&".join([quote(self.parsed.query), self.params.params]) - elif self.params.params: - self._query = self.params.params - elif self.parsed.query: - self._query = self.parsed.query - return self._query - - @property - def scheme(self) -> str: - if self._scheme is None: - self._scheme = self.parsed.scheme - return self._scheme - - @property - def username(self) -> str: - if self._username is None: - self._username = self.parsed.username or "" - return self._username + query = "" + if self.parsed.query and self.params.params: + query = "&".join([quote(self.parsed.query), self.params.params]) + elif self.params.params: + query = self.params.params + elif self.parsed.query: + query = self.parsed.query + return query def __str__(self): return self._build() @@ -226,31 +242,56 @@ def __str__(self): def __repr__(self): return "<%s: %s>" % (self.__class__.__name__, unquote(self._build(True))) - def _prepare(self, url: Union["URL", str, bytes], params: URLParamTypes) -> str: - try: - if isinstance(url, self.__class__): - return self._prepare(str(url), params) + def _prepare(self, url: Union[T, str, bytes]) -> ParseResult: + if isinstance(url, bytes): + url = url.decode("utf-8") + elif isinstance(url, self.__class__) or issubclass(self.__class__, url.__class__): + url = str(url) - if url: - url = ( - value.decode("utf-8").lstrip() - if isinstance(url, bytes) - else url.lstrip() - ) - return url + if not isinstance(url, str): + raise URLError("Invalid URL: %s" % url) - except Exception as e: - raise URLError("Invalid URL, details: %s" % e) + for attr in self.__attrs__: + setattr(self, attr, None) - raise URLError("Not found URL.") + parsed = urlparse(url.lstrip()) + try: + self.host = idna.encode(parsed.hostname.lower()).decode("ascii") + except AttributeError: + self.host = "" + except idna.IDNAError: + raise URLError("Invalid IDNA hostname.") + + self.fragment = parsed.fragment + if parsed.username or parsed.password: + self.auth = parsed.username, parsed.password + else: + self.auth = None, None + + self.path = parsed.path + self.port = "" + try: + if parsed.port: + self.port = str(parsed.port) + except ValueError as e: + raise URLError("%s. port range must be 0 - 65535." % e.args[0]) + + self.scheme = parsed.scheme + self.password = parsed.password or "" + self.username = parsed.username or "" + return parsed def _build(self, secure: bool = False) -> str: urls = [self.scheme, "://"] authority = self.netloc if self.username or self.password: + password = self.password or "" + if secure: + password = "[secure]" + authority = "@".join( [ - ":".join([self.username, "[secure]" if self.password else ""]), + ":".join([self.username, password]), self.netloc, ] ) @@ -262,20 +303,71 @@ def _build(self, secure: bool = False) -> str: urls.append(self.path) if self.fragment: - urls.append("#".join([self.fragment])) + urls.append("#" + self.fragment) return "".join(urls) class Proxy(URL): + """Proxy + + A specialized subclass of `URL` designed to handle proxy URLs with specific schemes and additional + validations. The class restricts allowed schemes to "http", "https", "socks5", and "socks5h". It + also modifies the URL construction process to focus on proxy-specific requirements. + + Attributes: + - ALLOWED_SCHEMES (tuple): A tuple of allowed schemes for the proxy ("http", "https", "socks5", "socks5h"). + Raises: + - ProxyError: Raised when an invalid proxy or unsupported protocol is encountered. + + Example Usage: + >>> proxy = Proxy("http://user:pass@127.0.0.1:8080") + >>> print(proxy.scheme) + 'http' + >>> print(proxy.netloc) + '127.0.0.1:8080' + >>> print(proxy) + 'http://user:pass@127.0.0.1:8080' + >>> print(proxy.__repr__()) + '' + + >>> socks5 = Proxy("socks5://127.0.0.1:8080") + >>> print(socks) + 'socks5://127.0.0.1:8080' + """ + ALLOWED_SCHEMES = ("http", "https", "socks5", "socks5h") - @property - def scheme(self) -> str: - if self._scheme is None: - if str(self.parsed.scheme).lower() not in self.ALLOWED_SCHEMES: - raise ProxyError("Invalid scheme.") + def _prepare(self, url: ProxyTypes) -> ParseResult: + try: + if isinstance(url, bytes): + url = url.decode("utf-8") - self._scheme = self.parsed.scheme + if isinstance(url, str): + url = url.strip() - return self._scheme + parsed = super(Proxy, self)._prepare(url) + if str(parsed.scheme).lower() not in self.ALLOWED_SCHEMES: + raise ProxyError("Invalid proxy scheme `%s`. The allowed schemes are ('http', 'https', 'socks5', 'socks5h')." % parsed.scheme) + + return urlparse("%s://%s" % (parsed.scheme, parsed.netloc)) + except URLError: + raise ProxyError("Invalid proxy: %s" % url) + + def _build(self, secure: bool = False) -> str: + urls = [self.scheme, "://"] + authority = self.netloc + if self.username or self.password: + userinfo = ":".join([self.username, self.password]) + if secure: + userinfo = "[secure]" + + authority = "@".join( + [ + userinfo, + self.netloc, + ] + ) + + urls.append(authority) + return "".join(urls) diff --git a/tls_requests/types.py b/tls_requests/types.py index 001f775..34351cc 100644 --- a/tls_requests/types.py +++ b/tls_requests/types.py @@ -18,7 +18,8 @@ "BasicAuth", ] ] -URLTypes = Union["URL", str] +URLTypes = Union["URL", str, bytes] +ProxyTypes = Union[str, bytes, "Proxy", "URL"] URL_ALLOWED_PARAMS = Union[str, bytes, int, float, bool] URLParamTypes = Optional[ Union[ @@ -128,8 +129,6 @@ ] TimeoutTypes = Optional[Union[int, float]] -ProxyTypes = Union["URL", str, "Proxy"] - ByteOrStr = Union[bytes, str] BufferTypes = Union[IO[bytes], "BytesIO", "BufferedReader"] FileContent = Union[ByteOrStr, BinaryIO]