diff --git a/curl_cffi/curl.py b/curl_cffi/curl.py index 3473ab8..f1b976c 100644 --- a/curl_cffi/curl.py +++ b/curl_cffi/curl.py @@ -2,12 +2,12 @@ import warnings from http.cookies import SimpleCookie from pathlib import Path -from typing import Any, List, Optional, Tuple, Union +from typing import Any, List, Literal, Optional, Tuple, Union, cast import certifi from ._wrapper import ffi, lib -from .const import CurlHttpVersion, CurlInfo, CurlOpt, CurlWsFlag +from .const import CurlECode, CurlHttpVersion, CurlInfo, CurlOpt, CurlWsFlag DEFAULT_CACERT = certifi.where() REASON_PHRASE_RE = re.compile(rb"HTTP/\d\.\d [0-9]{3} (.*)") @@ -17,9 +17,9 @@ class CurlError(Exception): """Base exception for curl_cffi package""" - def __init__(self, msg, code: int = 0, *args, **kwargs): + def __init__(self, msg, code: Union[CurlECode, Literal[0]] = 0, *args, **kwargs): super().__init__(msg, *args, **kwargs) - self.code = code + self.code: Union[CurlECode, Literal[0]] = code CURLINFO_TEXT = 0 @@ -143,7 +143,7 @@ def _get_error(self, errcode: int, *args: Any): return CurlError( f"Failed to {action}, curl: ({errcode}) {errmsg}. " "See https://curl.se/libcurl/c/libcurl-errors.html first for more details.", - code=errcode, + code=cast(CurlECode, errcode), ) def setopt(self, option: CurlOpt, value: Any) -> int: diff --git a/curl_cffi/requests/errors.py b/curl_cffi/requests/errors.py index b5dffa3..859fda1 100644 --- a/curl_cffi/requests/errors.py +++ b/curl_cffi/requests/errors.py @@ -1,17 +1,7 @@ -from .. import CurlError - - -class RequestsError(CurlError): - """Base exception for curl_cffi.requests package""" - - def __init__(self, msg, code=0, response=None, *args, **kwargs): - super().__init__(msg, code, *args, **kwargs) - self.response = response +# for compatibility with 0.5.x +__all__ = ["CurlError", "RequestsError", "CookieConflict", "SessionClosed"] -class CookieConflict(RequestsError): - pass - - -class SessionClosed(RequestsError): - pass +from .. import CurlError +from .exceptions import CookieConflict, SessionClosed +from .exceptions import RequestException as RequestsError diff --git a/curl_cffi/requests/exceptions.py b/curl_cffi/requests/exceptions.py index efbda06..df35b99 100644 --- a/curl_cffi/requests/exceptions.py +++ b/curl_cffi/requests/exceptions.py @@ -1,3 +1,217 @@ -from .errors import RequestsError +# Apache 2.0 License +# Vendored from https://github.com/psf/requests/blob/main/src/requests/exceptions.py +# With our own addtions -RequestsException = RequestsError +import json +from typing import Literal, Union + +from .. import CurlError +from ..const import CurlECode + + +# Note IOError is an alias of OSError in Python 3.x +class RequestException(CurlError, OSError): + """Base exception for curl_cffi.requests package""" + + def __init__(self, msg, code: Union[CurlECode, Literal[0]] = 0, response=None, *args, **kwargs): + super().__init__(msg, code, *args, **kwargs) + self.response = response + + +class CookieConflict(RequestException): + """Same cookie exists for different domains.""" + + +class SessionClosed(RequestException): + """The session has already been closed.""" + + +class ImpersonateError(RequestException): + """The impersonate config was wrong or impersonate failed.""" + + +# not used +class InvalidJSONError(RequestException): + """A JSON error occurred.""" + + +# not used +class JSONDecodeError(InvalidJSONError, json.JSONDecodeError): + """Couldn't decode the text into json""" + + +class HTTPError(RequestException): + """An HTTP error occurred.""" + + +class IncompleteRead(HTTPError): + """Incomplete read of content""" + + +class ConnectionError(RequestException): + """A Connection error occurred.""" + + +class DNSError(ConnectionError): + """Could not resolve""" + + +class ProxyError(RequestException): + """A proxy error occurred.""" + + +class SSLError(ConnectionError): + """An SSL error occurred.""" + + +class CertificateVerifyError(SSLError): + """Raised when certificate validated has failed""" + + +class Timeout(RequestException): + """The request timed out.""" + + +# not used +class ConnectTimeout(ConnectionError, Timeout): + """The request timed out while trying to connect to the remote server. + + Requests that produced this error are safe to retry. + """ + + +# not used +class ReadTimeout(Timeout): + """The server did not send any data in the allotted amount of time.""" + + +# not used +class URLRequired(RequestException): + """A valid URL is required to make a request.""" + + +class TooManyRedirects(RequestException): + """Too many redirects.""" + + +# not used +class MissingSchema(RequestException, ValueError): + """The URL scheme (e.g. http or https) is missing.""" + + +class InvalidSchema(RequestException, ValueError): + """The URL scheme provided is either invalid or unsupported.""" + + +class InvalidURL(RequestException, ValueError): + """The URL provided was somehow invalid.""" + + +# not used +class InvalidHeader(RequestException, ValueError): + """The header value provided was somehow invalid.""" + + +# not used +class InvalidProxyURL(InvalidURL): + """The proxy URL provided is invalid.""" + + +# not used +class ChunkedEncodingError(RequestException): + """The server declared chunked encoding but sent an invalid chunk.""" + + +# not used +class ContentDecodingError(RequestException): + """Failed to decode response content.""" + + +# not used +class StreamConsumedError(RequestException, TypeError): + """The content for this response was already consumed.""" + + +# does not support +class RetryError(RequestException): + """Custom retries logic failed""" + + +# not used +class UnrewindableBodyError(RequestException): + """Requests encountered an error when trying to rewind a body.""" + + +class InterfaceError(RequestException): + """A specified outgoing interface could not be used.""" + + +# Warnings + + +# not used +class RequestsWarning(Warning): + """Base warning for Requests.""" + + +# not used +class FileModeWarning(RequestsWarning, DeprecationWarning): + """A file was opened in text mode, but Requests determined its binary length.""" + + +# not used +class RequestsDependencyWarning(RequestsWarning): + """An imported dependency doesn't match the expected version range.""" + + +CODE2ERROR = { + 0: RequestException, + CurlECode.UNSUPPORTED_PROTOCOL: InvalidSchema, + CurlECode.URL_MALFORMAT: InvalidURL, + CurlECode.COULDNT_RESOLVE_PROXY: ProxyError, + CurlECode.COULDNT_RESOLVE_HOST: DNSError, + CurlECode.COULDNT_CONNECT: ConnectionError, + CurlECode.WEIRD_SERVER_REPLY: ConnectionError, + CurlECode.REMOTE_ACCESS_DENIED: ConnectionError, + CurlECode.HTTP2: HTTPError, + CurlECode.HTTP_RETURNED_ERROR: HTTPError, + CurlECode.WRITE_ERROR: RequestException, + CurlECode.READ_ERROR: RequestException, + CurlECode.OUT_OF_MEMORY: RequestException, + CurlECode.OPERATION_TIMEDOUT: Timeout, + CurlECode.SSL_CONNECT_ERROR: SSLError, + CurlECode.INTERFACE_FAILED: InterfaceError, + CurlECode.TOO_MANY_REDIRECTS: TooManyRedirects, + CurlECode.UNKNOWN_OPTION: RequestException, + CurlECode.SETOPT_OPTION_SYNTAX: RequestException, + CurlECode.GOT_NOTHING: ConnectionError, + CurlECode.SSL_ENGINE_NOTFOUND: SSLError, + CurlECode.SSL_ENGINE_SETFAILED: SSLError, + CurlECode.SEND_ERROR: ConnectionError, + CurlECode.RECV_ERROR: ConnectionError, + CurlECode.SSL_CERTPROBLEM: SSLError, + CurlECode.SSL_CIPHER: SSLError, + CurlECode.PEER_FAILED_VERIFICATION: CertificateVerifyError, + CurlECode.BAD_CONTENT_ENCODING: HTTPError, + CurlECode.SSL_ENGINE_INITFAILED: SSLError, + CurlECode.SSL_CACERT_BADFILE: SSLError, + CurlECode.SSL_CRL_BADFILE: SSLError, + CurlECode.SSL_ISSUER_ERROR: SSLError, + CurlECode.SSL_PINNEDPUBKEYNOTMATCH: SSLError, + CurlECode.SSL_INVALIDCERTSTATUS: SSLError, + CurlECode.HTTP2_STREAM: HTTPError, + CurlECode.HTTP3: HTTPError, + CurlECode.QUIC_CONNECT_ERROR: ConnectionError, + CurlECode.PROXY: ProxyError, + CurlECode.SSL_CLIENTCERT: SSLError, + CurlECode.ECH_REQUIRED: SSLError, + CurlECode.PARTIAL_FILE: IncompleteRead, +} + + +# credits: https://github.com/yt-dlp/yt-dlp/blob/master/yt_dlp/networking/_curlcffi.py#L241 +# Unlicense +def code2error(code: Union[CurlECode, Literal[0]], msg: str): + if code == CurlECode.RECV_ERROR and "CONNECT" in msg: + return ProxyError + return CODE2ERROR.get(code, RequestException) diff --git a/curl_cffi/requests/models.py b/curl_cffi/requests/models.py index c6d904c..e4766c4 100644 --- a/curl_cffi/requests/models.py +++ b/curl_cffi/requests/models.py @@ -7,7 +7,7 @@ from .. import Curl from .cookies import Cookies -from .errors import RequestsError +from .exceptions import HTTPError, RequestException from .headers import Headers CHARSET_RE = re.compile(r"charset=([\w-]+)") @@ -138,7 +138,7 @@ def _decode(self, content: bytes) -> str: def raise_for_status(self): """Raise an error if status code is not in [200, 400)""" if not self.ok: - raise RequestsError(f"HTTP Error {self.status_code}: {self.reason}") + raise HTTPError(f"HTTP Error {self.status_code}: {self.reason}") def iter_lines(self, chunk_size=None, decode_unicode=False, delimiter=None): """ @@ -179,7 +179,7 @@ def iter_content(self, chunk_size=None, decode_unicode=False): chunk = self.queue.get() # re-raise the exception if something wrong happened. - if isinstance(chunk, RequestsError): + if isinstance(chunk, RequestException): self.curl.reset() raise chunk @@ -242,7 +242,7 @@ async def aiter_content(self, chunk_size=None, decode_unicode=False): chunk = await self.queue.get() # re-raise the exception if something wrong happened. - if isinstance(chunk, RequestsError): + if isinstance(chunk, RequestException): await self.aclose() raise chunk diff --git a/curl_cffi/requests/session.py b/curl_cffi/requests/session.py index 1d51319..5d90041 100644 --- a/curl_cffi/requests/session.py +++ b/curl_cffi/requests/session.py @@ -29,13 +29,13 @@ from .. import AsyncCurl, Curl, CurlError, CurlHttpVersion, CurlInfo, CurlOpt, CurlSslVersion from ..curl import CURL_WRITEFUNC_ERROR, CurlMime from .cookies import Cookies, CookieTypes, CurlMorsel -from .errors import RequestsError, SessionClosed +from .exceptions import ImpersonateError, RequestException, SessionClosed, code2error from .headers import Headers, HeaderTypes +from .impersonate import BrowserType # noqa: F401 from .impersonate import ( TLS_CIPHER_NAME_MAP, TLS_EC_CURVES_MAP, TLS_VERSION_MAP, - BrowserType, # noqa: F401 BrowserTypeLiteral, ExtraFingerprints, ExtraFpDict, @@ -43,7 +43,7 @@ toggle_extension, ) from .models import Request, Response -from .websockets import WebSocket +from .websockets import ON_CLOSE_T, ON_ERROR_T, ON_MESSAGE_T, ON_OPEN_T, WebSocket with suppress(ImportError): import gevent @@ -616,7 +616,7 @@ def _set_curl_options( impersonate = normalize_browser_type(impersonate) ret = c.impersonate(impersonate, default_headers=default_headers) if ret != 0: - raise RequestsError(f"Impersonating {impersonate} is not supported") + raise ImpersonateError(f"Impersonating {impersonate} is not supported") # ja3 string ja3 = ja3 or self.ja3 @@ -644,7 +644,8 @@ def _set_curl_options( extra_fp = ExtraFingerprints(**extra_fp) if impersonate: warnings.warn( - "Extra fingerprints was altered after browser version was set.", stacklevel=1 + "Extra fingerprints was altered after browser version was set.", + stacklevel=1, ) self._set_extra_fp(c, extra_fp) @@ -855,10 +856,10 @@ def ws_connect( self, url, *args, - on_message: Optional[Callable[[WebSocket, bytes], None]] = None, - on_error: Optional[Callable[[WebSocket, CurlError], None]] = None, - on_open: Optional[Callable] = None, - on_close: Optional[Callable] = None, + on_message: Optional[ON_MESSAGE_T] = None, + on_error: Optional[ON_ERROR_T] = None, + on_open: Optional[ON_OPEN_T] = None, + on_close: Optional[ON_CLOSE_T] = None, **kwargs, ) -> WebSocket: """Connects to a websocket url. @@ -982,7 +983,7 @@ def perform(): except CurlError as e: rsp = self._parse_response(c, buffer, header_buffer, default_encoding) rsp.request = req - cast(queue.Queue, q).put_nowait(RequestsError(str(e), e.code, rsp)) + cast(queue.Queue, q).put_nowait(RequestException(str(e), e.code, rsp)) finally: if not cast(threading.Event, header_recved).is_set(): cast(threading.Event, header_recved).set() @@ -1003,7 +1004,7 @@ def cleanup(fut): # Raise the exception if something wrong happens when receiving the header. first_element = _peek_queue(cast(queue.Queue, q)) - if isinstance(first_element, RequestsError): + if isinstance(first_element, RequestException): c.reset() raise first_element @@ -1025,7 +1026,8 @@ def cleanup(fut): except CurlError as e: rsp = self._parse_response(c, buffer, header_buffer, default_encoding) rsp.request = req - raise RequestsError(str(e), e.code, rsp) from e + error = code2error(e.code, str(e)) + raise error(str(e), e.code, rsp) from e else: rsp = self._parse_response(c, buffer, header_buffer, default_encoding) rsp.request = req @@ -1266,7 +1268,7 @@ async def perform(): except CurlError as e: rsp = self._parse_response(curl, buffer, header_buffer, default_encoding) rsp.request = req - cast(asyncio.Queue, q).put_nowait(RequestsError(str(e), e.code, rsp)) + cast(asyncio.Queue, q).put_nowait(RequestException(str(e), e.code, rsp)) finally: if not cast(asyncio.Event, header_recved).is_set(): cast(asyncio.Event, header_recved).set() @@ -1287,7 +1289,7 @@ def cleanup(fut): rsp = self._parse_response(curl, buffer, header_buffer, default_encoding) first_element = _peek_aio_queue(cast(asyncio.Queue, q)) - if isinstance(first_element, RequestsError): + if isinstance(first_element, RequestException): self.release_curl(curl) raise first_element @@ -1305,7 +1307,8 @@ def cleanup(fut): except CurlError as e: rsp = self._parse_response(curl, buffer, header_buffer, default_encoding) rsp.request = req - raise RequestsError(str(e), e.code, rsp) from e + error = code2error(e.code, str(e)) + raise error(str(e), e.code, rsp) from e else: rsp = self._parse_response(curl, buffer, header_buffer, default_encoding) rsp.request = req diff --git a/curl_cffi/requests/websockets.py b/curl_cffi/requests/websockets.py index 836ab53..8a00b3e 100644 --- a/curl_cffi/requests/websockets.py +++ b/curl_cffi/requests/websockets.py @@ -3,8 +3,8 @@ from enum import IntEnum from typing import Callable, Optional, Tuple -from curl_cffi.const import CurlECode, CurlWsFlag -from curl_cffi.curl import CurlError +from ..const import CurlECode, CurlWsFlag +from ..curl import CurlError ON_MESSAGE_T = Callable[["WebSocket", bytes], None] ON_ERROR_T = Callable[["WebSocket", CurlError], None] diff --git a/pyproject.toml b/pyproject.toml index defb135..9e2e832 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -125,9 +125,12 @@ select = [ "UP", # pyupgrade "B", # flake8-bugbear "SIM", # flake8-simplify - "I", # isort ] +[tool.isort] +profile = "black" +line_length = 100 + [tool.mypy] python_version = "3.8" ignore_missing_imports = true