Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Non-git mixin #8

Merged
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
cf0a577
FIX typos in contributing.md
Wauplin Aug 3, 2022
396ed56
Remove redefined logger in HfApi.py
Wauplin Aug 3, 2022
d34b020
Use upload_folder in both mixins + some docstring
Wauplin Aug 3, 2022
bab52c2
moved back logger to top of hf_api ><
Wauplin Aug 3, 2022
198a425
space in documentation can be ambiguous
Wauplin Aug 3, 2022
4d3f9e4
WIP started deprecation
Wauplin Aug 3, 2022
0a55017
deprecate skip lfs file and use Path
Wauplin Aug 4, 2022
ada6607
added decorator to deprecate specific arguments + unittests for it
Wauplin Aug 4, 2022
bc0425a
simplified tests
Wauplin Aug 4, 2022
0b3eb5d
proper decorators
Wauplin Aug 4, 2022
4de8ea9
fix docstring
Wauplin Aug 4, 2022
f71b482
hubmixin: fixed existing tests + add http one
Wauplin Aug 4, 2022
b3b5bac
unique repo names across tests
Wauplin Aug 4, 2022
f1580ba
make push_to_hub_keras work + tests
Wauplin Aug 4, 2022
9b12475
logs are not overwritten in push_to_hub_keras
Wauplin Aug 4, 2022
41cc62a
flake8
Wauplin Aug 5, 2022
d039ad5
refacto push_to_hub from mixin.save_pretrained
Wauplin Aug 8, 2022
50069b1
deprecate positional argument in version 0.12
Wauplin Aug 8, 2022
ed06809
remove docstring for deprecated skip_lfs_files
Wauplin Aug 8, 2022
bea0b77
delete old logs when uploading keras model to hub
Wauplin Aug 8, 2022
bf87876
remove TODO in tests
Wauplin Aug 8, 2022
d593f75
remove useless todo
Wauplin Aug 8, 2022
8c09885
Update src/huggingface_hub/hub_mixin.py
Wauplin Aug 8, 2022
76a03c2
flake8
Wauplin Aug 8, 2022
12be988
Merge branch 'wauplin-non-git-mixin' of github.com:huggingface/huggin…
Wauplin Aug 8, 2022
9965f0e
remove un-explicit _generate_url helper
Wauplin Aug 8, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ venv/
ENV/
env.bak/
venv.bak/
.venv*

# Spyder project settings
.spyderproject
Expand Down
5 changes: 3 additions & 2 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -258,13 +258,14 @@ repository they can be run with the following:

```bash
$ HUGGINGFACE_CO_STAGING=1 python -m pytest -sv ./tests
```

In fact, that's how `make test` is implemented (sans the `pip install` line)!
In fact, that's how `make test` is implemented (without the `pip install` line)!

You can specify a smaller set of tests in order to test only the feature
you're working on.

For example, the following will only run the tests hel in the `test_repository.py` file:
For example, the following will only run the tests in the `test_repository.py` file:

```bash
$ HUGGINGFACE_CO_STAGING=1 python -m pytest -sv ./tests/test_repository.py
Expand Down
2 changes: 1 addition & 1 deletion docs/source/how-to-manage.mdx
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Create and manage a repository

A repository is a space for you to store your model or dataset files. This guide will show you how to:
A repository is a place where you can store your model or dataset files. This guide will show you how to:

* Create and delete a repository.
* Adjust repository visibility.
Expand Down
6 changes: 1 addition & 5 deletions src/huggingface_hub/hf_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@

logger = logging.get_logger(__name__)


# TODO: remove after deprecation period is over (v0.10)
def _validate_repo_id_deprecation(repo_id, name, organization):
"""Returns (name, organization) from the input."""
Expand Down Expand Up @@ -100,9 +99,6 @@ def _validate_repo_id_deprecation(repo_id, name, organization):
return name, organization


logger = logging.get_logger(__name__)


def repo_type_and_id_from_hf_id(
hf_id: str, hub_url: Optional[str] = None
) -> Tuple[Optional[str], Optional[str], str]:
Expand Down Expand Up @@ -1431,7 +1427,7 @@ def list_repo_files(
)
return [f.rfilename for f in repo_info.siblings]

@_deprecate_positional_args
@_deprecate_positional_args(version="0.8")
Wauplin marked this conversation as resolved.
Show resolved Hide resolved
Wauplin marked this conversation as resolved.
Show resolved Hide resolved
def create_repo(
self,
repo_id: str = None,
Expand Down
196 changes: 141 additions & 55 deletions src/huggingface_hub/hub_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,17 @@
import os
import tempfile
from pathlib import Path
from typing import Dict, List, Optional
from typing import Dict, List, Optional, Union

import requests
from huggingface_hub import hf_api

from .constants import CONFIG_NAME, PYTORCH_WEIGHTS_NAME
from .file_download import hf_hub_download, is_torch_available
from .hf_api import HfApi
from .hf_api import HfApi, HfFolder
from .repository import Repository
from .utils import logging
from .utils._deprecation import _deprecate_positional_args
from .utils._deprecation import _deprecate_arguments, _deprecate_positional_args


if is_torch_available():
Expand All @@ -33,44 +34,45 @@ def save_pretrained(
self,
save_directory: str,
config: Optional[dict] = None,
# NOTE: without new arguments, this behavior (with push_to_hub=True) will be
# deprecated as well. Should we already plan to extend the "save_pretrained"
# function, or should we let this feature die and then force users to call
# `model.save_pretrained(...)` and `model.push_to_hub(...)` separately?
# The main missing argument, I think, is "repo_id". All the others are optional
# but nice to have.
# TODO: To be discussed and removed before merging the PR
Wauplin marked this conversation as resolved.
Show resolved Hide resolved
push_to_hub: bool = False,
**kwargs,
):
"""
Save weights in local directory.

Parameters:
save_directory (`str`):
Specify directory in which you want to save weights.
config (`dict`, *optional*):
specify config (must be dict) in case you want to save
it.
push_to_hub (`bool`, *optional*, defaults to `False`):
Set it to `True` in case you want to push your weights
to huggingface_hub
kwargs (`Dict`, *optional*):
kwargs will be passed to `push_to_hub`
Parameters:
save_directory (`str`):
Specify directory in which you want to save weights.
config (`dict`, *optional*):
specify config (must be dict) in case you want to save
it.
push_to_hub (`bool`, *optional*, defaults to `False`):
Set it to `True` in case you want to push your weights to huggingface_hub.
kwargs (`Dict`, *optional*):
kwargs will be passed to `push_to_hub`
"""

os.makedirs(save_directory, exist_ok=True)

# saving model weights/files
files = self._save_pretrained(save_directory)
self._save_pretrained(save_directory)

# saving config
if isinstance(config, dict):
path = os.path.join(save_directory, CONFIG_NAME)
with open(path, "w") as f:
json.dump(config, f)

files.append(path)

if push_to_hub:
return self.push_to_hub(save_directory, **kwargs)

return files

def _save_pretrained(self, save_directory: str) -> List[str]:
def _save_pretrained(self, save_directory: str):
"""
Overwrite this method in subclass to define how to save your model.
"""
Expand Down Expand Up @@ -214,18 +216,48 @@ def _from_pretrained(
pretrained"""
raise NotImplementedError

@_deprecate_positional_args(version=0.8)
@_deprecate_positional_args(version="0.12")
@_deprecate_arguments(
version="0.12",
deprecated_args={
"repo_url",
"repo_path_or_name",
"organization",
"use_auth_token",
"git_user",
"git_email",
"skip_lfs_files",
},
)
Wauplin marked this conversation as resolved.
Show resolved Hide resolved
def push_to_hub(
self,
repo_id: str,
*,
# NOTE: deprecated signature that will change in 0.12
repo_path_or_name: Optional[str] = None,
repo_url: Optional[str] = None,
commit_message: Optional[str] = "Add model",
organization: Optional[str] = None,
private: Optional[bool] = None,
api_endpoint: Optional[str] = None,
token: Optional[str] = None,
branch: Optional[str] = None,
use_auth_token: Optional[Union[bool, str]] = None,
git_user: Optional[str] = None,
git_email: Optional[str] = None,
config: Optional[dict] = None,
skip_lfs_files: bool = False,
# NOTE: New arguments since 0.9
repo_id: Optional[str] = None, # optional only until 0.12
token: Optional[str] = None,
branch: Optional[str] = None,
create_pr: Optional[bool] = None,
# TODO (release 0.12): signature must be the following
# repo_id: str,
# *,
# commit_message: Optional[str] = "Add model",
# private: Optional[bool] = None,
# api_endpoint: Optional[str] = None,
# token: Optional[str] = None,
# branch: Optional[str] = None,
# create_pr: Optional[bool] = None,
# config: Optional[dict] = None,
) -> str:
"""
Upload model checkpoint to the Hub.
Expand All @@ -243,50 +275,104 @@ def push_to_hub(
The token to use as HTTP bearer authorization for remote files.
If not set, will use the token set when logging in with
`huggingface-cli login` (stored in `~/.huggingface`).
branch (Optional :obj:`str`):
The git branch on which to push the dataset. This defaults to
branch (`str`, *optional*):
The git branch on which to push the model. This defaults to
the default branch as specified in your repository, which
defaults to `"main"`.
create_pr (`boolean`, *optional*):
Whether or not to create a Pull Request from `branch` with that commit.
Defaults to `False`.
config (`dict`, *optional*):
Configuration object to be saved alongside the model weights.
skip_lfs_files (`bool`, *optional*, defaults to `False`):
Whether to skip git-LFS files or not.
Wauplin marked this conversation as resolved.
Show resolved Hide resolved


Returns:
The url of the commit of your model in the given repository.
"""
# If `repo_id` is set, use the new implementation based on the HTTP endpoint
# (introduced in v0.9)
if repo_id is not None:
token, _ = hf_api._validate_or_retrieve_token(token)
api = HfApi(endpoint=api_endpoint)

api.create_repo(
repo_id=repo_id,
repo_type="model",
token=token,
private=private,
exist_ok=True,
)

# Push the files to the repo in a single commit
with tempfile.TemporaryDirectory() as tmp:
saved_path = Path(tmp) / repo_id
self.save_pretrained(saved_path, config=config)
return api.upload_folder(
repo_id=repo_id,
repo_type="model",
path_in_repo="",
token=token,
folder_path=saved_path,
commit_message=commit_message,
revision=branch,
create_pr=create_pr,
)

Wauplin marked this conversation as resolved.
Show resolved Hide resolved
# If `repo_id` is None, fall back to the deprecated Git-based implementation
# TODO: remove code between here and `return repo.git_push()` in release 0.12
if repo_path_or_name is None and repo_url is None:
raise ValueError(
"You need to specify a `repo_path_or_name` or a `repo_url`."
)

token, _ = hf_api._validate_or_retrieve_token(token)
api = HfApi(endpoint=api_endpoint)
if use_auth_token is None and repo_url is None:
token = HfFolder.get_token()
if token is None:
raise ValueError(
"You must login to the Hugging Face hub on this computer by typing"
" `huggingface-cli login` and entering your credentials to use"
" `use_auth_token=True`. Alternatively, you can pass your own token"
" as the `use_auth_token` argument."
)
elif isinstance(use_auth_token, str):
token = use_auth_token
else:
token = None

if repo_path_or_name is None:
repo_path_or_name = repo_url.split("/")[-1]

# If no URL is passed and there's no path to a directory containing files, create a repo
if repo_url is None and not os.path.exists(repo_path_or_name):
repo_id = Path(repo_path_or_name).name
if organization:
repo_id = f"{organization}/{repo_id}"
repo_url = HfApi(endpoint=api_endpoint).create_repo(
repo_id=repo_id,
token=token,
private=private,
repo_type=None,
exist_ok=True,
)

api.create_repo(
repo_id=repo_id,
token=token,
private=private,
repo_type=None,
exist_ok=True,
repo = Repository(
repo_path_or_name,
clone_from=repo_url,
use_auth_token=use_auth_token,
git_user=git_user,
git_email=git_email,
skip_lfs_files=skip_lfs_files,
)
repo.git_pull(rebase=True)

# Save the files in the cloned repo
with tempfile.TemporaryDirectory() as tmp:
saved_path = Path(tmp) / repo_id
self.save_pretrained(saved_path, config=config)

for path, currentDirectory, files in os.walk(saved_path):
for filename in files:
file = os.path.join(path, filename)
common_prefix = os.path.commonprefix([saved_path, file])
relative_path = os.path.relpath(file, common_prefix)

api.upload_file(
path_or_fileobj=file,
path_in_repo=relative_path,
token=token,
repo_id=repo_id,
revision=branch,
commit_message=commit_message,
)
self.save_pretrained(repo_path_or_name, config=config)

# Commit and push!
repo.git_add(auto_lfs_track=True)
repo.git_commit(commit_message)
return repo.git_push()


class PyTorchModelHubMixin(ModelHubMixin):
Expand Down
Loading