Skip to content

Commit 1fa7773

Browse files
feat(data_classes): return empty dict or list instead of None (aws-powertools#4606)
* feat(data_classes): return empty dict or list instead of None This simplifies the code internally and also for users. Also wrap all headers in CaseInsensitiveDict from requests. These changes replace the need of utility functions like get_header_value, get_query_string_value or get_multi_value_query_string_values, which are removed. * Add custom CaseInsensitiveDict This is hopefully a simpler implementations that the requests' package one, but still had to be minimally complex to be complete. * Update CHANGELOG.md * Revert changes in prev_result * Revert changes to examples using "cloudfront-viewer-country" * Update tests/unit/data_classes/test_appsync_resolver_event.py * Minor simplification in CaseInsensitiveDict --------- Co-authored-by: Leandro Damascena <lcdama@amazon.pt>
1 parent 46473a1 commit 1fa7773

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

51 files changed

+280
-880
lines changed

Diff for: aws_lambda_powertools/event_handler/api_gateway.py

+1-5
Original file line numberDiff line numberDiff line change
@@ -817,11 +817,7 @@ def _has_compression_enabled(
817817
bool
818818
True if compression is enabled and the "gzip" encoding is accepted, False otherwise.
819819
"""
820-
encoding: str = event.get_header_value(
821-
name="accept-encoding",
822-
default_value="",
823-
case_sensitive=False,
824-
) # noqa: E501
820+
encoding = event.headers.get("accept-encoding", "")
825821
if "gzip" in encoding:
826822
if response_compression is not None:
827823
return response_compression # e.g., Response(compress=False/True))

Diff for: aws_lambda_powertools/event_handler/appsync.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ def handler(event, context: LambdaContext):
127127
class MyCustomModel(AppSyncResolverEvent):
128128
@property
129129
def country_viewer(self) -> str:
130-
return self.request_headers.get("cloudfront-viewer-country")
130+
return self.request_headers.get("cloudfront-viewer-country", "")
131131
132132
133133
@app.resolver(field_name="listLocations")

Diff for: aws_lambda_powertools/event_handler/middlewares/base.py

+1-4
Original file line numberDiff line numberDiff line change
@@ -47,10 +47,7 @@ def __init__(self, header: str):
4747
def handler(self, app: APIGatewayRestResolver, next_middleware: NextMiddleware) -> Response:
4848
# BEFORE logic
4949
request_id = app.current_event.request_context.request_id
50-
correlation_id = app.current_event.get_header_value(
51-
name=self.header,
52-
default_value=request_id,
53-
)
50+
correlation_id = app.current_event.headers.get(self.header, request_id)
5451
5552
# Call next middleware or route handler ('/todos')
5653
response = next_middleware(app)

Diff for: aws_lambda_powertools/event_handler/middlewares/openapi_validation.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import json
33
import logging
44
from copy import deepcopy
5-
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple
5+
from typing import Any, Callable, Dict, List, Mapping, MutableMapping, Optional, Sequence, Tuple
66

77
from pydantic import BaseModel
88

@@ -237,8 +237,8 @@ def _get_body(self, app: EventHandlerInstance) -> Dict[str, Any]:
237237
Get the request body from the event, and parse it as JSON.
238238
"""
239239

240-
content_type_value = app.current_event.get_header_value("content-type")
241-
if not content_type_value or content_type_value.strip().startswith("application/json"):
240+
content_type = app.current_event.headers.get("content-type")
241+
if not content_type or content_type.strip().startswith("application/json"):
242242
try:
243243
return app.current_event.json_body
244244
except json.JSONDecodeError as e:
@@ -410,7 +410,7 @@ def _normalize_multi_query_string_with_param(
410410
return resolved_query_string
411411

412412

413-
def _normalize_multi_header_values_with_param(headers: Dict[str, Any], params: Sequence[ModelField]):
413+
def _normalize_multi_header_values_with_param(headers: MutableMapping[str, Any], params: Sequence[ModelField]):
414414
"""
415415
Extract and normalize resolved_headers_field
416416

Diff for: aws_lambda_powertools/event_handler/util.py

+4-12
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
1-
from typing import Any, Dict
2-
3-
from aws_lambda_powertools.utilities.data_classes.shared_functions import get_header_value
1+
from typing import Any, Mapping, Optional
42

53

64
class _FrozenDict(dict):
@@ -18,25 +16,19 @@ def __hash__(self):
1816
return hash(frozenset(self.keys()))
1917

2018

21-
def extract_origin_header(resolver_headers: Dict[str, Any]):
19+
def extract_origin_header(resolved_headers: Mapping[str, Any]) -> Optional[str]:
2220
"""
2321
Extracts the 'origin' or 'Origin' header from the provided resolver headers.
2422
2523
The 'origin' or 'Origin' header can be either a single header or a multi-header.
2624
2725
Args:
28-
resolver_headers (Dict): A dictionary containing the headers.
26+
resolved_headers (Mapping): A dictionary containing the headers.
2927
3028
Returns:
3129
Optional[str]: The value(s) of the origin header or None.
3230
"""
33-
resolved_header = get_header_value(
34-
headers=resolver_headers,
35-
name="origin",
36-
default_value=None,
37-
case_sensitive=False,
38-
)
31+
resolved_header = resolved_headers.get("origin")
3932
if isinstance(resolved_header, list):
4033
return resolved_header[0]
41-
4234
return resolved_header

Diff for: aws_lambda_powertools/utilities/data_classes/alb_event.py

+7-16
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, Dict, List, Optional
1+
from typing import Any, Dict, List
22

33
from aws_lambda_powertools.shared.headers_serializer import (
44
BaseHeadersSerializer,
@@ -7,6 +7,7 @@
77
)
88
from aws_lambda_powertools.utilities.data_classes.common import (
99
BaseProxyEvent,
10+
CaseInsensitiveDict,
1011
DictWrapper,
1112
)
1213

@@ -37,25 +38,15 @@ def multi_value_query_string_parameters(self) -> Dict[str, List[str]]:
3738

3839
@property
3940
def resolved_query_string_parameters(self) -> Dict[str, List[str]]:
40-
if self.multi_value_query_string_parameters:
41-
return self.multi_value_query_string_parameters
42-
43-
return super().resolved_query_string_parameters
41+
return self.multi_value_query_string_parameters or super().resolved_query_string_parameters
4442

4543
@property
46-
def resolved_headers_field(self) -> Dict[str, Any]:
47-
headers: Dict[str, Any] = {}
48-
49-
if self.multi_value_headers:
50-
headers = self.multi_value_headers
51-
else:
52-
headers = self.headers
53-
54-
return {key.lower(): value for key, value in headers.items()}
44+
def multi_value_headers(self) -> Dict[str, List[str]]:
45+
return CaseInsensitiveDict(self.get("multiValueHeaders"))
5546

5647
@property
57-
def multi_value_headers(self) -> Optional[Dict[str, List[str]]]:
58-
return self.get("multiValueHeaders")
48+
def resolved_headers_field(self) -> Dict[str, Any]:
49+
return self.multi_value_headers or self.headers
5950

6051
def header_serializer(self) -> BaseHeadersSerializer:
6152
# When using the ALB integration, the `multiValueHeaders` feature can be disabled (default) or enabled.

Diff for: aws_lambda_powertools/utilities/data_classes/api_gateway_authorizer_event.py

+10-85
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,13 @@
11
import enum
22
import re
3-
from typing import Any, Dict, List, Optional, overload
3+
from typing import Any, Dict, List, Optional
44

55
from aws_lambda_powertools.utilities.data_classes.common import (
66
BaseRequestContext,
77
BaseRequestContextV2,
8+
CaseInsensitiveDict,
89
DictWrapper,
910
)
10-
from aws_lambda_powertools.utilities.data_classes.shared_functions import (
11-
get_header_value,
12-
)
1311

1412

1513
class APIGatewayRouteArn:
@@ -144,7 +142,7 @@ def http_method(self) -> str:
144142

145143
@property
146144
def headers(self) -> Dict[str, str]:
147-
return self["headers"]
145+
return CaseInsensitiveDict(self["headers"])
148146

149147
@property
150148
def query_string_parameters(self) -> Dict[str, str]:
@@ -162,45 +160,6 @@ def stage_variables(self) -> Dict[str, str]:
162160
def request_context(self) -> BaseRequestContext:
163161
return BaseRequestContext(self._data)
164162

165-
@overload
166-
def get_header_value(
167-
self,
168-
name: str,
169-
default_value: str,
170-
case_sensitive: bool = False,
171-
) -> str: ...
172-
173-
@overload
174-
def get_header_value(
175-
self,
176-
name: str,
177-
default_value: Optional[str] = None,
178-
case_sensitive: bool = False,
179-
) -> Optional[str]: ...
180-
181-
def get_header_value(
182-
self,
183-
name: str,
184-
default_value: Optional[str] = None,
185-
case_sensitive: bool = False,
186-
) -> Optional[str]:
187-
"""Get header value by name
188-
189-
Parameters
190-
----------
191-
name: str
192-
Header name
193-
default_value: str, optional
194-
Default value if no value was found by name
195-
case_sensitive: bool
196-
Whether to use a case-sensitive look up
197-
Returns
198-
-------
199-
str, optional
200-
Header value
201-
"""
202-
return get_header_value(self.headers, name, default_value, case_sensitive)
203-
204163

205164
class APIGatewayAuthorizerEventV2(DictWrapper):
206165
"""API Gateway Authorizer Event Format 2.0
@@ -234,14 +193,14 @@ def parsed_arn(self) -> APIGatewayRouteArn:
234193
return parse_api_gateway_arn(self.route_arn)
235194

236195
@property
237-
def identity_source(self) -> Optional[List[str]]:
196+
def identity_source(self) -> List[str]:
238197
"""The identity source for which authorization is requested.
239198
240199
For a REQUEST authorizer, this is optional. The value is a set of one or more mapping expressions of the
241200
specified request parameters. The identity source can be headers, query string parameters, stage variables,
242201
and context parameters.
243202
"""
244-
return self.get("identitySource")
203+
return self.get("identitySource") or []
245204

246205
@property
247206
def route_key(self) -> str:
@@ -265,7 +224,7 @@ def cookies(self) -> List[str]:
265224
@property
266225
def headers(self) -> Dict[str, str]:
267226
"""Http headers"""
268-
return self["headers"]
227+
return CaseInsensitiveDict(self["headers"])
269228

270229
@property
271230
def query_string_parameters(self) -> Dict[str, str]:
@@ -276,46 +235,12 @@ def request_context(self) -> BaseRequestContextV2:
276235
return BaseRequestContextV2(self._data)
277236

278237
@property
279-
def path_parameters(self) -> Optional[Dict[str, str]]:
280-
return self.get("pathParameters")
238+
def path_parameters(self) -> Dict[str, str]:
239+
return self.get("pathParameters") or {}
281240

282241
@property
283-
def stage_variables(self) -> Optional[Dict[str, str]]:
284-
return self.get("stageVariables")
285-
286-
@overload
287-
def get_header_value(self, name: str, default_value: str, case_sensitive: bool = False) -> str: ...
288-
289-
@overload
290-
def get_header_value(
291-
self,
292-
name: str,
293-
default_value: Optional[str] = None,
294-
case_sensitive: bool = False,
295-
) -> Optional[str]: ...
296-
297-
def get_header_value(
298-
self,
299-
name: str,
300-
default_value: Optional[str] = None,
301-
case_sensitive: bool = False,
302-
) -> Optional[str]:
303-
"""Get header value by name
304-
305-
Parameters
306-
----------
307-
name: str
308-
Header name
309-
default_value: str, optional
310-
Default value if no value was found by name
311-
case_sensitive: bool
312-
Whether to use a case-sensitive look up
313-
Returns
314-
-------
315-
str, optional
316-
Header value
317-
"""
318-
return get_header_value(self.headers, name, default_value, case_sensitive)
242+
def stage_variables(self) -> Dict[str, str]:
243+
return self.get("stageVariables") or {}
319244

320245

321246
class APIGatewayAuthorizerResponseV2:

Diff for: aws_lambda_powertools/utilities/data_classes/api_gateway_proxy_event.py

+16-25
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from functools import cached_property
12
from typing import Any, Dict, List, Optional
23

34
from aws_lambda_powertools.shared.headers_serializer import (
@@ -9,6 +10,7 @@
910
BaseProxyEvent,
1011
BaseRequestContext,
1112
BaseRequestContextV2,
13+
CaseInsensitiveDict,
1214
DictWrapper,
1315
)
1416

@@ -113,7 +115,7 @@ def resource(self) -> str:
113115

114116
@property
115117
def multi_value_headers(self) -> Dict[str, List[str]]:
116-
return self.get("multiValueHeaders") or {} # key might exist but can be `null`
118+
return CaseInsensitiveDict(self.get("multiValueHeaders"))
117119

118120
@property
119121
def multi_value_query_string_parameters(self) -> Dict[str, List[str]]:
@@ -128,26 +130,19 @@ def resolved_query_string_parameters(self) -> Dict[str, List[str]]:
128130

129131
@property
130132
def resolved_headers_field(self) -> Dict[str, Any]:
131-
headers: Dict[str, Any] = {}
132-
133-
if self.multi_value_headers:
134-
headers = self.multi_value_headers
135-
else:
136-
headers = self.headers
137-
138-
return {key.lower(): value for key, value in headers.items()}
133+
return self.multi_value_headers or self.headers
139134

140135
@property
141136
def request_context(self) -> APIGatewayEventRequestContext:
142137
return APIGatewayEventRequestContext(self._data)
143138

144139
@property
145-
def path_parameters(self) -> Optional[Dict[str, str]]:
146-
return self.get("pathParameters")
140+
def path_parameters(self) -> Dict[str, str]:
141+
return self.get("pathParameters") or {}
147142

148143
@property
149-
def stage_variables(self) -> Optional[Dict[str, str]]:
150-
return self.get("stageVariables")
144+
def stage_variables(self) -> Dict[str, str]:
145+
return self.get("stageVariables") or {}
151146

152147
def header_serializer(self) -> BaseHeadersSerializer:
153148
return MultiValueHeadersSerializer()
@@ -289,20 +284,20 @@ def raw_query_string(self) -> str:
289284
return self["rawQueryString"]
290285

291286
@property
292-
def cookies(self) -> Optional[List[str]]:
293-
return self.get("cookies")
287+
def cookies(self) -> List[str]:
288+
return self.get("cookies") or []
294289

295290
@property
296291
def request_context(self) -> RequestContextV2:
297292
return RequestContextV2(self._data)
298293

299294
@property
300-
def path_parameters(self) -> Optional[Dict[str, str]]:
301-
return self.get("pathParameters")
295+
def path_parameters(self) -> Dict[str, str]:
296+
return self.get("pathParameters") or {}
302297

303298
@property
304-
def stage_variables(self) -> Optional[Dict[str, str]]:
305-
return self.get("stageVariables")
299+
def stage_variables(self) -> Dict[str, str]:
300+
return self.get("stageVariables") or {}
306301

307302
@property
308303
def path(self) -> str:
@@ -319,10 +314,6 @@ def http_method(self) -> str:
319314
def header_serializer(self):
320315
return HttpApiHeadersSerializer()
321316

322-
@property
317+
@cached_property
323318
def resolved_headers_field(self) -> Dict[str, Any]:
324-
if self.headers is not None:
325-
headers = {key.lower(): value.split(",") if "," in value else value for key, value in self.headers.items()}
326-
return headers
327-
328-
return {}
319+
return CaseInsensitiveDict((k, v.split(",") if "," in v else v) for k, v in self.headers.items())

0 commit comments

Comments
 (0)