Skip to content

Commit 300a17d

Browse files
jonemonateprewitt
andauthored
fields & fieldlists interfaces and implementation (#122)
* fields & fieldlists * docstrings * Field.value cannot be None, test escaping of backslashes * missing "Protocol" * dots in docstrings * updated quoting & escaping rules * updated equality rules, check for duplicated initial field names * no return type annotation for __init__ * ", " instead of "," as field separator * __iter__ and improved __repr__ for Fields * move quote_and_escape_field_value to utility method * updated field value quoting and escaping rules * drop Field.get_value_list(), Field.add as_tuples() * normalize field names in Fields * accept any iterable for field values in Field and fields in Fields * type hints for tests * grammer, naming, reprs * Field.value --> Field.values * Apply suggestions from code review Co-authored-by: Nate Prewitt <nate.prewitt@gmail.com> * use kwargs everywhere --------- Co-authored-by: Nate Prewitt <nate.prewitt@gmail.com>
1 parent 2d202ab commit 300a17d

File tree

3 files changed

+459
-14
lines changed

3 files changed

+459
-14
lines changed

python-packages/smithy-python/smithy_python/_private/http/__init__.py

Lines changed: 175 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,14 @@
1414
# TODO: move all of this out of _private
1515

1616

17+
from collections import Counter, OrderedDict
18+
from collections.abc import Iterable
1719
from dataclasses import dataclass, field
1820
from typing import Any, Protocol
1921
from urllib.parse import urlparse, urlunparse
2022

21-
from smithy_python.interfaces import http as http_interface
23+
from ... import interfaces
24+
from ...interfaces.http import FieldPosition as FieldPosition # re-export
2225

2326

2427
class URI:
@@ -91,16 +94,15 @@ def __eq__(self, other: object) -> bool:
9194
class Request:
9295
def __init__(
9396
self,
94-
url: http_interface.URI,
97+
url: interfaces.http.URI,
9598
method: str = "GET",
96-
headers: http_interface.HeadersList | None = None,
99+
headers: interfaces.http.HeadersList | None = None,
97100
body: Any = None,
98101
):
99-
self.url: http_interface.URI = url
102+
self.url: interfaces.http.URI = url
100103
self.method: str = method
101104
self.body: Any = body
102-
103-
self.headers: http_interface.HeadersList = []
105+
self.headers: interfaces.http.HeadersList = []
104106
if headers is not None:
105107
self.headers = headers
106108

@@ -109,18 +111,177 @@ class Response:
109111
def __init__(
110112
self,
111113
status_code: int,
112-
headers: http_interface.HeadersList,
114+
headers: interfaces.http.HeadersList,
113115
body: Any,
114116
):
115117
self.status_code: int = status_code
116-
self.headers: http_interface.HeadersList = headers
118+
self.headers: interfaces.http.HeadersList = headers
117119
self.body: Any = body
118120

119121

122+
class Field(interfaces.http.Field):
123+
"""
124+
A name-value pair representing a single field in an HTTP Request or Response.
125+
126+
The kind will dictate metadata placement within an HTTP message.
127+
128+
All field names are case insensitive and case-variance must be treated as
129+
equivalent. Names may be normalized but should be preserved for accuracy during
130+
transmission.
131+
"""
132+
133+
def __init__(
134+
self,
135+
*,
136+
name: str,
137+
values: Iterable[str] | None = None,
138+
kind: FieldPosition = FieldPosition.HEADER,
139+
):
140+
self.name = name
141+
self.values: list[str] = [val for val in values] if values is not None else []
142+
self.kind = kind
143+
144+
def add(self, value: str) -> None:
145+
"""Append a value to a field."""
146+
self.values.append(value)
147+
148+
def set(self, values: list[str]) -> None:
149+
"""Overwrite existing field values."""
150+
self.values = values
151+
152+
def remove(self, value: str) -> None:
153+
"""Remove all matching entries from list."""
154+
try:
155+
while True:
156+
self.values.remove(value)
157+
except ValueError:
158+
return
159+
160+
def as_string(self) -> str:
161+
"""
162+
Get comma-delimited string of all values.
163+
164+
If the ``Field`` has zero values, the empty string is returned. If the ``Field``
165+
has exactly one value, the value is returned unmodified.
166+
167+
For ``Field``s with more than one value, the values are joined by a comma and a
168+
space. For such multi-valued ``Field``s, any values that already contain
169+
commas or double quotes will be surrounded by double quotes. Within any values
170+
that get quoted, pre-existing double quotes and backslashes are escaped with a
171+
backslash.
172+
"""
173+
value_count = len(self.values)
174+
if value_count == 0:
175+
return ""
176+
if value_count == 1:
177+
return self.values[0]
178+
return ", ".join(quote_and_escape_field_value(val) for val in self.values)
179+
180+
def as_tuples(self) -> list[tuple[str, str]]:
181+
"""
182+
Get list of ``name``, ``value`` tuples where each tuple represents one value.
183+
"""
184+
return [(self.name, val) for val in self.values]
185+
186+
def __eq__(self, other: object) -> bool:
187+
"""Name, values, and kind must match. Values order must match."""
188+
if not isinstance(other, Field):
189+
return False
190+
return (
191+
self.name == other.name
192+
and self.kind is other.kind
193+
and self.values == other.values
194+
)
195+
196+
def __repr__(self) -> str:
197+
return f"Field(name={self.name!r}, value={self.values!r}, kind={self.kind!r})"
198+
199+
200+
def quote_and_escape_field_value(value: str) -> str:
201+
"""Escapes and quotes a single :class:`Field` value if necessary.
202+
203+
See :func:`Field.as_string` for quoting and escaping logic.
204+
"""
205+
chars_to_quote = (",", '"')
206+
if any(char in chars_to_quote for char in value):
207+
escaped = value.replace("\\", "\\\\").replace('"', '\\"')
208+
return f'"{escaped}"'
209+
else:
210+
return value
211+
212+
213+
class Fields(interfaces.http.Fields):
214+
def __init__(
215+
self,
216+
initial: Iterable[interfaces.http.Field] | None = None,
217+
*,
218+
encoding: str = "utf-8",
219+
):
220+
"""
221+
Collection of header and trailer entries mapped by name.
222+
223+
:param initial: Initial list of ``Field`` objects. ``Field``s can alse be added
224+
with :func:`set_field` and later removed with :func:`remove_field`.
225+
:param encoding: The string encoding to be used when converting the ``Field``
226+
name and value from ``str`` to ``bytes`` for transmission.
227+
"""
228+
init_fields = [fld for fld in initial] if initial is not None else []
229+
init_field_names = [self._normalize_field_name(fld.name) for fld in init_fields]
230+
fname_counter = Counter(init_field_names)
231+
repeated_names_exist = (
232+
len(init_fields) > 0 and fname_counter.most_common(1)[0][1] > 1
233+
)
234+
if repeated_names_exist:
235+
non_unique_names = [name for name, num in fname_counter.items() if num > 1]
236+
raise ValueError(
237+
"Field names of the initial list of fields must be unique. The "
238+
"following normalized field names appear more than once: "
239+
f"{', '.join(non_unique_names)}."
240+
)
241+
init_tuples = zip(init_field_names, init_fields)
242+
self.entries: OrderedDict[str, interfaces.http.Field] = OrderedDict(init_tuples)
243+
self.encoding: str = encoding
244+
245+
def set_field(self, field: interfaces.http.Field) -> None:
246+
"""Set entry for a Field name."""
247+
normalized_name = self._normalize_field_name(field.name)
248+
self.entries[normalized_name] = field
249+
250+
def get_field(self, name: str) -> interfaces.http.Field:
251+
"""Retrieve Field entry."""
252+
normalized_name = self._normalize_field_name(name)
253+
return self.entries[normalized_name]
254+
255+
def remove_field(self, name: str) -> None:
256+
"""Delete entry from collection."""
257+
normalized_name = self._normalize_field_name(name)
258+
del self.entries[normalized_name]
259+
260+
def get_by_type(self, kind: FieldPosition) -> list[interfaces.http.Field]:
261+
"""Helper function for retrieving specific types of fields.
262+
263+
Used to grab all headers or all trailers.
264+
"""
265+
return [entry for entry in self.entries.values() if entry.kind is kind]
266+
267+
def _normalize_field_name(self, name: str) -> str:
268+
"""Normalize field names. For use as key in ``entries``."""
269+
return name.lower()
270+
271+
def __eq__(self, other: object) -> bool:
272+
"""Encoding must match. Entries must match in values and order."""
273+
if not isinstance(other, Fields):
274+
return False
275+
return self.encoding == other.encoding and self.entries == other.entries
276+
277+
def __iter__(self) -> Iterable[interfaces.http.Field]:
278+
yield from self.entries.values()
279+
280+
120281
@dataclass
121-
class Endpoint(http_interface.Endpoint):
122-
url: http_interface.URI
123-
headers: http_interface.HeadersList = field(default_factory=list)
282+
class Endpoint(interfaces.http.Endpoint):
283+
url: interfaces.http.URI
284+
headers: interfaces.http.HeadersList = field(default_factory=list)
124285

125286

126287
@dataclass
@@ -131,10 +292,10 @@ class StaticEndpointParams:
131292
:params url: A static URI to route requests to.
132293
"""
133294

134-
url: str | http_interface.URI
295+
url: str | interfaces.http.URI
135296

136297

137-
class StaticEndpointResolver(http_interface.EndpointResolver[StaticEndpointParams]):
298+
class StaticEndpointResolver(interfaces.http.EndpointResolver[StaticEndpointParams]):
138299
"""A basic endpoint resolver that forwards a static url."""
139300

140301
async def resolve_endpoint(self, params: StaticEndpointParams) -> Endpoint:
@@ -164,7 +325,7 @@ async def resolve_endpoint(self, params: StaticEndpointParams) -> Endpoint:
164325

165326

166327
class _StaticEndpointConfig(Protocol):
167-
endpoint_resolver: http_interface.EndpointResolver[StaticEndpointParams] | None
328+
endpoint_resolver: interfaces.http.EndpointResolver[StaticEndpointParams] | None
168329

169330

170331
def set_static_endpoint_resolver(config: _StaticEndpointConfig) -> None:

python-packages/smithy-python/smithy_python/interfaces/http.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,18 @@
1+
# Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from collections import OrderedDict
114
from dataclasses import dataclass
15+
from enum import Enum
216
from typing import Any, Protocol, TypeVar
317

418
# Defining headers as a list instead of a mapping to avoid ambiguity and
@@ -7,6 +21,94 @@
721
QueryParamsList = list[tuple[str, str]]
822

923

24+
class FieldPosition(Enum):
25+
"""
26+
The type of a field. Defines its placement in a request or response.
27+
"""
28+
29+
HEADER = 0
30+
"""
31+
Header field. In HTTP this is a header as defined in RFC 9114 Section 6.3.
32+
Implementations of other protocols may use this FieldPosition for similar types
33+
of metadata.
34+
"""
35+
36+
TRAILER = 1
37+
"""
38+
Trailer field. In HTTP this is a trailer as defined in RFC 9114 Section 6.5.
39+
Implementations of other protocols may use this FieldPosition for similar types
40+
of metadata.
41+
"""
42+
43+
44+
class Field(Protocol):
45+
"""
46+
A name-value pair representing a single field in a request or response.
47+
48+
The kind will dictate metadata placement within an the message, for example as
49+
header or trailer field in a HTTP request as defined in RFC 9114 Section 4.2.
50+
51+
All field names are case insensitive and case-variance must be treated as
52+
equivalent. Names may be normalized but should be preserved for accuracy during
53+
transmission.
54+
"""
55+
56+
name: str
57+
values: list[str]
58+
kind: FieldPosition = FieldPosition.HEADER
59+
60+
def add(self, value: str) -> None:
61+
"""Append a value to a field."""
62+
...
63+
64+
def set(self, values: list[str]) -> None:
65+
"""Overwrite existing field values."""
66+
...
67+
68+
def remove(self, value: str) -> None:
69+
"""Remove all matching entries from list."""
70+
...
71+
72+
def as_string(self) -> str:
73+
"""Serialize the ``Field``'s values into a single line string."""
74+
...
75+
76+
def as_tuples(self) -> list[tuple[str, str]]:
77+
"""
78+
Get list of ``name``, ``value`` tuples where each tuple represents one value.
79+
"""
80+
...
81+
82+
83+
class Fields(Protocol):
84+
"""
85+
Protocol agnostic mapping of key-value pair request metadata, such as HTTP fields.
86+
"""
87+
88+
# Entries are keyed off the name of a provided Field
89+
entries: OrderedDict[str, Field]
90+
encoding: str | None = "utf-8"
91+
92+
def set_field(self, field: Field) -> None:
93+
"""Set entry for a Field name."""
94+
...
95+
96+
def get_field(self, name: str) -> Field:
97+
"""Retrieve Field entry."""
98+
...
99+
100+
def remove_field(self, name: str) -> None:
101+
"""Delete entry from collection."""
102+
...
103+
104+
def get_by_type(self, kind: FieldPosition) -> list[Field]:
105+
"""Helper function for retrieving specific types of fields.
106+
107+
Used to grab all headers or all trailers.
108+
"""
109+
...
110+
111+
10112
class URI(Protocol):
11113
"""Universal Resource Identifier, target location for a :py:class:`Request`."""
12114

0 commit comments

Comments
 (0)