diff --git a/samcli/commands/sync/command.py b/samcli/commands/sync/command.py index be3297da8e..7270817818 100644 --- a/samcli/commands/sync/command.py +++ b/samcli/commands/sync/command.py @@ -27,6 +27,7 @@ ) from samcli.cli.cli_config_file import configuration_option, TomlProvider from samcli.commands._utils.click_mutex import ClickMutex +from samcli.commands.sync.sync_context import SyncContext from samcli.lib.utils.colors import Colored from samcli.lib.utils.version_checker import check_newer_version from samcli.lib.bootstrap.bootstrap import manage_stack @@ -329,14 +330,17 @@ def do_cli( disable_rollback=False, poll_delay=poll_delay, ) as deploy_context: - if watch: - execute_watch(template_file, build_context, package_context, deploy_context, dependency_layer) - elif code: - execute_code_sync( - template_file, build_context, deploy_context, resource_id, resource, dependency_layer - ) - else: - execute_infra_contexts(build_context, package_context, deploy_context) + with SyncContext(dependency_layer, build_context.build_dir, build_context.cache_dir): + if watch: + execute_watch( + template_file, build_context, package_context, deploy_context, dependency_layer + ) + elif code: + execute_code_sync( + template_file, build_context, deploy_context, resource_id, resource, dependency_layer + ) + else: + execute_infra_contexts(build_context, package_context, deploy_context) def execute_infra_contexts( diff --git a/samcli/commands/sync/sync_context.py b/samcli/commands/sync/sync_context.py new file mode 100644 index 0000000000..6995494875 --- /dev/null +++ b/samcli/commands/sync/sync_context.py @@ -0,0 +1,106 @@ +""" +Context object used by sync command +""" +import logging +from dataclasses import dataclass +from pathlib import Path +from typing import Optional, cast, Dict + +import tomlkit +from tomlkit.api import _TOMLDocument as TOMLDocument +from tomlkit.items import Item + +from samcli.lib.build.build_graph import DEFAULT_DEPENDENCIES_DIR +from samcli.lib.utils.osutils import rmtree_if_exists + +LOG = logging.getLogger(__name__) + + +DEFAULT_SYNC_STATE_FILE_NAME = "sync.toml" + +SYNC_STATE = "sync_state" +DEPENDENCY_LAYER = "dependency_layer" + + +@dataclass +class SyncState: + dependency_layer: bool + + +def _sync_state_to_toml_document(sync_state: SyncState) -> TOMLDocument: + sync_state_toml_table = tomlkit.table() + sync_state_toml_table[DEPENDENCY_LAYER] = sync_state.dependency_layer + + toml_document = tomlkit.document() + toml_document.add((tomlkit.comment("This file is auto generated by SAM CLI sync command"))) + toml_document.add(SYNC_STATE, cast(Item, sync_state_toml_table)) + + return toml_document + + +def _toml_document_to_sync_state(toml_document: Dict) -> Optional[SyncState]: + if not toml_document: + return None + + sync_state_toml_table = toml_document.get(SYNC_STATE) + if not sync_state_toml_table: + return None + + return SyncState(sync_state_toml_table.get(DEPENDENCY_LAYER)) + + +class SyncContext: + + _current_state: SyncState + _previous_state: Optional[SyncState] + _build_dir: Path + _cache_dir: Path + _file_path: Path + + def __init__(self, dependency_layer: bool, build_dir: str, cache_dir: str): + self._current_state = SyncState(dependency_layer) + self._previous_state = None + self._build_dir = Path(build_dir) + self._cache_dir = Path(cache_dir) + self._file_path = Path(build_dir).parent.joinpath(DEFAULT_SYNC_STATE_FILE_NAME) + + def __enter__(self) -> "SyncContext": + self._read() + LOG.debug( + "Entering sync context, previous state: %s, current state: %s", self._previous_state, self._current_state + ) + + # if adl parameter is changed between sam sync runs, cleanup build, cache and dependencies folders + if self._previous_state and self._previous_state.dependency_layer != self._current_state.dependency_layer: + self._cleanup_build_folders() + + return self + + def __exit__(self, *args): + self._write() + + def _write(self) -> None: + with open(self._file_path, "w+") as file: + file.write(tomlkit.dumps(_sync_state_to_toml_document(self._current_state))) + + def _read(self) -> None: + try: + with open(self._file_path) as file: + toml_document = cast(Dict, tomlkit.loads(file.read())) + self._previous_state = _toml_document_to_sync_state(toml_document) + except OSError: + LOG.debug("Missing previous sync state, will create a new file at the end of this execution") + + def _cleanup_build_folders(self): + """ + Cleans up build, cache and dependencies folders for clean start of the next session + """ + LOG.debug("Cleaning up build directory %s", self._build_dir) + rmtree_if_exists(self._build_dir) + + LOG.debug("Cleaning up cache directory %s", self._cache_dir) + rmtree_if_exists(self._cache_dir) + + dependencies_dir = Path(DEFAULT_DEPENDENCIES_DIR) + LOG.debug("Cleaning up dependencies directory: %s", dependencies_dir) + rmtree_if_exists(dependencies_dir) diff --git a/tests/integration/sync/test_sync_infra.py b/tests/integration/sync/test_sync_infra.py index 14e7bec97c..47c4260ce1 100644 --- a/tests/integration/sync/test_sync_infra.py +++ b/tests/integration/sync/test_sync_infra.py @@ -315,3 +315,48 @@ def test_cdk_templates(self, template_file, template_after, function_id, depende lambda_response = json.loads(self._get_lambda_response(lambda_function)) self.assertIn("extra_message", lambda_response) self.assertEqual(lambda_response.get("message"), "9") + + +@skipIf(SKIP_SYNC_TESTS, "Skip sync tests in CI/CD only") +@parameterized_class([{"dependency_layer": True}, {"dependency_layer": False}]) +class TestSyncInfraWithJava(SyncIntegBase): + @parameterized.expand(["infra/template-java.yaml"]) + def test_sync_infra_with_java(self, template_file): + """This will test a case where user will flip ADL flag between sync sessions""" + template_path = str(self.test_data_path.joinpath(template_file)) + stack_name = self._method_to_stack_name(self.id()) + self.stacks.append({"name": stack_name}) + + # first run with current dependency layer value + self._run_sync_and_validate_lambda_call(self.dependency_layer, template_path, stack_name) + + # now flip the dependency layer value and re-run the sync & tests + self._run_sync_and_validate_lambda_call(not self.dependency_layer, template_path, stack_name) + + def _run_sync_and_validate_lambda_call(self, dependency_layer: bool, template_path: str, stack_name: str) -> None: + # Run infra sync + sync_command_list = self.get_sync_command_list( + template_file=template_path, + code=False, + watch=False, + dependency_layer=dependency_layer, + stack_name=stack_name, + parameter_overrides="Parameter=Clarity", + image_repository=self.ecr_repo_name, + s3_prefix=self.s3_prefix, + kms_key_id=self.kms_key, + capabilities_list=["CAPABILITY_IAM", "CAPABILITY_AUTO_EXPAND"], + tags="integ=true clarity=yes foo_bar=baz", + ) + sync_process_execute = run_command_with_input(sync_command_list, "y\n".encode()) + self.assertEqual(sync_process_execute.process.returncode, 0) + self.assertIn("Sync infra completed.", str(sync_process_execute.stderr)) + + self.stack_resources = self._get_stacks(stack_name) + lambda_functions = self.stack_resources.get(AWS_LAMBDA_FUNCTION) + for lambda_function in lambda_functions: + lambda_response = json.loads(self._get_lambda_response(lambda_function)) + self.assertIn("message", lambda_response) + self.assertIn("sum", lambda_response) + self.assertEqual(lambda_response.get("message"), "hello world") + self.assertEqual(lambda_response.get("sum"), 12) diff --git a/tests/integration/testdata/sync/infra/before/Java/HelloWorldFunction/pom.xml b/tests/integration/testdata/sync/infra/before/Java/HelloWorldFunction/pom.xml new file mode 100644 index 0000000000..88390a5fc1 --- /dev/null +++ b/tests/integration/testdata/sync/infra/before/Java/HelloWorldFunction/pom.xml @@ -0,0 +1,58 @@ + + 4.0.0 + helloworld + HelloWorld + 1.0 + jar + A sample Hello World created for SAM CLI. + + 11 + 11 + + + + + helloworld + HelloWorldLayer + 1.0 + provided + + + com.amazonaws + aws-lambda-java-core + 1.2.1 + + + com.amazonaws + aws-lambda-java-events + 3.11.0 + + + junit + junit + 4.13.2 + test + + + + + + + org.apache.maven.plugins + maven-shade-plugin + 3.2.4 + + + + + package + + shade + + + + + + + diff --git a/tests/integration/testdata/sync/infra/before/Java/HelloWorldFunction/src/main/java/helloworld/App.java b/tests/integration/testdata/sync/infra/before/Java/HelloWorldFunction/src/main/java/helloworld/App.java new file mode 100644 index 0000000000..f8ea92bbd6 --- /dev/null +++ b/tests/integration/testdata/sync/infra/before/Java/HelloWorldFunction/src/main/java/helloworld/App.java @@ -0,0 +1,53 @@ +package helloworld; + +import java.io.BufferedReader; +import java.io.IOException; +import java.io.InputStreamReader; +import java.net.URL; +import java.util.HashMap; +import java.util.Map; +import java.util.stream.Collectors; + +import com.amazonaws.services.lambda.runtime.Context; +import com.amazonaws.services.lambda.runtime.RequestHandler; +import com.amazonaws.services.lambda.runtime.events.APIGatewayProxyRequestEvent; +import com.amazonaws.services.lambda.runtime.events.APIGatewayProxyResponseEvent; + +import helloworldlayer.SimpleMath; + +/** + * Handler for requests to Lambda function. + */ +public class App implements RequestHandler { + + public APIGatewayProxyResponseEvent handleRequest(final APIGatewayProxyRequestEvent input, final Context context) { + Map headers = new HashMap<>(); + headers.put("Content-Type", "application/json"); + headers.put("X-Custom-Header", "application/json"); + + APIGatewayProxyResponseEvent response = new APIGatewayProxyResponseEvent() + .withHeaders(headers); + + int sumResult = SimpleMath.sum(7, 5); + + try { + final String pageContents = this.getPageContents("https://checkip.amazonaws.com"); + String output = String.format("{ \"message\": \"hello world\", \"location\": \"%s\", \"sum\": %d }", pageContents, sumResult); + + return response + .withStatusCode(200) + .withBody(output); + } catch (IOException e) { + return response + .withBody("{}") + .withStatusCode(500); + } + } + + private String getPageContents(String address) throws IOException{ + URL url = new URL(address); + try(BufferedReader br = new BufferedReader(new InputStreamReader(url.openStream()))) { + return br.lines().collect(Collectors.joining(System.lineSeparator())); + } + } +} diff --git a/tests/integration/testdata/sync/infra/before/Java/HelloWorldLayer/pom.xml b/tests/integration/testdata/sync/infra/before/Java/HelloWorldLayer/pom.xml new file mode 100644 index 0000000000..0dd9b7b873 --- /dev/null +++ b/tests/integration/testdata/sync/infra/before/Java/HelloWorldLayer/pom.xml @@ -0,0 +1,36 @@ + + 4.0.0 + helloworld + HelloWorldLayer + 1.0 + jar + A sample Hello World created for SAM CLI. + + 11 + 11 + + + + + + + + + org.apache.maven.plugins + maven-shade-plugin + 3.2.4 + + + + + package + + shade + + + + + + + diff --git a/tests/integration/testdata/sync/infra/before/Java/HelloWorldLayer/src/main/java/helloworldlayer/SimpleMath.java b/tests/integration/testdata/sync/infra/before/Java/HelloWorldLayer/src/main/java/helloworldlayer/SimpleMath.java new file mode 100644 index 0000000000..1ad779173b --- /dev/null +++ b/tests/integration/testdata/sync/infra/before/Java/HelloWorldLayer/src/main/java/helloworldlayer/SimpleMath.java @@ -0,0 +1,8 @@ +package helloworldlayer; + +public class SimpleMath { + + public static int sum(int a, int b) { + return a + b; + } +} diff --git a/tests/integration/testdata/sync/infra/template-java.yaml b/tests/integration/testdata/sync/infra/template-java.yaml new file mode 100644 index 0000000000..2fc6001dee --- /dev/null +++ b/tests/integration/testdata/sync/infra/template-java.yaml @@ -0,0 +1,27 @@ +AWSTemplateFormatVersion: "2010-09-09" +Transform: AWS::Serverless-2016-10-31 + +Globals: + Function: + Timeout: 30 + +Resources: + HelloWorldFunction: + Type: AWS::Serverless::Function + Properties: + CodeUri: before/Java/HelloWorldFunction + Handler: helloworld.App::handleRequest + Runtime: java11 + MemorySize: 512 + Layers: + - !Ref HelloWorldLayer + + HelloWorldLayer: + Type: AWS::Serverless::LayerVersion + Properties: + ContentUri: before/Java/HelloWorldLayer + CompatibleRuntimes: + - java11 + Metadata: + BuildMethod: java11 + BuildArchitecture: x86_64 diff --git a/tests/unit/commands/sync/test_command.py b/tests/unit/commands/sync/test_command.py index 07c369141f..5846b638d8 100644 --- a/tests/unit/commands/sync/test_command.py +++ b/tests/unit/commands/sync/test_command.py @@ -60,11 +60,13 @@ def setUp(self): @patch("samcli.commands.deploy.deploy_context.DeployContext") @patch("samcli.commands.build.command.os") @patch("samcli.commands.sync.command.manage_stack") + @patch("samcli.commands.sync.command.SyncContext") def test_infra_must_succeed_sync( self, code, watch, auto_dependency_layer, + SyncContextMock, manage_stack_mock, os_mock, DeployContextMock, @@ -84,6 +86,8 @@ def test_infra_must_succeed_sync( PackageContextMock.return_value.__enter__.return_value = package_context_mock deploy_context_mock = Mock() DeployContextMock.return_value.__enter__.return_value = deploy_context_mock + sync_context_mock = Mock() + SyncContextMock.return_value.__enter__.return_value = sync_context_mock do_cli( self.template_file, @@ -188,11 +192,13 @@ def test_infra_must_succeed_sync( @patch("samcli.commands.deploy.deploy_context.DeployContext") @patch("samcli.commands.build.command.os") @patch("samcli.commands.sync.command.manage_stack") + @patch("samcli.commands.sync.command.SyncContext") def test_watch_must_succeed_sync( self, code, watch, auto_dependency_layer, + SyncContextMock, manage_stack_mock, os_mock, DeployContextMock, @@ -212,6 +218,8 @@ def test_watch_must_succeed_sync( PackageContextMock.return_value.__enter__.return_value = package_context_mock deploy_context_mock = Mock() DeployContextMock.return_value.__enter__.return_value = deploy_context_mock + sync_context_mock = Mock() + SyncContextMock.return_value.__enter__.return_value = sync_context_mock do_cli( self.template_file, @@ -314,11 +322,13 @@ def test_watch_must_succeed_sync( @patch("samcli.commands.deploy.deploy_context.DeployContext") @patch("samcli.commands.build.command.os") @patch("samcli.commands.sync.command.manage_stack") + @patch("samcli.commands.sync.command.SyncContext") def test_code_must_succeed_sync( self, code, watch, auto_dependency_layer, + SyncContextMock, manage_stack_mock, os_mock, DeployContextMock, @@ -338,6 +348,8 @@ def test_code_must_succeed_sync( PackageContextMock.return_value.__enter__.return_value = package_context_mock deploy_context_mock = Mock() DeployContextMock.return_value.__enter__.return_value = deploy_context_mock + sync_context_mock = Mock() + SyncContextMock.return_value.__enter__.return_value = sync_context_mock do_cli( self.template_file, diff --git a/tests/unit/commands/sync/test_sync_context.py b/tests/unit/commands/sync/test_sync_context.py new file mode 100644 index 0000000000..9238fede82 --- /dev/null +++ b/tests/unit/commands/sync/test_sync_context.py @@ -0,0 +1,98 @@ +from pathlib import Path +from unittest import TestCase, mock +from unittest.mock import mock_open, call, patch, Mock, MagicMock + +import tomlkit +from parameterized import parameterized, parameterized_class + +from samcli.commands.sync.sync_context import ( + SyncState, + _sync_state_to_toml_document, + SYNC_STATE, + DEPENDENCY_LAYER, + _toml_document_to_sync_state, + SyncContext, +) +from samcli.lib.build.build_graph import DEFAULT_DEPENDENCIES_DIR + + +class TestSyncState(TestCase): + @parameterized.expand([(True,), (False,)]) + def test_sync_state(self, dependency_layer): + sync_state = SyncState(dependency_layer) + self.assertEqual(sync_state.dependency_layer, dependency_layer) + + +TOML_TEMPLATE = """ +[sync_state] +dependency_layer = {dependency_layer}""" + + +class TestSyncStateToTomlSerde(TestCase): + @parameterized.expand([(True,), (False,)]) + def test_sync_state_to_toml(self, dependency_layer): + sync_state = SyncState(dependency_layer) + + toml_document = _sync_state_to_toml_document(sync_state) + self.assertIsNotNone(toml_document) + + sync_state_toml_table = toml_document.get(SYNC_STATE) + self.assertIsNotNone(sync_state_toml_table) + + dependency_layer_toml_field = sync_state_toml_table.get(DEPENDENCY_LAYER) + self.assertEqual(dependency_layer_toml_field, dependency_layer) + + @parameterized.expand([(True,), (False,)]) + def test_toml_to_sync_state(self, dependency_layer): + toml_doc = tomlkit.loads(TOML_TEMPLATE.format(dependency_layer=str(dependency_layer).lower())) + + sync_state = _toml_document_to_sync_state(toml_doc) + self.assertEqual(sync_state.dependency_layer, dependency_layer) + + def test_none_toml_doc_should_return_none(self): + self.assertIsNone(_toml_document_to_sync_state(None)) + + def test_none_toml_table_should_return_none(self): + self.assertIsNone(_toml_document_to_sync_state(tomlkit.document())) + + +@parameterized_class([{"dependency_layer": True}, {"dependency_layer": False}]) +class TestSyncContext(TestCase): + + dependency_layer: bool + + def setUp(self) -> None: + self.build_dir = "build_dir" + self.cache_dir = "cache_dir" + self.sync_context = SyncContext(self.dependency_layer, self.build_dir, self.cache_dir) + + @parameterized.expand([(True,), (False,)]) + @patch("samcli.commands.sync.sync_context.rmtree_if_exists") + def test_sync_context_dependency_layer(self, previous_dependency_layer_value, patched_rmtree_if_exists): + previous_session_state = TOML_TEMPLATE.format(dependency_layer=str(previous_dependency_layer_value).lower()) + with mock.patch("builtins.open", mock_open(read_data=previous_session_state)) as mock_file: + with self.sync_context: + pass + + mock_file.assert_has_calls( + [call().write(tomlkit.dumps(_sync_state_to_toml_document(self.sync_context._current_state)))] + ) + + if previous_dependency_layer_value != self.dependency_layer: + patched_rmtree_if_exists.assert_has_calls( + [ + call(self.sync_context._build_dir), + call(self.sync_context._cache_dir), + call(Path(DEFAULT_DEPENDENCIES_DIR)), + ] + ) + + @patch("samcli.commands.sync.sync_context.rmtree_if_exists") + def test_sync_context_has_no_previous_state_if_file_doesnt_exist(self, patched_rmtree_if_exists): + with mock.patch("builtins.open", mock_open()) as mock_file: + mock_file.side_effect = [OSError("File does not exist"), MagicMock()] + with self.sync_context: + pass + self.assertIsNone(self.sync_context._previous_state) + self.assertIsNotNone(self.sync_context._current_state) + patched_rmtree_if_exists.assert_not_called()