Skip to content

Commit

Permalink
Add request id to message in ConnectionError (#1544)
Browse files Browse the repository at this point in the history
* Add request id to message in ConnectionError

* X_AMZN_TRACE_ID

* fix tests
  • Loading branch information
Wauplin committed Jul 7, 2023
1 parent 95dcdd5 commit 554215b
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 11 deletions.
26 changes: 23 additions & 3 deletions src/huggingface_hub/utils/_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,28 +25,48 @@
from requests import Response
from requests.adapters import HTTPAdapter
from requests.exceptions import ProxyError, Timeout
from requests.models import PreparedRequest

from . import logging
from ._typing import HTTP_METHOD_T


logger = logging.get_logger(__name__)

# Both headers are used by the Hub to debug failed requests.
# `X_AMZN_TRACE_ID` is better as it also works to debug on Cloudfront and ALB.
# If `X_AMZN_TRACE_ID` is set, the Hub will use it as well.
X_AMZN_TRACE_ID = "X-Amzn-Trace-Id"
X_REQUEST_ID = "x-request-id"


class UniqueRequestIdAdapter(HTTPAdapter):
X_AMZN_TRACE_ID = "X-Amzn-Trace-Id"

def add_headers(self, request, **kwargs):
super().add_headers(request, **kwargs)

# Add random request ID => easier for server-side debug
if "x-request-id" not in request.headers:
request.headers["x-request-id"] = str(uuid.uuid4())
if X_AMZN_TRACE_ID not in request.headers:
request.headers[X_AMZN_TRACE_ID] = request.headers.get(X_REQUEST_ID) or str(uuid.uuid4())

# Add debug log
has_token = str(request.headers.get("authorization", "")).startswith("Bearer hf_")
logger.debug(
f"Request {request.headers['x-request-id']}: {request.method} {request.url} (authenticated: {has_token})"
f"Request {request.headers[X_AMZN_TRACE_ID]}: {request.method} {request.url} (authenticated: {has_token})"
)

def send(self, request: PreparedRequest, *args, **kwargs) -> Response:
"""Catch any RequestException to append request id to the error message for debugging."""
try:
return super().send(request, *args, **kwargs)
except requests.RequestException as e:
request_id = request.headers.get(X_AMZN_TRACE_ID)
if request_id is not None:
# Taken from https://stackoverflow.com/a/58270258
e.args = (*e.args, f"(Request ID: {request_id})")
raise


def _default_backend_factory() -> requests.Session:
session = requests.Session()
Expand Down
16 changes: 8 additions & 8 deletions tests/test_utils_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,21 +220,21 @@ class TestUniqueRequestId(unittest.TestCase):
def test_request_id_is_used_by_server(self):
response = get_session().get(self.api_endpoint)

request_id = response.request.headers.get("x-request-id")
request_id = response.request.headers.get("X-Amzn-Trace-Id")
response_id = response.headers.get("x-request-id")
self.assertEqual(request_id, response_id)
self.assertTrue(_is_uuid(response_id))
self.assertIn(request_id, response_id)
self.assertTrue(_is_uuid(request_id))

def test_request_id_is_unique(self):
response_1 = get_session().get(self.api_endpoint)
response_2 = get_session().get(self.api_endpoint)

response_id_1 = response_1.headers["x-request-id"]
response_id_2 = response_2.headers["x-request-id"]
self.assertNotEqual(response_id_1, response_id_2)
request_id_1 = response_1.request.headers["X-Amzn-Trace-Id"]
request_id_2 = response_2.request.headers["X-Amzn-Trace-Id"]
self.assertNotEqual(request_id_1, request_id_2)

self.assertTrue(_is_uuid(response_id_1))
self.assertTrue(_is_uuid(response_id_2))
self.assertTrue(_is_uuid(request_id_1))
self.assertTrue(_is_uuid(request_id_2))

def test_request_id_not_overwritten(self):
response = get_session().get(self.api_endpoint, headers={"x-request-id": "custom-id"})
Expand Down

0 comments on commit 554215b

Please sign in to comment.