1414# TODO: move all of this out of _private
1515
1616
17+ from collections import OrderedDict
1718from dataclasses import dataclass , field
1819from typing import Any , Protocol
1920from urllib .parse import urlparse , urlunparse
2021
21- from smithy_python .interfaces import http as http_interface
22+ from ... import interfaces
23+ from ...interfaces .http import FieldPosition as FieldPosition # re-export
2224
2325
2426class URI :
@@ -91,16 +93,15 @@ def __eq__(self, other: object) -> bool:
9193class Request :
9294 def __init__ (
9395 self ,
94- url : http_interface .URI ,
96+ url : interfaces . http .URI ,
9597 method : str = "GET" ,
96- headers : http_interface .HeadersList | None = None ,
98+ headers : interfaces . http .HeadersList | None = None ,
9799 body : Any = None ,
98100 ):
99- self .url : http_interface .URI = url
101+ self .url : interfaces . http .URI = url
100102 self .method : str = method
101103 self .body : Any = body
102-
103- self .headers : http_interface .HeadersList = []
104+ self .headers : interfaces .http .HeadersList = []
104105 if headers is not None :
105106 self .headers = headers
106107
@@ -109,18 +110,150 @@ class Response:
109110 def __init__ (
110111 self ,
111112 status_code : int ,
112- headers : http_interface .HeadersList ,
113+ headers : interfaces . http .HeadersList ,
113114 body : Any ,
114115 ):
115116 self .status_code : int = status_code
116- self .headers : http_interface .HeadersList = headers
117+ self .headers : interfaces . http .HeadersList = headers
117118 self .body : Any = body
118119
119120
121+ class Field (interfaces .http .Field ):
122+ """
123+ A name-value pair representing a single field in an HTTP Request or Response.
124+
125+ The kind will dictate metadata placement within an HTTP message.
126+
127+ All field names are case insensitive and case-variance must be treated as
128+ equivalent. Names may be normalized but should be preserved for accuracy during
129+ transmission.
130+ """
131+
132+ def __init__ (
133+ self ,
134+ name : str ,
135+ value : list [str ] | None = None ,
136+ kind : FieldPosition = FieldPosition .HEADER ,
137+ ) -> None :
138+ self .name = name
139+ self .value = value
140+ self .kind = kind
141+
142+ def add (self , value : str ) -> None :
143+ """Append a value to a field"""
144+ if self .value is None :
145+ self .value = [value ]
146+ else :
147+ self .value .append (value )
148+
149+ def set (self , value : list [str ]) -> None :
150+ """Overwrite existing field values."""
151+ self .value = value
152+
153+ def remove (self , value : str ) -> None :
154+ """Remove all matching entries from list"""
155+ if self .value is None :
156+ return
157+ try :
158+ while True :
159+ self .value .remove (value )
160+ except ValueError :
161+ return
162+
163+ def _quote_and_escape_single_value (self , value : str ) -> str :
164+ """Escapes and quotes a single value if necessary.
165+
166+ A value is surrounded by double quotes if it contains comma (,) or whitespace.
167+ Any double quote characters present in the value (before quoting) are escaped
168+ with a backslash.
169+ """
170+ escaped = value .replace ('"' , '\\ "' )
171+ needs_quoting = any (char == "," or char .isspace () for char in escaped )
172+ quoted = f'"{ escaped } "' if needs_quoting else escaped
173+ return quoted
174+
175+ def get_value (self ) -> str :
176+ """
177+ Get comma-delimited string values.
178+
179+ Values with spaces or commas are double-quoted.
180+ """
181+ if self .value is None :
182+ return ""
183+ return "," .join (self ._quote_and_escape_single_value (val ) for val in self .value )
184+
185+ def get_value_list (self ) -> list [str ]:
186+ """Get string values as a list"""
187+ if self .value is None :
188+ return []
189+ else :
190+ return self .value
191+
192+ def __eq__ (self , other : object ) -> bool :
193+ """Name, values, and kind must match. Values order must match."""
194+ if not isinstance (other , Field ):
195+ return False
196+ return (
197+ self .name == other .name
198+ and self .kind == other .kind
199+ and self .value == other .value
200+ )
201+
202+ def __repr__ (self ) -> str :
203+ return f"Field({ self .kind .name } { self .name } : { self .get_value ()} )"
204+
205+
206+ class Fields (interfaces .http .Fields ):
207+ """Collection of Field entries mapped by name."""
208+
209+ def __init__ (
210+ self ,
211+ initial : list [interfaces .http .Field ] | None = None ,
212+ * ,
213+ encoding : str = "utf-8" ,
214+ ) -> None :
215+ init_tuples = [] if initial is None else [(fld .name , fld ) for fld in initial ]
216+ self .entries : OrderedDict [str , interfaces .http .Field ] = OrderedDict (init_tuples )
217+ self .encoding : str = encoding
218+
219+ def set_field (self , field : interfaces .http .Field ) -> None :
220+ """Set entry for a Field name."""
221+ self .entries [field .name ] = field
222+
223+ def get_field (self , name : str ) -> interfaces .http .Field :
224+ """Retrieve Field entry"""
225+ return self .entries [name ]
226+
227+ def remove_field (self , name : str ) -> None :
228+ """Delete entry from collection"""
229+ del self .entries [name ]
230+
231+ def get_by_type (self , kind : FieldPosition ) -> list [interfaces .http .Field ]:
232+ """Helper function for retrieving specific types of fields
233+
234+ Used to grab all headers or all trailers
235+ """
236+ return [entry for entry in self .entries .values () if entry .kind == kind ]
237+
238+ def __eq__ (self , other : object ) -> bool :
239+ """Encoding must match. Entries must match in values but not order."""
240+ if not isinstance (other , Fields ):
241+ return False
242+ if self .encoding != other .encoding :
243+ return False
244+ if set (self .entries .keys ()) != set (other .entries .keys ()):
245+ return False
246+ for field_name , self_field in self .entries .items ():
247+ other_field = other .get_field (field_name )
248+ if self_field != other_field :
249+ return False
250+ return True
251+
252+
120253@dataclass
121- class Endpoint (http_interface .Endpoint ):
122- url : http_interface .URI
123- headers : http_interface .HeadersList = field (default_factory = list )
254+ class Endpoint (interfaces . http .Endpoint ):
255+ url : interfaces . http .URI
256+ headers : interfaces . http .HeadersList = field (default_factory = list )
124257
125258
126259@dataclass
@@ -131,10 +264,10 @@ class StaticEndpointParams:
131264 :params url: A static URI to route requests to.
132265 """
133266
134- url : str | http_interface .URI
267+ url : str | interfaces . http .URI
135268
136269
137- class StaticEndpointResolver (http_interface .EndpointResolver [StaticEndpointParams ]):
270+ class StaticEndpointResolver (interfaces . http .EndpointResolver [StaticEndpointParams ]):
138271 """A basic endpoint resolver that forwards a static url."""
139272
140273 async def resolve_endpoint (self , params : StaticEndpointParams ) -> Endpoint :
@@ -164,7 +297,7 @@ async def resolve_endpoint(self, params: StaticEndpointParams) -> Endpoint:
164297
165298
166299class _StaticEndpointConfig (Protocol ):
167- endpoint_resolver : http_interface .EndpointResolver [StaticEndpointParams ] | None
300+ endpoint_resolver : interfaces . http .EndpointResolver [StaticEndpointParams ] | None
168301
169302
170303def set_static_endpoint_resolver (config : _StaticEndpointConfig ) -> None :
0 commit comments