From 47c715c26e49ed6a99486c2893f01bfe2bfe48ed Mon Sep 17 00:00:00 2001 From: Dave Berenbaum Date: Fri, 5 Jan 2024 12:22:22 -0500 Subject: [PATCH] add back dvclive to tests (#2280) * add back dvclive * dvclive tracker: handle and test step increments * fix python<3.9 compatibility --- setup.py | 2 +- src/accelerate/tracking.py | 1 + tests/test_tracking.py | 15 +++++++++++---- 3 files changed, 13 insertions(+), 5 deletions(-) diff --git a/setup.py b/setup.py index f6eefda0dea..b3a8fda47bf 100644 --- a/setup.py +++ b/setup.py @@ -25,7 +25,7 @@ extras["testing"] = extras["test_prod"] + extras["test_dev"] extras["rich"] = ["rich"] -extras["test_trackers"] = ["wandb", "comet-ml", "tensorboard"] +extras["test_trackers"] = ["wandb", "comet-ml", "tensorboard", "dvclive"] extras["dev"] = extras["quality"] + extras["testing"] + extras["rich"] extras["sagemaker"] = [ diff --git a/src/accelerate/tracking.py b/src/accelerate/tracking.py index 4966c5e2532..6ec95224d9d 100644 --- a/src/accelerate/tracking.py +++ b/src/accelerate/tracking.py @@ -947,6 +947,7 @@ def log(self, values: dict, step: Optional[int] = None, **kwargs): "This invocation of DVCLive's Live.log_metric() " "is incorrect so we dropped this attribute." ) + self.live.next_step() @on_main_process def finish(self): diff --git a/tests/test_tracking.py b/tests/test_tracking.py index 58709546ea5..3264ca13160 100644 --- a/tests/test_tracking.py +++ b/tests/test_tracking.py @@ -513,12 +513,19 @@ def test_log(self, mock_repo): init_kwargs = {"dvclive": {"dir": dirpath, "save_dvc_exp": False, "dvcyaml": None}} accelerator.init_trackers(project_name, init_kwargs=init_kwargs) values = {"total_loss": 0.1, "iteration": 1, "my_text": "some_value"} - accelerator.log(values, step=0) + # Log step 0 + accelerator.log(values) + # Log step 1 + accelerator.log(values) + # Log step 3 (skip step 2) + accelerator.log(values, step=3) accelerator.end_training() live = accelerator.trackers[0].live logs, latest = parse_metrics(live) + assert latest.pop("step") == 3 assert latest == values scalars = os.path.join(live.plots_dir, Metric.subfolder) - assert os.path.join(scalars, "total_loss.tsv") in logs - assert os.path.join(scalars, "iteration.tsv") in logs - assert os.path.join(scalars, "my_text.tsv") in logs + for val in values.keys(): + val_path = os.path.join(scalars, f"{val}.tsv") + steps = [int(row["step"]) for row in logs[val_path]] + assert steps == [0, 1, 3]