Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 42 additions & 7 deletions samcli/commands/local/lib/sam_api_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,23 +103,58 @@ def extract_cors(self, cors_prop):
"""
cors = None
if cors_prop and isinstance(cors_prop, dict):
allow_methods = cors_prop.get("AllowMethods", ",".join(sorted(Route.ANY_HTTP_METHODS)))
allow_methods = self.normalize_cors_allow_methods(allow_methods)
allow_methods = self._get_cors_prop(cors_prop, "AllowMethods")
if allow_methods:
allow_methods = self.normalize_cors_allow_methods(allow_methods)
else:
allow_methods = ",".join(sorted(Route.ANY_HTTP_METHODS))

allow_origin = self._get_cors_prop(cors_prop, "AllowOrigin")
allow_headers = self._get_cors_prop(cors_prop, "AllowHeaders")
max_age = self._get_cors_prop(cors_prop, "MaxAge")

cors = Cors(
allow_origin=cors_prop.get("AllowOrigin"),
allow_methods=allow_methods,
allow_headers=cors_prop.get("AllowHeaders"),
max_age=cors_prop.get("MaxAge"),
allow_origin=allow_origin, allow_methods=allow_methods, allow_headers=allow_headers, max_age=max_age
)
elif cors_prop and isinstance(cors_prop, string_types):
allow_origin = cors_prop
if not (allow_origin.startswith("'") and allow_origin.endswith("'")):
raise InvalidSamDocumentException(
"Cors Properties must be a quoted string " '(i.e. "\'*\'" is correct, but "*" is not).'
)
allow_origin = allow_origin.strip("'")

cors = Cors(
allow_origin=cors_prop,
allow_origin=allow_origin,
allow_methods=",".join(sorted(Route.ANY_HTTP_METHODS)),
allow_headers=None,
max_age=None,
)
return cors

@staticmethod
def _get_cors_prop(cors_dict, prop_name):
"""
Extract cors properties from dictionary and remove extra quotes.

Parameters
----------
cors_dict : dict
Resource properties for Cors

Return
------
A string with the extra quotes removed
"""
prop = cors_dict.get(prop_name)
if prop:
if (not isinstance(prop, string_types)) or (not (prop.startswith("'") and prop.endswith("'"))):
raise InvalidSamDocumentException(
"{} must be a quoted string " '(i.e. "\'value\'" is correct, but "value" is not).'.format(prop_name)
)
prop = prop.strip("'")
return prop

@staticmethod
def normalize_cors_allow_methods(allow_methods):
"""
Expand Down
8 changes: 4 additions & 4 deletions tests/integration/testdata/start_api/swagger-template.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@ Resources:
Variables:
VarName: varValue
Cors:
AllowOrigin: "*"
AllowMethods: "GET"
AllowHeaders: "origin, x-requested-with"
MaxAge: 510
AllowOrigin: "'*''"
AllowMethods: "'GET'"
AllowHeaders: "'origin, x-requested-with'"
MaxAge: "'510'"
DefinitionBody:
swagger: "2.0"
info:
Expand Down
2 changes: 1 addition & 1 deletion tests/integration/testdata/start_api/template.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ Globals:
- image~1png
Variables:
VarName: varValue
Cors: "*"
Cors: "'*''"
Resources:
HelloWorldFunction:
Type: AWS::Serverless::Function
Expand Down
137 changes: 116 additions & 21 deletions tests/unit/commands/local/lib/test_sam_api_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -816,7 +816,7 @@ def test_provider_parse_cors_string(self):
"Type": "AWS::Serverless::Api",
"Properties": {
"StageName": "Prod",
"Cors": "*",
"Cors": "'*'",
"DefinitionBody": {
"paths": {
"/path2": {
Expand Down Expand Up @@ -873,10 +873,10 @@ def test_provider_parse_cors_dict(self):
"Properties": {
"StageName": "Prod",
"Cors": {
"AllowMethods": "POST, GET",
"AllowOrigin": "*",
"AllowHeaders": "Upgrade-Insecure-Requests",
"MaxAge": 600,
"AllowMethods": "'POST, GET'",
"AllowOrigin": "'*'",
"AllowHeaders": "'Upgrade-Insecure-Requests'",
"MaxAge": "'600'",
},
"DefinitionBody": {
"paths": {
Expand Down Expand Up @@ -918,7 +918,7 @@ def test_provider_parse_cors_dict(self):
allow_origin="*",
allow_methods=",".join(sorted(["POST", "GET", "OPTIONS"])),
allow_headers="Upgrade-Insecure-Requests",
max_age=600,
max_age="600",
)
route1 = Route(path="/path2", methods=["POST", "OPTIONS"], function_name="NoApiEventFunction")
route2 = Route(path="/path", methods=["POST", "OPTIONS"], function_name="NoApiEventFunction")
Expand All @@ -936,10 +936,10 @@ def test_provider_parse_cors_dict_star_allow(self):
"Properties": {
"StageName": "Prod",
"Cors": {
"AllowMethods": "*",
"AllowOrigin": "*",
"AllowHeaders": "Upgrade-Insecure-Requests",
"MaxAge": 600,
"AllowMethods": "'*'",
"AllowOrigin": "'*'",
"AllowHeaders": "'Upgrade-Insecure-Requests'",
"MaxAge": "'600'",
},
"DefinitionBody": {
"paths": {
Expand Down Expand Up @@ -981,7 +981,7 @@ def test_provider_parse_cors_dict_star_allow(self):
allow_origin="*",
allow_methods=",".join(sorted(Route.ANY_HTTP_METHODS)),
allow_headers="Upgrade-Insecure-Requests",
max_age=600,
max_age="600",
)
route1 = Route(path="/path2", methods=["POST", "OPTIONS"], function_name="NoApiEventFunction")
route2 = Route(path="/path", methods=["POST", "OPTIONS"], function_name="NoApiEventFunction")
Expand All @@ -991,7 +991,7 @@ def test_provider_parse_cors_dict_star_allow(self):
self.assertIn(route2, routes)
self.assertEquals(provider.api.cors, cors)

def test_invalid_cors_dict_allow_methods(self):
def test_raises_error_when_cors_allowmethods_not_single_quoted(self):
template = {
"Resources": {
"TestApi": {
Expand All @@ -1000,9 +1000,104 @@ def test_invalid_cors_dict_allow_methods(self):
"StageName": "Prod",
"Cors": {
"AllowMethods": "GET, INVALID_METHOD",
"AllowOrigin": "*",
"AllowHeaders": "Upgrade-Insecure-Requests",
"MaxAge": 600,
"AllowOrigin": "'*'",
"AllowHeaders": "'Upgrade-Insecure-Requests'",
"MaxAge": "'600'",
},
"DefinitionBody": {
"paths": {
"/path2": {
"post": {
"x-amazon-apigateway-integration": {
"type": "aws_proxy",
"uri": {
"Fn::Sub": "arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31"
"/functions/${NoApiEventFunction.Arn}/invocations"
},
"responses": {},
}
}
},
"/path": {
"post": {
"x-amazon-apigateway-integration": {
"type": "aws_proxy",
"uri": {
"Fn::Sub": "arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31"
"/functions/${NoApiEventFunction.Arn}/invocations"
},
"responses": {},
}
}
},
}
},
},
}
}
}
with self.assertRaises(
InvalidSamDocumentException, msg="ApiProvider should fail for Invalid Cors AllowMethods not single quoted"
):
ApiProvider(template)

def test_raises_error_when_cors_value_not_single_quoted(self):
template = {
"Resources": {
"TestApi": {
"Type": "AWS::Serverless::Api",
"Properties": {
"StageName": "Prod",
"Cors": "example.com",
"DefinitionBody": {
"paths": {
"/path2": {
"post": {
"x-amazon-apigateway-integration": {
"type": "aws_proxy",
"uri": {
"Fn::Sub": "arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31"
"/functions/${NoApiEventFunction.Arn}/invocations"
},
"responses": {},
}
}
},
"/path": {
"post": {
"x-amazon-apigateway-integration": {
"type": "aws_proxy",
"uri": {
"Fn::Sub": "arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31"
"/functions/${NoApiEventFunction.Arn}/invocations"
},
"responses": {},
}
}
},
}
},
},
}
}
}
with self.assertRaises(
InvalidSamDocumentException, msg="ApiProvider should fail for Invalid Cors value not single quoted"
):
ApiProvider(template)

def test_invalid_cors_dict_allow_methods(self):
template = {
"Resources": {
"TestApi": {
"Type": "AWS::Serverless::Api",
"Properties": {
"StageName": "Prod",
"Cors": {
"AllowMethods": "'GET, INVALID_METHOD'",
"AllowOrigin": "'*'",
"AllowHeaders": "'Upgrade-Insecure-Requests'",
"MaxAge": "'600'",
},
"DefinitionBody": {
"paths": {
Expand Down Expand Up @@ -1048,7 +1143,7 @@ def test_default_cors_dict_prop(self):
"Type": "AWS::Serverless::Api",
"Properties": {
"StageName": "Prod",
"Cors": {"AllowOrigin": "www.domain.com"},
"Cors": {"AllowOrigin": "'www.domain.com'"},
"DefinitionBody": {
"paths": {
"/path2": {
Expand Down Expand Up @@ -1085,10 +1180,10 @@ def test_global_cors(self):
"Globals": {
"Api": {
"Cors": {
"AllowMethods": "GET",
"AllowOrigin": "*",
"AllowHeaders": "Upgrade-Insecure-Requests",
"MaxAge": 600,
"AllowMethods": "'GET'",
"AllowOrigin": "'*'",
"AllowHeaders": "'Upgrade-Insecure-Requests'",
"MaxAge": "'600'",
}
}
},
Expand Down Expand Up @@ -1137,7 +1232,7 @@ def test_global_cors(self):
allow_origin="*",
allow_headers="Upgrade-Insecure-Requests",
allow_methods=",".join(["GET", "OPTIONS"]),
max_age=600,
max_age="600",
)
route1 = Route(path="/path2", methods=["GET", "OPTIONS"], function_name="NoApiEventFunction")
route2 = Route(path="/path", methods=["GET", "OPTIONS"], function_name="NoApiEventFunction")
Expand Down