Skip to content

Commit

Permalink
add support for foreach target (#8210)
Browse files Browse the repository at this point in the history
add general support for foreach target
  • Loading branch information
skshetry authored Sep 6, 2022
1 parent 0b2b59a commit 071b2b8
Show file tree
Hide file tree
Showing 6 changed files with 28 additions and 53 deletions.
2 changes: 1 addition & 1 deletion dvc/commands/stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def _get_stages(self) -> Iterable["Stage"]:
# removing duplicates while maintaining order
collected = chain.from_iterable(
self.repo.stage.collect(
target=target, recursive=self.args.recursive, accept_group=True
target=target, recursive=self.args.recursive
)
for target in self.args.targets
)
Expand Down
7 changes: 5 additions & 2 deletions dvc/dvcfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,9 +243,11 @@ def _dump_lockfile(self, stage):
self._lockfile.dump(stage)

@staticmethod
def _check_if_parametrized(stage):
def _check_if_parametrized(stage, action: str = "dump") -> None:
if stage.raw_data.parametrized:
raise ParametrizedDumpError(f"cannot dump a parametrized {stage}")
raise ParametrizedDumpError(
f"cannot {action} a parametrized {stage}"
)

def _dump_pipeline_file(self, stage):
self._check_if_parametrized(stage)
Expand Down Expand Up @@ -291,6 +293,7 @@ def remove(self, force=False):
self._lockfile.remove()

def remove_stage(self, stage):
self._check_if_parametrized(stage, "remove")
self._lockfile.remove_stage(stage)
if not self.exists():
return
Expand Down
2 changes: 1 addition & 1 deletion dvc/repo/remove.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
@locked
def remove(self: "Repo", target: str, outs: bool = False):
try:
stages = self.stage.from_target(target)
stages = self.stage.from_target(target, accept_group=False)
except (StageNotFound, StageFileDoesNotExistError) as e:
# If the user specified a tracked file as a target instead of a stage,
# e.g. `data.csv` instead of `data.csv.dvc`,
Expand Down
2 changes: 0 additions & 2 deletions dvc/repo/reproduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,6 @@ def reproduce(
from .graph import get_pipeline, get_pipelines

glob = kwargs.pop("glob", False)
accept_group = not glob

if isinstance(targets, str):
targets = [targets]
Expand Down Expand Up @@ -137,7 +136,6 @@ def reproduce(
self.stage.collect(
target,
recursive=recursive,
accept_group=accept_group,
glob=glob,
)
)
Expand Down
33 changes: 11 additions & 22 deletions dvc/repo/stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,6 @@ def _collect_specific_target(
target: str,
with_deps: bool,
recursive: bool,
accept_group: bool,
) -> Tuple[StageIter, "OptStr", "OptStr"]:
from dvc.dvcfile import is_valid_filename

Expand All @@ -98,13 +97,11 @@ def _collect_specific_target(
msg = "Checking if stage '%s' is in '%s'"
logger.debug(msg, target, PIPELINE_FILE)
if not (recursive and loader.fs.isdir(target)):
stages = _maybe_collect_from_dvc_yaml(
loader, target, with_deps, accept_group=accept_group
)
stages = _maybe_collect_from_dvc_yaml(loader, target, with_deps)
if stages:
return stages, file, name
elif not with_deps and is_valid_filename(file):
stages = loader.load_all(file, name, accept_group=accept_group)
stages = loader.load_all(file, name)
return stages, file, name
return [], file, name

Expand Down Expand Up @@ -209,7 +206,7 @@ def create(
return stage

def from_target(
self, target: str, accept_group: bool = False, glob: bool = False
self, target: str, accept_group: bool = True, glob: bool = False
) -> StageList:
"""
Returns a list of stages from the provided target.
Expand Down Expand Up @@ -250,15 +247,12 @@ def _get_keys(
self,
stages: "StageLoader",
name: str = None,
accept_group: bool = False,
accept_group: bool = True,
glob: bool = False,
) -> Iterable[str]:

assert not (accept_group and glob)

if not name:
return stages.keys()

if accept_group and stages.is_foreach_generated(name):
return self._get_group_keys(stages, name)
if glob:
Expand All @@ -269,7 +263,7 @@ def load_all(
self,
path: str = None,
name: str = None,
accept_group: bool = False,
accept_group: bool = True,
glob: bool = False,
) -> StageList:
"""Load a list of stages from a file.
Expand Down Expand Up @@ -338,7 +332,6 @@ def collect(
with_deps: bool = False,
recursive: bool = False,
graph: "DiGraph" = None,
accept_group: bool = False,
glob: bool = False,
) -> StageIter:
"""Collect list of stages from the provided target.
Expand All @@ -347,15 +340,14 @@ def collect(
target: if not provided, all of the stages in the graph are
returned.
Target can be:
- a stage name in the `dvc.yaml` file.
- a foreach group name or a stage name in the `dvc.yaml` file.
- a generated stage name from a foreach group.
- a path to `dvc.yaml` or `.dvc` file.
- for a stage in a dvc.yaml file located in a directory other
than the current working directory, it can be the path to that
dvc.yaml file, followed by a colon `:`, followed by the stage
name (e.g. `../dvc.yaml:build`).
- in case of `recursive`, it can be a path to a directory.
- in case of `accept_group`, it can be a group name of
`foreach` generated stage.
- in case of `glob`, it can be a wildcard pattern to match
stages. Example: `build*` for stages in `dvc.yaml` file, or
`../dvc.yaml:build*` for stages in dvc.yaml in a different
Expand All @@ -367,8 +359,6 @@ def collect(
recursive: if true and `target` is a directory, all of the
stages inside that directory are returned.
graph: graph to use. Defaults to `repo.graph`.
accept_group: if true, all of the `foreach` generated stages of
the specified target is returned.
glob: Use `target` as a pattern to match stages in a file.
"""
if not target:
Expand All @@ -380,7 +370,7 @@ def collect(
path = self.fs.path.abspath(target)
return collect_inside_path(path, graph or self.graph)

stages = self.from_target(target, accept_group=accept_group, glob=glob)
stages = self.from_target(target, glob=glob)
if not with_deps:
return stages

Expand All @@ -392,14 +382,14 @@ def collect_granular(
with_deps: bool = False,
recursive: bool = False,
graph: "DiGraph" = None,
accept_group: bool = False,
) -> List[StageInfo]:
"""Collects a list of (stage, filter_info) from the given target.
In case of ambiguity, targets are resolved in the following order of priority:
- .dvc file or .yaml file
- dir if recursive and directory exists
- stage_name
- foreach_group_name or stage_name
- generated stage name from a foreach group
- output file
Args:
Expand All @@ -418,7 +408,7 @@ def collect_granular(
target = as_posix(target)

stages, file, _ = _collect_specific_target(
self, target, with_deps, recursive, accept_group
self, target, with_deps, recursive
)
if not stages:
if not (recursive and self.fs.isdir(target)):
Expand All @@ -440,7 +430,6 @@ def collect_granular(
with_deps,
recursive,
graph,
accept_group=accept_group,
)
except StageFileDoesNotExistError as exc:
# collect() might try to use `target` as a stage name
Expand Down
35 changes: 10 additions & 25 deletions tests/func/test_stage_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,23 +116,19 @@ def stages(tmp_dir, run_copy):


def test_collect_not_a_group_stage_with_group_flag(tmp_dir, dvc, stages):
assert set(dvc.stage.collect("copy-bar-foobar", accept_group=True)) == {
assert set(dvc.stage.collect("copy-bar-foobar")) == {
stages["copy-bar-foobar"]
}
assert set(
dvc.stage.collect("copy-bar-foobar", accept_group=True, with_deps=True)
) == {
assert set(dvc.stage.collect("copy-bar-foobar", with_deps=True)) == {
stages["copy-bar-foobar"],
stages["copy-foo-bar"],
stages["foo-generate"],
}
assert set(dvc.stage.collect_granular("copy-bar-foobar")) == {
(stages["copy-bar-foobar"], None)
}
assert set(
dvc.stage.collect_granular("copy-bar-foobar", accept_group=True)
) == {(stages["copy-bar-foobar"], None)}
assert set(
dvc.stage.collect_granular(
"copy-bar-foobar", accept_group=True, with_deps=True
)
dvc.stage.collect_granular("copy-bar-foobar", with_deps=True)
) == {
(stages["copy-bar-foobar"], None),
(stages["copy-foo-bar"], None),
Expand All @@ -153,29 +149,18 @@ def test_collect_generated(tmp_dir, dvc):
assert len(all_stages) == 5

assert set(dvc.stage.collect()) == all_stages
assert set(dvc.stage.collect("build", accept_group=True)) == all_stages
assert (
set(dvc.stage.collect("build", accept_group=True, with_deps=True))
== all_stages
)
assert set(dvc.stage.collect("build")) == all_stages
assert set(dvc.stage.collect("build", with_deps=True)) == all_stages
assert set(dvc.stage.collect("build*", glob=True)) == all_stages
assert (
set(dvc.stage.collect("build*", glob=True, with_deps=True))
== all_stages
)

stages_info = {(stage, None) for stage in all_stages}
assert set(dvc.stage.collect_granular("build")) == stages_info
assert (
set(dvc.stage.collect_granular("build", accept_group=True))
== stages_info
)
assert (
set(
dvc.stage.collect_granular(
"build", accept_group=True, with_deps=True
)
)
== stages_info
set(dvc.stage.collect_granular("build", with_deps=True)) == stages_info
)


Expand Down

0 comments on commit 071b2b8

Please sign in to comment.