Skip to content

Commit

Permalink
Adding mixin class for ease saving, uploading, downloading (as discus…
Browse files Browse the repository at this point in the history
…sed in issue #9). (#11)

* work initiated

* start upload_to_hub

* add changes

* final-push

* i feel this is better.

* updated for Repositary class

* small updates

* fix mutiple calling

* small fix

* make style

* add everything

* minor fix

* minor fix

* done evrything

* small fix

* [doc] remove mention of TF support

* Fix typings (i think)

* We do NOT want to have a hard requirement on torch

* Fix flake8

* Fix CI

Co-authored-by: Julien Chaumond <julien@huggingface.co>
  • Loading branch information
thevasudevgupta and julien-c authored Mar 18, 2021
1 parent be902d8 commit 9e1d3d5
Show file tree
Hide file tree
Showing 4 changed files with 343 additions and 6 deletions.
1 change: 1 addition & 0 deletions src/huggingface_hub/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,5 @@
)
from .file_download import cached_download, hf_hub_url
from .hf_api import HfApi, HfFolder
from .hub_mixin import ModelHubMixin
from .repository import Repository
246 changes: 246 additions & 0 deletions src/huggingface_hub/hub_mixin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,246 @@
import json
import logging
import os
from typing import Dict, Optional

import requests

from .constants import CONFIG_NAME, PYTORCH_WEIGHTS_NAME
from .file_download import cached_download, hf_hub_url, is_torch_available
from .hf_api import HfApi, HfFolder
from .repository import Repository


if is_torch_available():
import torch


logger = logging.getLogger(__name__)


class ModelHubMixin(object):
def __init__(self, *args, **kwargs):
"""
Mix this class with your torch-model class for ease process of saving & loading from huggingface-hub
Example::
>>> from huggingface_hub import ModelHubMixin
>>> class MyModel(nn.Module, ModelHubMixin):
... def __init__(self, **kwargs):
... super().__init__()
... self.config = kwargs.pop("config", None)
... self.layer = ...
... def forward(self, ...)
... return ...
>>> model = MyModel()
>>> model.save_pretrained("mymodel", push_to_hub=False) # Saving model weights in the directory
>>> model.push_to_hub("mymodel", "model-1") # Pushing model-weights to hf-hub
>>> # Downloading weights from hf-hub & model will be initialized from those weights
>>> model = MyModel.from_pretrained("username/mymodel@main")
"""

def save_pretrained(
self,
save_directory: str,
config: Optional[dict] = None,
push_to_hub: bool = False,
**kwargs,
):
"""
Saving weights in local directory.
Parameters:
save_directory (:obj:`str`):
Specify directory in which you want to save weights.
config (:obj:`dict`, `optional`):
specify config (must be dict) incase you want to save it.
push_to_hub (:obj:`bool`, `optional`, defaults to :obj:`False`):
Set it to `True` in case you want to push your weights to huggingface_hub
model_id (:obj:`str`, `optional`, defaults to :obj:`save_directory`):
Repo name in huggingface_hub. If not specified, repo name will be same as `save_directory`
kwargs (:obj:`Dict`, `optional`):
kwargs will be passed to `push_to_hub`
"""

os.makedirs(save_directory, exist_ok=True)

# saving config
if isinstance(config, dict):
path = os.path.join(save_directory, CONFIG_NAME)
with open(path, "w") as f:
json.dump(config, f)

# saving model weights
path = os.path.join(save_directory, PYTORCH_WEIGHTS_NAME)
self._save_pretrained(path)

if push_to_hub:
return self.push_to_hub(save_directory, **kwargs)

def _save_pretrained(self, path):
"""
Overwrite this method in case you don't want to save complete model, rather some specific layers
"""
model_to_save = self.module if hasattr(self, "module") else self
torch.save(model_to_save.state_dict(), path)

@classmethod
def from_pretrained(
cls,
pretrained_model_name_or_path: Optional[str],
strict: bool = True,
map_location: Optional[str] = "cpu",
force_download: bool = False,
resume_download: bool = False,
proxies: Dict = None,
use_auth_token: Optional[str] = None,
cache_dir: Optional[str] = None,
local_files_only: bool = False,
**model_kwargs,
):
r"""
Instantiate a pretrained pytorch model from a pre-trained model configuration from huggingface-hub.
The model is set in evaluation mode by default using ``model.eval()`` (Dropout modules are deactivated). To
train the model, you should first set it back in training mode with ``model.train()``.
Parameters:
pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`, `optional`):
Can be either:
- A string, the `model id` of a pretrained model hosted inside a model repo on huggingface.co.
Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under
a user or organization name, like ``dbmdz/bert-base-german-cased``.
- You can add `revision` by appending `@` at the end of model_id simply like this: ``dbmdz/bert-base-german-cased@main``
Revision is the specific model version to use. It can be a branch name, a tag name, or a commit id,
since we use a git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any identifier allowed by git.
- A path to a `directory` containing model weights saved using
:func:`~transformers.PreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``.
- :obj:`None` if you are both providing the configuration and state dictionary (resp. with keyword
arguments ``config`` and ``state_dict``).
cache_dir (:obj:`Union[str, os.PathLike]`, `optional`):
Path to a directory in which a downloaded pretrained model configuration should be cached if the
standard cache should not be used.
force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to delete incompletely received files. Will attempt to resume the download if such a
file exists.
proxies (:obj:`Dict[str, str], `optional`):
A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
local_files_only(:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to only look at local files (i.e., do not try to download the model).
use_auth_token (:obj:`str` or `bool`, `optional`):
The token to use as HTTP bearer authorization for remote files. If :obj:`True`, will use the token
generated when running :obj:`transformers-cli login` (stored in :obj:`~/.huggingface`).
model_kwargs (:obj:`Dict`, `optional`)::
model_kwargs will be passed to the model during initialization
.. note::
Passing :obj:`use_auth_token=True` is required when you want to use a private model.
"""

model_id = pretrained_model_name_or_path
map_location = torch.device(map_location)

revision = None
if len(model_id.split("@")) == 2:
model_id, revision = model_id.split("@")

if model_id in os.listdir() and CONFIG_NAME in os.listdir(model_id):
config_file = os.path.join(model_id, CONFIG_NAME)
else:
try:
config_url = hf_hub_url(
model_id, filename=CONFIG_NAME, revision=revision
)
config_file = cached_download(
config_url,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
)
except requests.exceptions.RequestException:
logger.warning("config.json NOT FOUND in HuggingFace Hub")
config_file = None

if model_id in os.listdir():
print("LOADING weights from local directory")
model_file = os.path.join(model_id, PYTORCH_WEIGHTS_NAME)
else:
model_url = hf_hub_url(
model_id, filename=PYTORCH_WEIGHTS_NAME, revision=revision
)
model_file = cached_download(
model_url,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
)

if config_file is not None:
with open(config_file, "r", encoding="utf-8") as f:
config = json.load(f)
model_kwargs.update({"config": config})

model = cls(**model_kwargs)

state_dict = torch.load(model_file, map_location=map_location)
model.load_state_dict(state_dict, strict=strict)
model.eval()

return model

@staticmethod
def push_to_hub(
save_directory: Optional[str],
model_id: Optional[str] = None,
repo_url: Optional[str] = None,
commit_message: Optional[str] = "add model",
organization: Optional[str] = None,
private: bool = None,
) -> str:
"""
Parameters:
save_directory (:obj:`Union[str, os.PathLike]`):
Directory having model weights & config.
model_id (:obj:`str`, `optional`, defaults to :obj:`save_directory`):
Repo name in huggingface_hub. If not specified, repo name will be same as `save_directory`
repo_url (:obj:`str`, `optional`):
Specify this in case you want to push to existing repo in hub.
organization (:obj:`str`, `optional`):
Organization in which you want to push your model.
private (:obj:`bool`, `optional`):
private: Whether the model repo should be private (requires a paid huggingface.co account)
commit_message (:obj:`str`, `optional`, defaults to :obj:`add model`):
Message to commit while pushing
Returns:
url to commit on remote repo.
"""
if model_id is None:
model_id = save_directory

token = HfFolder.get_token()
if repo_url is None:
repo_url = HfApi().create_repo(
token,
model_id,
organization=organization,
private=private,
repo_type=None,
exist_ok=True,
)

repo = Repository(save_directory, clone_from=repo_url, use_auth_token=token)

return repo.push_to_hub(commit_message=commit_message)
36 changes: 30 additions & 6 deletions src/huggingface_hub/repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,14 +159,26 @@ def clone_from(self, repo_url: str, use_auth_token: Union[bool, str, None] = Non
encoding="utf-8",
cwd=self.local_dir,
)
subprocess.run(
["git", "remote", "add", "origin", repo_url],

output = subprocess.run(
"git remote -v".split(),
stderr=subprocess.PIPE,
stdout=subprocess.PIPE,
check=True,
encoding="utf-8",
cwd=self.local_dir,
)

if "origin" not in output.stdout.split():
subprocess.run(
["git", "remote", "add", "origin", repo_url],
stderr=subprocess.PIPE,
stdout=subprocess.PIPE,
check=True,
encoding="utf-8",
cwd=self.local_dir,
)

subprocess.run(
"git fetch".split(),
stderr=subprocess.PIPE,
Expand All @@ -183,15 +195,27 @@ def clone_from(self, repo_url: str, use_auth_token: Union[bool, str, None] = Non
check=True,
cwd=self.local_dir,
)
# TODO(check if we really want the --force flag)
subprocess.run(
"git checkout origin/main -ft".split(),

output = subprocess.run(
"git branch".split(),
stderr=subprocess.PIPE,
stdout=subprocess.PIPE,
encoding="utf-8",
check=True,
encoding="utf-8",
cwd=self.local_dir,
)

if "main" not in output.stdout.split():
# TODO(check if we really want the --force flag)
subprocess.run(
"git checkout origin/main -ft".split(),
stderr=subprocess.PIPE,
stdout=subprocess.PIPE,
encoding="utf-8",
check=True,
cwd=self.local_dir,
)

except subprocess.CalledProcessError as exc:
raise EnvironmentError(exc.stderr)

Expand Down
66 changes: 66 additions & 0 deletions tests/test_hubmixin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import unittest

from huggingface_hub.file_download import is_torch_available
from huggingface_hub.hub_mixin import ModelHubMixin


if is_torch_available():
import torch.nn as nn


HUGGINGFACE_ID = "vasudevgupta"
DUMMY_REPO_NAME = "dummy"


def require_torch(test_case):
"""
Decorator marking a test that requires PyTorch.
These tests are skipped when PyTorch isn't installed.
"""
if not is_torch_available():
return unittest.skip("test requires PyTorch")(test_case)
else:
return test_case


@require_torch
class DummyModel(ModelHubMixin):
def __init__(self, **kwargs):
super().__init__()
self.config = kwargs.pop("config", None)
self.l1 = nn.Linear(2, 2)

def forward(self, x):
return self.l1(x)


@require_torch
class DummyModelTest(unittest.TestCase):
def test_save_pretrained(self):
model = DummyModel()
model.save_pretrained(DUMMY_REPO_NAME)
model.save_pretrained(
DUMMY_REPO_NAME, config={"num": 12, "act": "gelu"}, push_to_hub=True
)
model.save_pretrained(
DUMMY_REPO_NAME, config={"num": 24, "act": "relu"}, push_to_hub=True
)
model.save_pretrained(
"dummy-wts", config=None, push_to_hub=True, model_id=DUMMY_REPO_NAME
)

def test_from_pretrained(self):
model = DummyModel()
model.save_pretrained(
DUMMY_REPO_NAME, config={"num": 7, "act": "gelu_fast"}, push_to_hub=True
)

model = DummyModel.from_pretrained(f"{HUGGINGFACE_ID}/{DUMMY_REPO_NAME}@main")
self.assertTrue(model.config == {"num": 7, "act": "gelu_fast"})

def test_push_to_hub(self):
model = DummyModel()
model.save_pretrained("dummy-wts", push_to_hub=False)
model.push_to_hub("dummy-wts", model_id=DUMMY_REPO_NAME)

0 comments on commit 9e1d3d5

Please sign in to comment.