diff --git a/samcli/lib/cookiecutter/interactive_flow.py b/samcli/lib/cookiecutter/interactive_flow.py index 486e8c4d30..996ac89ce3 100644 --- a/samcli/lib/cookiecutter/interactive_flow.py +++ b/samcli/lib/cookiecutter/interactive_flow.py @@ -1,5 +1,6 @@ """A flow of questions to be asked to the user in an interactive way.""" from typing import Any, Dict, Optional + from .question import Question @@ -40,7 +41,10 @@ def advance_to_next_question(self, current_answer: Optional[Any] = None) -> Opti self._current_question = self._questions.get(next_question_key) if next_question_key else None return self._current_question - def run(self, context: Dict) -> Dict: + def run( + self, + context: Dict, + ) -> Dict: """ starts the flow, collects user's answers to the question and return a new copy of the passed context with the answers appended to the copy @@ -49,14 +53,17 @@ def run(self, context: Dict) -> Dict: ---------- context: Dict The cookiecutter context before prompting this flow's questions + The context can be used to provide default values, and support both str keys and List[str] keys. - Returns: A new copy of the context with user's answers added to the copy such that each answer is - associated to the key of the corresponding question + Returns + ------- + A new copy of the context with user's answers added to the copy such that each answer is + associated to the key of the corresponding question """ context = context.copy() question = self.advance_to_next_question() while question: - answer = question.ask() + answer = question.ask(context=context) context[question.key] = answer question = self.advance_to_next_question(answer) return context diff --git a/samcli/lib/cookiecutter/interactive_flow_creator.py b/samcli/lib/cookiecutter/interactive_flow_creator.py index d1a227f1c8..d861174951 100644 --- a/samcli/lib/cookiecutter/interactive_flow_creator.py +++ b/samcli/lib/cookiecutter/interactive_flow_creator.py @@ -42,6 +42,19 @@ def create_flow(flow_definition_path: str, extra_context: Optional[Dict] = None) "True": "key of the question to jump to if the user answered 'Yes'", "False": "key of the question to jump to if the user answered 'Yes'", } + "default": "default_answer", + # the default value can also be loaded from cookiecutter context + # with a key path whose key path item can be loaded from cookiecutter as well. + "default": { + "keyPath": [ + { + "valueOf": "key-of-another-question" + }, + "pipeline_user" + ] + } + # assuming the answer of "key-of-another-question" is "ABC" + # the default value will be load from cookiecutter context with key "['ABC', 'pipeline_user]" }, ... ] @@ -63,15 +76,18 @@ def _load_questions( questions: Dict[str, Question] = {} questions_definition = InteractiveFlowCreator._parse_questions_definition(flow_definition_path, extra_context) - for question in questions_definition.get("questions"): - q = QuestionFactory.create_question_from_json(question) - if not first_question_key: - first_question_key = q.key - elif previous_question and not previous_question.default_next_question_key: - previous_question.set_default_next_question_key(q.key) - questions[q.key] = q - previous_question = q - return questions, first_question_key + try: + for question in questions_definition.get("questions"): + q = QuestionFactory.create_question_from_json(question) + if not first_question_key: + first_question_key = q.key + elif previous_question and not previous_question.default_next_question_key: + previous_question.set_default_next_question_key(q.key) + questions[q.key] = q + previous_question = q + return questions, first_question_key + except (KeyError, ValueError, AttributeError, TypeError) as ex: + raise QuestionsFailedParsingException(f"Failed to parse questions: {str(ex)}") from ex @staticmethod def _parse_questions_definition(file_path, extra_context: Optional[Dict] = None): diff --git a/samcli/lib/cookiecutter/question.py b/samcli/lib/cookiecutter/question.py index 71c30d98da..786836a400 100644 --- a/samcli/lib/cookiecutter/question.py +++ b/samcli/lib/cookiecutter/question.py @@ -1,6 +1,7 @@ """ This module represents the questions to ask to the user to fulfill the cookiecutter context. """ from enum import Enum -from typing import Any, Dict, List, Optional, Type +from typing import Any, Dict, List, Optional, Type, Union + import click @@ -26,8 +27,10 @@ class Question: The text to prompt to the user _required: bool Whether the user must provide an answer for this question or not. - _default_answer: Optional[str] - A default answer that is suggested to the user + _default_answer: Optional[Union[str, Dict]] + A default answer that is suggested to the user, + it can be directly provided (a string) + or resolved from cookiecutter context (a Dict, in the form of {"keyPath": [...,]}) _next_question_map: Optional[Dict[str, str]] A simple branching mechanism, it refers to what is the next question to ask the user if he answered a particular answer to this question. this map is in the form of {answer: next-question-key}. this @@ -48,7 +51,7 @@ def __init__( self, key: str, text: str, - default: Optional[str] = None, + default: Optional[Union[str, Dict]] = None, is_required: Optional[bool] = None, next_question_map: Optional[Dict[str, str]] = None, default_next_question_key: Optional[str] = None, @@ -87,8 +90,21 @@ def next_question_map(self): def default_next_question_key(self): return self._default_next_question_key - def ask(self) -> Any: - return click.prompt(text=self._text, default=self._default_answer) + def ask(self, context: Dict) -> Any: + """ + prompt the user this question + + Parameters + ---------- + context + The cookiecutter context dictionary containing previous questions' answers and default values + + Returns + ------- + The user provided answer. + """ + resolved_default_answer = self._resolve_default_answer(context) + return click.prompt(text=self._text, default=resolved_default_answer) def get_next_question_key(self, answer: Any) -> Optional[str]: # _next_question_map is a Dict[str(answer), str(next question key)] @@ -99,14 +115,83 @@ def get_next_question_key(self, answer: Any) -> Optional[str]: def set_default_next_question_key(self, next_question_key): self._default_next_question_key = next_question_key + def _resolve_key_path(self, key_path: List, context: Dict) -> List[str]: + """ + key_path element is a list of str and Dict. + When the element is a dict, in the form of { "valueOf": question_key }, + it means it refers to the answer to another questions. + _resolve_key_path() will replace such dict with the actual question answer + + Parameters + ---------- + key_path + The key_path list containing str and dict + context + The cookiecutter context containing answers to previous answered questions + Returns + ------- + The key_path list containing only str + """ + resolved_key_path: List[str] = [] + for unresolved_key in key_path: + if isinstance(unresolved_key, str): + resolved_key_path.append(unresolved_key) + elif isinstance(unresolved_key, dict): + if "valueOf" not in unresolved_key: + raise KeyError(f'Missing key "valueOf" in question default keyPath element "{unresolved_key}".') + query_question_key: str = unresolved_key.get("valueOf", "") + if query_question_key not in context: + raise KeyError( + f'Invalid question key "{query_question_key}" referenced ' + f"in default answer of question {self.key}" + ) + resolved_key_path.append(context[query_question_key]) + else: + raise ValueError(f'Invalid value "{unresolved_key}" in key path') + return resolved_key_path + + def _resolve_default_answer(self, context: Dict) -> Optional[Any]: + """ + a question may have a default answer provided directly through the "default_answer" value + or indirectly from cookiecutter context using a key path + + Parameters + ---------- + context + Cookiecutter context used to resolve default values and answered questions' answers. + + Raises + ------ + KeyError + When default value depends on the answer to a non-existent question + ValueError + The default value is malformed + + Returns + ------- + Optional default answer, it might be resolved from cookiecutter context using specified key path. + + """ + if isinstance(self._default_answer, dict): + # load value using key path from cookiecutter + if "keyPath" not in self._default_answer: + raise KeyError(f'Missing key "keyPath" in question default "{self._default_answer}".') + unresolved_key_path = self._default_answer.get("keyPath", []) + if not isinstance(unresolved_key_path, list): + raise ValueError(f'Invalid default answer "{self._default_answer}" for question {self.key}') + + return context.get(str(self._resolve_key_path(unresolved_key_path, context))) + + return self._default_answer + class Info(Question): - def ask(self) -> None: + def ask(self, context: Dict) -> None: return click.echo(message=self._text) class Confirm(Question): - def ask(self) -> bool: + def ask(self, context: Dict) -> bool: return click.confirm(text=self._text) @@ -126,7 +211,8 @@ def __init__( self._options = options super().__init__(key, text, default, is_required, next_question_map, default_next_question_key) - def ask(self) -> str: + def ask(self, context: Dict) -> str: + resolved_default_answer = self._resolve_default_answer(context) click.echo(self._text) for index, option in enumerate(self._options): click.echo(f"\t{index + 1} - {option}") @@ -134,7 +220,7 @@ def ask(self) -> str: choices = list(map(str, options_indexes)) choice = click.prompt( text="Choice", - default=self._default_answer, + default=resolved_default_answer, show_choices=False, type=click.Choice(choices), ) @@ -145,7 +231,6 @@ def _get_options_indexes(self, base: int = 0) -> List[int]: class QuestionFactory: - question_classes: Dict[QuestionKind, Type[Question]] = { QuestionKind.info: Info, QuestionKind.choice: Choice, diff --git a/tests/unit/lib/cookiecutter/test_interactive_flow.py b/tests/unit/lib/cookiecutter/test_interactive_flow.py index ed52626451..47ed0ec2b6 100644 --- a/tests/unit/lib/cookiecutter/test_interactive_flow.py +++ b/tests/unit/lib/cookiecutter/test_interactive_flow.py @@ -1,3 +1,4 @@ +from pathlib import Path from unittest import TestCase from unittest.mock import patch from samcli.lib.cookiecutter.interactive_flow import InteractiveFlow @@ -49,3 +50,26 @@ def test_run(self, mock_3rd_q, mock_2nd_q, mock_1st_q): mock_3rd_q.assert_called_once() self.assertEqual(expected_context, actual_context) self.assertIsNot(actual_context, initial_context) # shouldn't modify the input, it should copy and return new + + @patch.object(Question, "ask") + @patch.object(Confirm, "ask") + @patch.object(Choice, "ask") + def test_run_with_preloaded_default_values(self, mock_3rd_q, mock_2nd_q, mock_1st_q): + + mock_1st_q.return_value = "answer1" + mock_2nd_q.return_value = False + mock_3rd_q.return_value = "option1" + + initial_context = {"key": "value", "['beta', 'bootstrap', 'x']": "y"} + + actual_context = self.flow.run(initial_context) + + mock_1st_q.assert_called_once() + mock_2nd_q.assert_called_once() + mock_3rd_q.assert_called_once() + + self.assertEqual( + {"1st": "answer1", "2nd": False, "3rd": "option1", "['beta', 'bootstrap', 'x']": "y", "key": "value"}, + actual_context, + ) + self.assertIsNot(actual_context, initial_context) # shouldn't modify the input, it should copy and return new diff --git a/tests/unit/lib/cookiecutter/test_question.py b/tests/unit/lib/cookiecutter/test_question.py index e59a76b782..c46a37fa43 100644 --- a/tests/unit/lib/cookiecutter/test_question.py +++ b/tests/unit/lib/cookiecutter/test_question.py @@ -1,5 +1,9 @@ +from typing import List, Union, Dict from unittest import TestCase -from unittest.mock import ANY, patch +from unittest.mock import ANY, patch, Mock + +from parameterized import parameterized + from samcli.lib.cookiecutter.question import Question, QuestionKind, Choice, Confirm, Info, QuestionFactory @@ -27,6 +31,18 @@ def setUp(self): default_next_question_key=self._ANY_DEFAULT_NEXT_QUESTION_KEY, ) + def get_question_with_default_from_cookiecutter_context_using_keypath( + self, key_path: List[Union[str, Dict]] + ) -> Question: + return Question( + text=self._ANY_TEXT, + key=self._ANY_KEY, + default={"keyPath": key_path}, + is_required=True, + next_question_map=self._ANY_NEXT_QUESTION_MAP, + default_next_question_key=self._ANY_DEFAULT_NEXT_QUESTION_KEY, + ) + def test_creating_questions(self): q = Question(text=self._ANY_TEXT, key=self._ANY_KEY) self.assertEqual(q.text, self._ANY_TEXT) @@ -61,10 +77,80 @@ def test_get_next_question_key(self): @patch("samcli.lib.cookiecutter.question.click") def test_ask(self, mock_click): mock_click.prompt.return_value = self._ANY_ANSWER - answer = self.question.ask() + answer = self.question.ask({}) self.assertEqual(answer, self._ANY_ANSWER) mock_click.prompt.assert_called_once_with(text=self.question.text, default=self.question.default_answer) + @patch("samcli.lib.cookiecutter.question.click") + def test_ask_resolves_from_cookiecutter_context(self, mock_click): + # Setup + expected_default_value = Mock() + previous_question_key = "this is a question" + previous_question_answer = "this is an answer" + context = { + "['x', 'this is an answer']": expected_default_value, + previous_question_key: previous_question_answer, + } + question = self.get_question_with_default_from_cookiecutter_context_using_keypath( + ["x", {"valueOf": previous_question_key}] + ) + + # Trigger + question.ask(context=context) + + # Verify + mock_click.prompt.assert_called_once_with(text=self.question.text, default=expected_default_value) + + @patch("samcli.lib.cookiecutter.question.click") + def test_ask_resolves_from_cookiecutter_context_non_exist_key_path(self, mock_click): + # Setup + context = {} + question = self.get_question_with_default_from_cookiecutter_context_using_keypath(["y"]) + + # Trigger + question.ask(context=context) + + # Verify + mock_click.prompt.assert_called_once_with(text=self.question.text, default=None) + + def test_ask_resolves_from_cookiecutter_context_non_exist_question_key(self): + # Setup + expected_default_value = Mock() + previous_question_key = "this is a question" + previous_question_answer = "this is an answer" + context = { + "['x', 'this is an answer']": expected_default_value, + previous_question_key: previous_question_answer, + } + question = self.get_question_with_default_from_cookiecutter_context_using_keypath( + ["x", {"valueOf": "non_exist_question_key"}] + ) + + # Trigger + with self.assertRaises(KeyError): + question.ask(context=context) + + @parameterized.expand([("this should have been a list"), ([1],), ({},)]) + def test_ask_resolves_from_cookiecutter_context_with_key_path_not_a_list(self, key_path): + # Setup + context = {} + question = self.get_question_with_default_from_cookiecutter_context_using_keypath(key_path) + + # Trigger + with self.assertRaises(ValueError): + question.ask(context=context) + + @parameterized.expand([({"keyPath123": Mock()},), ({"keyPath": [{"valueOf123": Mock()}]},)]) + def test_ask_resolves_from_cookiecutter_context_with_default_object_missing_keys(self, default_object): + # Setup + context = {} + question = self.get_question_with_default_from_cookiecutter_context_using_keypath([]) + question._default_answer = default_object + + # Trigger + with self.assertRaises(KeyError): + question.ask(context=context) + class TestChoice(TestCase): def setUp(self): @@ -99,7 +185,7 @@ def test_get_options_indexes_with_different_bases(self): @patch("samcli.lib.cookiecutter.question.click") def test_ask(self, mock_click, mock_choice): mock_click.prompt.return_value = 2 - answer = self.question.ask() + answer = self.question.ask({}) self.assertEqual(answer, TestQuestion._ANY_OPTIONS[1]) # we deduct one from user's choice (base 1 vs base 0) mock_click.prompt.assert_called_once_with( text="Choice", default=self.question.default_answer, show_choices=False, type=ANY @@ -112,7 +198,7 @@ class TestInfo(TestCase): def test_ask(self, mock_click): q = Info(text=TestQuestion._ANY_TEXT, key=TestQuestion._ANY_KEY) mock_click.echo.return_value = None - answer = q.ask() + answer = q.ask({}) self.assertIsNone(answer) mock_click.echo.assert_called_once_with(message=q.text) @@ -122,7 +208,7 @@ class TestConfirm(TestCase): def test_ask(self, mock_click): q = Confirm(text=TestQuestion._ANY_TEXT, key=TestQuestion._ANY_KEY) mock_click.confirm.return_value = True - answer = q.ask() + answer = q.ask({}) self.assertTrue(answer) mock_click.confirm.assert_called_once_with(text=q.text)