Skip to content

Commit

Permalink
add back dvclive to tests (huggingface#2280)
Browse files Browse the repository at this point in the history
* add back dvclive

* dvclive tracker: handle and test step increments

* fix python<3.9 compatibility
  • Loading branch information
Dave Berenbaum authored and unit_test committed Jan 22, 2024
1 parent c5eee26 commit 47c715c
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 5 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
extras["testing"] = extras["test_prod"] + extras["test_dev"]
extras["rich"] = ["rich"]

extras["test_trackers"] = ["wandb", "comet-ml", "tensorboard"]
extras["test_trackers"] = ["wandb", "comet-ml", "tensorboard", "dvclive"]
extras["dev"] = extras["quality"] + extras["testing"] + extras["rich"]

extras["sagemaker"] = [
Expand Down
1 change: 1 addition & 0 deletions src/accelerate/tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -947,6 +947,7 @@ def log(self, values: dict, step: Optional[int] = None, **kwargs):
"This invocation of DVCLive's Live.log_metric() "
"is incorrect so we dropped this attribute."
)
self.live.next_step()

@on_main_process
def finish(self):
Expand Down
15 changes: 11 additions & 4 deletions tests/test_tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,12 +513,19 @@ def test_log(self, mock_repo):
init_kwargs = {"dvclive": {"dir": dirpath, "save_dvc_exp": False, "dvcyaml": None}}
accelerator.init_trackers(project_name, init_kwargs=init_kwargs)
values = {"total_loss": 0.1, "iteration": 1, "my_text": "some_value"}
accelerator.log(values, step=0)
# Log step 0
accelerator.log(values)
# Log step 1
accelerator.log(values)
# Log step 3 (skip step 2)
accelerator.log(values, step=3)
accelerator.end_training()
live = accelerator.trackers[0].live
logs, latest = parse_metrics(live)
assert latest.pop("step") == 3
assert latest == values
scalars = os.path.join(live.plots_dir, Metric.subfolder)
assert os.path.join(scalars, "total_loss.tsv") in logs
assert os.path.join(scalars, "iteration.tsv") in logs
assert os.path.join(scalars, "my_text.tsv") in logs
for val in values.keys():
val_path = os.path.join(scalars, f"{val}.tsv")
steps = [int(row["step"]) for row in logs[val_path]]
assert steps == [0, 1, 3]

0 comments on commit 47c715c

Please sign in to comment.