Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@ max-args=5
ignored-argument-names=_.*

# Maximum number of locals for function / method body
max-locals=15
max-locals=17

# Maximum number of return / yield for function / method body
max-returns=6
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@
"queryStringParameters": {
"foo": "bar"
},
"multiValueQueryStringParameters": {
"foo": [
"bar"
]
},
"pathParameters": {
"proxy": "/{{{path}}}"
},
Expand All @@ -33,6 +38,62 @@
"X-Forwarded-Port": "443",
"X-Forwarded-Proto": "https"
},
"multiValueHeaders": {
"Accept": [
"text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,*/*;q=0.8"
],
"Accept-Encoding": [
"gzip, deflate, sdch"
],
"Accept-Language": [
"en-US,en;q=0.8"
],
"Cache-Control": [
"max-age=0"
],
"CloudFront-Forwarded-Proto": [
"https"
],
"CloudFront-Is-Desktop-Viewer": [
"true"
],
"CloudFront-Is-Mobile-Viewer": [
"false"
],
"CloudFront-Is-SmartTV-Viewer": [
"false"
],
"CloudFront-Is-Tablet-Viewer": [
"false"
],
"CloudFront-Viewer-Country": [
"US"
],
"Host": [
"0123456789.execute-api.{{dns_suffix}}"
],
"Upgrade-Insecure-Requests": [
"1"
],
"User-Agent": [
"Custom User Agent String"
],
"Via": [
"1.1 08f323deadbeefa7af34d5feb414ce27.cloudfront.net (CloudFront)"
],
"X-Amz-Cf-Id": [
"cDehVQoZnx43VYQb9j2-nvCh-9z396Uhbp027Y2JvkCPNLmGJHqlaA=="
],
"X-Forwarded-For": [
"127.0.0.1, 127.0.0.2"
],
"X-Forwarded-Port": [
"443"
],
"X-Forwarded-Proto": [
"https"
]
},
"requestContext": {
"accountId": "{{{account_id}}}",
"resourceId": "123456",
Expand Down
55 changes: 44 additions & 11 deletions samcli/local/apigw/local_apigw_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,21 +346,18 @@ def _construct_event(flask_request, port, binary_types):
identity=identity,
path=endpoint)

event_headers = dict(flask_request.headers)
event_headers["X-Forwarded-Proto"] = flask_request.scheme
event_headers["X-Forwarded-Port"] = str(port)
headers_dict, multi_value_headers_dict = LocalApigwService._event_headers(flask_request, port)

# APIGW does not support duplicate query parameters. Flask gives query params as a list so
# we need to convert only grab the first item unless many were given, were we grab the last to be consistent
# with APIGW
query_string_dict = LocalApigwService._query_string_params(flask_request)
query_string_dict, multi_value_query_string_dict = LocalApigwService._query_string_params(flask_request)

event = ApiGatewayLambdaEvent(http_method=method,
body=request_data,
resource=endpoint,
request_context=context,
query_string_params=query_string_dict,
headers=event_headers,
multi_value_query_string_params=multi_value_query_string_dict,
headers=headers_dict,
multi_value_headers=multi_value_headers_dict,
path_parameters=flask_request.view_args,
path=flask_request.path,
is_base_64_encoded=is_base_64)
Expand All @@ -379,12 +376,13 @@ def _query_string_params(flask_request):
flask_request request
Request from Flask

Returns dict (str: str)
Returns dict (str: str), dict (str: list of str)
-------
Empty dict if no query params where in the request otherwise returns a dictionary of key to value

"""
query_string_dict = {}
multi_value_query_string_dict = {}

# Flask returns an ImmutableMultiDict so convert to a dictionary that becomes
# a dict(str: list) then iterate over
Expand All @@ -394,11 +392,46 @@ def _query_string_params(flask_request):
# if the list is empty, default to empty string
if not query_string_value_length:
query_string_dict[query_string_key] = ""
multi_value_query_string_dict[query_string_key] = [""]
else:
# APIGW doesn't handle duplicate query string keys, picking the last one in the list
query_string_dict[query_string_key] = query_string_list[-1]
multi_value_query_string_dict[query_string_key] = query_string_list

return query_string_dict
return query_string_dict, multi_value_query_string_dict

@staticmethod
def _event_headers(flask_request, port):
"""
Constructs an APIGW equivalent headers dictionary

Parameters
----------
flask_request request
Request from Flask
int port
Forwarded Port

Returns dict (str: str), dict (str: list of str)
-------
Returns a dictionary of key to list of strings

"""
headers_dict = {}
multi_value_headers_dict = {}

# Multi-value request headers is not really supported by Flask.
# See https://github.com/pallets/flask/issues/850
for header_key in flask_request.headers.keys():
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What do you think about logging some information about not fully supporting multiple headers? We would have to print this each time, since we have no idea if the customer is impacted but would result in two things:

  1. It would give the customer some information when invoking (ideally linking out to a Github issue). This could result in less of a "Wait is this me or is this SAM CLI not giving the function the headers".
  2. It would give us (SAM CLI) a way to get feedback on how many customers are impacted by this through the Github Issue +1s.

I know this is current state and I think the only alternative is to rewrite the local service or possibly patch Flask in our implementation so we can support this. I just want to try and avoid the confusion and frustration that may come about from this.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good idea, what message should be displayed? Agree, it's time to consider patching Flask or rewriting the local service.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe something like:
"WARNING: Multi-value request headers are not fully supported. See issue: for more details."

We would just need to create the issue with some details on how Flask doesn't support this and therefore we currently can't.

Thoughts?

headers_dict[header_key] = flask_request.headers.get(header_key)
multi_value_headers_dict[header_key] = flask_request.headers.getlist(header_key)

headers_dict["X-Forwarded-Proto"] = flask_request.scheme
multi_value_headers_dict["X-Forwarded-Proto"] = [flask_request.scheme]

headers_dict["X-Forwarded-Port"] = str(port)
multi_value_headers_dict["X-Forwarded-Port"] = [str(port)]

return headers_dict, multi_value_headers_dict

@staticmethod
def _should_base64_encode(binary_types, request_mimetype):
Expand Down
16 changes: 16 additions & 0 deletions samcli/local/events/api_event.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,9 @@ def __init__(self,
resource=None,
request_context=None,
query_string_params=None,
multi_value_query_string_params=None,
headers=None,
multi_value_headers=None,
path_parameters=None,
stage_variables=None,
path=None,
Expand All @@ -145,7 +147,9 @@ def __init__(self,
:param str resource: Resource for the reqeust
:param RequestContext request_context: RequestContext for the request
:param dict query_string_params: Query String parameters
:param dict multi_value_query_string_params: Multi-value Query String parameters
:param dict headers: dict of the request Headers
:param dict multi_value_headers: dict of the multi-value request Headers
:param dict path_parameters: Path Parameters
:param dict stage_variables: API Gateway Stage Variables
:param str path: Path of the request
Expand All @@ -156,9 +160,16 @@ def __init__(self,
query_string_params is not None:
raise TypeError("'query_string_params' must be of type dict or None")

if not isinstance(multi_value_query_string_params, dict) and \
multi_value_query_string_params is not None:
raise TypeError("'multi_value_query_string_params' must be of type dict or None")

if not isinstance(headers, dict) and headers is not None:
raise TypeError("'headers' must be of type dict or None")

if not isinstance(multi_value_headers, dict) and multi_value_headers is not None:
raise TypeError("'multi_value_headers' must be of type dict or None")

if not isinstance(path_parameters, dict) and path_parameters is not None:
raise TypeError("'path_parameters' must be of type dict or None")

Expand All @@ -170,7 +181,9 @@ def __init__(self,
self.resource = resource
self.request_context = request_context
self.query_string_params = query_string_params
self.multi_value_query_string_params = multi_value_query_string_params
self.headers = headers
self.multi_value_headers = multi_value_headers
self.path_parameters = path_parameters
self.stage_variables = stage_variables
self.path = path
Expand All @@ -191,7 +204,10 @@ def to_dict(self):
"resource": self.resource,
"requestContext": request_context_dict,
"queryStringParameters": dict(self.query_string_params) if self.query_string_params else None,
"multiValueQueryStringParameters": dict(self.multi_value_query_string_params)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we need dict(self.multi_value_query_string_params) again? we do validation above to make sure its a dict.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably because someone could mutate the public variable on the object after initializing?

if self.multi_value_query_string_params else None,
"headers": dict(self.headers) if self.headers else None,
"multiValueHeaders": dict(self.multi_value_headers) if self.multi_value_headers else None,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same here

"pathParameters": dict(self.path_parameters) if self.path_parameters else None,
"stageVariables": dict(self.stage_variables) if self.stage_variables else None,
"path": self.path,
Expand Down
15 changes: 15 additions & 0 deletions tests/integration/local/start_api/test_start_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,17 @@ def test_request_to_an_endpoint_with_two_different_handlers(self):

self.assertEquals(response_data.get("handler"), 'echo_event_handler_2')

def test_request_with_multi_value_headers(self):
response = requests.get(self.url + "/echoeventbody",
headers={"Content-Type": "application/x-www-form-urlencoded, image/gif"})

self.assertEquals(response.status_code, 200)
response_data = response.json()
self.assertEquals(response_data.get("multiValueHeaders").get("Content-Type"),
["application/x-www-form-urlencoded, image/gif"])
self.assertEquals(response_data.get("headers").get("Content-Type"),
"application/x-www-form-urlencoded, image/gif")

def test_request_with_query_params(self):
"""
Query params given should be put into the Event to Lambda
Expand All @@ -433,6 +444,7 @@ def test_request_with_query_params(self):
response_data = response.json()

self.assertEquals(response_data.get("queryStringParameters"), {"key": "value"})
self.assertEquals(response_data.get("multiValueQueryStringParameters"), {"key": ["value"]})

def test_request_with_list_of_query_params(self):
"""
Expand All @@ -446,6 +458,7 @@ def test_request_with_list_of_query_params(self):
response_data = response.json()

self.assertEquals(response_data.get("queryStringParameters"), {"key": "value2"})
self.assertEquals(response_data.get("multiValueQueryStringParameters"), {"key": ["value", "value2"]})

def test_request_with_path_params(self):
"""
Expand Down Expand Up @@ -480,4 +493,6 @@ def test_forward_headers_are_added_to_event(self):
response_data = response.json()

self.assertEquals(response_data.get("headers").get("X-Forwarded-Proto"), "http")
self.assertEquals(response_data.get("multiValueHeaders").get("X-Forwarded-Proto"), ["http"])
self.assertEquals(response_data.get("headers").get("X-Forwarded-Port"), self.port)
self.assertEquals(response_data.get("multiValueHeaders").get("X-Forwarded-Port"), [self.port])
42 changes: 38 additions & 4 deletions tests/unit/local/apigw/test_local_apigw_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,11 +465,16 @@ def setUp(self):
query_param_args_mock = Mock()
query_param_args_mock.lists.return_value = {"query": ["params"]}.items()
self.request_mock.args = query_param_args_mock
self.request_mock.headers = {"Content-Type": "application/json", "X-Test": "Value"}
headers_mock = Mock()
headers_mock.keys.return_value = ["Content-Type", "X-Test"]
headers_mock.get.side_effect = ["application/json", "Value"]
headers_mock.getlist.side_effect = [["application/json"], ["Value"]]
self.request_mock.headers = headers_mock
self.request_mock.view_args = {"path": "params"}
self.request_mock.scheme = "http"

expected = '{"body": "DATA!!!!", "httpMethod": "GET", ' \
'"multiValueQueryStringParameters": {"query": ["params"]}, ' \
'"queryStringParameters": {"query": "params"}, "resource": ' \
'"endpoint", "requestContext": {"httpMethod": "GET", "requestId": ' \
'"c6af9ac6-7b61-11e6-9a41-93e8deadbeef", "path": "endpoint", "extendedRequestId": null, ' \
Expand All @@ -479,6 +484,8 @@ def setUp(self):
'"Custom User Agent String", "caller": null, "cognitoAuthenticationType": null, "sourceIp": ' \
'"190.0.0.0", "user": null}, "accountId": "123456789012"}, "headers": {"Content-Type": ' \
'"application/json", "X-Test": "Value", "X-Forwarded-Port": "3000", "X-Forwarded-Proto": "http"}, ' \
'"multiValueHeaders": {"Content-Type": ["application/json"], "X-Test": ["Value"], '\
'"X-Forwarded-Port": ["3000"], "X-Forwarded-Proto": ["http"]}, ' \
'"stageVariables": null, "path": "path", "pathParameters": {"path": "params"}, ' \
'"isBase64Encoded": false}'

Expand All @@ -505,18 +512,45 @@ def test_construct_event_with_binary_data(self, should_base64_encode_patch):
self.request_mock.get_data.return_value = binary_body
self.expected_dict["body"] = base64_body
self.expected_dict["isBase64Encoded"] = True
self.maxDiff = None

actual_event_str = LocalApigwService._construct_event(self.request_mock, 3000, binary_types=[])
self.assertEquals(json.loads(actual_event_str), self.expected_dict)

def test_event_headers_with_empty_list(self):
request_mock = Mock()
headers_mock = Mock()
headers_mock.keys.return_value = []
request_mock.headers = headers_mock
request_mock.scheme = "http"

actual_query_string = LocalApigwService._event_headers(request_mock, "3000")
self.assertEquals(actual_query_string, ({"X-Forwarded-Proto": "http", "X-Forwarded-Port": "3000"},
{"X-Forwarded-Proto": ["http"], "X-Forwarded-Port": ["3000"]}))

def test_event_headers_with_non_empty_list(self):
request_mock = Mock()
headers_mock = Mock()
headers_mock.keys.return_value = ["Content-Type", "X-Test"]
headers_mock.get.side_effect = ["application/json", "Value"]
headers_mock.getlist.side_effect = [["application/json"], ["Value"]]
request_mock.headers = headers_mock
request_mock.scheme = "http"

actual_query_string = LocalApigwService._event_headers(request_mock, "3000")
self.assertEquals(actual_query_string, ({"Content-Type": "application/json", "X-Test": "Value",
"X-Forwarded-Proto": "http", "X-Forwarded-Port": "3000"},
{"Content-Type": ["application/json"], "X-Test": ["Value"],
"X-Forwarded-Proto": ["http"], "X-Forwarded-Port": ["3000"]}))

def test_query_string_params_with_empty_params(self):
request_mock = Mock()
query_param_args_mock = Mock()
query_param_args_mock.lists.return_value = {}.items()
request_mock.args = query_param_args_mock

actual_query_string = LocalApigwService._query_string_params(request_mock)
self.assertEquals(actual_query_string, {})
self.assertEquals(actual_query_string, ({}, {}))

def test_query_string_params_with_param_value_being_empty_list(self):
request_mock = Mock()
Expand All @@ -525,7 +559,7 @@ def test_query_string_params_with_param_value_being_empty_list(self):
request_mock.args = query_param_args_mock

actual_query_string = LocalApigwService._query_string_params(request_mock)
self.assertEquals(actual_query_string, {"param": ""})
self.assertEquals(actual_query_string, ({"param": ""}, {"param": [""]}))

def test_query_string_params_with_param_value_being_non_empty_list(self):
request_mock = Mock()
Expand All @@ -534,7 +568,7 @@ def test_query_string_params_with_param_value_being_non_empty_list(self):
request_mock.args = query_param_args_mock

actual_query_string = LocalApigwService._query_string_params(request_mock)
self.assertEquals(actual_query_string, {"param": "b"})
self.assertEquals(actual_query_string, ({"param": "b"}, {"param": ["a", "b"]}))


class TestService_should_base64_encode(TestCase):
Expand Down
Loading