diff --git a/paraffin/abc.py b/paraffin/abc.py index b170253..e34f094 100644 --- a/paraffin/abc.py +++ b/paraffin/abc.py @@ -1,27 +1,5 @@ -import dataclasses import typing as t from dvc.stage import PipelineStage - -@dataclasses.dataclass(frozen=True) -class StageContainer: - stage: PipelineStage - branch: str - origin: t.Optional[str] - commit: bool - - @property - def name(self) -> str: - return self.stage.name - - def to_dict(self) -> dict[str, t.Any]: - return { - "name": self.name, - "branch": self.branch, - "origin": self.origin, - "commit": self.commit, - } - - -HirachicalStages = t.Dict[int, t.List[StageContainer]] +HirachicalStages = t.Dict[int, t.List[PipelineStage]] diff --git a/paraffin/cli.py b/paraffin/cli.py index 4070467..2087d9d 100644 --- a/paraffin/cli.py +++ b/paraffin/cli.py @@ -6,6 +6,7 @@ import git import networkx as nx +import tqdm import typer from paraffin.submit import submit_node_graph @@ -114,12 +115,19 @@ def submit( commit: bool = typer.Option( False, help="Automatically commit changes and push to remotes." ), - v: bool = typer.Option(False, help="Verbose output."), + verbose: bool = typer.Option(False, help="Verbose output."), + use_dvc: bool = typer.Option( + True, + help="Use DVC to manage pipeline stages. Do not change" + " this unless you know what you are doing.", + ), ): """Run DVC stages in parallel using Celery.""" if skip_unchanged: raise NotImplementedError("Skipping unchanged stages is not yet implemented.") - if v: + if not use_dvc and commit: + raise ValueError("Cannot commit changes without using DVC.") + if verbose: logging.basicConfig(level=logging.DEBUG) log.debug("Getting stage graph") @@ -141,9 +149,6 @@ def submit( disconnected_levels.append( dag_to_levels( graph=graph.subgraph(subgraph), - branch=str(repo.active_branch), - origin=origin, - commit=commit, ) ) # iterate disconnected subgraphs for better performance @@ -155,6 +160,10 @@ def submit( levels, custom_queues=custom_queues, changed_stages=changed_stages, + branch=str(repo.active_branch), + origin=origin, + commit=commit, + use_dvc=use_dvc, ) if show_mermaid: log.debug("Visualizing graph") @@ -167,3 +176,48 @@ def submit( "Start your celery worker using `paraffin worker`" " and specify concurrency with `--concurrency`." ) + + +@app.command() +def commit( + names: t.Optional[list[str]] = typer.Argument( + None, help="Stage names to run. If not specified, run all stages." + ), + check: bool = typer.Option(False), + verbose: bool = typer.Option(False, help="Verbose output."), +): + if verbose: + logging.basicConfig(level=logging.DEBUG) + log.debug("Getting stage graph") + graph = get_stage_graph(names=names, glob=True) + if check: + log.debug("Getting changed stages") + changed_stages = get_changed_stages(graph) + else: + changed_stages = [node.name for node in graph.nodes] + + disconnected_subgraphs = list(nx.connected_components(graph.to_undirected())) + disconnected_levels = [] + for subgraph in disconnected_subgraphs: + disconnected_levels.append( + dag_to_levels( + graph=graph.subgraph(subgraph), + ) + ) + + tbar = tqdm.tqdm( + disconnected_levels, desc="Committing stages", total=len(changed_stages) + ) + + for levels in disconnected_levels: + for nodes in levels.values(): + for node in nodes: + if node.name in changed_stages: + tbar.set_postfix(current=node.name) + cmd = ["dvc", "commit", node.name, "--force"] + res = subprocess.run(cmd, capture_output=True) + if res.returncode != 0: + log.error(f"Failed to commit {node.name}") + log.error(res.stderr.decode()) + raise RuntimeError(f"Failed to commit {node.name}") + tbar.update() diff --git a/paraffin/submit.py b/paraffin/submit.py index 8848bfd..e6e5f6f 100644 --- a/paraffin/submit.py +++ b/paraffin/submit.py @@ -1,5 +1,4 @@ import fnmatch -import typing as t from celery import chain, group @@ -9,8 +8,12 @@ def submit_node_graph( levels: HirachicalStages, - custom_queues: t.Optional[t.Dict[str, str]] = None, - changed_stages: list[str] | None = None, + custom_queues: dict[str, str], + changed_stages: list[str], + branch: str, + origin: str | None, + commit: bool, + use_dvc: bool, ): per_level_groups = [] for nodes in levels.values(): @@ -27,10 +30,26 @@ def submit_node_graph( None, ): group_tasks.append( - repro.s(**node.to_dict()).set(queue=custom_queues[matched_pattern]) + repro.s( + name=node.name, + cmd=node.cmd, + branch=branch, + commit=commit, + origin=origin, + use_dvc=use_dvc, + ).set(queue=custom_queues[matched_pattern]) ) else: - group_tasks.append(repro.s(**node.to_dict())) + group_tasks.append( + repro.s( + name=node.name, + cmd=node.cmd, + branch=branch, + commit=commit, + origin=origin, + use_dvc=use_dvc, + ) + ) per_level_groups.append(group(group_tasks)) workflow = chain(per_level_groups) diff --git a/paraffin/utils.py b/paraffin/utils.py index d1743a0..538a0e0 100644 --- a/paraffin/utils.py +++ b/paraffin/utils.py @@ -7,7 +7,7 @@ import networkx as nx import yaml -from paraffin.abc import HirachicalStages, StageContainer +from paraffin.abc import HirachicalStages def get_subgraph_with_predecessors(graph, nodes, reverse=False): @@ -109,9 +109,7 @@ def get_custom_queue(): return {} -def dag_to_levels( - graph, branch: str, origin: str | None, commit: bool -) -> HirachicalStages: +def dag_to_levels(graph) -> HirachicalStages: """Converts a directed acyclic graph (DAG) into hierarchical levels. This function takes a directed acyclic graph (DAG) and organizes its nodes @@ -150,23 +148,9 @@ def dag_to_levels( for path in nx.all_simple_paths(graph, start_node, node): level = max(level, len(path) - 1) try: - levels[level].append( - StageContainer( - stage=node, - branch=branch, - origin=origin, - commit=commit, - ) - ) + levels[level].append(node) except KeyError: - levels[level] = [ - StageContainer( - stage=node, - branch=branch, - origin=origin, - commit=commit, - ) - ] + levels[level] = [node] else: # this part has already been added break diff --git a/paraffin/worker/__init__.py b/paraffin/worker/__init__.py index c69b72e..fb96ed9 100644 --- a/paraffin/worker/__init__.py +++ b/paraffin/worker/__init__.py @@ -51,9 +51,8 @@ def make_celery() -> Celery: app = make_celery() -@app.task(bind=True, default_retry_delay=5) # retry in 5 seconds -def repro(self, *args, name: str, branch: str, origin: str | None, commit: bool): - """Celery task to reproduce a DVC pipeline stage. +def _run_dvc(self, name: str): + """Run DVC repro command for a given stage. This task attempts to reproduce a specified DVC pipeline stage using the `dvc repro` command. @@ -62,30 +61,7 @@ def repro(self, *args, name: str, branch: str, origin: str | None, commit: bool) If the error occurs after the stage has been executed, it attempts to commit the lock using the `dvc commit` command with a forced option to avoid loss of computational resources. - - Args: - self (Task): The bound Celery task instance. - *args: Additional arguments. - name (str): The name of the DVC pipeline stage to reproduce. - - Raises: - self.retry: If the "Unable to acquire lock" error occurs, - the task is retried up to 5 times. - RuntimeError: If unable to commit the lock after multiple attempts. - - Returns: - bool: True if the operation is successful. """ - working_dir = pathlib.Path(os.environ.get("PARAFFIN_WORKING_DIRECTORY", ".")) - cleanup = True if os.environ.get("PARAFFIN_CLEANUP", "True") == "True" else False - print(f"Working directory: {working_dir} with cleanup: {cleanup}") - - if not working_dir.exists(): - working_dir.mkdir(parents=True) - os.chdir(working_dir) - - clone_and_checkout(branch, origin) - popen = subprocess.Popen( ["dvc", "repro", "--single-item", name], stdout=subprocess.PIPE, @@ -123,6 +99,58 @@ def repro(self, *args, name: str, branch: str, origin: str | None, commit: bool) raise RuntimeError(f"Unable to commit lock for {name}") popen.stderr.close() + +def _run_vanilla(self, cmd: str): + """Run a vanilla command for a given stage. + + This task attempts to run a specified command + using the `subprocess.Popen` function. + """ + print(f"Running command: {cmd}") + subprocess.check_call(cmd, shell=True) + + +@app.task(bind=True, default_retry_delay=5) # retry in 5 seconds +def repro( + self, + *args, + name: str, + branch: str, + origin: str | None, + commit: bool, + cmd: str, + use_dvc: bool, +): + """Celery task to reproduce a DVC pipeline stage. + + Args: + self (Task): The bound Celery task instance. + *args: Additional arguments. + name (str): The name of the DVC pipeline stage to reproduce. + + Raises: + self.retry: If the "Unable to acquire lock" error occurs, + the task is retried up to 5 times. + RuntimeError: If unable to commit the lock after multiple attempts. + + Returns: + bool: True if the operation is successful. + """ + working_dir = pathlib.Path(os.environ.get("PARAFFIN_WORKING_DIRECTORY", ".")) + cleanup = True if os.environ.get("PARAFFIN_CLEANUP", "True") == "True" else False + # print(f"Working directory: {working_dir} with cleanup: {cleanup}") + + if not working_dir.exists(): + working_dir.mkdir(parents=True) + os.chdir(working_dir) + + clone_and_checkout(branch, origin) + + if use_dvc: + _run_dvc(self, name) + else: + _run_vanilla(self, cmd) + if commit: commit_and_push(name=name, origin=origin) diff --git a/tests/test_utils.py b/tests/test_utils.py index d54d822..64e5ba0 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -13,14 +13,9 @@ def test_dag_to_levels_1(): """ digraph = nx.DiGraph() digraph.add_edges_from([("A", "C"), ("B", "C")]) - levels = dag_to_levels(digraph, branch="main", origin=None, commit=False) - assert len(levels) == 2 + levels = dag_to_levels(digraph) - assert len(levels[0]) == 2 - assert len(levels[1]) == 1 - assert levels[0][0].stage == "A" - assert levels[0][1].stage == "B" - assert levels[1][0].stage == "C" + assert levels == {0: ["A", "B"], 1: ["C"]} def test_dag_to_levels_2(): @@ -32,15 +27,9 @@ def test_dag_to_levels_2(): """ digraph = nx.DiGraph() digraph.add_edges_from([("A", "B"), ("B", "C")]) - levels = dag_to_levels(digraph, branch="main", origin=None, commit=False) - assert len(levels) == 3 + levels = dag_to_levels(digraph) - assert len(levels[0]) == 1 - assert len(levels[1]) == 1 - assert len(levels[2]) == 1 - assert levels[0][0].stage == "A" - assert levels[1][0].stage == "B" - assert levels[2][0].stage == "C" + assert levels == {0: ["A"], 1: ["B"], 2: ["C"]} def test_dag_to_levels_3(): @@ -53,15 +42,9 @@ def test_dag_to_levels_3(): """ digraph = nx.DiGraph() digraph.add_edges_from([("A", "B"), ("B", "C"), ("A", "C")]) - levels = dag_to_levels(digraph, branch="main", origin=None, commit=False) - assert len(levels) == 3 + levels = dag_to_levels(digraph) - assert len(levels[0]) == 1 - assert len(levels[1]) == 1 - assert len(levels[2]) == 1 - assert levels[0][0].stage == "A" - assert levels[1][0].stage == "B" - assert levels[2][0].stage == "C" + assert levels == {0: ["A"], 1: ["B"], 2: ["C"]} def test_dag_to_levles_4(): @@ -76,13 +59,6 @@ def test_dag_to_levles_4(): """ digraph = nx.DiGraph() digraph.add_edges_from([("A", "D"), ("B", "D"), ("B", "E"), ("C", "E")]) - levels = dag_to_levels(digraph, branch="main", origin=None, commit=False) - assert len(levels) == 2 + levels = dag_to_levels(digraph) - assert len(levels[0]) == 3 - assert len(levels[1]) == 2 - assert levels[0][0].stage == "A" - assert levels[0][1].stage == "B" - assert levels[0][2].stage == "C" - assert levels[1][0].stage == "D" - assert levels[1][1].stage == "E" + assert levels == {0: ["A", "B", "C"], 1: ["D", "E"]}