Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix body matcher for chunked requests (fixes #734) #739

Merged
merged 3 commits into from
Jul 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions tests/unit/test_matchers.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,9 @@ def test_uri_matcher():
"Expect": b"100-continue",
"Content-Length": "21",
}
chunked_headers = {
"Transfer-Encoding": "chunked",
}


@pytest.mark.parametrize(
Expand Down Expand Up @@ -151,6 +154,36 @@ def test_uri_matcher():
request.Request("POST", "http://aws.custom.com/", b"123", boto3_bytes_headers),
request.Request("POST", "http://aws.custom.com/", b"123", boto3_bytes_headers),
),
(
# chunked transfer encoding: decoded bytes versus encoded bytes
request.Request("POST", "scheme1://host1.test/", b"123456789_123456", chunked_headers),
request.Request(
"GET",
"scheme2://host2.test/",
b"10\r\n123456789_123456\r\n0\r\n\r\n",
chunked_headers,
),
),
(
# chunked transfer encoding: bytes iterator versus string iterator
request.Request(
"POST",
"scheme1://host1.test/",
iter([b"123456789_", b"123456"]),
chunked_headers,
),
request.Request("GET", "scheme2://host2.test/", iter(["123456789_", "123456"]), chunked_headers),
),
(
# chunked transfer encoding: bytes iterator versus single byte iterator
request.Request(
"POST",
"scheme1://host1.test/",
iter([b"123456789_", b"123456"]),
chunked_headers,
),
request.Request("GET", "scheme2://host2.test/", iter(b"123456789_123456"), chunked_headers),
),
],
)
def test_body_matcher_does_match(r1, r2):
Expand Down
87 changes: 74 additions & 13 deletions vcr/matchers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,13 @@
import logging
import urllib
import xmlrpc.client
from string import hexdigits
from typing import List, Set

from .util import read_body

_HEXDIG_CODE_POINTS: Set[int] = {ord(s.encode("ascii")) for s in hexdigits}

log = logging.getLogger(__name__)


Expand Down Expand Up @@ -49,11 +53,17 @@ def raw_body(r1, r2):


def body(r1, r2):
transformer = _get_transformer(r1)
r2_transformer = _get_transformer(r2)
if transformer != r2_transformer:
transformer = _identity
if transformer(read_body(r1)) != transformer(read_body(r2)):
transformers = list(_get_transformers(r1))
if transformers != list(_get_transformers(r2)):
transformers = []

b1 = read_body(r1)
b2 = read_body(r2)
for transform in transformers:
b1 = transform(b1)
b2 = transform(b2)

if b1 != b2:
raise AssertionError


Expand All @@ -72,6 +82,62 @@ def checker(headers):
return checker


def _dechunk(body):
if isinstance(body, str):
body = body.encode("utf-8")
elif isinstance(body, bytearray):
body = bytes(body)
elif hasattr(body, "__iter__"):
body = list(body)
if body:
if isinstance(body[0], str):
body = ("".join(body)).encode("utf-8")
elif isinstance(body[0], bytes):
body = b"".join(body)
elif isinstance(body[0], int):
body = bytes(body)
else:
raise ValueError(f"Body chunk type {type(body[0])} not supported")
else:
body = None

if not isinstance(body, bytes):
return body

# Now decode chunked data format (https://en.wikipedia.org/wiki/Chunked_transfer_encoding)
# Example input: b"45\r\n<69 bytes>\r\n0\r\n\r\n" where int(b"45", 16) == 69.
CHUNK_GAP = b"\r\n"
BODY_LEN: int = len(body)

chunks: List[bytes] = []
pos: int = 0

while True:
for i in range(pos, BODY_LEN):
if body[i] not in _HEXDIG_CODE_POINTS:
break

if i == 0 or body[i : i + len(CHUNK_GAP)] != CHUNK_GAP:
if pos == 0:
return body # i.e. assume non-chunk data
raise ValueError("Malformed chunked data")

size_bytes = int(body[pos:i], 16)
if size_bytes == 0: # i.e. well-formed ending
return b"".join(chunks)

chunk_data_first = i + len(CHUNK_GAP)
chunk_data_after_last = chunk_data_first + size_bytes

if body[chunk_data_after_last : chunk_data_after_last + len(CHUNK_GAP)] != CHUNK_GAP:
raise ValueError("Malformed chunked data")

chunk_data = body[chunk_data_first:chunk_data_after_last]
chunks.append(chunk_data)

pos = chunk_data_after_last + len(CHUNK_GAP)


def _transform_json(body):
# Request body is always a byte string, but json.loads() wants a text
# string. RFC 7159 says the default encoding is UTF-8 (although UTF-16
Expand All @@ -83,6 +149,7 @@ def _transform_json(body):
_xml_header_checker = _header_checker("text/xml")
_xmlrpc_header_checker = _header_checker("xmlrpc", header="User-Agent")
_checker_transformer_pairs = (
(_header_checker("chunked", header="Transfer-Encoding"), _dechunk),
(
_header_checker("application/x-www-form-urlencoded"),
lambda body: urllib.parse.parse_qs(body.decode("ascii")),
Expand All @@ -92,16 +159,10 @@ def _transform_json(body):
)


def _identity(x):
return x


def _get_transformer(request):
def _get_transformers(request):
for checker, transformer in _checker_transformer_pairs:
if checker(request.headers):
return transformer
else:
return _identity
yield transformer


def requests_match(r1, r2, matchers):
Expand Down