Skip to content

Commit

Permalink
TensorBoard Override for same Repository (#709)
Browse files Browse the repository at this point in the history
* tensorboard override enabled for the next time user pushes to same repository

* make style

* removed redundant tests

* changed adding separate log from separate folder
  • Loading branch information
merveenoyan authored Feb 23, 2022
1 parent 2e36429 commit 0773518
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 1 deletion.
5 changes: 4 additions & 1 deletion src/huggingface_hub/keras_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import logging
import os
from pathlib import Path
from shutil import copytree
from shutil import copytree, rmtree
from typing import Any, Dict, Optional, Union

from huggingface_hub import ModelHubMixin
Expand Down Expand Up @@ -179,7 +179,10 @@ def push_to_hub_keras(
**model_save_kwargs,
)
if log_dir is not None:
if os.path.exists(f"{repo_path_or_name}/logs"):
rmtree(f"{repo_path_or_name}/logs")
copytree(log_dir, f"{repo_path_or_name}/logs")

# Commit and push!
repo.git_add(auto_lfs_track=True)
repo.git_commit(commit_message)
Expand Down
42 changes: 42 additions & 0 deletions tests/test_keras_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,48 @@ def test_push_to_hub_tensorboard(self):

self._api.delete_repo(name=f"{REPO_NAME}", token=self._token)

@retry_endpoint
def test_override_tensorboard(self):
os.makedirs(f"{WORKING_REPO_DIR}/tb_log_dir")
with open(f"{WORKING_REPO_DIR}/tb_log_dir/tensorboard.txt", "w") as fp:
fp.write("Keras FTW")
REPO_NAME = repo_name("PUSH_TO_HUB")
model = self.model_init()
model.build((None, 2))
push_to_hub_keras(
model,
repo_path_or_name=f"{WORKING_REPO_DIR}/{REPO_NAME}",
log_dir=f"{WORKING_REPO_DIR}/tb_log_dir",
api_endpoint=ENDPOINT_STAGING,
use_auth_token=self._token,
git_user="ci",
git_email="ci@dummy.com",
)
os.makedirs(f"{WORKING_REPO_DIR}/tb_log_dir2")
with open(f"{WORKING_REPO_DIR}/tb_log_dir2/override.txt", "w") as fp:
fp.write("Keras FTW")
push_to_hub_keras(
model,
repo_path_or_name=f"{WORKING_REPO_DIR}/{REPO_NAME}",
log_dir=f"{WORKING_REPO_DIR}/tb_log_dir2",
api_endpoint=ENDPOINT_STAGING,
use_auth_token=self._token,
git_user="ci",
git_email="ci@dummy.com",
)

model_info = HfApi(endpoint=ENDPOINT_STAGING).model_info(
f"{USER}/{REPO_NAME}",
)
self.assertTrue(
"logs/override.txt" in [f.rfilename for f in model_info.siblings]
)
self.assertFalse(
"logs/tensorboard.txt" in [f.rfilename for f in model_info.siblings]
)

self._api.delete_repo(name=f"{REPO_NAME}", token=self._token)

@retry_endpoint
def test_push_to_hub_model_kwargs(self):
REPO_NAME = repo_name("PUSH_TO_HUB")
Expand Down

0 comments on commit 0773518

Please sign in to comment.