1414# TODO: move all of this out of _private
1515
1616
17+ from collections import Counter , OrderedDict
18+ from collections .abc import Iterable
1719from dataclasses import dataclass , field
1820from typing import Any , Protocol
1921from urllib .parse import urlparse , urlunparse
2022
21- from smithy_python .interfaces import http as http_interface
23+ from ... import interfaces
24+ from ...interfaces .http import FieldPosition as FieldPosition # re-export
2225
2326
2427class URI :
@@ -91,16 +94,15 @@ def __eq__(self, other: object) -> bool:
9194class Request :
9295 def __init__ (
9396 self ,
94- url : http_interface .URI ,
97+ url : interfaces . http .URI ,
9598 method : str = "GET" ,
96- headers : http_interface .HeadersList | None = None ,
99+ headers : interfaces . http .HeadersList | None = None ,
97100 body : Any = None ,
98101 ):
99- self .url : http_interface .URI = url
102+ self .url : interfaces . http .URI = url
100103 self .method : str = method
101104 self .body : Any = body
102-
103- self .headers : http_interface .HeadersList = []
105+ self .headers : interfaces .http .HeadersList = []
104106 if headers is not None :
105107 self .headers = headers
106108
@@ -109,18 +111,177 @@ class Response:
109111 def __init__ (
110112 self ,
111113 status_code : int ,
112- headers : http_interface .HeadersList ,
114+ headers : interfaces . http .HeadersList ,
113115 body : Any ,
114116 ):
115117 self .status_code : int = status_code
116- self .headers : http_interface .HeadersList = headers
118+ self .headers : interfaces . http .HeadersList = headers
117119 self .body : Any = body
118120
119121
122+ class Field (interfaces .http .Field ):
123+ """
124+ A name-value pair representing a single field in an HTTP Request or Response.
125+
126+ The kind will dictate metadata placement within an HTTP message.
127+
128+ All field names are case insensitive and case-variance must be treated as
129+ equivalent. Names may be normalized but should be preserved for accuracy during
130+ transmission.
131+ """
132+
133+ def __init__ (
134+ self ,
135+ * ,
136+ name : str ,
137+ values : Iterable [str ] | None = None ,
138+ kind : FieldPosition = FieldPosition .HEADER ,
139+ ):
140+ self .name = name
141+ self .values : list [str ] = [val for val in values ] if values is not None else []
142+ self .kind = kind
143+
144+ def add (self , value : str ) -> None :
145+ """Append a value to a field."""
146+ self .values .append (value )
147+
148+ def set (self , values : list [str ]) -> None :
149+ """Overwrite existing field values."""
150+ self .values = values
151+
152+ def remove (self , value : str ) -> None :
153+ """Remove all matching entries from list."""
154+ try :
155+ while True :
156+ self .values .remove (value )
157+ except ValueError :
158+ return
159+
160+ def as_string (self ) -> str :
161+ """
162+ Get comma-delimited string of all values.
163+
164+ If the ``Field`` has zero values, the empty string is returned. If the ``Field``
165+ has exactly one value, the value is returned unmodified.
166+
167+ For ``Field``s with more than one value, the values are joined by a comma and a
168+ space. For such multi-valued ``Field``s, any values that already contain
169+ commas or double quotes will be surrounded by double quotes. Within any values
170+ that get quoted, pre-existing double quotes and backslashes are escaped with a
171+ backslash.
172+ """
173+ value_count = len (self .values )
174+ if value_count == 0 :
175+ return ""
176+ if value_count == 1 :
177+ return self .values [0 ]
178+ return ", " .join (quote_and_escape_field_value (val ) for val in self .values )
179+
180+ def as_tuples (self ) -> list [tuple [str , str ]]:
181+ """
182+ Get list of ``name``, ``value`` tuples where each tuple represents one value.
183+ """
184+ return [(self .name , val ) for val in self .values ]
185+
186+ def __eq__ (self , other : object ) -> bool :
187+ """Name, values, and kind must match. Values order must match."""
188+ if not isinstance (other , Field ):
189+ return False
190+ return (
191+ self .name == other .name
192+ and self .kind is other .kind
193+ and self .values == other .values
194+ )
195+
196+ def __repr__ (self ) -> str :
197+ return f"Field(name={ self .name !r} , value={ self .values !r} , kind={ self .kind !r} )"
198+
199+
200+ def quote_and_escape_field_value (value : str ) -> str :
201+ """Escapes and quotes a single :class:`Field` value if necessary.
202+
203+ See :func:`Field.as_string` for quoting and escaping logic.
204+ """
205+ chars_to_quote = ("," , '"' )
206+ if any (char in chars_to_quote for char in value ):
207+ escaped = value .replace ("\\ " , "\\ \\ " ).replace ('"' , '\\ "' )
208+ return f'"{ escaped } "'
209+ else :
210+ return value
211+
212+
213+ class Fields (interfaces .http .Fields ):
214+ def __init__ (
215+ self ,
216+ initial : Iterable [interfaces .http .Field ] | None = None ,
217+ * ,
218+ encoding : str = "utf-8" ,
219+ ):
220+ """
221+ Collection of header and trailer entries mapped by name.
222+
223+ :param initial: Initial list of ``Field`` objects. ``Field``s can alse be added
224+ with :func:`set_field` and later removed with :func:`remove_field`.
225+ :param encoding: The string encoding to be used when converting the ``Field``
226+ name and value from ``str`` to ``bytes`` for transmission.
227+ """
228+ init_fields = [fld for fld in initial ] if initial is not None else []
229+ init_field_names = [self ._normalize_field_name (fld .name ) for fld in init_fields ]
230+ fname_counter = Counter (init_field_names )
231+ repeated_names_exist = (
232+ len (init_fields ) > 0 and fname_counter .most_common (1 )[0 ][1 ] > 1
233+ )
234+ if repeated_names_exist :
235+ non_unique_names = [name for name , num in fname_counter .items () if num > 1 ]
236+ raise ValueError (
237+ "Field names of the initial list of fields must be unique. The "
238+ "following normalized field names appear more than once: "
239+ f"{ ', ' .join (non_unique_names )} ."
240+ )
241+ init_tuples = zip (init_field_names , init_fields )
242+ self .entries : OrderedDict [str , interfaces .http .Field ] = OrderedDict (init_tuples )
243+ self .encoding : str = encoding
244+
245+ def set_field (self , field : interfaces .http .Field ) -> None :
246+ """Set entry for a Field name."""
247+ normalized_name = self ._normalize_field_name (field .name )
248+ self .entries [normalized_name ] = field
249+
250+ def get_field (self , name : str ) -> interfaces .http .Field :
251+ """Retrieve Field entry."""
252+ normalized_name = self ._normalize_field_name (name )
253+ return self .entries [normalized_name ]
254+
255+ def remove_field (self , name : str ) -> None :
256+ """Delete entry from collection."""
257+ normalized_name = self ._normalize_field_name (name )
258+ del self .entries [normalized_name ]
259+
260+ def get_by_type (self , kind : FieldPosition ) -> list [interfaces .http .Field ]:
261+ """Helper function for retrieving specific types of fields.
262+
263+ Used to grab all headers or all trailers.
264+ """
265+ return [entry for entry in self .entries .values () if entry .kind is kind ]
266+
267+ def _normalize_field_name (self , name : str ) -> str :
268+ """Normalize field names. For use as key in ``entries``."""
269+ return name .lower ()
270+
271+ def __eq__ (self , other : object ) -> bool :
272+ """Encoding must match. Entries must match in values and order."""
273+ if not isinstance (other , Fields ):
274+ return False
275+ return self .encoding == other .encoding and self .entries == other .entries
276+
277+ def __iter__ (self ) -> Iterable [interfaces .http .Field ]:
278+ yield from self .entries .values ()
279+
280+
120281@dataclass
121- class Endpoint (http_interface .Endpoint ):
122- url : http_interface .URI
123- headers : http_interface .HeadersList = field (default_factory = list )
282+ class Endpoint (interfaces . http .Endpoint ):
283+ url : interfaces . http .URI
284+ headers : interfaces . http .HeadersList = field (default_factory = list )
124285
125286
126287@dataclass
@@ -131,10 +292,10 @@ class StaticEndpointParams:
131292 :params url: A static URI to route requests to.
132293 """
133294
134- url : str | http_interface .URI
295+ url : str | interfaces . http .URI
135296
136297
137- class StaticEndpointResolver (http_interface .EndpointResolver [StaticEndpointParams ]):
298+ class StaticEndpointResolver (interfaces . http .EndpointResolver [StaticEndpointParams ]):
138299 """A basic endpoint resolver that forwards a static url."""
139300
140301 async def resolve_endpoint (self , params : StaticEndpointParams ) -> Endpoint :
@@ -164,7 +325,7 @@ async def resolve_endpoint(self, params: StaticEndpointParams) -> Endpoint:
164325
165326
166327class _StaticEndpointConfig (Protocol ):
167- endpoint_resolver : http_interface .EndpointResolver [StaticEndpointParams ] | None
328+ endpoint_resolver : interfaces . http .EndpointResolver [StaticEndpointParams ] | None
168329
169330
170331def set_static_endpoint_resolver (config : _StaticEndpointConfig ) -> None :
0 commit comments