Skip to content

Commit

Permalink
add support for foreach target
Browse files Browse the repository at this point in the history
  • Loading branch information
skshetry committed Aug 31, 2022
1 parent e7daeb1 commit d1272ed
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 57 deletions.
2 changes: 1 addition & 1 deletion dvc/commands/stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def _get_stages(self) -> Iterable["Stage"]:
# removing duplicates while maintaining order
collected = chain.from_iterable(
self.repo.stage.collect(
target=target, recursive=self.args.recursive, accept_group=True
target=target, recursive=self.args.recursive
)
for target in self.args.targets
)
Expand Down
2 changes: 0 additions & 2 deletions dvc/repo/reproduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,6 @@ def reproduce(
from .graph import get_pipeline, get_pipelines

glob = kwargs.pop("glob", False)
accept_group = not glob

if isinstance(targets, str):
targets = [targets]
Expand Down Expand Up @@ -137,7 +136,6 @@ def reproduce(
self.stage.collect(
target,
recursive=recursive,
accept_group=accept_group,
glob=glob,
)
)
Expand Down
38 changes: 9 additions & 29 deletions dvc/repo/stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,6 @@ def _collect_specific_target(
target: str,
with_deps: bool,
recursive: bool,
accept_group: bool,
) -> Tuple[StageIter, "OptStr", "OptStr"]:
from dvc.dvcfile import is_valid_filename

Expand All @@ -98,13 +97,11 @@ def _collect_specific_target(
msg = "Checking if stage '%s' is in '%s'"
logger.debug(msg, target, PIPELINE_FILE)
if not (recursive and loader.fs.isdir(target)):
stages = _maybe_collect_from_dvc_yaml(
loader, target, with_deps, accept_group=accept_group
)
stages = _maybe_collect_from_dvc_yaml(loader, target, with_deps)
if stages:
return stages, file, name
elif not with_deps and is_valid_filename(file):
stages = loader.load_all(file, name, accept_group=accept_group)
stages = loader.load_all(file, name)
return stages, file, name
return [], file, name

Expand Down Expand Up @@ -208,17 +205,13 @@ def create(
restore_fields(stage)
return stage

def from_target(
self, target: str, accept_group: bool = False, glob: bool = False
) -> StageList:
def from_target(self, target: str, glob: bool = False) -> StageList:
"""
Returns a list of stage from the provided target.
(see load method below for further details)
"""
path, name = parse_target(target, isa_glob=glob)
return self.load_all(
path=path, name=name, accept_group=accept_group, glob=glob
)
return self.load_all(path=path, name=name, glob=glob)

def get_target(self, target: str) -> "Stage":
"""
Expand Down Expand Up @@ -250,16 +243,12 @@ def _get_keys(
self,
stages: "StageLoader",
name: str = None,
accept_group: bool = False,
glob: bool = False,
) -> Iterable[str]:

assert not (accept_group and glob)

if not name:
return stages.keys()

if accept_group and stages.is_foreach_generated(name):
if stages.is_foreach_generated(name):
return self._get_group_keys(stages, name)
if glob:
return fnmatch.filter(stages.keys(), name)
Expand All @@ -269,16 +258,13 @@ def load_all(
self,
path: str = None,
name: str = None,
accept_group: bool = False,
glob: bool = False,
) -> StageList:
"""Load a list of stages from a file.
Args:
path: if not provided, default `dvc.yaml` is assumed.
name: required for `dvc.yaml` files, ignored for `.dvc` files.
accept_group: if true, all of the the stages generated from `name`
foreach are returned.
glob: if true, `name` is considered as a glob, which is
used to filter list of stages from the given `path`.
"""
Expand All @@ -295,7 +281,7 @@ def load_all(
return [stage]

assert isinstance(stages, StageLoader)
keys = self._get_keys(stages, name, accept_group, glob)
keys = self._get_keys(stages, name, glob)
return [stages[key] for key in keys]

def load_one(self, path: str = None, name: str = None) -> "Stage":
Expand Down Expand Up @@ -338,7 +324,6 @@ def collect(
with_deps: bool = False,
recursive: bool = False,
graph: "DiGraph" = None,
accept_group: bool = False,
glob: bool = False,
) -> StageIter:
"""Collect list of stages from the provided target.
Expand All @@ -348,14 +333,13 @@ def collect(
returned.
Target can be:
- a stage name in the `dvc.yaml` file.
- a foreach group name in the `dvc.yaml` file.
- a path to `dvc.yaml` or `.dvc` file.
- in case of a stage to a dvc.yaml file in a different
directory than current working directory, it can be a path
to dvc.yaml file, followed by a colon `:`, followed by stage
name (eg: `../dvc.yaml:build`).
- in case of `recursive`, it can be a path to a directory.
- in case of `accept_group`, it can be a group name of
`foreach` generated stage.
- in case of `glob`, it can be a wildcard pattern to match
stages. Example: `build*` for stages in `dvc.yaml` file, or
`../dvc.yaml:build*` for stages in dvc.yaml in a different
Expand All @@ -367,8 +351,6 @@ def collect(
recursive: if true and if `target` is a directory, all of the
stages inside that directory is returned.
graph: graph to use. Defaults to `repo.graph`.
accept_group: if true, all of the `foreach` generated stages of
the specified target is returned.
glob: Use `target` as a pattern to match stages in a file.
"""
if not target:
Expand All @@ -380,7 +362,7 @@ def collect(
path = self.fs.path.abspath(target)
return collect_inside_path(path, graph or self.graph)

stages = self.from_target(target, accept_group=accept_group, glob=glob)
stages = self.from_target(target, glob=glob)
if not with_deps:
return stages

Expand All @@ -392,7 +374,6 @@ def collect_granular(
with_deps: bool = False,
recursive: bool = False,
graph: "DiGraph" = None,
accept_group: bool = False,
) -> List[StageInfo]:
"""Collects a list of (stage, filter_info) from the given target.
Expand All @@ -418,7 +399,7 @@ def collect_granular(
target = as_posix(target)

stages, file, _ = _collect_specific_target(
self, target, with_deps, recursive, accept_group
self, target, with_deps, recursive
)
if not stages:
if not (recursive and self.fs.isdir(target)):
Expand All @@ -440,7 +421,6 @@ def collect_granular(
with_deps,
recursive,
graph,
accept_group=accept_group,
)
except StageFileDoesNotExistError as exc:
# collect() might try to use `target` as a stage name
Expand Down
35 changes: 10 additions & 25 deletions tests/func/test_stage_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,23 +116,19 @@ def stages(tmp_dir, run_copy):


def test_collect_not_a_group_stage_with_group_flag(tmp_dir, dvc, stages):
assert set(dvc.stage.collect("copy-bar-foobar", accept_group=True)) == {
assert set(dvc.stage.collect("copy-bar-foobar")) == {
stages["copy-bar-foobar"]
}
assert set(
dvc.stage.collect("copy-bar-foobar", accept_group=True, with_deps=True)
) == {
assert set(dvc.stage.collect("copy-bar-foobar", with_deps=True)) == {
stages["copy-bar-foobar"],
stages["copy-foo-bar"],
stages["foo-generate"],
}
assert set(dvc.stage.collect_granular("copy-bar-foobar")) == {
(stages["copy-bar-foobar"], None)
}
assert set(
dvc.stage.collect_granular("copy-bar-foobar", accept_group=True)
) == {(stages["copy-bar-foobar"], None)}
assert set(
dvc.stage.collect_granular(
"copy-bar-foobar", accept_group=True, with_deps=True
)
dvc.stage.collect_granular("copy-bar-foobar", with_deps=True)
) == {
(stages["copy-bar-foobar"], None),
(stages["copy-foo-bar"], None),
Expand All @@ -153,29 +149,18 @@ def test_collect_generated(tmp_dir, dvc):
assert len(all_stages) == 5

assert set(dvc.stage.collect()) == all_stages
assert set(dvc.stage.collect("build", accept_group=True)) == all_stages
assert (
set(dvc.stage.collect("build", accept_group=True, with_deps=True))
== all_stages
)
assert set(dvc.stage.collect("build")) == all_stages
assert set(dvc.stage.collect("build", with_deps=True)) == all_stages
assert set(dvc.stage.collect("build*", glob=True)) == all_stages
assert (
set(dvc.stage.collect("build*", glob=True, with_deps=True))
== all_stages
)

stages_info = {(stage, None) for stage in all_stages}
assert set(dvc.stage.collect_granular("build")) == stages_info
assert (
set(dvc.stage.collect_granular("build", accept_group=True))
== stages_info
)
assert (
set(
dvc.stage.collect_granular(
"build", accept_group=True, with_deps=True
)
)
== stages_info
set(dvc.stage.collect_granular("build", with_deps=True)) == stages_info
)


Expand Down

0 comments on commit d1272ed

Please sign in to comment.