diff --git a/samcli/commands/local/lib/swagger/parser.py b/samcli/commands/local/lib/swagger/parser.py index 9c46e0c631..68ada78024 100644 --- a/samcli/commands/local/lib/swagger/parser.py +++ b/samcli/commands/local/lib/swagger/parser.py @@ -81,7 +81,7 @@ def get_authorizers(self, event_type: str = Route.API) -> Dict[str, Authorizer]: authorizers: Dict[str, Authorizer] = {} authorizer_dict = {} - document_version = self.swagger.get(SwaggerParser._SWAGGER) or self.swagger.get(SwaggerParser._OPENAPI) or "" + document_version = self._get_document_version() if document_version.startswith(SwaggerParser._2_X_VERSION): LOG.debug("Parsing Swagger document using 2.0 specification") @@ -240,6 +240,19 @@ def _get_lambda_identity_sources( return identity_sources + def _get_document_version(self) -> str: + """ + Helper method to fetch the Swagger document version + + Returns + ------- + str + A string representing a version, blank if not found + """ + document_version = self.swagger.get(SwaggerParser._SWAGGER) or self.swagger.get(SwaggerParser._OPENAPI) or "" + + return str(document_version) + def get_default_authorizer(self, event_type: str) -> Union[str, None]: """ Parses the body definition to find root level Authorizer definitions @@ -254,7 +267,7 @@ def get_default_authorizer(self, event_type: str) -> Union[str, None]: Union[str, None] Returns the name of the authorizer, if there is one defined, otherwise None """ - document_version = self.swagger.get(SwaggerParser._SWAGGER) or self.swagger.get(SwaggerParser._OPENAPI) or "" + document_version = self._get_document_version() authorizers = self.swagger.get(SwaggerParser._SWAGGER_SECURITY, []) if not authorizers: diff --git a/tests/unit/commands/local/lib/swagger/test_parser.py b/tests/unit/commands/local/lib/swagger/test_parser.py index 84ce8899de..d854ca595e 100644 --- a/tests/unit/commands/local/lib/swagger/test_parser.py +++ b/tests/unit/commands/local/lib/swagger/test_parser.py @@ -1022,3 +1022,21 @@ def test_invalid_identity_source_throws_exception(self): with self.assertRaises(InvalidSecurityDefinition): parser._get_lambda_identity_sources(Mock(), "request", Route.API, properties, auth_properties) + + +class TestGetDocumentVersion(TestCase): + @parameterized.expand( + [ + ({"swagger": "2.0"}, "2.0"), + ({"swagger": 2.0}, "2.0"), + ({"openapi": "3.0"}, "3.0"), + ({"openapi": 3.0}, "3.0"), + ({"not valid": 3.0}, ""), + ({}, ""), + ] + ) + def test_get_document_version(self, swagger_doc, expected_output): + parser = SwaggerParser(Mock(), swagger_doc) + output = parser._get_document_version() + + self.assertEqual(output, expected_output)