Skip to content

Commit

Permalink
Ensure exclusive access to underlying connection
Browse files Browse the repository at this point in the history
  • Loading branch information
a-feld committed Nov 26, 2021
1 parent abf03a6 commit 7ad0417
Showing 1 changed file with 60 additions and 52 deletions.
112 changes: 60 additions & 52 deletions src/pywreck/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,11 +87,13 @@ def __init__(
self,
reader: asyncio.StreamReader,
writer: asyncio.StreamWriter,
lock: asyncio.Lock,
host: str,
timeout: Optional[float],
):
self._reader = _HttpReader(reader, timeout)
self._writer = _HttpWriter(writer, timeout)
self._lock = lock
self._host = host

@classmethod
Expand Down Expand Up @@ -120,7 +122,7 @@ async def create(
:rtype: Connection
"""
reader, writer = await asyncio.open_connection(host, port, ssl=ssl)
return Connection(reader, writer, host, timeout)
return Connection(reader, writer, asyncio.Lock(), host, timeout)

async def __aenter__(self) -> "Connection":
return self
Expand Down Expand Up @@ -158,68 +160,74 @@ async def request(
:rtype: Response
"""
reader, writer = self._reader, self._writer
request = f"{method} {uri} HTTP/1.1\r\n"
writer.write_ascii(request)
# Since the connection can be a shared resource, we must ensure
# exclusive access for the duration of the request/response cycle
async with self._lock:
reader, writer = self._reader, self._writer
request = f"{method} {uri} HTTP/1.1\r\n"
writer.write_ascii(request)

request_headers = {"host": self._host, "user-agent": f"pywreck/{__version__}"}
if payload:
request_headers["content-length"] = str(len(payload))
request_headers = {
"host": self._host,
"user-agent": f"pywreck/{__version__}",
}
if payload:
request_headers["content-length"] = str(len(payload))

if headers:
request_headers.update(headers)
if headers:
request_headers.update(headers)

for header_name, header_value in request_headers.items():
writer.write_ascii(f"{header_name}:{header_value}\r\n")
for header_name, header_value in request_headers.items():
writer.write_ascii(f"{header_name}:{header_value}\r\n")

# Finish request metadata section
writer.write(b"\r\n")
# Finish request metadata section
writer.write(b"\r\n")

# Send payload
writer.write(payload)
# Send payload
writer.write(payload)

response_line = await reader.readline()
status = int(response_line.split(" ", 2)[1])
response_line = await reader.readline()
status = int(response_line.split(" ", 2)[1])

response_headers: Dict[str, str] = {}
content_length = 0
chunked = False
response_headers: Dict[str, str] = {}
content_length = 0
chunked = False

while True:
header_line = await reader.readline()
header_line = header_line.rstrip()
if not header_line:
break

header_name, header_value = header_line.split(":", 1)
header_name = header_name.rstrip().lower()
header_value = header_value.lstrip()

if header_name in response_headers:
separator = "," if header_name != "set-cookie" else ";"
response_headers[header_name] += separator + header_value
else:
response_headers[header_name] = header_value

if method != "HEAD":
if "content-length" in response_headers:
content_length = int(response_headers["content-length"])

chunked = response_headers.get("transfer-encoding", "") == "chunked"

if chunked:
response_chunks = []
while True:
chunk_len_bytes = await reader.readuntil(b"\r\n")
content_length = int(chunk_len_bytes.rstrip(), 16)
if not content_length:
header_line = await reader.readline()
header_line = header_line.rstrip()
if not header_line:
break
part = await reader.readexactly(content_length + 2)
response_chunks.append(part[:-2])

response_data = b"".join(response_chunks)
else:
response_data = await reader.readexactly(content_length)
header_name, header_value = header_line.split(":", 1)
header_name = header_name.rstrip().lower()
header_value = header_value.lstrip()

if header_name in response_headers:
separator = "," if header_name != "set-cookie" else ";"
response_headers[header_name] += separator + header_value
else:
response_headers[header_name] = header_value

if method != "HEAD":
if "content-length" in response_headers:
content_length = int(response_headers["content-length"])

chunked = response_headers.get("transfer-encoding", "") == "chunked"

if chunked:
response_chunks = []
while True:
chunk_len_bytes = await reader.readuntil(b"\r\n")
content_length = int(chunk_len_bytes.rstrip(), 16)
if not content_length:
break
part = await reader.readexactly(content_length + 2)
response_chunks.append(part[:-2])

response_data = b"".join(response_chunks)
else:
response_data = await reader.readexactly(content_length)

return Response(status, response_headers, response_data)

Expand Down

0 comments on commit 7ad0417

Please sign in to comment.