Skip to content

Commit

Permalink
Skip-unchanged-nodes (#41)
Browse files Browse the repository at this point in the history
* add `get_changed_stages`

* raname `G` to `digraph`

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add logging

* bugfix unchanged stages

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* typo

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
PythonFZ and pre-commit-ci[bot] authored Dec 20, 2024
1 parent 325e1c4 commit 690f754
Show file tree
Hide file tree
Showing 6 changed files with 73 additions and 18 deletions.
19 changes: 18 additions & 1 deletion paraffin/cli.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
import os
import subprocess
import time
Expand All @@ -10,11 +11,14 @@
from paraffin.submit import submit_node_graph
from paraffin.utils import (
dag_to_levels,
get_changed_stages,
get_custom_queue,
get_stage_graph,
levels_to_mermaid,
)

log = logging.getLogger(__name__)

app = typer.Typer()


Expand Down Expand Up @@ -110,12 +114,18 @@ def submit(
commit: bool = typer.Option(
False, help="Automatically commit changes and push to remotes."
),
v: bool = typer.Option(False, help="Verbose output."),
):
"""Run DVC stages in parallel using Celery."""
if skip_unchanged:
raise NotImplementedError("Skipping unchanged stages is not yet implemented.")
if v:
logging.basicConfig(level=logging.DEBUG)

log.debug("Getting stage graph")
graph = get_stage_graph(names=names, glob=glob)
log.debug("Getting changed stages")
changed_stages = get_changed_stages(graph)
custom_queues = get_custom_queue()

repo = git.Repo() # TODO: consider allow submitting remote repos
Expand All @@ -124,6 +134,7 @@ def submit(
except AttributeError:
origin = None

log.debug("Converting graph to levels")
disconnected_subgraphs = list(nx.connected_components(graph.to_undirected()))
disconnected_levels = []
for subgraph in disconnected_subgraphs:
Expand All @@ -137,13 +148,19 @@ def submit(
)
# iterate disconnected subgraphs for better performance
if not dry:
log.debug("Submitting node graph")
for levels in disconnected_levels:
# TODO: why not have the commit, repo, branch and origin as arguments here!
submit_node_graph(
levels,
custom_queues=custom_queues,
changed_stages=changed_stages,
)
if show_mermaid:
typer.echo(levels_to_mermaid(disconnected_levels))
log.debug("Visualizing graph")
typer.echo(
levels_to_mermaid(disconnected_levels, changed_stages=changed_stages)
)

typer.echo(f"Submitted all (n = {len(graph)}) tasks.")
typer.echo(
Expand Down
7 changes: 5 additions & 2 deletions paraffin/submit.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,21 @@
from celery import chain, group

from paraffin.abc import HirachicalStages
from paraffin.worker import repro
from paraffin.worker import repro, skipped_repro


def submit_node_graph(
levels: HirachicalStages,
custom_queues: t.Optional[t.Dict[str, str]] = None,
changed_stages: list[str] | None = None,
):
per_level_groups = []
for nodes in levels.values():
group_tasks = []
for node in nodes:
if matched_pattern := next(
if changed_stages and node.name not in changed_stages:
group_tasks.append(skipped_repro.s())
elif matched_pattern := next(
(
pattern
for pattern in custom_queues
Expand Down
29 changes: 27 additions & 2 deletions paraffin/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,26 @@ def get_stage_graph(names, glob=False):
return subgraph


def get_changed_stages(subgraph) -> list:
fs = dvc.api.DVCFileSystem(url=None, rev=None)
repo = fs.repo
names = [x.name for x in subgraph.nodes]
changed = list(repo.status(targets=names))
graph = fs.repo.index.graph.reverse(copy=True)
# find all downstream stages and add them to the changed list
# Issue with changed stages is, if any upstream stage was changed
# then we need to run ALL downstream stages, because
# dvc status does not know / tell us because the immediate
# upstream stage was unchanged at the point of checking.

for name in changed:
stage = next(x for x in graph.nodes if hasattr(x, "name") and x.name == name)
for node in nx.descendants(graph, stage):
changed.append(node.name)
# TODO: split into definitely changed and maybe changed stages
return changed


def get_custom_queue():
try:
with pathlib.Path("paraffin.yaml").open() as f:
Expand Down Expand Up @@ -153,7 +173,9 @@ def dag_to_levels(
return levels


def levels_to_mermaid(all_levels: list[HirachicalStages]) -> str:
def levels_to_mermaid(
all_levels: list[HirachicalStages], changed_stages: list[str]
) -> str:
# Initialize Mermaid syntax
mermaid_syntax = "flowchart TD\n"

Expand All @@ -162,7 +184,10 @@ def levels_to_mermaid(all_levels: list[HirachicalStages]) -> str:
for level, nodes in levels.items():
mermaid_syntax += f"\tsubgraph Level{idx}:{level + 1}\n"
for node in nodes:
mermaid_syntax += f"\t\t{node.name}\n"
if node.name in changed_stages:
mermaid_syntax += f"\t\t{node.name}\n"
else:
mermaid_syntax += f"\t\t{node.name}(✓)\n"
mermaid_syntax += "\tend\n"

# Add connections between levels
Expand Down
6 changes: 6 additions & 0 deletions paraffin/worker/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,3 +130,9 @@ def repro(self, *args, name: str, branch: str, origin: str | None, commit: bool)
# remove the working directory
shutil.rmtree(working_dir)
return True


@app.task(bind=True)
def skipped_repro(*args, **kwargs):
"""Dummy Celery task for testing purposes."""
pass
6 changes: 5 additions & 1 deletion tests/test_run_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,10 @@ def check_finished(names: list[str] | None = None, exclusive: bool = False) -> b
for name in names or []:
cmd.append(name)
result = subprocess.run(cmd, capture_output=True, check=True)
return result.stdout.decode().strip() == "Data and pipelines are up to date."
finished = result.stdout.decode().strip() == "Data and pipelines are up to date."
if not finished:
print(result.stdout.decode())
return finished


def test_check_finished(proj01):
Expand Down Expand Up @@ -149,6 +152,7 @@ def test_run_datafile(proj02, caplog):
data_file.write_text("4,5,6")

result = runner.invoke(app, ["submit", "--glob", "a*"])
print(result.stdout)
# assert "Running 2 stages" in caplog.text
# caplog.clear()
assert result.exit_code == 0
Expand Down
24 changes: 12 additions & 12 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@ def test_dag_to_levels_1():
B --> C
```
"""
G = nx.DiGraph()
G.add_edges_from([("A", "C"), ("B", "C")])
levels = dag_to_levels(G, branch="main", origin=None, commit=False)
digraph = nx.DiGraph()
digraph.add_edges_from([("A", "C"), ("B", "C")])
levels = dag_to_levels(digraph, branch="main", origin=None, commit=False)
assert len(levels) == 2

assert len(levels[0]) == 2
Expand All @@ -30,9 +30,9 @@ def test_dag_to_levels_2():
A --> B --> C
```
"""
G = nx.DiGraph()
G.add_edges_from([("A", "B"), ("B", "C")])
levels = dag_to_levels(G, branch="main", origin=None, commit=False)
digraph = nx.DiGraph()
digraph.add_edges_from([("A", "B"), ("B", "C")])
levels = dag_to_levels(digraph, branch="main", origin=None, commit=False)
assert len(levels) == 3

assert len(levels[0]) == 1
Expand All @@ -51,9 +51,9 @@ def test_dag_to_levels_3():
A --> C
```
"""
G = nx.DiGraph()
G.add_edges_from([("A", "B"), ("B", "C"), ("A", "C")])
levels = dag_to_levels(G, branch="main", origin=None, commit=False)
digraph = nx.DiGraph()
digraph.add_edges_from([("A", "B"), ("B", "C"), ("A", "C")])
levels = dag_to_levels(digraph, branch="main", origin=None, commit=False)
assert len(levels) == 3

assert len(levels[0]) == 1
Expand All @@ -74,9 +74,9 @@ def test_dag_to_levles_4():
C --> E
```
"""
G = nx.DiGraph()
G.add_edges_from([("A", "D"), ("B", "D"), ("B", "E"), ("C", "E")])
levels = dag_to_levels(G, branch="main", origin=None, commit=False)
digraph = nx.DiGraph()
digraph.add_edges_from([("A", "D"), ("B", "D"), ("B", "E"), ("C", "E")])
levels = dag_to_levels(digraph, branch="main", origin=None, commit=False)
assert len(levels) == 2

assert len(levels[0]) == 3
Expand Down

0 comments on commit 690f754

Please sign in to comment.