Skip to content

Commit

Permalink
feat(context): expose a _context_phase context variable (fix copier…
Browse files Browse the repository at this point in the history
  • Loading branch information
noirbizarre committed Mar 1, 2025
1 parent 3daa5bc commit ca9986d
Show file tree
Hide file tree
Showing 9 changed files with 168 additions and 16 deletions.
38 changes: 24 additions & 14 deletions copier/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
AnyByStrMutableMapping,
JSONSerializable,
LazyDict,
Phase,
RelativePath,
StrOrPath,
)
Expand Down Expand Up @@ -375,6 +376,7 @@ def _render_context(self) -> AnyByStrMutableMapping:
_copier_conf=conf,
_folder_name=self.subproject.local_abspath.name,
_copier_python=sys.executable,
_copier_phase=Phase.current(),
)

def _path_matcher(self, patterns: Iterable[str]) -> Callable[[Path], bool]:
Expand Down Expand Up @@ -560,7 +562,9 @@ def answers_relpath(self) -> Path:
"""
path = self.answers_file or self.template.answers_relpath
template = self.jinja_env.from_string(str(path))
return Path(template.render(**self.answers.combined))
return Path(
template.render(_copier_phase=Phase.current(), **self.answers.combined)
)

@cached_property
def all_exclusions(self) -> Sequence[str]:
Expand Down Expand Up @@ -927,8 +931,9 @@ def run_copy(self) -> None:
See [generating a project][generating-a-project].
"""
self._check_unsafe("copy")
self._print_message(self.template.message_before_copy)
self._ask()
with Phase.use(Phase.PROMPT):
self._print_message(self.template.message_before_copy)
self._ask()
was_existing = self.subproject.local_abspath.exists()
try:
if not self.quiet:
Expand All @@ -937,12 +942,14 @@ def run_copy(self) -> None:
f"\nCopying from template version {self.template.version}",
file=sys.stderr,
)
self._render_template()
with Phase.use(Phase.RENDER):
self._render_template()
if not self.quiet:
# TODO Unify printing tools
print("") # padding space
if not self.skip_tasks:
self._execute_tasks(self.template.tasks)
with Phase.use(Phase.TASKS):
self._execute_tasks(self.template.tasks)
except Exception:
if not was_existing and self.cleanup_on_error:
rmtree(self.subproject.local_abspath)
Expand Down Expand Up @@ -1009,8 +1016,9 @@ def run_update(self) -> None:
print(
f"Updating to template version {self.template.version}", file=sys.stderr
)
self._apply_update()
self._print_message(self.template.message_after_update)
with Phase.use(Phase.UPDATE):
self._apply_update()
self._print_message(self.template.message_after_update)

def _apply_update(self) -> None: # noqa: C901
git = get_git()
Expand Down Expand Up @@ -1044,9 +1052,10 @@ def _apply_update(self) -> None: # noqa: C901
) as old_worker:
old_worker.run_copy()
# Run pre-migration tasks
self._execute_tasks(
self.template.migration_tasks("before", self.subproject.template) # type: ignore[arg-type]
)
with Phase.use(Phase.MIGRATE):
self._execute_tasks(
self.template.migration_tasks("before", self.subproject.template) # type: ignore[arg-type]
)
# Create a Git tree object from the current (possibly dirty) index
# and keep the object reference.
with local.cwd(subproject_top):
Expand Down Expand Up @@ -1120,7 +1129,7 @@ def _apply_update(self) -> None: # noqa: C901
self._git_initialize_repo()
new_copy_head = git("rev-parse", "HEAD").strip()
# Extract diff between temporary destination and real destination
# with some special handling of newly added files in both the poject
# with some special handling of newly added files in both the project
# and the template.
with local.cwd(old_copy):
# Configure borrowing Git objects from the real destination and
Expand Down Expand Up @@ -1265,9 +1274,10 @@ def _apply_update(self) -> None: # noqa: C901
_remove_old_files(subproject_top, compared)

# Run post-migration tasks
self._execute_tasks(
self.template.migration_tasks("after", self.subproject.template) # type: ignore[arg-type]
)
with Phase.use(Phase.MIGRATE):
self._execute_tasks(
self.template.migration_tasks("after", self.subproject.template) # type: ignore[arg-type]
)

def _git_initialize_repo(self) -> None:
"""Initialize a git repository in the current directory."""
Expand Down
38 changes: 38 additions & 0 deletions copier/types.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
"""Complex types, annotations, validators."""

from __future__ import annotations

from contextlib import contextmanager
from contextvars import ContextVar
from enum import Enum
from pathlib import Path
from typing import (
Annotated,
Any,
Callable,
Dict,
Iterator,
Literal,
Mapping,
MutableMapping,
Expand Down Expand Up @@ -75,3 +81,35 @@ def __getitem__(self, key: str) -> Any:
if key not in self.done:
self.done[key] = self.pending[key]()
return self.done[key]


class Phase(str, Enum):
"""The known execution phases."""

PROMPT = "prompt"
TASKS = "tasks"
MIGRATE = "migrate"
RENDER = "render"
UPDATE = "update"
UNDEFINED = "undefined"

def __str__(self) -> str:
return str(self.value)

@classmethod
@contextmanager
def use(cls, phase: Phase) -> Iterator[None]:
"""Set the current phase for the duration of a context."""
token = _phase.set(phase)
try:
yield
finally:
_phase.reset(token)

@classmethod
def current(cls) -> Phase:
"""Get the current phase."""
return _phase.get()


_phase: ContextVar[Phase] = ContextVar("phase", default=Phase.UNDEFINED)
12 changes: 10 additions & 2 deletions copier/user_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,14 @@
import json
import warnings
from collections import ChainMap
from collections.abc import Mapping, Sequence
from dataclasses import field
from datetime import datetime
from functools import cached_property
from hashlib import sha512
from os import urandom
from pathlib import Path
from typing import Any, Callable, Literal, Mapping, Sequence
from typing import Any, Callable, Literal

import yaml
from jinja2 import UndefinedError
Expand All @@ -33,6 +34,7 @@
AnyByStrMutableMapping,
LazyDict,
MissingType,
Phase,
StrOrPath,
)

Expand Down Expand Up @@ -464,7 +466,13 @@ def render_value(
else value
)
try:
return template.render({**self.answers.combined, **(extra_answers or {})})
return template.render(
{
**self.answers.combined,
**(extra_answers or {}),
"_copier_phase": Phase.current(),
}
)
except UndefinedError as error:
raise UserMessageError(str(error)) from error

Expand Down
7 changes: 7 additions & 0 deletions docs/creating.md
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,13 @@ variable:

The name of the project root directory.

### `_copier_phase`

The current phase, one of `"prompt"`,`"tasks"`, `"migrate"`, `"update"` or `"render"`.

!!! note There is also an additional `"undefined"` phase used when not in any phase,
mostly for testing purpose.

## Variables (context-specific)

Some rendering contexts provide variables unique to them:
Expand Down
18 changes: 18 additions & 0 deletions tests/test_answersfile_templating.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,3 +118,21 @@ def test_answersfile_templating_with_message_before_copy(
assert answers["module_name"] == "mymodule"
assert (dst / "result.txt").exists()
assert (dst / "result.txt").read_text() == "mymodule"


def test_answersfile_templating_phase(tmp_path_factory: pytest.TempPathFactory) -> None:
"""
Ensure `_copier_phase` is available while render `answers_relpath`.
Not because it is directly useful, but because some extensions might need it.
"""
src, dst = map(tmp_path_factory.mktemp, ("src", "dst"))
build_file_tree(
{
src / "copier.yml": """\
_answers_file: ".copier-answers-{{ _copier_phase }}.yml"
""",
src / "{{ _copier_conf.answers_file }}.jinja": "",
}
)
copier.run_copy(str(src), dst, overwrite=True, unsafe=True)
assert (dst / ".copier-answers-render.yml").exists()
7 changes: 7 additions & 0 deletions tests/test_copy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1041,3 +1041,10 @@ def test_templated_choices(tmp_path_factory: pytest.TempPathFactory, spec: str)
)
copier.run_copy(str(src), dst, data={"q": "two"})
assert yaml.safe_load((dst / "q.txt").read_text()) == "two"


def test_copier_phase_variable(tmp_path_factory: pytest.TempPathFactory) -> None:
src, dst = map(tmp_path_factory.mktemp, ("src", "dst"))
build_file_tree({src / "{{ _copier_phase }}": ""})
copier.run_copy(str(src), dst)
assert (dst / "render").exists()
29 changes: 29 additions & 0 deletions tests/test_migrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,3 +521,32 @@ def test_migration_jinja_variables(
assert f"{variable}={value}" in vars
else:
assert f"{variable}=" in vars


def test_copier_phase_variable(tmp_path_factory: pytest.TempPathFactory) -> None:
"""Test that the Phase variable is properly set."""
src, dst = map(tmp_path_factory.mktemp, ("src", "dst"))

with local.cwd(src):
build_file_tree(
{
**COPIER_ANSWERS_FILE,
"copier.yml": (
"""\
_migrations:
- touch {{ _copier_phase }}
"""
),
}
)
git_save(tag="v1")
with local.cwd(dst):
run_copy(src_path=str(src))
git_save()

with local.cwd(src):
git("tag", "v2")
with local.cwd(dst):
run_update(defaults=True, overwrite=True, unsafe=True)

assert (dst / "migrate").is_file()
16 changes: 16 additions & 0 deletions tests/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,3 +164,19 @@ def test_os_specific_tasks(
monkeypatch.setattr("copier.main.OS", os)
copier.run_copy(str(src), dst, unsafe=True)
assert (dst / filename).exists()


def test_copier_phase_variable(tmp_path_factory: pytest.TempPathFactory) -> None:
src, dst = map(tmp_path_factory.mktemp, ("src", "dst"))
build_file_tree(
{
(src / "copier.yml"): (
"""
_tasks:
- touch {{ _copier_phase }}
"""
)
}
)
copier.run_copy(str(src), dst, unsafe=True)
assert (dst / "tasks").exists()
19 changes: 19 additions & 0 deletions tests/test_templated_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,3 +563,22 @@ def test_multiselect_choices_with_templated_default_value(
"python_version": "3.11",
"github_runner_python_version": ["3.11"],
}


def test_copier_phase_variable(
tmp_path_factory: pytest.TempPathFactory,
spawn: Spawn,
) -> None:
src, dst = map(tmp_path_factory.mktemp, ("src", "dst"))
build_file_tree(
{
(src / "copier.yml"): """\
phase:
type: str
default: "{{ _copier_phase }}"
"""
}
)
tui = spawn(COPIER_PATH + ("copy", str(src), str(dst)), timeout=10)
expect_prompt(tui, "phase", "str")
tui.expect_exact("prompt")

0 comments on commit ca9986d

Please sign in to comment.