Skip to content

Commit

Permalink
add tqdm bar for projects with more than 5 nodes (#738)
Browse files Browse the repository at this point in the history
  • Loading branch information
PythonFZ authored Nov 7, 2023
1 parent adcb123 commit a730fd3
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 4 deletions.
14 changes: 12 additions & 2 deletions zntrack/project/zntrack_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

import dvc.api
import git
import tqdm
import yaml
import znflow
from znflow.handler import UpdateConnectors
Expand Down Expand Up @@ -217,7 +218,13 @@ def run(
else:
raise ValueError(f"Unknown node type {type(node)}")

for node_uuid in self.graph.get_sorted_nodes():
sorted_nodes = self.graph.get_sorted_nodes()

_tqdm_disabled = True if eager or len(sorted_nodes) <= 5 else False

tbar = tqdm.tqdm(self.graph.get_sorted_nodes(), ncols=140, disable=_tqdm_disabled)

for node_uuid in tbar:
node: Node = self.graph.nodes[node_uuid]["value"]
if node_names is not None and node.name not in node_names:
continue
Expand All @@ -238,7 +245,10 @@ def run(
node, git_only_repo=self.git_only_repo, **optional.get(node.name, {})
)
for x in cmd:
run_dvc_cmd(x)
stdout = None
if not _tqdm_disabled:
stdout = tbar.set_description
run_dvc_cmd(x, stdout=stdout)
node.save(results=False)
if not eager and repro:
self.repro()
Expand Down
9 changes: 7 additions & 2 deletions zntrack/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,13 +95,16 @@ class DVCProcessError(Exception):
"""DVC specific message for CalledProcessError."""


def run_dvc_cmd(script):
def run_dvc_cmd(script, stdout=None):
"""Run the DVC script via subprocess calls.
Parameters
----------
script: tuple[str]|list[str]
A list of strings to pass the subprocess command
stdout: callable, optional
A callable to which the stdout is passed. If None, the stdout is
passed to log.warning.
Raises
------
Expand All @@ -111,7 +114,9 @@ def run_dvc_cmd(script):
dvc_short_string = " ".join(script[:5])
if len(script) > 5:
dvc_short_string += " ..."
log.warning(f"Running DVC command: '{dvc_short_string}'")
if stdout is None:
stdout = log.warning
stdout(f"Running DVC command: '{dvc_short_string}'")
# do not display the output if log.log_level > logging.INFO
show_log = config.log_level < logging.INFO
if not show_log:
Expand Down

0 comments on commit a730fd3

Please sign in to comment.