Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions samcli/lib/cookiecutter/interactive_flow.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""A flow of questions to be asked to the user in an interactive way."""
from typing import Any, Dict, Optional

from .question import Question


Expand Down Expand Up @@ -40,7 +41,10 @@ def advance_to_next_question(self, current_answer: Optional[Any] = None) -> Opti
self._current_question = self._questions.get(next_question_key) if next_question_key else None
return self._current_question

def run(self, context: Dict) -> Dict:
def run(
self,
context: Dict,
) -> Dict:
"""
starts the flow, collects user's answers to the question and return a new copy of the passed context
with the answers appended to the copy
Expand All @@ -49,14 +53,17 @@ def run(self, context: Dict) -> Dict:
----------
context: Dict
The cookiecutter context before prompting this flow's questions
The context can be used to provide default values, and support both str keys and List[str] keys.

Returns: A new copy of the context with user's answers added to the copy such that each answer is
associated to the key of the corresponding question
Returns
-------
A new copy of the context with user's answers added to the copy such that each answer is
associated to the key of the corresponding question
"""
context = context.copy()
question = self.advance_to_next_question()
while question:
answer = question.ask()
answer = question.ask(context=context)
context[question.key] = answer
question = self.advance_to_next_question(answer)
return context
34 changes: 25 additions & 9 deletions samcli/lib/cookiecutter/interactive_flow_creator.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,19 @@ def create_flow(flow_definition_path: str, extra_context: Optional[Dict] = None)
"True": "key of the question to jump to if the user answered 'Yes'",
"False": "key of the question to jump to if the user answered 'Yes'",
}
"default": "default_answer",
# the default value can also be loaded from cookiecutter context
# with a key path whose key path item can be loaded from cookiecutter as well.
"default": {
"keyPath": [
{
"valueOf": "key-of-another-question"
},
"pipeline_user"
]
}
# assuming the answer of "key-of-another-question" is "ABC"
# the default value will be load from cookiecutter context with key "['ABC', 'pipeline_user]"
},
...
]
Expand All @@ -63,15 +76,18 @@ def _load_questions(
questions: Dict[str, Question] = {}
questions_definition = InteractiveFlowCreator._parse_questions_definition(flow_definition_path, extra_context)

for question in questions_definition.get("questions"):
q = QuestionFactory.create_question_from_json(question)
if not first_question_key:
first_question_key = q.key
elif previous_question and not previous_question.default_next_question_key:
previous_question.set_default_next_question_key(q.key)
questions[q.key] = q
previous_question = q
return questions, first_question_key
try:
for question in questions_definition.get("questions"):
q = QuestionFactory.create_question_from_json(question)
if not first_question_key:
first_question_key = q.key
elif previous_question and not previous_question.default_next_question_key:
previous_question.set_default_next_question_key(q.key)
questions[q.key] = q
previous_question = q
return questions, first_question_key
except (KeyError, ValueError, AttributeError, TypeError) as ex:
raise QuestionsFailedParsingException(f"Failed to parse questions: {str(ex)}") from ex

@staticmethod
def _parse_questions_definition(file_path, extra_context: Optional[Dict] = None):
Expand Down
107 changes: 96 additions & 11 deletions samcli/lib/cookiecutter/question.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
""" This module represents the questions to ask to the user to fulfill the cookiecutter context. """
from enum import Enum
from typing import Any, Dict, List, Optional, Type
from typing import Any, Dict, List, Optional, Type, Union

import click


Expand All @@ -26,8 +27,10 @@ class Question:
The text to prompt to the user
_required: bool
Whether the user must provide an answer for this question or not.
_default_answer: Optional[str]
A default answer that is suggested to the user
_default_answer: Optional[Union[str, Dict]]
A default answer that is suggested to the user,
it can be directly provided (a string)
or resolved from cookiecutter context (a Dict, in the form of {"keyPath": [...,]})
_next_question_map: Optional[Dict[str, str]]
A simple branching mechanism, it refers to what is the next question to ask the user if he answered
a particular answer to this question. this map is in the form of {answer: next-question-key}. this
Expand All @@ -48,7 +51,7 @@ def __init__(
self,
key: str,
text: str,
default: Optional[str] = None,
default: Optional[Union[str, Dict]] = None,
is_required: Optional[bool] = None,
next_question_map: Optional[Dict[str, str]] = None,
default_next_question_key: Optional[str] = None,
Expand Down Expand Up @@ -87,8 +90,21 @@ def next_question_map(self):
def default_next_question_key(self):
return self._default_next_question_key

def ask(self) -> Any:
return click.prompt(text=self._text, default=self._default_answer)
def ask(self, context: Dict) -> Any:
"""
prompt the user this question

Parameters
----------
context
The cookiecutter context dictionary containing previous questions' answers and default values

Returns
-------
The user provided answer.
"""
resolved_default_answer = self._resolve_default_answer(context)
return click.prompt(text=self._text, default=resolved_default_answer)

def get_next_question_key(self, answer: Any) -> Optional[str]:
# _next_question_map is a Dict[str(answer), str(next question key)]
Expand All @@ -99,14 +115,83 @@ def get_next_question_key(self, answer: Any) -> Optional[str]:
def set_default_next_question_key(self, next_question_key):
self._default_next_question_key = next_question_key

def _resolve_key_path(self, key_path: List, context: Dict) -> List[str]:
"""
key_path element is a list of str and Dict.
When the element is a dict, in the form of { "valueOf": question_key },
it means it refers to the answer to another questions.
_resolve_key_path() will replace such dict with the actual question answer

Parameters
----------
key_path
The key_path list containing str and dict
context
The cookiecutter context containing answers to previous answered questions
Returns
-------
The key_path list containing only str
"""
resolved_key_path: List[str] = []
for unresolved_key in key_path:
if isinstance(unresolved_key, str):
resolved_key_path.append(unresolved_key)
elif isinstance(unresolved_key, dict):
if "valueOf" not in unresolved_key:
raise KeyError(f'Missing key "valueOf" in question default keyPath element "{unresolved_key}".')
query_question_key: str = unresolved_key.get("valueOf", "")
if query_question_key not in context:
raise KeyError(
f'Invalid question key "{query_question_key}" referenced '
f"in default answer of question {self.key}"
)
resolved_key_path.append(context[query_question_key])
else:
raise ValueError(f'Invalid value "{unresolved_key}" in key path')
return resolved_key_path

def _resolve_default_answer(self, context: Dict) -> Optional[Any]:
"""
a question may have a default answer provided directly through the "default_answer" value
or indirectly from cookiecutter context using a key path

Parameters
----------
context
Cookiecutter context used to resolve default values and answered questions' answers.

Raises
------
KeyError
When default value depends on the answer to a non-existent question
ValueError
The default value is malformed

Returns
-------
Optional default answer, it might be resolved from cookiecutter context using specified key path.

"""
if isinstance(self._default_answer, dict):
# load value using key path from cookiecutter
if "keyPath" not in self._default_answer:
raise KeyError(f'Missing key "keyPath" in question default "{self._default_answer}".')
unresolved_key_path = self._default_answer.get("keyPath", [])
if not isinstance(unresolved_key_path, list):
raise ValueError(f'Invalid default answer "{self._default_answer}" for question {self.key}')

return context.get(str(self._resolve_key_path(unresolved_key_path, context)))

return self._default_answer


class Info(Question):
def ask(self) -> None:
def ask(self, context: Dict) -> None:
return click.echo(message=self._text)


class Confirm(Question):
def ask(self) -> bool:
def ask(self, context: Dict) -> bool:
return click.confirm(text=self._text)


Expand All @@ -126,15 +211,16 @@ def __init__(
self._options = options
super().__init__(key, text, default, is_required, next_question_map, default_next_question_key)

def ask(self) -> str:
def ask(self, context: Dict) -> str:
resolved_default_answer = self._resolve_default_answer(context)
click.echo(self._text)
for index, option in enumerate(self._options):
click.echo(f"\t{index + 1} - {option}")
options_indexes = self._get_options_indexes(base=1)
choices = list(map(str, options_indexes))
choice = click.prompt(
text="Choice",
default=self._default_answer,
default=resolved_default_answer,
show_choices=False,
type=click.Choice(choices),
)
Expand All @@ -145,7 +231,6 @@ def _get_options_indexes(self, base: int = 0) -> List[int]:


class QuestionFactory:

question_classes: Dict[QuestionKind, Type[Question]] = {
QuestionKind.info: Info,
QuestionKind.choice: Choice,
Expand Down
24 changes: 24 additions & 0 deletions tests/unit/lib/cookiecutter/test_interactive_flow.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from pathlib import Path
from unittest import TestCase
from unittest.mock import patch
from samcli.lib.cookiecutter.interactive_flow import InteractiveFlow
Expand Down Expand Up @@ -49,3 +50,26 @@ def test_run(self, mock_3rd_q, mock_2nd_q, mock_1st_q):
mock_3rd_q.assert_called_once()
self.assertEqual(expected_context, actual_context)
self.assertIsNot(actual_context, initial_context) # shouldn't modify the input, it should copy and return new

@patch.object(Question, "ask")
@patch.object(Confirm, "ask")
@patch.object(Choice, "ask")
def test_run_with_preloaded_default_values(self, mock_3rd_q, mock_2nd_q, mock_1st_q):

mock_1st_q.return_value = "answer1"
mock_2nd_q.return_value = False
mock_3rd_q.return_value = "option1"

initial_context = {"key": "value", "['beta', 'bootstrap', 'x']": "y"}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is value? Where is this sort of format outputted?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is a {"AnyKey": "AnyValue"} pair that we don't care about

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But why is it in that format? It's not correct JSON, so I'm wondering why the test uses a custom format.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

list cannot be a key so converting it to a string (using str()) can solve it.


actual_context = self.flow.run(initial_context)

mock_1st_q.assert_called_once()
mock_2nd_q.assert_called_once()
mock_3rd_q.assert_called_once()

self.assertEqual(
{"1st": "answer1", "2nd": False, "3rd": "option1", "['beta', 'bootstrap', 'x']": "y", "key": "value"},
actual_context,
)
self.assertIsNot(actual_context, initial_context) # shouldn't modify the input, it should copy and return new
Loading