Skip to content

Commit

Permalink
Ensure body is consumed only once
Browse files Browse the repository at this point in the history
Fixes: kevin1024#846
Signed-off-by: Mathieu Parent <math.parent@gmail.com>
  • Loading branch information
sathieu committed Jul 7, 2024
1 parent 9cfa6c5 commit ffb8c07
Show file tree
Hide file tree
Showing 4 changed files with 101 additions and 2 deletions.
53 changes: 52 additions & 1 deletion tests/unit/test_stubs.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import contextlib
import http.client as httplib
from io import BytesIO
from unittest import mock

from pytest import mark

from vcr import mode
from vcr import mode, use_cassette
from vcr.cassette import Cassette
from vcr.stubs import VCRHTTPSConnection

Expand All @@ -21,3 +23,52 @@ def testing_connect(*args):
vcr_connection.cassette = Cassette("test", record_mode=mode.ALL)
vcr_connection.real_connection.connect()
assert vcr_connection.real_connection.sock is not None

def test_body_consumed_once_stream(self, tmpdir, httpbin):
self._test_body_consumed_once(
tmpdir,
httpbin,
"body_consumed_once_stream.yml",
BytesIO(b"1234567890"),
BytesIO(b"9876543210"),
BytesIO(b"9876543210"),
)

def test_body_consumed_once_iterator(self, tmpdir, httpbin):
self._test_body_consumed_once(
tmpdir,
httpbin,
"body_consumed_once_iterator.yml",
iter([b"1234567890"]),
iter([b"9876543210"]),
iter([b"9876543210"]),
)

def _test_body_consumed_once(
self,
tmpdir,
httpbin,
testfile,
data1,
data2,
data3,
):
testpath = str(tmpdir.join(testfile))
host, port = httpbin.host, httpbin.port
match_on = ["method", "uri", "body"]
with use_cassette(testpath, match_on=match_on):
conn1 = httplib.HTTPConnection(host, port)
conn1.request("POST", "/anything", body=data1)
resp1 = conn1.getresponse()
assert resp1.status == 501 # Chunked request Not implemented
conn2 = httplib.HTTPConnection(host, port)
conn2.request("POST", "/anything", body=data2)
resp2 = conn2.getresponse()
assert resp2.status == 501 # Chunked request Not implemented
with use_cassette(testpath, match_on=match_on) as cass:
conn3 = httplib.HTTPConnection(host, port)
conn3.request("POST", "/anything", body=data3)
resp3 = conn3.getresponse()
assert resp3.status == 501 # Chunked request Not implemented
assert cass.play_counts[0] == 0
assert cass.play_counts[1] == 1
23 changes: 23 additions & 0 deletions tests/unit/test_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from io import BytesIO, StringIO

import pytest

from vcr import request
from vcr.util import read_body


@pytest.mark.parametrize(
"input, output",
[
(BytesIO(b"Stream"), b"Stream"),
(StringIO("Stream"), b"Stream"),
(iter(["StringIter"]), b"StringIter"),
(iter([b"BytesIter"]), b"BytesIter"),
(iter([70, 111, 111]), b"Foo"),
("String", b"String"),
(b"Bytes", b"Bytes"),
],
)
def test_read_body(input, output):
r = request.Request("POST", "http://host.com/", input, {})
assert read_body(r) == output
12 changes: 11 additions & 1 deletion vcr/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,14 @@ def __init__(self, method, uri, body, headers):
self.method = method
self.uri = uri
self._was_file = hasattr(body, "read")
self._was_iter = hasattr(body, "__iter__") and not isinstance(
body,
(bytearray, bytes, dict, list, str),
)
if self._was_file:
self.body = body.read()
elif self._was_iter:
self.body = list(body)
else:
self.body = body
self.headers = headers
Expand All @@ -36,7 +42,11 @@ def headers(self, value):

@property
def body(self):
return BytesIO(self._body) if self._was_file else self._body
if self._was_file:
return BytesIO(self._body)
if self._was_iter:
return iter(self._body)
return self._body

@body.setter
def body(self, value):
Expand Down
15 changes: 15 additions & 0 deletions vcr/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,21 @@ def composed(incoming):
def read_body(request):
if hasattr(request.body, "read"):
return request.body.read()
if hasattr(request.body, "__iter__") and not isinstance(
request.body,
(bytearray, bytes, dict, list, str),
):
body = list(request.body)
if body:
if isinstance(body[0], str):
return "".join(body).encode("utf-8")
elif isinstance(body[0], bytes):
return b"".join(body)
elif isinstance(body[0], int):
return bytes(body)
else:
raise ValueError(f"Body type {type(body[0])} not supported")
return None
return request.body


Expand Down

0 comments on commit ffb8c07

Please sign in to comment.