Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TensorBoard for Keras models #651

Merged
merged 16 commits into from
Feb 7, 2022
8 changes: 7 additions & 1 deletion src/huggingface_hub/keras_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import logging
import os
from pathlib import Path
from shutil import copytree
from typing import Any, Dict, Optional, Union

from huggingface_hub import ModelHubMixin
Expand Down Expand Up @@ -71,6 +72,7 @@ def push_to_hub_keras(
model,
repo_path_or_name: Optional[str] = None,
repo_url: Optional[str] = None,
log_dir: Optional[str] = None,
commit_message: Optional[str] = "Add model",
organization: Optional[str] = None,
private: Optional[bool] = None,
Expand All @@ -97,6 +99,9 @@ def push_to_hub_keras(
Specify this in case you want to push to an existing repository in the hub. If unspecified, a new
repository will be created in your namespace (unless you specify an :obj:`organization`) with
:obj:`repo_name`.
log_dir (:obj:`str`, `optional`):
TensorBoard logging directory to be pushed. The Hub automatically hosts
and displays a TensorBoard instance if log files are included in the repository.
commit_message (:obj:`str`, `optional`):
Message to commit while pushing. Will default to :obj:`"add model"`.
organization (:obj:`str`, `optional`):
Expand Down Expand Up @@ -173,7 +178,8 @@ def push_to_hub_keras(
include_optimizer=include_optimizer,
**model_save_kwargs,
)

if log_dir is not None:
copytree(log_dir, f"{repo_path_or_name}/logs")
# Commit and push!
repo.git_add(auto_lfs_track=True)
repo.git_commit(commit_message)
Expand Down
26 changes: 26 additions & 0 deletions tests/test_keras_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,32 @@ def test_push_to_hub(self):

self._api.delete_repo(name=f"{REPO_NAME}", token=self._token)

def test_push_to_hub_tensorboard(self):
os.makedirs(f"{WORKING_REPO_DIR}/tb_log_dir")
with open(f"{WORKING_REPO_DIR}/tb_log_dir/tensorboard.txt", "w") as fp:
fp.write("Keras FTW")
REPO_NAME = repo_name("PUSH_TO_HUB")
model = self.model_init()
model.build((None, 2))
push_to_hub_keras(
model,
repo_path_or_name=f"{WORKING_REPO_DIR}/{REPO_NAME}",
log_dir=f"{WORKING_REPO_DIR}/tb_log_dir",
api_endpoint=ENDPOINT_STAGING,
use_auth_token=self._token,
git_user="ci",
git_email="ci@dummy.com",
)
model_info = HfApi(endpoint=ENDPOINT_STAGING).model_info(
f"{USER}/{REPO_NAME}",
)

self.assertTrue(
osanseviero marked this conversation as resolved.
Show resolved Hide resolved
"logs/tensorboard.txt" in [f.rfilename for f in model_info.siblings]
)

self._api.delete_repo(name=f"{REPO_NAME}", token=self._token)

def test_push_to_hub_model_kwargs(self):
REPO_NAME = repo_name("PUSH_TO_HUB")
model = self.model_init()
Expand Down