Skip to content

Commit

Permalink
FIX metadata_update: work on a copy of the upstream file, to not mess up the cache (#891)
Browse files Browse the repository at this point in the history

* Tiny consistency nit for symmetry with `snapshot_download`

@LysandreJik

* `metadata_update`: work on a copy of the upstream file, to not mess up the cache

* Add tests

* Additional test

Co-authored-by: Lysandre <lysandre.debut@reseau.eseo.fr>
  • Loading branch information
julien-c and LysandreJik authored Jun 10, 2022
1 parent 1d3637c commit 8994fd3
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 36 deletions.
5 changes: 3 additions & 2 deletions src/huggingface_hub/file_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -1051,8 +1051,9 @@ def hf_hub_download(
"We have no connection or you passed local_files_only, so"
" force_download is not an accepted option."
)
commit_hash = revision
if not REGEX_COMMIT_HASH.match(revision):
if REGEX_COMMIT_HASH.match(revision):
commit_hash = revision
else:
ref_path = os.path.join(storage_folder, "refs", revision)
with open(ref_path) as f:
commit_hash = f.read()
Expand Down
73 changes: 39 additions & 34 deletions src/huggingface_hub/repocard.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import dataclasses
import os
import re
import shutil
import tempfile
from pathlib import Path
from typing import Any, Dict, Optional, Union

Expand Down Expand Up @@ -149,50 +151,53 @@ def metadata_update(
`str`: URL of the commit which updated the card metadata.
"""

filepath = hf_hub_download(
upstream_filepath = hf_hub_download(
repo_id,
filename=REPOCARD_NAME,
repo_type=repo_type,
use_auth_token=token,
force_download=True,
)
existing_metadata = metadata_load(filepath)
# work on a copy of the upstream file, to not mess up the cache
with tempfile.TemporaryDirectory() as tmpdirname:
filepath = shutil.copy(upstream_filepath, tmpdirname)

for key in metadata:
# update model index containing the evaluation results
if key == "model-index":
if "model-index" not in existing_metadata:
existing_metadata["model-index"] = metadata["model-index"]
else:
# the model-index contains a list of results as used by PwC but only has one element thus we take the first one
existing_metadata["model-index"][0][
"results"
] = _update_metadata_model_index(
existing_metadata["model-index"][0]["results"],
metadata["model-index"][0]["results"],
overwrite=overwrite,
)
# update all fields except model index
else:
if key in existing_metadata and not overwrite:
if existing_metadata[key] != metadata[key]:
raise ValueError(
f"""You passed a new value for the existing meta data field '{key}'. Set `overwrite=True` to overwrite existing metadata."""
existing_metadata = metadata_load(filepath)

for key in metadata:
# update model index containing the evaluation results
if key == "model-index":
if "model-index" not in existing_metadata:
existing_metadata["model-index"] = metadata["model-index"]
else:
# the model-index contains a list of results as used by PwC but only has one element thus we take the first one
existing_metadata["model-index"][0][
"results"
] = _update_metadata_model_index(
existing_metadata["model-index"][0]["results"],
metadata["model-index"][0]["results"],
overwrite=overwrite,
)
# update all fields except model index
else:
existing_metadata[key] = metadata[key]
if key in existing_metadata and not overwrite:
if existing_metadata[key] != metadata[key]:
raise ValueError(
f"""You passed a new value for the existing meta data field '{key}'. Set `overwrite=True` to overwrite existing metadata."""
)
else:
existing_metadata[key] = metadata[key]

# save and push to hub
metadata_save(filepath, existing_metadata)
# save and push to hub
metadata_save(filepath, existing_metadata)

return HfApi().upload_file(
path_or_fileobj=filepath,
path_in_repo=REPOCARD_NAME,
repo_id=repo_id,
repo_type=repo_type,
identical_ok=False,
token=token,
)
return HfApi().upload_file(
path_or_fileobj=filepath,
path_in_repo=REPOCARD_NAME,
repo_id=repo_id,
repo_type=repo_type,
identical_ok=False,
token=token,
)


def _update_metadata_model_index(existing_results, new_results, overwrite=False):
Expand Down
18 changes: 18 additions & 0 deletions tests/test_repocard.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
metadata_update,
)
from huggingface_hub.constants import REPOCARD_NAME
from huggingface_hub.file_download import hf_hub_download
from huggingface_hub.hf_api import HfApi
from huggingface_hub.repository import Repository
from huggingface_hub.utils import logging
Expand Down Expand Up @@ -240,6 +241,23 @@ def test_update_existing_result_with_overwrite(self):
updated_metadata = metadata_load(self.repo_path / self.REPO_NAME / "README.md")
self.assertDictEqual(updated_metadata, new_metadata)

def test_metadata_update_upstream(self):
# Regression test: calling `metadata_update` must leave the repo card file
# previously returned by `hf_hub_download` unchanged on disk.
#
# Start from the existing metadata and change a single metric value so the
# pushed metadata differs from what is currently in the repo.
new_metadata = copy.deepcopy(self.existing_metadata)
new_metadata["model-index"][0]["results"][0]["metrics"][0]["value"] = 0.1

# Download the repo card first; `path` is the local file that is re-read
# after the update below.
path = hf_hub_download(
f"{USER}/{self.REPO_NAME}",
filename=REPOCARD_NAME,
use_auth_token=self._token,
)

# Push the modified metadata to the repo, allowing existing values to be
# overwritten.
metadata_update(
f"{USER}/{self.REPO_NAME}", new_metadata, token=self._token, overwrite=True
)

# The file downloaded before the update must still load as the old
# metadata, not as the metadata that was just pushed.
self.assertNotEqual(metadata_load(path), new_metadata)
self.assertEqual(metadata_load(path), self.existing_metadata)

def test_update_existing_result_without_overwrite(self):
new_metadata = copy.deepcopy(self.existing_metadata)
new_metadata["model-index"][0]["results"][0]["metrics"][0][
Expand Down

0 comments on commit 8994fd3

Please sign in to comment.