Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

huggingface-cli upload - Validate README.md before file hashing #2452

Merged
merged 1 commit into from
Aug 16, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
159 changes: 102 additions & 57 deletions src/huggingface_hub/hf_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -3797,26 +3797,10 @@ def create_commit(
for addition in additions:
if addition.path_in_repo == "README.md":
with addition.as_file() as file:
response = get_session().post(
f"{ENDPOINT}/api/validate-yaml",
json={"content": file.read().decode(), "repoType": repo_type},
headers=headers,
)
# Handle warnings (example: empty metadata)
response_content = response.json()
message = "\n".join(
[f"- {warning.get('message')}" for warning in response_content.get("warnings", [])]
)
if message:
warnings.warn(f"Warnings while validating metadata in README.md:\n{message}")

# Raise on errors
try:
hf_raise_for_status(response)
except BadRequestError as e:
errors = response_content.get("errors", [])
message = "\n".join([f"- {error.get('message')}" for error in errors])
raise ValueError(f"Invalid metadata in README.md.\n{message}") from e
content = file.read().decode()
self._validate_yaml(content, repo_type=repo_type, token=token)
# Skip other additions after `README.md` has been processed
break

# If updating twice the same file or update then delete a file in a single commit
_warn_on_overwriting_operations(operations)
Expand Down Expand Up @@ -4875,11 +4859,13 @@ def upload_folder(
path_in_repo=path_in_repo,
delete_patterns=delete_patterns,
)
add_operations = _prepare_upload_folder_additions(
add_operations = self._prepare_upload_folder_additions(
folder_path,
path_in_repo,
allow_patterns=allow_patterns,
ignore_patterns=ignore_patterns,
token=token,
repo_type=repo_type,
)

# Optimize operations: if some files will be overwritten, we don't need to delete them first
Expand Down Expand Up @@ -9182,6 +9168,101 @@ def _prepare_folder_deletions(
if relpath_to_abspath[relpath] != ".gitattributes"
]

def _prepare_upload_folder_additions(
    self,
    folder_path: Union[str, Path],
    path_in_repo: str,
    allow_patterns: Optional[Union[List[str], str]] = None,
    ignore_patterns: Optional[Union[List[str], str]] = None,
    repo_type: Optional[str] = None,
    token: Union[bool, str, None] = None,
) -> List[CommitOperationAdd]:
    """Generate the list of Add operations for a commit to upload a folder.

    Files not matching the `allow_patterns` (allowlist) and `ignore_patterns` (denylist)
    constraints are discarded. If a `README.md` located at the root of `folder_path`
    survives the filtering, its content is validated with `self._validate_yaml`
    before any `CommitOperationAdd` is created.

    Args:
        folder_path (`str` or `Path`):
            Local folder to upload. It is expanded (`~`) and resolved to an absolute path.
        path_in_repo (`str`):
            Folder in the repo under which the files are placed. Leading and trailing
            slashes are stripped; an empty value means no prefix is added.
        allow_patterns (`List[str]` or `str`, *optional*):
            Forwarded to `filter_repo_objects`; matched against paths relative to `folder_path`.
        ignore_patterns (`List[str]` or `str`, *optional*):
            Forwarded to `filter_repo_objects`; matched against paths relative to `folder_path`.
        repo_type (`str`, *optional*):
            Forwarded to `self._validate_yaml` when a `README.md` is validated.
        token (`bool` or `str`, *optional*):
            Forwarded to `self._validate_yaml` when a `README.md` is validated.

    Returns:
        `List[CommitOperationAdd]`: one operation per retained file.

    Raises:
        `ValueError`: if `folder_path` is not a directory. Errors raised by
        `self._validate_yaml` or while decoding `README.md` are propagated as-is.
    """
    folder_path = Path(folder_path).expanduser().resolve()
    if not folder_path.is_dir():
        raise ValueError(f"Provided path: '{folder_path}' is not a directory")

    # List files from folder
    relpath_to_abspath = {
        path.relative_to(folder_path).as_posix(): path
        for path in sorted(folder_path.glob("**/*"))  # sorted to be deterministic
        if path.is_file()
    }

    # Filter files
    # Patterns are applied on the path relative to `folder_path`. `path_in_repo` is prefixed after the filtering.
    filtered_repo_objects = list(
        filter_repo_objects(
            relpath_to_abspath.keys(), allow_patterns=allow_patterns, ignore_patterns=ignore_patterns
        )
    )

    prefix = f"{path_in_repo.strip('/')}/" if path_in_repo else ""

    # If updating a README.md file, make sure the metadata format is valid
    # It's better to fail early than to fail after all the files have been hashed.
    # NOTE(review): the check is done on the folder-relative path, so it also fires when
    # `path_in_repo` places the file in a subfolder of the repo — confirm this is intended.
    if "README.md" in filtered_repo_objects:
        self._validate_yaml(
            # Decode explicitly as UTF-8: without `encoding`, `read_text` falls back to the
            # platform locale encoding, which makes the result machine-dependent.
            content=relpath_to_abspath["README.md"].read_text(encoding="utf-8"),
            repo_type=repo_type,
            token=token,
        )

    return [
        CommitOperationAdd(
            path_or_fileobj=relpath_to_abspath[relpath],  # absolute path on disk
            path_in_repo=prefix + relpath,  # "absolute" path in repo
        )
        for relpath in filtered_repo_objects
    ]

def _validate_yaml(self, content: str, *, repo_type: Optional[str] = None, token: Union[bool, str, None] = None):
    """
    Check the YAML metadata of a `README.md` by posting it to the `/api/validate-yaml` route.

    Warnings returned in the response are surfaced with `warnings.warn`; a
    `BadRequestError` raised while checking the response status is converted
    into a `ValueError` listing the returned error messages.

    Args:
        content (`str`):
            Content of `README.md` to validate.
        repo_type (`str`, *optional*):
            Repo type sent along with the content as `repoType`.
            Defaults to `REPO_TYPE_MODEL` when `None`.
        token (Union[bool, str, None], optional):
            Passed to `self._build_hf_headers` to build the request headers.

    Raises:
        - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
          if the validation request is rejected with a `BadRequestError`.
    """
    if repo_type is None:
        repo_type = REPO_TYPE_MODEL
    # Headers are built before the session is fetched, same as the request order below.
    request_headers = self._build_hf_headers(token=token)

    validation_url = f"{self.endpoint}/api/validate-yaml"
    request_payload = {"content": content, "repoType": repo_type}
    response = get_session().post(validation_url, json=request_payload, headers=request_headers)

    # The body is parsed first: it carries both the warnings and the errors.
    response_content = response.json()

    # Surface warnings (example: empty metadata) without failing.
    warning_lines = []
    for warning in response_content.get("warnings", []):
        warning_lines.append(f"- {warning.get('message')}")
    if warning_lines:
        message = "\n".join(warning_lines)
        warnings.warn(f"Warnings while validating metadata in README.md:\n{message}")

    # Turn a rejected request into a ValueError; any other failure propagates unchanged.
    try:
        hf_raise_for_status(response)
    except BadRequestError as e:
        error_lines = []
        for error in response_content.get("errors", []):
            error_lines.append(f"- {error.get('message')}")
        message = "\n".join(error_lines)
        raise ValueError(f"Invalid metadata in README.md.\n{message}") from e

def get_user_overview(self, username: str) -> User:
"""
Get an overview of a user on the Hub.
Expand Down Expand Up @@ -9275,42 +9356,6 @@ def list_user_following(self, username: str) -> Iterable[User]:
yield User(**followed_user)


def _prepare_upload_folder_additions(
    folder_path: Union[str, Path],
    path_in_repo: str,
    allow_patterns: Optional[Union[List[str], str]] = None,
    ignore_patterns: Optional[Union[List[str], str]] = None,
) -> List[CommitOperationAdd]:
    """Build the Add operations needed to upload the content of a local folder.

    Every regular file found recursively under `folder_path` is a candidate. Candidates
    are filtered with `allow_patterns` (allowlist) and `ignore_patterns` (denylist), both
    applied to the path relative to `folder_path`. The `path_in_repo` prefix is only
    added afterwards, when building the operations.

    Raises:
        `ValueError`: if `folder_path` is not a directory.
    """
    root = Path(folder_path).expanduser().resolve()
    if not root.is_dir():
        raise ValueError(f"Provided path: '{root}' is not a directory")

    # Map "relative posix path" -> "absolute path" for every file, in sorted (deterministic) order.
    files_by_relpath = {}
    for candidate in sorted(root.glob("**/*")):
        if candidate.is_file():
            files_by_relpath[candidate.relative_to(root).as_posix()] = candidate

    # Prefix to prepend to each relative path to get its location in the repo.
    repo_prefix = f"{path_in_repo.strip('/')}/" if path_in_repo else ""

    kept_relpaths = filter_repo_objects(
        files_by_relpath.keys(), allow_patterns=allow_patterns, ignore_patterns=ignore_patterns
    )

    operations = []
    for relpath in kept_relpaths:
        operations.append(
            CommitOperationAdd(
                path_or_fileobj=files_by_relpath[relpath],  # absolute path on disk
                path_in_repo=repo_prefix + relpath,  # "absolute" path in repo
            )
        )
    return operations


def _parse_revision_from_pr_url(pr_url: str) -> str:
"""Safely parse revision number from a PR url.

Expand Down
Loading