Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add support for foreach target #8210

Merged
merged 1 commit into from
Sep 6, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion dvc/commands/stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def _get_stages(self) -> Iterable["Stage"]:
# removing duplicates while maintaining order
collected = chain.from_iterable(
self.repo.stage.collect(
target=target, recursive=self.args.recursive, accept_group=True
target=target, recursive=self.args.recursive
)
for target in self.args.targets
)
Expand Down
7 changes: 5 additions & 2 deletions dvc/dvcfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,9 +243,11 @@ def _dump_lockfile(self, stage):
self._lockfile.dump(stage)

@staticmethod
def _check_if_parametrized(stage):
def _check_if_parametrized(stage, action: str = "dump") -> None:
if stage.raw_data.parametrized:
raise ParametrizedDumpError(f"cannot dump a parametrized {stage}")
raise ParametrizedDumpError(
f"cannot {action} a parametrized {stage}"
)

def _dump_pipeline_file(self, stage):
self._check_if_parametrized(stage)
Expand Down Expand Up @@ -291,6 +293,7 @@ def remove(self, force=False):
self._lockfile.remove()

def remove_stage(self, stage):
self._check_if_parametrized(stage, "remove")
self._lockfile.remove_stage(stage)
if not self.exists():
return
Expand Down
2 changes: 1 addition & 1 deletion dvc/repo/remove.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
@locked
def remove(self: "Repo", target: str, outs: bool = False):
try:
stages = self.stage.from_target(target)
stages = self.stage.from_target(target, accept_group=False)
except (StageNotFound, StageFileDoesNotExistError) as e:
# If the user specified a tracked file as a target instead of a stage,
# e.g. `data.csv` instead of `data.csv.dvc`,
Expand Down
2 changes: 0 additions & 2 deletions dvc/repo/reproduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,6 @@ def reproduce(
from .graph import get_pipeline, get_pipelines

glob = kwargs.pop("glob", False)
accept_group = not glob

if isinstance(targets, str):
targets = [targets]
Expand Down Expand Up @@ -137,7 +136,6 @@ def reproduce(
self.stage.collect(
target,
recursive=recursive,
accept_group=accept_group,
glob=glob,
)
)
Expand Down
33 changes: 11 additions & 22 deletions dvc/repo/stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,6 @@ def _collect_specific_target(
target: str,
with_deps: bool,
recursive: bool,
accept_group: bool,
) -> Tuple[StageIter, "OptStr", "OptStr"]:
from dvc.dvcfile import is_valid_filename

Expand All @@ -98,13 +97,11 @@ def _collect_specific_target(
msg = "Checking if stage '%s' is in '%s'"
logger.debug(msg, target, PIPELINE_FILE)
if not (recursive and loader.fs.isdir(target)):
stages = _maybe_collect_from_dvc_yaml(
loader, target, with_deps, accept_group=accept_group
)
stages = _maybe_collect_from_dvc_yaml(loader, target, with_deps)
if stages:
return stages, file, name
elif not with_deps and is_valid_filename(file):
stages = loader.load_all(file, name, accept_group=accept_group)
stages = loader.load_all(file, name)
return stages, file, name
return [], file, name

Expand Down Expand Up @@ -209,7 +206,7 @@ def create(
return stage

def from_target(
self, target: str, accept_group: bool = False, glob: bool = False
self, target: str, accept_group: bool = True, glob: bool = False
) -> StageList:
"""
Returns a list of stage from the provided target.
Expand Down Expand Up @@ -250,15 +247,12 @@ def _get_keys(
self,
stages: "StageLoader",
name: str = None,
accept_group: bool = False,
accept_group: bool = True,
glob: bool = False,
) -> Iterable[str]:

assert not (accept_group and glob)

if not name:
return stages.keys()

if accept_group and stages.is_foreach_generated(name):
return self._get_group_keys(stages, name)
if glob:
Expand All @@ -269,7 +263,7 @@ def load_all(
self,
path: str = None,
name: str = None,
accept_group: bool = False,
accept_group: bool = True,
glob: bool = False,
) -> StageList:
"""Load a list of stages from a file.
Expand Down Expand Up @@ -338,7 +332,6 @@ def collect(
with_deps: bool = False,
recursive: bool = False,
graph: "DiGraph" = None,
accept_group: bool = False,
glob: bool = False,
) -> StageIter:
"""Collect list of stages from the provided target.
Expand All @@ -347,15 +340,14 @@ def collect(
target: if not provided, all of the stages in the graph are
returned.
Target can be:
- a stage name in the `dvc.yaml` file.
- a foreach group name or a stage name in the `dvc.yaml` file.
- a generated stage name from a foreach group.
- a path to `dvc.yaml` or `.dvc` file.
- in case of a stage to a dvc.yaml file in a different
directory than current working directory, it can be a path
to dvc.yaml file, followed by a colon `:`, followed by stage
name (eg: `../dvc.yaml:build`).
- in case of `recursive`, it can be a path to a directory.
- in case of `accept_group`, it can be a group name of
`foreach` generated stage.
- in case of `glob`, it can be a wildcard pattern to match
stages. Example: `build*` for stages in `dvc.yaml` file, or
`../dvc.yaml:build*` for stages in dvc.yaml in a different
Expand All @@ -367,8 +359,6 @@ def collect(
recursive: if true and if `target` is a directory, all of the
stages inside that directory is returned.
graph: graph to use. Defaults to `repo.graph`.
accept_group: if true, all of the `foreach` generated stages of
the specified target is returned.
glob: Use `target` as a pattern to match stages in a file.
"""
if not target:
Expand All @@ -380,7 +370,7 @@ def collect(
path = self.fs.path.abspath(target)
return collect_inside_path(path, graph or self.graph)

stages = self.from_target(target, accept_group=accept_group, glob=glob)
stages = self.from_target(target, glob=glob)
if not with_deps:
return stages

Expand All @@ -392,14 +382,14 @@ def collect_granular(
with_deps: bool = False,
recursive: bool = False,
graph: "DiGraph" = None,
accept_group: bool = False,
) -> List[StageInfo]:
"""Collects a list of (stage, filter_info) from the given target.

Priority is in the order of following in case of ambiguity:
- .dvc file or .yaml file
- dir if recursive and directory exists
- stage_name
- foreach_group_name or stage_name
- generated stage name from a foreach group
- output file

Args:
Expand All @@ -418,7 +408,7 @@ def collect_granular(
target = as_posix(target)

stages, file, _ = _collect_specific_target(
self, target, with_deps, recursive, accept_group
self, target, with_deps, recursive
)
if not stages:
if not (recursive and self.fs.isdir(target)):
Expand All @@ -440,7 +430,6 @@ def collect_granular(
with_deps,
recursive,
graph,
accept_group=accept_group,
)
except StageFileDoesNotExistError as exc:
# collect() might try to use `target` as a stage name
Expand Down
35 changes: 10 additions & 25 deletions tests/func/test_stage_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,23 +116,19 @@ def stages(tmp_dir, run_copy):


def test_collect_not_a_group_stage_with_group_flag(tmp_dir, dvc, stages):
assert set(dvc.stage.collect("copy-bar-foobar", accept_group=True)) == {
assert set(dvc.stage.collect("copy-bar-foobar")) == {
stages["copy-bar-foobar"]
}
assert set(
dvc.stage.collect("copy-bar-foobar", accept_group=True, with_deps=True)
) == {
assert set(dvc.stage.collect("copy-bar-foobar", with_deps=True)) == {
stages["copy-bar-foobar"],
stages["copy-foo-bar"],
stages["foo-generate"],
}
assert set(dvc.stage.collect_granular("copy-bar-foobar")) == {
(stages["copy-bar-foobar"], None)
}
assert set(
dvc.stage.collect_granular("copy-bar-foobar", accept_group=True)
) == {(stages["copy-bar-foobar"], None)}
assert set(
dvc.stage.collect_granular(
"copy-bar-foobar", accept_group=True, with_deps=True
)
dvc.stage.collect_granular("copy-bar-foobar", with_deps=True)
) == {
(stages["copy-bar-foobar"], None),
(stages["copy-foo-bar"], None),
Expand All @@ -153,29 +149,18 @@ def test_collect_generated(tmp_dir, dvc):
assert len(all_stages) == 5

assert set(dvc.stage.collect()) == all_stages
assert set(dvc.stage.collect("build", accept_group=True)) == all_stages
assert (
set(dvc.stage.collect("build", accept_group=True, with_deps=True))
== all_stages
)
assert set(dvc.stage.collect("build")) == all_stages
assert set(dvc.stage.collect("build", with_deps=True)) == all_stages
assert set(dvc.stage.collect("build*", glob=True)) == all_stages
assert (
set(dvc.stage.collect("build*", glob=True, with_deps=True))
== all_stages
)

stages_info = {(stage, None) for stage in all_stages}
assert set(dvc.stage.collect_granular("build")) == stages_info
assert (
set(dvc.stage.collect_granular("build", accept_group=True))
== stages_info
)
assert (
set(
dvc.stage.collect_granular(
"build", accept_group=True, with_deps=True
)
)
== stages_info
set(dvc.stage.collect_granular("build", with_deps=True)) == stages_info
)


Expand Down