Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add "trailing_headers" extension for HTTP/1.1 #582

Closed
wants to merge 8 commits into from
20 changes: 20 additions & 0 deletions docs/extensions.md
Original file line number Diff line number Diff line change
Expand Up @@ -275,3 +275,23 @@ with httpcore.stream("GET", "https://www.example.com") as response:
ssl_object = network_stream.get_extra_info("ssl_object")
print("TLS version", ssl_object.version())
```

### `"trailing_headers"`

Trailing headers are a rarely used feature of HTTP, where supplementary headers may be sent at the end of the response data.

The `trailing_headers` response extenstion is implemented as a list of `(byte, byte)` tuples containing any [trailing headers](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Trailer#chunked_transfer_encoding_using_a_trailing_header) sent at the end of the response. This list is only populated once the response is complete, and will be empty while streaming the response data.
tomchristie marked this conversation as resolved.
Show resolved Hide resolved

```python
# The "TE: trailers" header should be used in order to indicate that we're
# willing to accept trailing headers. This isn't required by the `httpcore`
# package itself, but is mandated by the HTTP spec, and might be required
# by some servers or proxies.
response = httpcore.request("GET", "https://www.example.com", headers={"TE": "trailers"})

# Show the standard response headers.
print(response.headers)

# Show any trailing headers sent at the end of the response.
print(response.extensions['trailing_headers'])
```
35 changes: 29 additions & 6 deletions httpcore/_async/http11.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,14 +96,16 @@ async def handle_async_request(self, request: Request) -> Response:
headers,
)

trailing_headers: List[Tuple[bytes, bytes]] = []
return Response(
status=status,
headers=headers,
content=HTTP11ConnectionByteStream(self, request),
content=HTTP11ConnectionByteStream(self, request, trailing_headers),
extensions={
"http_version": http_version,
"reason_phrase": reason_phrase,
"network_stream": self._network_stream,
"trailing_headers": trailing_headers,
},
)
except BaseException as exc:
Expand Down Expand Up @@ -164,15 +166,28 @@ async def _receive_response_headers(

return http_version, event.status_code, event.reason, headers

async def _receive_response_body(self, request: Request) -> AsyncIterator[bytes]:
async def _receive_response_body(
self, request: Request, trailing_headers: List[Tuple[bytes, bytes]]
) -> AsyncIterator[bytes]:
timeouts = request.extensions.get("timeout", {})
timeout = timeouts.get("read", None)

while True:
event = await self._receive_event(timeout=timeout)
if isinstance(event, h11.Data):
# Each response will have zero, one, or more data events,
# containing the body of the response.
yield bytes(event.data)
elif isinstance(event, (h11.EndOfMessage, h11.PAUSED)):
elif isinstance(event, h11.EndOfMessage):
# Once we get an EndOfMessage event, the response data has finished.
if event.headers:
trailing_headers.extend(event.headers.raw_items())
break
elif isinstance(event, h11.PAUSED):
# This can occur here on a successful CONNECT or Upgrade
# response, where it is returned rather than EndOfMessage.
#
# See https://h11.readthedocs.io/en/latest/api.html#flow-control
break

async def _receive_event(
Expand Down Expand Up @@ -291,16 +306,24 @@ async def __aexit__(


class HTTP11ConnectionByteStream:
def __init__(self, connection: AsyncHTTP11Connection, request: Request) -> None:
def __init__(
self,
connection: AsyncHTTP11Connection,
request: Request,
trailing_headers: List[Tuple[bytes, bytes]],
) -> None:
self._connection = connection
self._request = request
self._trailing_headers = trailing_headers
self._closed = False

async def __aiter__(self) -> AsyncIterator[bytes]:
kwargs = {"request": self._request}
kwargs = {"request": self._request, "trailing_headers": self._trailing_headers}
try:
async with Trace("http11.receive_response_body", self._request, kwargs):
async for chunk in self._connection._receive_response_body(**kwargs):
async for chunk in self._connection._receive_response_body(
request=self._request, trailing_headers=self._trailing_headers
):
yield chunk
except BaseException as exc:
# If we get an exception while streaming the response,
Expand Down
35 changes: 29 additions & 6 deletions httpcore/_sync/http11.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,14 +96,16 @@ def handle_request(self, request: Request) -> Response:
headers,
)

trailing_headers: List[Tuple[bytes, bytes]] = []
return Response(
status=status,
headers=headers,
content=HTTP11ConnectionByteStream(self, request),
content=HTTP11ConnectionByteStream(self, request, trailing_headers),
extensions={
"http_version": http_version,
"reason_phrase": reason_phrase,
"network_stream": self._network_stream,
"trailing_headers": trailing_headers,
},
)
except BaseException as exc:
Expand Down Expand Up @@ -164,15 +166,28 @@ def _receive_response_headers(

return http_version, event.status_code, event.reason, headers

def _receive_response_body(self, request: Request) -> Iterator[bytes]:
def _receive_response_body(
self, request: Request, trailing_headers: List[Tuple[bytes, bytes]]
) -> Iterator[bytes]:
timeouts = request.extensions.get("timeout", {})
timeout = timeouts.get("read", None)

while True:
event = self._receive_event(timeout=timeout)
if isinstance(event, h11.Data):
# Each response will have zero, one, or more data events,
# containing the body of the response.
yield bytes(event.data)
elif isinstance(event, (h11.EndOfMessage, h11.PAUSED)):
elif isinstance(event, h11.EndOfMessage):
# Once we get an EndOfMessage event, the response data has finished.
if event.headers:
trailing_headers.extend(event.headers.raw_items())
break
elif isinstance(event, h11.PAUSED):
# This can occur here on a successful CONNECT or Upgrade
# response, where it is returned rather than EndOfMessage.
#
# See https://h11.readthedocs.io/en/latest/api.html#flow-control
break

def _receive_event(
Expand Down Expand Up @@ -291,16 +306,24 @@ def __exit__(


class HTTP11ConnectionByteStream:
def __init__(self, connection: HTTP11Connection, request: Request) -> None:
def __init__(
self,
connection: HTTP11Connection,
request: Request,
trailing_headers: List[Tuple[bytes, bytes]],
) -> None:
self._connection = connection
self._request = request
self._trailing_headers = trailing_headers
self._closed = False

def __iter__(self) -> Iterator[bytes]:
kwargs = {"request": self._request}
kwargs = {"request": self._request, "trailing_headers": self._trailing_headers}
try:
with Trace("http11.receive_response_body", self._request, kwargs):
for chunk in self._connection._receive_response_body(**kwargs):
for chunk in self._connection._receive_response_body(
request=self._request, trailing_headers=self._trailing_headers
):
yield chunk
except BaseException as exc:
# If we get an exception while streaming the response,
Expand Down
74 changes: 74 additions & 0 deletions tests/_async/test_http11.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,80 @@ async def test_http11_connection():
)


@pytest.mark.anyio
async def test_http11_connection_chunked_response():
origin = Origin(b"https", b"example.com", 443)
stream = AsyncMockStream(
[
b"HTTP/1.1 200 OK\r\n",
b"Content-Type: plain/text\r\n",
b"Transfer-Encoding: chunked\r\n",
b"\r\n",
b"3\r\n",
b"Hel\r\n",
b"4\r\n",
b"lo, \r\n",
b"6\r\n",
b"world!\r\n",
b"0\r\n",
b"\r\n",
]
)
async with AsyncHTTP11Connection(
origin=origin, stream=stream, keepalive_expiry=5.0
) as conn:
response = await conn.request("GET", "https://example.com/")
assert response.status == 200
assert response.content == b"Hello, world!"

assert conn.is_idle()
assert not conn.is_closed()
assert conn.is_available()
assert not conn.has_expired()
assert (
repr(conn)
== "<AsyncHTTP11Connection ['https://example.com:443', IDLE, Request Count: 1]>"
)


@pytest.mark.anyio
async def test_http11_connection_trailing_headers_response():
origin = Origin(b"https", b"example.com", 443)
stream = AsyncMockStream(
[
b"HTTP/1.1 200 OK\r\n",
b"Content-Type: plain/text\r\n",
b"Transfer-Encoding: chunked\r\n",
b"Trailer: Surprise\r\n",
b"\r\n",
b"3\r\n",
b"Hel\r\n",
b"4\r\n",
b"lo, \r\n",
b"6\r\n",
b"world!\r\n",
b"0\r\n",
b"Surprise: You thought we were done here?\r\n",
b"\r\n",
]
)
async with AsyncHTTP11Connection(
origin=origin, stream=stream, keepalive_expiry=5.0
) as conn:
response = await conn.request(
"GET", "https://example.com/", headers={"TE": "trailers"}
)
assert response.status == 200
assert response.content == b"Hello, world!"
assert response.headers == [
(b"Content-Type", b"plain/text"),
(b"Transfer-Encoding", b"chunked"),
(b"Trailer", b"Surprise"),
]
trailing_headers = response.extensions["trailing_headers"]
assert trailing_headers == [(b"Surprise", b"You thought we were done here?")]


@pytest.mark.anyio
async def test_http11_connection_unread_response():
"""
Expand Down
74 changes: 74 additions & 0 deletions tests/_sync/test_http11.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,80 @@ def test_http11_connection():



def test_http11_connection_chunked_response():
origin = Origin(b"https", b"example.com", 443)
stream = MockStream(
[
b"HTTP/1.1 200 OK\r\n",
b"Content-Type: plain/text\r\n",
b"Transfer-Encoding: chunked\r\n",
b"\r\n",
b"3\r\n",
b"Hel\r\n",
b"4\r\n",
b"lo, \r\n",
b"6\r\n",
b"world!\r\n",
b"0\r\n",
b"\r\n",
]
)
with HTTP11Connection(
origin=origin, stream=stream, keepalive_expiry=5.0
) as conn:
response = conn.request("GET", "https://example.com/")
assert response.status == 200
assert response.content == b"Hello, world!"

assert conn.is_idle()
assert not conn.is_closed()
assert conn.is_available()
assert not conn.has_expired()
assert (
repr(conn)
== "<HTTP11Connection ['https://example.com:443', IDLE, Request Count: 1]>"
)



def test_http11_connection_trailing_headers_response():
origin = Origin(b"https", b"example.com", 443)
stream = MockStream(
[
b"HTTP/1.1 200 OK\r\n",
b"Content-Type: plain/text\r\n",
b"Transfer-Encoding: chunked\r\n",
b"Trailer: Surprise\r\n",
b"\r\n",
b"3\r\n",
b"Hel\r\n",
b"4\r\n",
b"lo, \r\n",
b"6\r\n",
b"world!\r\n",
b"0\r\n",
b"Surprise: You thought we were done here?\r\n",
b"\r\n",
]
)
with HTTP11Connection(
origin=origin, stream=stream, keepalive_expiry=5.0
) as conn:
response = conn.request(
"GET", "https://example.com/", headers={"TE": "trailers"}
)
assert response.status == 200
assert response.content == b"Hello, world!"
assert response.headers == [
(b"Content-Type", b"plain/text"),
(b"Transfer-Encoding", b"chunked"),
(b"Trailer", b"Surprise"),
]
trailing_headers = response.extensions["trailing_headers"]
assert trailing_headers == [(b"Surprise", b"You thought we were done here?")]



def test_http11_connection_unread_response():
"""
If the client releases the response without reading it to termination,
Expand Down