Skip to content
10 changes: 6 additions & 4 deletions samcli/lib/init/template_modifiers/cli_template_modifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def _section_position(self, section: str, position: int = 0) -> int:
return section_index
return -1

def _add_fields_to_section(self, position: int, fields: str) -> Any:
def _add_fields_to_section(self, position: int, fields: List[str]) -> Any:
"""
Adds fields to section in the template

Expand All @@ -72,10 +72,12 @@ def _add_fields_to_section(self, position: int, fields: str) -> Any:
list
array with updated template data
"""
template = self.template[position:]
section_space_count = self.template[position].find(self.template[position].strip())
start_position = position + 1
template = self.template[start_position:]
for index, line in enumerate(template):
if not (line.startswith(" ") or line.startswith("#")):
return self.template[: position + index] + fields + self.template[position + index :]
if not (line.find(line.strip()) > section_space_count or line.startswith("#")):
return self.template[: start_position + index] + fields + self.template[start_position + index :]
return self.template

def _field_position(self, position: int, field: str) -> Any:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,14 @@

class XRayTracingTemplateModifier(TemplateModifier):

FIELD_NAME = "Tracing"
FIELD_NAME_FUNCTION_TRACING = "Tracing"
FIELD_NAME_API_TRACING = "TracingEnabled"
GLOBALS = "Globals:\n"
RESOURCE = "Resources:\n"
FUNCTION = " Function:\n"
TRACING = " Tracing: Active\n"
TRACING_FUNCTION = " Tracing: Active\n"
API = " Api:\n"
TRACING_API = " TracingEnabled: True\n"
COMMENT = (
"# More info about Globals: "
"https://github.com/awslabs/serverless-application-model/blob/master/docs/globals.rst\n"
Expand All @@ -26,37 +29,60 @@ def _add_new_field_to_template(self):
global_section_position = self._section_position(self.GLOBALS)

if global_section_position >= 0:
function_section_position = self._section_position(self.FUNCTION, global_section_position)

if function_section_position >= 0:
field_positon = self._field_position(function_section_position, self.FIELD_NAME)
if field_positon >= 0:
self.template[field_positon] = self.TRACING
return
self._add_tracing_section(
global_section_position, self.FUNCTION, self.FIELD_NAME_FUNCTION_TRACING, self.TRACING_FUNCTION
)
self._add_tracing_section(global_section_position, self.API, self.FIELD_NAME_API_TRACING, self.TRACING_API)
else:
self._add_tracing_with_globals()

new_fields = [self.TRACING]
_section_position = function_section_position
def _add_tracing_with_globals(self):
"""Adds Globals and tracing fields"""
resource_section_position = self._section_position(self.RESOURCE)
globals_section_data = [
self.COMMENT,
self.GLOBALS,
self.FUNCTION,
self.TRACING_FUNCTION,
self.API,
self.TRACING_API,
"\n",
]
self.template = (
self.template[:resource_section_position] + globals_section_data + self.template[resource_section_position:]
)

else:
new_fields = [self.FUNCTION, self.TRACING]
_section_position = global_section_position + 1
def _add_tracing_section(
self,
global_section_position: int,
parent_section: str,
tracing_field_name: str,
tracing_field: str,
):
"""
Adds tracing into the designated field

self.template = self._add_fields_to_section(_section_position, new_fields)
Parameters
----------
global_section_position : dict
Position of the Globals field in the template
parent_section: str
Name of the parent section that the tracing field would be added.
tracing_field_name: str
Name of the tracing field, which will be used to check if it already exist
tracing_field: str
Name of the whole tracing field, which includes its name and value
"""
parent_section_position = self._section_position(parent_section, global_section_position)
if parent_section_position >= 0:
field_positon_function = self._field_position(parent_section_position, tracing_field_name)
if field_positon_function >= 0:
self.template[field_positon_function] = tracing_field

else:
self.template = self._add_fields_to_section(parent_section_position, [tracing_field])
else:
resource_section_position = self._section_position(self.RESOURCE)
globals_section_data = [
self.COMMENT,
self.GLOBALS,
self.FUNCTION,
self.TRACING,
"\n",
]
self.template = (
self.template[:resource_section_position]
+ globals_section_data
+ self.template[resource_section_position:]
)
self.template = self._add_fields_to_section(global_section_position, [parent_section, tracing_field])

def _print_sanity_check_error(self):
link = (
Expand Down
140 changes: 138 additions & 2 deletions tests/unit/lib/init/test_cli_template_modifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ def test_must_add_new_field_to_template(self, get_template_patch):
"Globals:\n",
" Function:\n",
" Tracing: Active\n",
" Api:\n",
" TracingEnabled: True\n",
"\n",
"Resources:\n",
" HelloWorldFunction:\n",
Expand Down Expand Up @@ -69,7 +71,90 @@ def test_must_add_new_function_field_to_template(self, get_template_patch):
"Globals:\n",
" Api:\n",
" api_field: field_value\n",
" TracingEnabled: True\n",
" Function:\n",
" Tracing: Active\n",
"\n",
"Resources:\n",
" HelloWorldFunction:\n",
" Type: AWS::Serverless::Function\n",
" Properties:\n",
" CodeUri: hello_world/\n",
" Handler: app.lambda_handler\n",
]

template_modifier = XRayTracingTemplateModifier(self.location)
template_modifier._add_new_field_to_template()

self.assertEqual(template_modifier.template, expected_template_data)

@patch("samcli.lib.init.template_modifiers.cli_template_modifier.TemplateModifier._get_template")
def test_must_add_new_api_function_field_to_template(self, get_template_patch):
get_template_patch.return_value = [
"# More info about Globals: https://github.com/awslabs/serverless-application-model/blob/master/docs/globals.rst\n",
"Globals:\n",
" HttpApi:\n",
" http_api_field: field_value\n",
"\n",
"Resources:\n",
" HelloWorldFunction:\n",
" Type: AWS::Serverless::Function\n",
" Properties:\n",
" CodeUri: hello_world/\n",
" Handler: app.lambda_handler\n",
]

expected_template_data = [
"# More info about Globals: https://github.com/awslabs/serverless-application-model/blob/master/docs/globals.rst\n",
"Globals:\n",
" HttpApi:\n",
" http_api_field: field_value\n",
" Function:\n",
" Tracing: Active\n",
" Api:\n",
" TracingEnabled: True\n",
"\n",
"Resources:\n",
" HelloWorldFunction:\n",
" Type: AWS::Serverless::Function\n",
" Properties:\n",
" CodeUri: hello_world/\n",
" Handler: app.lambda_handler\n",
]

template_modifier = XRayTracingTemplateModifier(self.location)
template_modifier._add_new_field_to_template()

self.assertEqual(template_modifier.template, expected_template_data)

@patch("samcli.lib.init.template_modifiers.cli_template_modifier.TemplateModifier._get_template")
def test_must_replace_new_field_to_template(self, get_template_patch):
get_template_patch.return_value = [
"# More info about Globals: https://github.com/awslabs/serverless-application-model/blob/master/docs/globals.rst\n",
"Globals:\n",
" Api:\n",
" api_field: field_value\n",
" TracingEnabled: False\n",
" Function:\n",
" function_field: field_value\n",
" Tracing: PassThrough\n",
"\n",
"Resources:\n",
" HelloWorldFunction:\n",
" Type: AWS::Serverless::Function\n",
" Properties:\n",
" CodeUri: hello_world/\n",
" Handler: app.lambda_handler\n",
]

expected_template_data = [
"# More info about Globals: https://github.com/awslabs/serverless-application-model/blob/master/docs/globals.rst\n",
"Globals:\n",
" Api:\n",
" api_field: field_value\n",
" TracingEnabled: True\n",
" Function:\n",
" function_field: field_value\n",
" Tracing: Active\n",
"\n",
"Resources:\n",
Expand Down Expand Up @@ -107,6 +192,8 @@ def test_must_add_new_tracing_field_to_template(self, get_template_patch):
" Function:\n",
" Timeout: 3\n",
" Tracing: Active\n",
" Api:\n",
" TracingEnabled: True\n",
"\n",
"Resources:\n",
" HelloWorldFunction:\n",
Expand All @@ -127,6 +214,8 @@ def test_must_get_section_position(self, get_template_patch):
"Globals:\n",
" Function:\n",
" Tracing: Active\n",
" Api:\n",
" TracingEnabled: True\n",
"\n",
"Resources:\n",
" HelloWorldFunction:\n",
Expand All @@ -139,14 +228,45 @@ def test_must_get_section_position(self, get_template_patch):
template_modifier = XRayTracingTemplateModifier(self.location)
global_location = template_modifier._section_position("Globals:\n")
function_location = template_modifier._section_position(" Function:\n")
api_location = template_modifier._section_position(" Api:\n")
resource_location = template_modifier._section_position("Resources:\n")

self.assertEqual(global_location, 1)
self.assertEqual(function_location, 2)
self.assertEqual(resource_location, 5)
self.assertEqual(api_location, 4)
self.assertEqual(resource_location, 7)

@patch("samcli.lib.init.template_modifiers.cli_template_modifier.TemplateModifier._get_template")
def test_must_get_field_position(self, get_template_patch):
def test_must_get_section_position_desc(self, get_template_patch):
get_template_patch.return_value = [
"# More info about Globals: https://github.com/awslabs/serverless-application-model/blob/master/docs/globals.rst\n",
"Globals:\n",
" Api:\n",
" TracingEnabled: True\n",
" Function:\n",
" Tracing: Active\n",
"\n",
"Resources:\n",
" HelloWorldFunction:\n",
" Type: AWS::Serverless::Function\n",
" Properties:\n",
" CodeUri: hello_world/\n",
" Handler: app.lambda_handler\n",
]

template_modifier = XRayTracingTemplateModifier(self.location)
global_location = template_modifier._section_position("Globals:\n")
function_location = template_modifier._section_position(" Function:\n")
api_location = template_modifier._section_position(" Api:\n")
resource_location = template_modifier._section_position("Resources:\n")

self.assertEqual(global_location, 1)
self.assertEqual(api_location, 2)
self.assertEqual(function_location, 4)
self.assertEqual(resource_location, 7)

@patch("samcli.lib.init.template_modifiers.cli_template_modifier.TemplateModifier._get_template")
def test_must_get_function_field_position(self, get_template_patch):
get_template_patch.return_value = [
"Resources:\n",
" HelloWorldFunction:\n",
Expand All @@ -161,6 +281,22 @@ def test_must_get_field_position(self, get_template_patch):

self.assertEqual(tracing_location, -1)

@patch("samcli.lib.init.template_modifiers.cli_template_modifier.TemplateModifier._get_template")
def test_must_get_api_field_position(self, get_template_patch):
get_template_patch.return_value = [
"Resources:\n",
" HelloWorldFunction:\n",
" Type: AWS::Serverless::Function\n",
" Properties:\n",
" CodeUri: hello_world/\n",
" Handler: app.lambda_handler\n",
]

template_modifier = XRayTracingTemplateModifier(self.location)
tracing_location = template_modifier._field_position(0, "TracingEnabled")

self.assertEqual(tracing_location, -1)

@patch("samcli.lib.init.template_modifiers.xray_tracing_template_modifier.LOG")
@patch("samcli.lib.init.template_modifiers.cli_template_modifier.parse_yaml_file")
def test_must_fail_sanity_check(self, parse_yaml_file_mock, log_mock):
Expand Down