From a8c7d53ceef336ee287b90a5c0e539d335b40d4d Mon Sep 17 00:00:00 2001 From: "Axel H." Date: Sun, 19 Jan 2025 15:54:16 +0100 Subject: [PATCH] feat(context): expose a `_context_phase` context variable (fix #1883) --- copier/main.py | 16 ++++++++++++++-- copier/types.py | 10 ++++++++++ copier/user_data.py | 10 ++++++++-- docs/creating.md | 4 ++++ tests/test_copy.py | 7 +++++++ tests/test_migrations.py | 29 +++++++++++++++++++++++++++++ tests/test_tasks.py | 16 ++++++++++++++++ tests/test_templated_prompt.py | 19 +++++++++++++++++++ 8 files changed, 107 insertions(+), 4 deletions(-) diff --git a/copier/main.py b/copier/main.py index f2b1db195..f89b58ed5 100644 --- a/copier/main.py +++ b/copier/main.py @@ -62,6 +62,7 @@ MISSING, AnyByStrDict, JSONSerializable, + Phase, RelativePath, StrOrPath, ) @@ -206,6 +207,7 @@ class Worker: unsafe: bool = False skip_answered: bool = False skip_tasks: bool = False + phase: Phase = Phase.PROMPT answers: AnswersMap = field(default_factory=AnswersMap, init=False) _cleanup_hooks: list[Callable[[], None]] = field(default_factory=list, init=False) @@ -354,6 +356,7 @@ def _render_context(self) -> Mapping[str, Any]: _copier_conf=conf, _folder_name=self.subproject.local_abspath.name, _copier_python=sys.executable, + _copier_phase=self.phase.value, ) def _path_matcher(self, patterns: Iterable[str]) -> Callable[[Path], bool]: @@ -456,6 +459,7 @@ def _render_allowed( def _ask(self) -> None: # noqa: C901 """Ask the questions of the questionnaire and record their answers.""" + self.phase = Phase.PROMPT result = AnswersMap( user_defaults=self.user_defaults, init=self.data, @@ -534,7 +538,12 @@ def answers_relpath(self) -> Path: """ path = self.answers_file or self.template.answers_relpath template = self.jinja_env.from_string(str(path)) - return Path(template.render(**self.answers.combined)) + return Path( + template.render( + **self.answers.combined, + _copier_phase=self.phase.value, + ) + ) @cached_property def all_exclusions(self) -> Sequence[str]: @@ -601,6 +610,7 @@ def match_skip(self) -> Callable[[Path], bool]: def _render_template(self) -> None: """Render the template in the subproject root.""" + self.phase = Phase.RENDER follow_symlinks = not self.template.preserve_symlinks for src in scantree(str(self.template_copy_root), follow_symlinks): src_abspath = Path(src.path) @@ -916,6 +926,7 @@ def run_copy(self) -> None: # TODO Unify printing tools print("") # padding space if not self.skip_tasks: + self.phase = Phase.TASKS self._execute_tasks(self.template.tasks) except Exception: if not was_existing and self.cleanup_on_error: @@ -1015,6 +1026,7 @@ def _apply_update(self) -> None: # noqa: C901 ) as old_worker: old_worker.run_copy() # Run pre-migration tasks + self.phase = Phase.MIGRATE self._execute_tasks( self.template.migration_tasks("before", self.subproject.template) # type: ignore[arg-type] ) @@ -1090,7 +1102,7 @@ def _apply_update(self) -> None: # noqa: C901 self._git_initialize_repo() new_copy_head = git("rev-parse", "HEAD").strip() # Extract diff between temporary destination and real destination - # with some special handling of newly added files in both the poject + # with some special handling of newly added files in both the project # and the template. with local.cwd(old_copy): # Configure borrowing Git objects from the real destination and diff --git a/copier/types.py b/copier/types.py index be9a5de91..e1752eec2 100644 --- a/copier/types.py +++ b/copier/types.py @@ -1,5 +1,6 @@ """Complex types, annotations, validators.""" +from enum import Enum from pathlib import Path from typing import ( Annotated, @@ -58,3 +59,12 @@ def path_is_relative(value: Path) -> Path: AbsolutePath = Annotated[Path, AfterValidator(path_is_absolute)] RelativePath = Annotated[Path, AfterValidator(path_is_relative)] + + +class Phase(str, Enum): + """The known execution phases.""" + + PROMPT = "prompt" + TASKS = "tasks" + MIGRATE = "migrate" + RENDER = "render" diff --git a/copier/user_data.py b/copier/user_data.py index 6672128e0..2270b7a86 100644 --- a/copier/user_data.py +++ b/copier/user_data.py @@ -25,7 +25,7 @@ from .errors import InvalidTypeError, UserMessageError from .tools import cast_to_bool, cast_to_str, force_str_end -from .types import MISSING, AnyByStrDict, MissingType, OptStrOrPath, StrOrPath +from .types import MISSING, AnyByStrDict, MissingType, OptStrOrPath, Phase, StrOrPath # TODO Remove these two functions as well as DEFAULT_DATA in a future release @@ -441,7 +441,13 @@ def render_value( else value ) try: - return template.render({**self.answers.combined, **(extra_answers or {})}) + return template.render( + { + **self.answers.combined, + **(extra_answers or {}), + "_copier_phase": Phase.PROMPT.value, + } + ) except UndefinedError as error: raise UserMessageError(str(error)) from error diff --git a/docs/creating.md b/docs/creating.md index 75c07a863..fef15a851 100644 --- a/docs/creating.md +++ b/docs/creating.md @@ -125,6 +125,10 @@ The absolute path of the Python interpreter running Copier. The name of the project root directory. +### `_copier_phase` + +The current phase, one of `"prompt"`,`"tasks"`, `"migrate"` or `"render"`. + ## Variables (context-specific) Some rendering contexts provide variables unique to them: diff --git a/tests/test_copy.py b/tests/test_copy.py index c98060a1a..a41574a3c 100644 --- a/tests/test_copy.py +++ b/tests/test_copy.py @@ -1042,3 +1042,10 @@ def test_templated_choices(tmp_path_factory: pytest.TempPathFactory, spec: str) ) copier.run_copy(str(src), dst, data={"q": "two"}) assert yaml.safe_load((dst / "q.txt").read_text()) == "two" + + +def test_copier_phase_variable(tmp_path_factory: pytest.TempPathFactory) -> None: + src, dst = map(tmp_path_factory.mktemp, ("src", "dst")) + build_file_tree({src / "{{ _copier_phase }}": ""}) + copier.run_copy(str(src), dst) + assert (dst / "render").exists() diff --git a/tests/test_migrations.py b/tests/test_migrations.py index b77f50d90..61aa4f13c 100644 --- a/tests/test_migrations.py +++ b/tests/test_migrations.py @@ -521,3 +521,32 @@ def test_migration_jinja_variables( assert f"{variable}={value}" in vars else: assert f"{variable}=" in vars + + +def test_copier_phase_variable(tmp_path_factory: pytest.TempPathFactory) -> None: + """Test that the Phase variable is properly set.""" + src, dst = map(tmp_path_factory.mktemp, ("src", "dst")) + + with local.cwd(src): + build_file_tree( + { + **COPIER_ANSWERS_FILE, + "copier.yml": ( + """\ + _migrations: + - touch {{ _copier_phase }} + """ + ), + } + ) + git_save(tag="v1") + with local.cwd(dst): + run_copy(src_path=str(src)) + git_save() + + with local.cwd(src): + git("tag", "v2") + with local.cwd(dst): + run_update(defaults=True, overwrite=True, unsafe=True) + + assert (dst / "migrate").is_file() diff --git a/tests/test_tasks.py b/tests/test_tasks.py index 38be36b50..d4286e1f3 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -164,3 +164,19 @@ def test_os_specific_tasks( monkeypatch.setattr("copier.main.OS", os) copier.run_copy(str(src), dst, unsafe=True) assert (dst / filename).exists() + + +def test_copier_phase_variable(tmp_path_factory: pytest.TempPathFactory) -> None: + src, dst = map(tmp_path_factory.mktemp, ("src", "dst")) + build_file_tree( + { + (src / "copier.yml"): ( + """ + _tasks: + - touch {{ _copier_phase }} + """ + ) + } + ) + copier.run_copy(str(src), dst, unsafe=True) + assert (dst / "tasks").exists() diff --git a/tests/test_templated_prompt.py b/tests/test_templated_prompt.py index d7e745dc0..53dedc14c 100644 --- a/tests/test_templated_prompt.py +++ b/tests/test_templated_prompt.py @@ -563,3 +563,22 @@ def test_multiselect_choices_with_templated_default_value( "python_version": "3.11", "github_runner_python_version": ["3.11"], } + + +def test_copier_phase_variable( + tmp_path_factory: pytest.TempPathFactory, + spawn: Spawn, +) -> None: + src, dst = map(tmp_path_factory.mktemp, ("src", "dst")) + build_file_tree( + { + (src / "copier.yml"): """\ + phase: + type: str + default: "{{ _copier_phase }}" + """ + } + ) + tui = spawn(COPIER_PATH + ("copy", str(src), str(dst)), timeout=10) + expect_prompt(tui, "phase", "str") + tui.expect_exact("prompt")