Skip to content

Commit 7b3d983

Browse files
jonemodavidlm
authored andcommitted
Update pytest-asyncio to actually run async tests (#119)
* bump pytest and pytest-asyncio to latest versions * regenerate lockfiles * remove unnecessary @pytest.mark.asyncio (we have asyncio_mode=auto) * fix endpoint provider unit test * equality operator for URI objects * update pants and ignore W503 which conflicts with black
1 parent d13ee5c commit 7b3d983

File tree

14 files changed

+825
-310
lines changed

14 files changed

+825
-310
lines changed

codegen/smithy-python-codegen/src/main/java/software/amazon/smithy/python/codegen/HttpProtocolTestGenerator.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -375,7 +375,7 @@ private void writeUtilStubs(Symbol serviceSymbol) {
375375
LOGGER.fine(String.format("Writing utility stubs for %s : %s", serviceSymbol.getName(), protocol.getName()));
376376
writer.addStdlibImport("typing", "Any");
377377
writer.addImports("smithy_python.interfaces.http", Set.of(
378-
"HeadersList", "HttpRequestConfiguration", "Request", "Response")
378+
"Fields", "HttpRequestConfiguration", "Request", "Response")
379379
);
380380
writer.addImport("smithy_python._private.http", "Response", "_Response");
381381

@@ -415,7 +415,7 @@ async def read(self, size: int = -1) -> bytes:
415415
class $4L:
416416
""\"An asynchronous HTTP client solely for testing purposes.""\"
417417
418-
def __init__(self, status_code: int, headers: HeadersList, body: bytes):
418+
def __init__(self, status_code: int, headers: Fields, body: bytes):
419419
self.status_code = status_code
420420
self.headers = headers
421421
self.body = AwaitableBody(body)

codegen/smithy-python-codegen/src/main/java/software/amazon/smithy/python/codegen/integration/HttpBindingProtocolGenerator.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -186,8 +186,8 @@ private void serializeHeaders(
186186
writer.pushState(new SerializeFieldsSection(operation));
187187
// TODO: map headers from inputs
188188
// TODO: write out default http and protocol headers
189-
writer.addImport("smithy_python.interfaces.http", "HeadersList", "_HeadersList");
190-
writer.write("headers: _HeadersList = []");
189+
writer.addImport("smithy_python.interfaces.http", "Fields", "_Fields");
190+
writer.write("headers: _Fields = []");
191191
writer.popState();
192192
}
193193

python-packages/smithy-python/smithy_python/_private/http/__init__.py

Lines changed: 220 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -13,36 +13,50 @@
1313

1414
# TODO: move all of this out of _private
1515

16+
<<<<<<< HEAD
1617

1718
from collections import OrderedDict
1819
from dataclasses import dataclass, field
1920
from typing import Any, Iterable, Protocol
21+
=======
22+
from collections import OrderedDict
23+
from collections.abc import AsyncIterable
24+
from dataclasses import dataclass, field
25+
from typing import Protocol
26+
>>>>>>> 23a2690 (Update pytest-asyncio to actually run async tests (#119))
2027
from urllib.parse import urlparse, urlunparse
2128

2229
from ... import interfaces
2330
from ...interfaces.http import FieldPosition as FieldPosition # re-export
2431

2532

26-
class URI:
27-
def __init__(
28-
self,
29-
host: str,
30-
path: str | None = None,
31-
scheme: str | None = None,
32-
query: str | None = None,
33-
port: int | None = None,
34-
username: str | None = None,
35-
password: str | None = None,
36-
fragment: str | None = None,
37-
):
38-
self.scheme: str = "https" if scheme is None else scheme
39-
self.host = host
40-
self.port = port
41-
self.path = path
42-
self.query = query
43-
self.username = username
44-
self.password = password
45-
self.fragment = fragment
33+
@dataclass(kw_only=True)
34+
class URI(interfaces.URI):
35+
"""Universal Resource Identifier, target location for a :py:class:`HTTPRequest`."""
36+
37+
scheme: str = "https"
38+
"""For example ``http`` or ``https``."""
39+
40+
username: str | None = None
41+
"""Username part of the userinfo URI component."""
42+
43+
password: str | None = None
44+
"""Password part of the userinfo URI component."""
45+
46+
host: str
47+
"""The hostname, for example ``amazonaws.com``."""
48+
49+
port: int | None = None
50+
"""An explicit port number."""
51+
52+
path: str | None = None
53+
"""Path component of the URI."""
54+
55+
query: str | None = None
56+
"""Query component of the URI as string."""
57+
58+
fragment: str | None = None
59+
"""Part of the URI specification, but may not be transmitted by a client."""
4660

4761
@property
4862
def netloc(self) -> str:
@@ -89,10 +103,68 @@ def __eq__(self, other: object) -> bool:
89103
and self.fragment == other.fragment
90104
)
91105

106+
<<<<<<< HEAD
107+
=======
108+
109+
@dataclass(kw_only=True)
110+
class HTTPRequest(interfaces.http.HTTPRequest):
111+
"""
112+
HTTP primitives for an Exchange to construct a version agnostic HTTP message.
113+
"""
114+
115+
destination: interfaces.URI
116+
body: AsyncIterable[bytes]
117+
method: str
118+
fields: interfaces.http.Fields
119+
120+
async def consume_body(self) -> bytes:
121+
"""Iterate over request body and return as bytes."""
122+
body = b""
123+
async for chunk in self.body:
124+
body += chunk
125+
return body
126+
127+
128+
@dataclass(kw_only=True)
129+
class HTTPResponse(interfaces.http.HTTPResponse):
130+
body: AsyncIterable[bytes]
131+
status: int
132+
fields: interfaces.http.Fields
133+
reason: str | None = None
134+
135+
async def consume_body(self) -> bytes:
136+
"""Iterate over response body and return as bytes."""
137+
body = b""
138+
async for chunk in self.body:
139+
body += chunk
140+
return body
141+
142+
@property
143+
def done(self) -> bool:
144+
"""
145+
Has the complete body been received.
146+
147+
Always returns True. Subclasses in implementations that support response
148+
streaming may override this.
149+
"""
150+
return True
151+
152+
153+
class Field(interfaces.http.Field):
154+
"""
155+
A name-value pair representing a single field in an HTTP Request or Response.
156+
157+
The kind will dictate metadata placement within an HTTP message.
158+
159+
All field names are case insensitive and case-variance must be treated as
160+
equivalent. Names may be normalized but should be preserved for accuracy during
161+
transmission.
162+
"""
163+
>>>>>>> 23a2690 (Update pytest-asyncio to actually run async tests (#119))
92164

93-
class Request:
94165
def __init__(
95166
self,
167+
<<<<<<< HEAD
96168
url: interfaces.http.URI,
97169
method: str = "GET",
98170
headers: interfaces.http.HeadersList | None = None,
@@ -104,18 +176,135 @@ def __init__(
104176
self.headers: interfaces.http.HeadersList = []
105177
if headers is not None:
106178
self.headers = headers
179+
=======
180+
name: str,
181+
value: list[str] | None = None,
182+
kind: FieldPosition | None = FieldPosition.HEADER,
183+
) -> None:
184+
self.name = name
185+
self.value = value
186+
self.kind = kind
187+
188+
def add(self, value: str) -> None:
189+
"""Append a value to a field"""
190+
if self.value is None:
191+
self.value = [value]
192+
else:
193+
self.value.append(value)
194+
195+
def set(self, value: list[str]) -> None:
196+
"""Overwrite existing field values."""
197+
self.value = value
198+
199+
def remove(self, value: str) -> None:
200+
"""Remove all matching entries from list"""
201+
if self.value is None:
202+
return
203+
try:
204+
while True:
205+
self.value.remove(value)
206+
except ValueError:
207+
return
208+
209+
def _quote_and_escape_single_value(self, value: str) -> str:
210+
"""Escapes and quotes a single value if necessary.
211+
212+
A value is surrounded by double quotes if it contains comma (,) or whitespace.
213+
Any double quote characters present in the value (before quoting) are escaped
214+
with a backslash.
215+
"""
216+
escaped = value.replace('"', '\\"')
217+
needs_quoting = any(char == "," or char.isspace() for char in escaped)
218+
quoted = f'"{escaped}"' if needs_quoting else escaped
219+
return quoted
220+
221+
def get_value(self) -> str:
222+
"""
223+
Get comma-delimited string values.
224+
225+
Values with spaces or commas are double-quoted.
226+
"""
227+
if self.value is None:
228+
return ""
229+
return ",".join(self._quote_and_escape_single_value(val) for val in self.value)
230+
231+
def get_value_list(self) -> list[str]:
232+
"""Get string values as a list"""
233+
if self.value is None:
234+
return []
235+
else:
236+
return self.value
237+
238+
def __eq__(self, other: object) -> bool:
239+
"""Name, values, and kind must match. Values order must match."""
240+
if not isinstance(other, Field):
241+
return False
242+
return (
243+
self.name == other.name
244+
and self.kind == other.kind
245+
and self.value == other.value
246+
)
247+
248+
def __repr__(self) -> str:
249+
return f"Field({self.kind.name} {self.name}: {self.get_value()})"
250+
>>>>>>> 23a2690 (Update pytest-asyncio to actually run async tests (#119))
251+
107252

253+
class Fields(interfaces.http.Fields):
254+
"""Collection of Field entries mapped by name."""
108255

109-
class Response:
110256
def __init__(
111257
self,
258+
<<<<<<< HEAD
112259
status_code: int,
113260
headers: interfaces.http.HeadersList,
114261
body: Any,
115262
):
116263
self.status_code: int = status_code
117264
self.headers: interfaces.http.HeadersList = headers
118265
self.body: Any = body
266+
=======
267+
initial: list[interfaces.http.Field] | None = None,
268+
*,
269+
encoding: str = "utf-8",
270+
) -> None:
271+
init_tuples = [] if initial is None else [(fld.name, fld) for fld in initial]
272+
self.entries: OrderedDict[str, interfaces.http.Field] = OrderedDict(init_tuples)
273+
self.encoding: str = encoding
274+
275+
def set_field(self, field: interfaces.http.Field) -> None:
276+
"""Set entry for a Field name."""
277+
self.entries[field.name] = field
278+
279+
def get_field(self, name: str) -> interfaces.http.Field:
280+
"""Retrieve Field entry"""
281+
return self.entries[name]
282+
283+
def remove_field(self, name: str) -> None:
284+
"""Delete entry from collection"""
285+
del self.entries[name]
286+
287+
def get_by_type(self, kind: FieldPosition) -> list[interfaces.http.Field]:
288+
"""Helper function for retrieving specific types of fields
289+
290+
Used to grab all headers or all trailers
291+
"""
292+
return [entry for entry in self.entries.values() if entry.kind == kind]
293+
294+
def __eq__(self, other: object) -> bool:
295+
"""Encoding must match. Entries must match in values but not order."""
296+
if not isinstance(other, Fields):
297+
return False
298+
if self.encoding != other.encoding:
299+
return False
300+
if set(self.entries.keys()) != set(other.entries.keys()):
301+
return False
302+
for field_name, self_field in self.entries.items():
303+
other_field = other.get_field(field_name)
304+
if self_field != other_field:
305+
return False
306+
return True
307+
>>>>>>> 23a2690 (Update pytest-asyncio to actually run async tests (#119))
119308

120309

121310
class Field(interfaces.http.Field):
@@ -251,8 +440,13 @@ def __iter__(self) -> Iterable[interfaces.http.Field]:
251440

252441
@dataclass
253442
class Endpoint(interfaces.http.Endpoint):
443+
<<<<<<< HEAD
254444
url: interfaces.http.URI
255445
headers: interfaces.http.HeadersList = field(default_factory=list)
446+
=======
447+
url: interfaces.URI
448+
headers: interfaces.http.Fields = field(default_factory=Fields)
449+
>>>>>>> 23a2690 (Update pytest-asyncio to actually run async tests (#119))
256450

257451

258452
@dataclass
@@ -263,7 +457,11 @@ class StaticEndpointParams:
263457
:params url: A static URI to route requests to.
264458
"""
265459

460+
<<<<<<< HEAD
266461
url: str | interfaces.http.URI
462+
=======
463+
url: str | interfaces.URI
464+
>>>>>>> 23a2690 (Update pytest-asyncio to actually run async tests (#119))
267465

268466

269467
class StaticEndpointResolver(interfaces.http.EndpointResolver[StaticEndpointParams]):

0 commit comments

Comments
 (0)