diff --git a/samcli/commands/sync/command.py b/samcli/commands/sync/command.py
index be3297da8e..7270817818 100644
--- a/samcli/commands/sync/command.py
+++ b/samcli/commands/sync/command.py
@@ -27,6 +27,7 @@
)
from samcli.cli.cli_config_file import configuration_option, TomlProvider
from samcli.commands._utils.click_mutex import ClickMutex
+from samcli.commands.sync.sync_context import SyncContext
from samcli.lib.utils.colors import Colored
from samcli.lib.utils.version_checker import check_newer_version
from samcli.lib.bootstrap.bootstrap import manage_stack
@@ -329,14 +330,17 @@ def do_cli(
disable_rollback=False,
poll_delay=poll_delay,
) as deploy_context:
- if watch:
- execute_watch(template_file, build_context, package_context, deploy_context, dependency_layer)
- elif code:
- execute_code_sync(
- template_file, build_context, deploy_context, resource_id, resource, dependency_layer
- )
- else:
- execute_infra_contexts(build_context, package_context, deploy_context)
+ with SyncContext(dependency_layer, build_context.build_dir, build_context.cache_dir):
+ if watch:
+ execute_watch(
+ template_file, build_context, package_context, deploy_context, dependency_layer
+ )
+ elif code:
+ execute_code_sync(
+ template_file, build_context, deploy_context, resource_id, resource, dependency_layer
+ )
+ else:
+ execute_infra_contexts(build_context, package_context, deploy_context)
def execute_infra_contexts(
diff --git a/samcli/commands/sync/sync_context.py b/samcli/commands/sync/sync_context.py
new file mode 100644
index 0000000000..6995494875
--- /dev/null
+++ b/samcli/commands/sync/sync_context.py
@@ -0,0 +1,106 @@
+"""
+Context object used by sync command
+"""
+import logging
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Optional, cast, Dict
+
+import tomlkit
+from tomlkit.api import _TOMLDocument as TOMLDocument
+from tomlkit.items import Item
+
+from samcli.lib.build.build_graph import DEFAULT_DEPENDENCIES_DIR
+from samcli.lib.utils.osutils import rmtree_if_exists
+
+LOG = logging.getLogger(__name__)
+
+
+DEFAULT_SYNC_STATE_FILE_NAME = "sync.toml"
+
+SYNC_STATE = "sync_state"
+DEPENDENCY_LAYER = "dependency_layer"
+
+
+@dataclass
+class SyncState:
+ dependency_layer: bool
+
+
+def _sync_state_to_toml_document(sync_state: SyncState) -> TOMLDocument:
+ sync_state_toml_table = tomlkit.table()
+ sync_state_toml_table[DEPENDENCY_LAYER] = sync_state.dependency_layer
+
+ toml_document = tomlkit.document()
+ toml_document.add((tomlkit.comment("This file is auto generated by SAM CLI sync command")))
+ toml_document.add(SYNC_STATE, cast(Item, sync_state_toml_table))
+
+ return toml_document
+
+
+def _toml_document_to_sync_state(toml_document: Dict) -> Optional[SyncState]:
+ if not toml_document:
+ return None
+
+ sync_state_toml_table = toml_document.get(SYNC_STATE)
+ if not sync_state_toml_table:
+ return None
+
+ return SyncState(sync_state_toml_table.get(DEPENDENCY_LAYER))
+
+
+class SyncContext:
+
+ _current_state: SyncState
+ _previous_state: Optional[SyncState]
+ _build_dir: Path
+ _cache_dir: Path
+ _file_path: Path
+
+ def __init__(self, dependency_layer: bool, build_dir: str, cache_dir: str):
+ self._current_state = SyncState(dependency_layer)
+ self._previous_state = None
+ self._build_dir = Path(build_dir)
+ self._cache_dir = Path(cache_dir)
+ self._file_path = Path(build_dir).parent.joinpath(DEFAULT_SYNC_STATE_FILE_NAME)
+
+ def __enter__(self) -> "SyncContext":
+ self._read()
+ LOG.debug(
+ "Entering sync context, previous state: %s, current state: %s", self._previous_state, self._current_state
+ )
+
+ # if adl parameter is changed between sam sync runs, cleanup build, cache and dependencies folders
+ if self._previous_state and self._previous_state.dependency_layer != self._current_state.dependency_layer:
+ self._cleanup_build_folders()
+
+ return self
+
+ def __exit__(self, *args):
+ self._write()
+
+ def _write(self) -> None:
+ with open(self._file_path, "w+") as file:
+ file.write(tomlkit.dumps(_sync_state_to_toml_document(self._current_state)))
+
+ def _read(self) -> None:
+ try:
+ with open(self._file_path) as file:
+ toml_document = cast(Dict, tomlkit.loads(file.read()))
+ self._previous_state = _toml_document_to_sync_state(toml_document)
+ except OSError:
+ LOG.debug("Missing previous sync state, will create a new file at the end of this execution")
+
+ def _cleanup_build_folders(self):
+ """
+ Cleans up build, cache and dependencies folders for clean start of the next session
+ """
+ LOG.debug("Cleaning up build directory %s", self._build_dir)
+ rmtree_if_exists(self._build_dir)
+
+ LOG.debug("Cleaning up cache directory %s", self._cache_dir)
+ rmtree_if_exists(self._cache_dir)
+
+ dependencies_dir = Path(DEFAULT_DEPENDENCIES_DIR)
+ LOG.debug("Cleaning up dependencies directory: %s", dependencies_dir)
+ rmtree_if_exists(dependencies_dir)
diff --git a/tests/integration/sync/test_sync_infra.py b/tests/integration/sync/test_sync_infra.py
index 14e7bec97c..47c4260ce1 100644
--- a/tests/integration/sync/test_sync_infra.py
+++ b/tests/integration/sync/test_sync_infra.py
@@ -315,3 +315,48 @@ def test_cdk_templates(self, template_file, template_after, function_id, depende
lambda_response = json.loads(self._get_lambda_response(lambda_function))
self.assertIn("extra_message", lambda_response)
self.assertEqual(lambda_response.get("message"), "9")
+
+
+@skipIf(SKIP_SYNC_TESTS, "Skip sync tests in CI/CD only")
+@parameterized_class([{"dependency_layer": True}, {"dependency_layer": False}])
+class TestSyncInfraWithJava(SyncIntegBase):
+ @parameterized.expand(["infra/template-java.yaml"])
+ def test_sync_infra_with_java(self, template_file):
+ """This will test a case where user will flip ADL flag between sync sessions"""
+ template_path = str(self.test_data_path.joinpath(template_file))
+ stack_name = self._method_to_stack_name(self.id())
+ self.stacks.append({"name": stack_name})
+
+ # first run with current dependency layer value
+ self._run_sync_and_validate_lambda_call(self.dependency_layer, template_path, stack_name)
+
+ # now flip the dependency layer value and re-run the sync & tests
+ self._run_sync_and_validate_lambda_call(not self.dependency_layer, template_path, stack_name)
+
+ def _run_sync_and_validate_lambda_call(self, dependency_layer: bool, template_path: str, stack_name: str) -> None:
+ # Run infra sync
+ sync_command_list = self.get_sync_command_list(
+ template_file=template_path,
+ code=False,
+ watch=False,
+ dependency_layer=dependency_layer,
+ stack_name=stack_name,
+ parameter_overrides="Parameter=Clarity",
+ image_repository=self.ecr_repo_name,
+ s3_prefix=self.s3_prefix,
+ kms_key_id=self.kms_key,
+ capabilities_list=["CAPABILITY_IAM", "CAPABILITY_AUTO_EXPAND"],
+ tags="integ=true clarity=yes foo_bar=baz",
+ )
+ sync_process_execute = run_command_with_input(sync_command_list, "y\n".encode())
+ self.assertEqual(sync_process_execute.process.returncode, 0)
+ self.assertIn("Sync infra completed.", str(sync_process_execute.stderr))
+
+ self.stack_resources = self._get_stacks(stack_name)
+ lambda_functions = self.stack_resources.get(AWS_LAMBDA_FUNCTION)
+ for lambda_function in lambda_functions:
+ lambda_response = json.loads(self._get_lambda_response(lambda_function))
+ self.assertIn("message", lambda_response)
+ self.assertIn("sum", lambda_response)
+ self.assertEqual(lambda_response.get("message"), "hello world")
+ self.assertEqual(lambda_response.get("sum"), 12)
diff --git a/tests/integration/testdata/sync/infra/before/Java/HelloWorldFunction/pom.xml b/tests/integration/testdata/sync/infra/before/Java/HelloWorldFunction/pom.xml
new file mode 100644
index 0000000000..88390a5fc1
--- /dev/null
+++ b/tests/integration/testdata/sync/infra/before/Java/HelloWorldFunction/pom.xml
@@ -0,0 +1,58 @@
+
+ 4.0.0
+ helloworld
+ HelloWorld
+ 1.0
+ jar
+ A sample Hello World created for SAM CLI.
+
+ 11
+ 11
+
+
+
+
+ helloworld
+ HelloWorldLayer
+ 1.0
+ provided
+
+
+ com.amazonaws
+ aws-lambda-java-core
+ 1.2.1
+
+
+ com.amazonaws
+ aws-lambda-java-events
+ 3.11.0
+
+
+ junit
+ junit
+ 4.13.2
+ test
+
+
+
+
+
+
+ org.apache.maven.plugins
+ maven-shade-plugin
+ 3.2.4
+
+
+
+
+ package
+
+ shade
+
+
+
+
+
+
+
diff --git a/tests/integration/testdata/sync/infra/before/Java/HelloWorldFunction/src/main/java/helloworld/App.java b/tests/integration/testdata/sync/infra/before/Java/HelloWorldFunction/src/main/java/helloworld/App.java
new file mode 100644
index 0000000000..f8ea92bbd6
--- /dev/null
+++ b/tests/integration/testdata/sync/infra/before/Java/HelloWorldFunction/src/main/java/helloworld/App.java
@@ -0,0 +1,53 @@
+package helloworld;
+
+import java.io.BufferedReader;
+import java.io.IOException;
+import java.io.InputStreamReader;
+import java.net.URL;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.stream.Collectors;
+
+import com.amazonaws.services.lambda.runtime.Context;
+import com.amazonaws.services.lambda.runtime.RequestHandler;
+import com.amazonaws.services.lambda.runtime.events.APIGatewayProxyRequestEvent;
+import com.amazonaws.services.lambda.runtime.events.APIGatewayProxyResponseEvent;
+
+import helloworldlayer.SimpleMath;
+
+/**
+ * Handler for requests to Lambda function.
+ */
+public class App implements RequestHandler {
+
+ public APIGatewayProxyResponseEvent handleRequest(final APIGatewayProxyRequestEvent input, final Context context) {
+ Map headers = new HashMap<>();
+ headers.put("Content-Type", "application/json");
+ headers.put("X-Custom-Header", "application/json");
+
+ APIGatewayProxyResponseEvent response = new APIGatewayProxyResponseEvent()
+ .withHeaders(headers);
+
+ int sumResult = SimpleMath.sum(7, 5);
+
+ try {
+ final String pageContents = this.getPageContents("https://checkip.amazonaws.com");
+ String output = String.format("{ \"message\": \"hello world\", \"location\": \"%s\", \"sum\": %d }", pageContents, sumResult);
+
+ return response
+ .withStatusCode(200)
+ .withBody(output);
+ } catch (IOException e) {
+ return response
+ .withBody("{}")
+ .withStatusCode(500);
+ }
+ }
+
+ private String getPageContents(String address) throws IOException{
+ URL url = new URL(address);
+ try(BufferedReader br = new BufferedReader(new InputStreamReader(url.openStream()))) {
+ return br.lines().collect(Collectors.joining(System.lineSeparator()));
+ }
+ }
+}
diff --git a/tests/integration/testdata/sync/infra/before/Java/HelloWorldLayer/pom.xml b/tests/integration/testdata/sync/infra/before/Java/HelloWorldLayer/pom.xml
new file mode 100644
index 0000000000..0dd9b7b873
--- /dev/null
+++ b/tests/integration/testdata/sync/infra/before/Java/HelloWorldLayer/pom.xml
@@ -0,0 +1,36 @@
+
+ 4.0.0
+ helloworld
+ HelloWorldLayer
+ 1.0
+ jar
+ A sample Hello World created for SAM CLI.
+
+ 11
+ 11
+
+
+
+
+
+
+
+
+ org.apache.maven.plugins
+ maven-shade-plugin
+ 3.2.4
+
+
+
+
+ package
+
+ shade
+
+
+
+
+
+
+
diff --git a/tests/integration/testdata/sync/infra/before/Java/HelloWorldLayer/src/main/java/helloworldlayer/SimpleMath.java b/tests/integration/testdata/sync/infra/before/Java/HelloWorldLayer/src/main/java/helloworldlayer/SimpleMath.java
new file mode 100644
index 0000000000..1ad779173b
--- /dev/null
+++ b/tests/integration/testdata/sync/infra/before/Java/HelloWorldLayer/src/main/java/helloworldlayer/SimpleMath.java
@@ -0,0 +1,8 @@
+package helloworldlayer;
+
+public class SimpleMath {
+
+ public static int sum(int a, int b) {
+ return a + b;
+ }
+}
diff --git a/tests/integration/testdata/sync/infra/template-java.yaml b/tests/integration/testdata/sync/infra/template-java.yaml
new file mode 100644
index 0000000000..2fc6001dee
--- /dev/null
+++ b/tests/integration/testdata/sync/infra/template-java.yaml
@@ -0,0 +1,27 @@
+AWSTemplateFormatVersion: "2010-09-09"
+Transform: AWS::Serverless-2016-10-31
+
+Globals:
+ Function:
+ Timeout: 30
+
+Resources:
+ HelloWorldFunction:
+ Type: AWS::Serverless::Function
+ Properties:
+ CodeUri: before/Java/HelloWorldFunction
+ Handler: helloworld.App::handleRequest
+ Runtime: java11
+ MemorySize: 512
+ Layers:
+ - !Ref HelloWorldLayer
+
+ HelloWorldLayer:
+ Type: AWS::Serverless::LayerVersion
+ Properties:
+ ContentUri: before/Java/HelloWorldLayer
+ CompatibleRuntimes:
+ - java11
+ Metadata:
+ BuildMethod: java11
+ BuildArchitecture: x86_64
diff --git a/tests/unit/commands/sync/test_command.py b/tests/unit/commands/sync/test_command.py
index 07c369141f..5846b638d8 100644
--- a/tests/unit/commands/sync/test_command.py
+++ b/tests/unit/commands/sync/test_command.py
@@ -60,11 +60,13 @@ def setUp(self):
@patch("samcli.commands.deploy.deploy_context.DeployContext")
@patch("samcli.commands.build.command.os")
@patch("samcli.commands.sync.command.manage_stack")
+ @patch("samcli.commands.sync.command.SyncContext")
def test_infra_must_succeed_sync(
self,
code,
watch,
auto_dependency_layer,
+ SyncContextMock,
manage_stack_mock,
os_mock,
DeployContextMock,
@@ -84,6 +86,8 @@ def test_infra_must_succeed_sync(
PackageContextMock.return_value.__enter__.return_value = package_context_mock
deploy_context_mock = Mock()
DeployContextMock.return_value.__enter__.return_value = deploy_context_mock
+ sync_context_mock = Mock()
+ SyncContextMock.return_value.__enter__.return_value = sync_context_mock
do_cli(
self.template_file,
@@ -188,11 +192,13 @@ def test_infra_must_succeed_sync(
@patch("samcli.commands.deploy.deploy_context.DeployContext")
@patch("samcli.commands.build.command.os")
@patch("samcli.commands.sync.command.manage_stack")
+ @patch("samcli.commands.sync.command.SyncContext")
def test_watch_must_succeed_sync(
self,
code,
watch,
auto_dependency_layer,
+ SyncContextMock,
manage_stack_mock,
os_mock,
DeployContextMock,
@@ -212,6 +218,8 @@ def test_watch_must_succeed_sync(
PackageContextMock.return_value.__enter__.return_value = package_context_mock
deploy_context_mock = Mock()
DeployContextMock.return_value.__enter__.return_value = deploy_context_mock
+ sync_context_mock = Mock()
+ SyncContextMock.return_value.__enter__.return_value = sync_context_mock
do_cli(
self.template_file,
@@ -314,11 +322,13 @@ def test_watch_must_succeed_sync(
@patch("samcli.commands.deploy.deploy_context.DeployContext")
@patch("samcli.commands.build.command.os")
@patch("samcli.commands.sync.command.manage_stack")
+ @patch("samcli.commands.sync.command.SyncContext")
def test_code_must_succeed_sync(
self,
code,
watch,
auto_dependency_layer,
+ SyncContextMock,
manage_stack_mock,
os_mock,
DeployContextMock,
@@ -338,6 +348,8 @@ def test_code_must_succeed_sync(
PackageContextMock.return_value.__enter__.return_value = package_context_mock
deploy_context_mock = Mock()
DeployContextMock.return_value.__enter__.return_value = deploy_context_mock
+ sync_context_mock = Mock()
+ SyncContextMock.return_value.__enter__.return_value = sync_context_mock
do_cli(
self.template_file,
diff --git a/tests/unit/commands/sync/test_sync_context.py b/tests/unit/commands/sync/test_sync_context.py
new file mode 100644
index 0000000000..9238fede82
--- /dev/null
+++ b/tests/unit/commands/sync/test_sync_context.py
@@ -0,0 +1,98 @@
+from pathlib import Path
+from unittest import TestCase, mock
+from unittest.mock import mock_open, call, patch, Mock, MagicMock
+
+import tomlkit
+from parameterized import parameterized, parameterized_class
+
+from samcli.commands.sync.sync_context import (
+ SyncState,
+ _sync_state_to_toml_document,
+ SYNC_STATE,
+ DEPENDENCY_LAYER,
+ _toml_document_to_sync_state,
+ SyncContext,
+)
+from samcli.lib.build.build_graph import DEFAULT_DEPENDENCIES_DIR
+
+
+class TestSyncState(TestCase):
+ @parameterized.expand([(True,), (False,)])
+ def test_sync_state(self, dependency_layer):
+ sync_state = SyncState(dependency_layer)
+ self.assertEqual(sync_state.dependency_layer, dependency_layer)
+
+
+TOML_TEMPLATE = """
+[sync_state]
+dependency_layer = {dependency_layer}"""
+
+
+class TestSyncStateToTomlSerde(TestCase):
+ @parameterized.expand([(True,), (False,)])
+ def test_sync_state_to_toml(self, dependency_layer):
+ sync_state = SyncState(dependency_layer)
+
+ toml_document = _sync_state_to_toml_document(sync_state)
+ self.assertIsNotNone(toml_document)
+
+ sync_state_toml_table = toml_document.get(SYNC_STATE)
+ self.assertIsNotNone(sync_state_toml_table)
+
+ dependency_layer_toml_field = sync_state_toml_table.get(DEPENDENCY_LAYER)
+ self.assertEqual(dependency_layer_toml_field, dependency_layer)
+
+ @parameterized.expand([(True,), (False,)])
+ def test_toml_to_sync_state(self, dependency_layer):
+ toml_doc = tomlkit.loads(TOML_TEMPLATE.format(dependency_layer=str(dependency_layer).lower()))
+
+ sync_state = _toml_document_to_sync_state(toml_doc)
+ self.assertEqual(sync_state.dependency_layer, dependency_layer)
+
+ def test_none_toml_doc_should_return_none(self):
+ self.assertIsNone(_toml_document_to_sync_state(None))
+
+ def test_none_toml_table_should_return_none(self):
+ self.assertIsNone(_toml_document_to_sync_state(tomlkit.document()))
+
+
+@parameterized_class([{"dependency_layer": True}, {"dependency_layer": False}])
+class TestSyncContext(TestCase):
+
+ dependency_layer: bool
+
+ def setUp(self) -> None:
+ self.build_dir = "build_dir"
+ self.cache_dir = "cache_dir"
+ self.sync_context = SyncContext(self.dependency_layer, self.build_dir, self.cache_dir)
+
+ @parameterized.expand([(True,), (False,)])
+ @patch("samcli.commands.sync.sync_context.rmtree_if_exists")
+ def test_sync_context_dependency_layer(self, previous_dependency_layer_value, patched_rmtree_if_exists):
+ previous_session_state = TOML_TEMPLATE.format(dependency_layer=str(previous_dependency_layer_value).lower())
+ with mock.patch("builtins.open", mock_open(read_data=previous_session_state)) as mock_file:
+ with self.sync_context:
+ pass
+
+ mock_file.assert_has_calls(
+ [call().write(tomlkit.dumps(_sync_state_to_toml_document(self.sync_context._current_state)))]
+ )
+
+ if previous_dependency_layer_value != self.dependency_layer:
+ patched_rmtree_if_exists.assert_has_calls(
+ [
+ call(self.sync_context._build_dir),
+ call(self.sync_context._cache_dir),
+ call(Path(DEFAULT_DEPENDENCIES_DIR)),
+ ]
+ )
+
+ @patch("samcli.commands.sync.sync_context.rmtree_if_exists")
+ def test_sync_context_has_no_previous_state_if_file_doesnt_exist(self, patched_rmtree_if_exists):
+ with mock.patch("builtins.open", mock_open()) as mock_file:
+ mock_file.side_effect = [OSError("File does not exist"), MagicMock()]
+ with self.sync_context:
+ pass
+ self.assertIsNone(self.sync_context._previous_state)
+ self.assertIsNotNone(self.sync_context._current_state)
+ patched_rmtree_if_exists.assert_not_called()