Skip to content

Commit

Permalink
feat(context): expose a _context_phase context variable (fix #1883)
Browse files Browse the repository at this point in the history
  • Loading branch information
noirbizarre committed Jan 29, 2025
1 parent 71358ed commit 17efb40
Show file tree
Hide file tree
Showing 8 changed files with 114 additions and 5 deletions.
24 changes: 21 additions & 3 deletions copier/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,14 @@
scantree,
set_git_alternates,
)
from .types import MISSING, AnyByStrDict, JSONSerializable, RelativePath, StrOrPath
from .types import (
MISSING,
AnyByStrDict,
JSONSerializable,
Phase,
RelativePath,
StrOrPath,
)
from .user_data import DEFAULT_DATA, AnswersMap, Question
from .vcs import get_git

Expand Down Expand Up @@ -202,6 +209,7 @@ class Worker:
unsafe: bool = False
skip_answered: bool = False
skip_tasks: bool = False
phase: Phase = Phase.PROMPT

answers: AnswersMap = field(default_factory=AnswersMap, init=False)
_cleanup_hooks: list[Callable[[], None]] = field(default_factory=list, init=False)
Expand Down Expand Up @@ -350,6 +358,7 @@ def _render_context(self) -> Mapping[str, Any]:
_copier_conf=conf,
_folder_name=self.subproject.local_abspath.name,
_copier_python=sys.executable,
_copier_phase=self.phase.value,
)

def _path_matcher(self, patterns: Iterable[str]) -> Callable[[Path], bool]:
Expand Down Expand Up @@ -452,6 +461,7 @@ def _render_allowed(

def _ask(self) -> None: # noqa: C901
"""Ask the questions of the questionnaire and record their answers."""
self.phase = Phase.PROMPT
result = AnswersMap(
user_defaults=self.user_defaults,
init=self.data,
Expand Down Expand Up @@ -531,7 +541,12 @@ def answers_relpath(self) -> Path:
"""
path = self.answers_file or self.template.answers_relpath
template = self.jinja_env.from_string(str(path))
return Path(template.render(**self.answers.combined))
return Path(
template.render(
**self.answers.combined,
_copier_phase=self.phase.value,
)
)

@cached_property
def all_exclusions(self) -> Sequence[str]:
Expand Down Expand Up @@ -598,6 +613,7 @@ def match_skip(self) -> Callable[[Path], bool]:

def _render_template(self) -> None:
"""Render the template in the subproject root."""
self.phase = Phase.RENDER
follow_symlinks = not self.template.preserve_symlinks
for src in scantree(str(self.template_copy_root), follow_symlinks):
src_abspath = Path(src.path)
Expand Down Expand Up @@ -913,6 +929,7 @@ def run_copy(self) -> None:
# TODO Unify printing tools
print("") # padding space
if not self.skip_tasks:
self.phase = Phase.TASKS
self._execute_tasks(self.template.tasks)
except Exception:
if not was_existing and self.cleanup_on_error:
Expand Down Expand Up @@ -1015,6 +1032,7 @@ def _apply_update(self) -> None: # noqa: C901
) as old_worker:
old_worker.run_copy()
# Run pre-migration tasks
self.phase = Phase.MIGRATE
self._execute_tasks(
self.template.migration_tasks("before", self.subproject.template) # type: ignore[arg-type]
)
Expand Down Expand Up @@ -1090,7 +1108,7 @@ def _apply_update(self) -> None: # noqa: C901
self._git_initialize_repo()
new_copy_head = git("rev-parse", "HEAD").strip()
# Extract diff between temporary destination and real destination
# with some special handling of newly added files in both the poject
# with some special handling of newly added files in both the project
# and the template.
with local.cwd(old_copy):
# Configure borrowing Git objects from the real destination and
Expand Down
10 changes: 10 additions & 0 deletions copier/types.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Complex types, annotations, validators."""

from enum import Enum
from pathlib import Path
from typing import (
Annotated,
Expand Down Expand Up @@ -58,3 +59,12 @@ def path_is_relative(value: Path) -> Path:

AbsolutePath = Annotated[Path, AfterValidator(path_is_absolute)]
RelativePath = Annotated[Path, AfterValidator(path_is_relative)]


class Phase(str, Enum):
"""The known execution phases."""

PROMPT = "prompt"
TASKS = "tasks"
MIGRATE = "migrate"
RENDER = "render"
10 changes: 8 additions & 2 deletions copier/user_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

from .errors import InvalidTypeError, UserMessageError
from .tools import cast_to_bool, cast_to_str, force_str_end
from .types import MISSING, AnyByStrDict, MissingType, OptStrOrPath, StrOrPath
from .types import MISSING, AnyByStrDict, MissingType, OptStrOrPath, Phase, StrOrPath


# TODO Remove these two functions as well as DEFAULT_DATA in a future release
Expand Down Expand Up @@ -446,7 +446,13 @@ def render_value(
else value
)
try:
return template.render({**self.answers.combined, **(extra_answers or {})})
return template.render(
{
**self.answers.combined,
**(extra_answers or {}),
"_copier_phase": Phase.PROMPT.value,
}
)
except UndefinedError as error:
raise UserMessageError(str(error)) from error

Expand Down
4 changes: 4 additions & 0 deletions docs/creating.md
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,10 @@ The absolute path of the Python interpreter running Copier.

The name of the project root directory.

### `_copier_phase`

The current phase, one of `"prompt"`,`"tasks"`, `"migrate"` or `"render"`.

## Variables (context-specific)

Some rendering contexts provide variables unique to them:
Expand Down
7 changes: 7 additions & 0 deletions tests/test_copy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1042,3 +1042,10 @@ def test_templated_choices(tmp_path_factory: pytest.TempPathFactory, spec: str)
)
copier.run_copy(str(src), dst, data={"q": "two"})
assert yaml.safe_load((dst / "q.txt").read_text()) == "two"


def test_copier_phase_variable(tmp_path_factory: pytest.TempPathFactory) -> None:
src, dst = map(tmp_path_factory.mktemp, ("src", "dst"))
build_file_tree({src / "{{ _copier_phase }}": ""})
copier.run_copy(str(src), dst)
assert (dst / "render").exists()
29 changes: 29 additions & 0 deletions tests/test_migrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,3 +521,32 @@ def test_migration_jinja_variables(
assert f"{variable}={value}" in vars
else:
assert f"{variable}=" in vars


def test_copier_phase_variable(tmp_path_factory: pytest.TempPathFactory) -> None:
"""Test that the Phase variable is properly set."""
src, dst = map(tmp_path_factory.mktemp, ("src", "dst"))

with local.cwd(src):
build_file_tree(
{
**COPIER_ANSWERS_FILE,
"copier.yml": (
"""\
_migrations:
- touch {{ _copier_phase }}
"""
),
}
)
git_save(tag="v1")
with local.cwd(dst):
run_copy(src_path=str(src))
git_save()

with local.cwd(src):
git("tag", "v2")
with local.cwd(dst):
run_update(defaults=True, overwrite=True, unsafe=True)

assert (dst / "migrate").is_file()
16 changes: 16 additions & 0 deletions tests/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,3 +164,19 @@ def test_os_specific_tasks(
monkeypatch.setattr("copier.main.OS", os)
copier.run_copy(str(src), dst, unsafe=True)
assert (dst / filename).exists()


def test_copier_phase_variable(tmp_path_factory: pytest.TempPathFactory) -> None:
src, dst = map(tmp_path_factory.mktemp, ("src", "dst"))
build_file_tree(
{
(src / "copier.yml"): (
"""
_tasks:
- touch {{ _copier_phase }}
"""
)
}
)
copier.run_copy(str(src), dst, unsafe=True)
assert (dst / "tasks").exists()
19 changes: 19 additions & 0 deletions tests/test_templated_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,3 +563,22 @@ def test_multiselect_choices_with_templated_default_value(
"python_version": "3.11",
"github_runner_python_version": ["3.11"],
}


def test_copier_phase_variable(
tmp_path_factory: pytest.TempPathFactory,
spawn: Spawn,
) -> None:
src, dst = map(tmp_path_factory.mktemp, ("src", "dst"))
build_file_tree(
{
(src / "copier.yml"): """\
phase:
type: str
default: "{{ _copier_phase }}"
"""
}
)
tui = spawn(COPIER_PATH + ("copy", str(src), str(dst)), timeout=10)
expect_prompt(tui, "phase", "str")
tui.expect_exact("prompt")

0 comments on commit 17efb40

Please sign in to comment.