Skip to content

Commit

Permalink
feat(context): expose a _phase context variable (fix #1883)
Browse files Browse the repository at this point in the history
  • Loading branch information
noirbizarre committed Jan 26, 2025
1 parent 57439e5 commit 166fc44
Show file tree
Hide file tree
Showing 8 changed files with 101 additions and 3 deletions.
9 changes: 8 additions & 1 deletion copier/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
MISSING,
AnyByStrDict,
JSONSerializable,
Phase,
RelativePath,
StrOrPath,
)
Expand Down Expand Up @@ -206,6 +207,7 @@ class Worker:
unsafe: bool = False
skip_answered: bool = False
skip_tasks: bool = False
phase: Phase = Phase.prompt

answers: AnswersMap = field(default_factory=AnswersMap, init=False)
_cleanup_hooks: list[Callable[[], None]] = field(default_factory=list, init=False)
Expand Down Expand Up @@ -354,6 +356,7 @@ def _render_context(self) -> Mapping[str, Any]:
_copier_conf=conf,
_folder_name=self.subproject.local_abspath.name,
_copier_python=sys.executable,
_phase=self.phase.value,
)

def _path_matcher(self, patterns: Iterable[str]) -> Callable[[Path], bool]:
Expand Down Expand Up @@ -456,6 +459,7 @@ def _render_allowed(

def _ask(self) -> None: # noqa: C901
"""Ask the questions of the questionnaire and record their answers."""
self.phase = Phase.prompt
result = AnswersMap(
user_defaults=self.user_defaults,
init=self.data,
Expand Down Expand Up @@ -601,6 +605,7 @@ def match_skip(self) -> Callable[[Path], bool]:

def _render_template(self) -> None:
"""Render the template in the subproject root."""
self.phase = Phase.render
follow_symlinks = not self.template.preserve_symlinks
for src in scantree(str(self.template_copy_root), follow_symlinks):
src_abspath = Path(src.path)
Expand Down Expand Up @@ -916,6 +921,7 @@ def run_copy(self) -> None:
# TODO Unify printing tools
print("") # padding space
if not self.skip_tasks:
self.phase = Phase.tasks
self._execute_tasks(self.template.tasks)
except Exception:
if not was_existing and self.cleanup_on_error:
Expand Down Expand Up @@ -1015,6 +1021,7 @@ def _apply_update(self) -> None: # noqa: C901
) as old_worker:
old_worker.run_copy()
# Run pre-migration tasks
self.phase = Phase.migrate
self._execute_tasks(
self.template.migration_tasks("before", self.subproject.template) # type: ignore[arg-type]
)
Expand Down Expand Up @@ -1090,7 +1097,7 @@ def _apply_update(self) -> None: # noqa: C901
self._git_initialize_repo()
new_copy_head = git("rev-parse", "HEAD").strip()
# Extract diff between temporary destination and real destination
# with some special handling of newly added files in both the poject
# with some special handling of newly added files in both the project
# and the template.
with local.cwd(old_copy):
# Configure borrowing Git objects from the real destination and
Expand Down
10 changes: 10 additions & 0 deletions copier/types.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Complex types, annotations, validators."""

from enum import Enum
from pathlib import Path
from typing import (
Annotated,
Expand Down Expand Up @@ -58,3 +59,12 @@ def path_is_relative(value: Path) -> Path:

AbsolutePath = Annotated[Path, AfterValidator(path_is_absolute)]
RelativePath = Annotated[Path, AfterValidator(path_is_relative)]


class Phase(str, Enum):
"""The known execution phases."""

prompt = "prompt"
tasks = "tasks"
migrate = "migrate"
render = "render"
10 changes: 8 additions & 2 deletions copier/user_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

from .errors import InvalidTypeError, UserMessageError
from .tools import cast_to_bool, cast_to_str, force_str_end
from .types import MISSING, AnyByStrDict, MissingType, OptStrOrPath, StrOrPath
from .types import MISSING, AnyByStrDict, MissingType, OptStrOrPath, Phase, StrOrPath


# TODO Remove these two functions as well as DEFAULT_DATA in a future release
Expand Down Expand Up @@ -441,7 +441,13 @@ def render_value(
else value
)
try:
return template.render({**self.answers.combined, **(extra_answers or {})})
return template.render(
{
**self.answers.combined,
**(extra_answers or {}),
"_phase": Phase.prompt,
}
)
except UndefinedError as error:
raise UserMessageError(str(error)) from error

Expand Down
4 changes: 4 additions & 0 deletions docs/creating.md
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,10 @@ The absolute path of the Python interpreter running Copier.

The name of the project root directory.

### `_phase`

The current phase, one of `"prompt"`,`"tasks"`, `"migrate"` or `"render"`.

## Variables (context-specific)

Some rendering contexts provide variables unique to them:
Expand Down
7 changes: 7 additions & 0 deletions tests/test_copy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1042,3 +1042,10 @@ def test_templated_choices(tmp_path_factory: pytest.TempPathFactory, spec: str)
)
copier.run_copy(str(src), dst, data={"q": "two"})
assert yaml.safe_load((dst / "q.txt").read_text()) == "two"


def test_phase_variable(tmp_path_factory: pytest.TempPathFactory) -> None:
src, dst = map(tmp_path_factory.mktemp, ("src", "dst"))
build_file_tree({src / "{{ _phase }}": ""})
copier.run_copy(str(src), dst)
assert (dst / "render").exists()
29 changes: 29 additions & 0 deletions tests/test_migrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,3 +521,32 @@ def test_migration_jinja_variables(
assert f"{variable}={value}" in vars
else:
assert f"{variable}=" in vars


def test_phase_variable(tmp_path_factory: pytest.TempPathFactory) -> None:
"""Test that the j-phase variable is properly set"""
src, dst = map(tmp_path_factory.mktemp, ("src", "dst"))

with local.cwd(src):
build_file_tree(
{
**COPIER_ANSWERS_FILE,
"copier.yml": (
"""\
_migrations:
- touch {{ _phase }}
"""
),
}
)
git_save(tag="v1")
with local.cwd(dst):
run_copy(src_path=str(src))
git_save()

with local.cwd(src):
git("tag", "v2")
with local.cwd(dst):
run_update(defaults=True, overwrite=True, unsafe=True)

assert (dst / "migrate").is_file()
16 changes: 16 additions & 0 deletions tests/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,3 +164,19 @@ def test_os_specific_tasks(
monkeypatch.setattr("copier.main.OS", os)
copier.run_copy(str(src), dst, unsafe=True)
assert (dst / filename).exists()


def test_phase_variable(tmp_path_factory: pytest.TempPathFactory) -> None:
src, dst = map(tmp_path_factory.mktemp, ("src", "dst"))
build_file_tree(
{
(src / "copier.yml"): (
"""
_tasks:
- touch {{ _phase }}
"""
)
}
)
copier.run_copy(str(src), dst, unsafe=True)
assert (dst / "tasks").exists()
19 changes: 19 additions & 0 deletions tests/test_templated_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,3 +563,22 @@ def test_multiselect_choices_with_templated_default_value(
"python_version": "3.11",
"github_runner_python_version": ["3.11"],
}


def test_phase_variable(
tmp_path_factory: pytest.TempPathFactory,
spawn: Spawn,
) -> None:
src, dst = map(tmp_path_factory.mktemp, ("src", "dst"))
build_file_tree(
{
(src / "copier.yml"): """\
phase:
type: str
default: "{{ _phase }}"
"""
}
)
tui = spawn(COPIER_PATH + ("copy", str(src), str(dst)), timeout=10)
expect_prompt(tui, "phase", "str")
tui.expect_exact("prompt")

0 comments on commit 166fc44

Please sign in to comment.