diff --git a/.circleci/config.yml b/.circleci/config.yml index 92b5b2ae058bea..c4e9b79c703cd5 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -317,24 +317,33 @@ jobs: - store_artifacts: path: ~/transformers/reports - run_tests_git_lfs: + run_tests_hub: working_directory: ~/transformers docker: - image: circleci/python:3.7 environment: + HUGGINGFACE_CO_STAGING: yes RUN_GIT_LFS_TESTS: yes TRANSFORMERS_IS_CI: yes resource_class: xlarge parallelism: 1 steps: - checkout + - restore_cache: + keys: + - v0.4-hub-{{ checksum "setup.py" }} + - v0.4-{{ checksum "setup.py" }} - run: sudo apt-get install git-lfs - run: | git config --global user.email "ci@dummy.com" git config --global user.name "ci" - run: pip install --upgrade pip - - run: pip install .[testing] - - run: python -m pytest -sv ./tests/test_hf_api.py -k "HfLargefilesTest" + - run: pip install .[torch,sentencepiece,testing] + - save_cache: + key: v0.4-hub-{{ checksum "setup.py" }} + paths: + - '~/.cache/pip' + - run: python -m pytest -sv ./tests/ -m is_staging_test build_doc: working_directory: ~/transformers @@ -469,7 +478,7 @@ workflows: - run_tests_flax - run_tests_pipelines_torch - run_tests_pipelines_tf - - run_tests_git_lfs + - run_tests_hub - build_doc - deploy_doc: *workflow_filters # tpu_testing_jobs: diff --git a/docs/source/main_classes/model.rst b/docs/source/main_classes/model.rst index fef1426fa9a9f2..0f93bec8cef47c 100644 --- a/docs/source/main_classes/model.rst +++ b/docs/source/main_classes/model.rst @@ -73,3 +73,10 @@ Generation .. autoclass:: transformers.generation_tf_utils.TFGenerationMixin :members: + + +Pushing to the Hub +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.file_utils.PushToHubMixin + :members: diff --git a/docs/source/model_sharing.rst b/docs/source/model_sharing.rst index 3c6a9fc8353531..5c545695b38339 100644 --- a/docs/source/model_sharing.rst +++ b/docs/source/model_sharing.rst @@ -22,8 +22,6 @@ the `model hub `__. Optionally, you can join an existing organization or create a new one. -Prepare your model for uploading -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ We have seen in the :doc:`training tutorial `: how to fine-tune a model on a given task. You have probably done something similar on your task, either using the model directly in your own training loop or using the @@ -31,7 +29,7 @@ done something similar on your task, either using the model directly in your own `model hub `__. Model versioning -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Since version v3.5.0, the model hub has built-in model versioning based on git and git-lfs. It is based on the paradigm that one model *is* one repo. @@ -54,6 +52,106 @@ For instance: >>> revision="v2.0.1" # tag name, or branch name, or commit hash >>> ) + +Push your model from Python +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Preparation +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +The first step is to make sure your credentials to the hub are stored somewhere. This can be done in two ways. 
If you +have access to a terminal, you can just run the following command in the virtual environment where you installed 🤗 +Transformers: + +.. code-block:: bash + + transformers-cli login + +It will store your access token in the Hugging Face cache folder (by default :obj:`~/.cache/`). + +If you don't have easy access to a terminal (for instance in a Colab session), you can find a token linked to your +account by going to `huggingface.co <https://huggingface.co>`__, clicking on your avatar in the top left corner, then on +`Edit profile` on the left, just beneath your profile picture. In the submenu `API Tokens`, you will find your API +token that you can just copy. + +Directly push your model to the hub +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Once you have an API token (either stored in the cache or copied and pasted in your notebook), you can directly push a +fine-tuned model you saved in :obj:`save_directory` by calling: + +.. code-block:: python + + finetuned_model.push_to_hub("my-awesome-model") + +If your API token is not stored in the cache, you will need to pass it with :obj:`use_auth_token=your_token`. +This will also be the case for all the examples below, so we won't mention it again. + +This will create a repository named :obj:`my-awesome-model` in your namespace, so anyone can now run: + +.. code-block:: python + + from transformers import AutoModel + + model = AutoModel.from_pretrained("your_username/my-awesome-model") + +Even better, you can combine this push to the hub with the call to :obj:`save_pretrained`: + +.. code-block:: python + + finetuned_model.save_pretrained(save_directory, push_to_hub=True, repo_name="my-awesome-model") + +If you are a premium user and want your model to be private, just add :obj:`private=True` to this call. + +If you are a member of an organization and want to push the model inside the organization's namespace instead of your +own, just add :obj:`organization=my_amazing_org`. + +Add new files to your model repo +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Once you have pushed your model to the hub, you might want to add the tokenizer, or a version of your model for another +framework (TensorFlow, PyTorch, Flax). This is super easy to do! Let's begin with the tokenizer. You can add it to the +repo you created before like this: + +.. code-block:: python + + tokenizer.push_to_hub("my-awesome-model") + +If you know the repo's URL (it should be :obj:`https://huggingface.co/username/repo_name`), you can also do: + +.. code-block:: python + + tokenizer.push_to_hub(repo_url=my_repo_url) + +And that's all there is to it! It's also a very easy way to fix a mistake if one of the files online has a bug. + +Adding a model for another backend is also super easy. Let's say you have fine-tuned a TensorFlow model and want to +add the PyTorch model files to your model repo, so that anyone in the community can use them. The following allows you to +directly create a PyTorch version of your TensorFlow model: + +.. code-block:: python + + from transformers import AutoModel + + model = AutoModel.from_pretrained(save_directory, from_tf=True) + +You can also replace :obj:`save_directory` with the identifier of your model (:obj:`username/repo_name`) if you don't +have a local save of it anymore. Then, just do the same as before: + +.. code-block:: python + + model.push_to_hub("my-awesome-model") + +or + +.. 
code-block:: python + + model.push_to_hub(repo_url=my_repo_url) + + +Use your terminal and git +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + Basic steps ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/examples/legacy/text-classification/run_tf_text_classification.py b/examples/legacy/text-classification/run_tf_text_classification.py old mode 100644 new mode 100755 diff --git a/examples/pytorch/language-modeling/run_clm.py b/examples/pytorch/language-modeling/run_clm.py index 3bc3fe40ff5da7..ea4b4f0934b09c 100755 --- a/examples/pytorch/language-modeling/run_clm.py +++ b/examples/pytorch/language-modeling/run_clm.py @@ -447,6 +447,9 @@ def group_texts(examples): trainer.log_metrics("eval", metrics) trainer.save_metrics("eval", metrics) + if training_args.push_to_hub: + trainer.push_to_hub() + def _mp_fn(index): # For xla_spawn (TPUs) diff --git a/examples/pytorch/language-modeling/run_mlm.py b/examples/pytorch/language-modeling/run_mlm.py index 1cb9fcc6cdd482..d0932e5f93fde2 100755 --- a/examples/pytorch/language-modeling/run_mlm.py +++ b/examples/pytorch/language-modeling/run_mlm.py @@ -476,6 +476,9 @@ def group_texts(examples): trainer.log_metrics("eval", metrics) trainer.save_metrics("eval", metrics) + if training_args.push_to_hub: + trainer.push_to_hub() + def _mp_fn(index): # For xla_spawn (TPUs) diff --git a/examples/pytorch/language-modeling/run_plm.py b/examples/pytorch/language-modeling/run_plm.py index b5734815781579..d7136144729e1c 100755 --- a/examples/pytorch/language-modeling/run_plm.py +++ b/examples/pytorch/language-modeling/run_plm.py @@ -452,6 +452,9 @@ def group_texts(examples): trainer.log_metrics("eval", metrics) trainer.save_metrics("eval", metrics) + if training_args.push_to_hub: + trainer.push_to_hub() + def _mp_fn(index): # For xla_spawn (TPUs) diff --git a/examples/pytorch/multiple-choice/run_swag.py b/examples/pytorch/multiple-choice/run_swag.py index de012171938c68..e73e13492f5b6b 100755 --- a/examples/pytorch/multiple-choice/run_swag.py +++ b/examples/pytorch/multiple-choice/run_swag.py @@ -428,6 +428,9 @@ def compute_metrics(eval_predictions): trainer.log_metrics("eval", metrics) trainer.save_metrics("eval", metrics) + if training_args.push_to_hub: + trainer.push_to_hub() + def _mp_fn(index): # For xla_spawn (TPUs) diff --git a/examples/pytorch/question-answering/run_qa.py b/examples/pytorch/question-answering/run_qa.py index d275a2992212c9..949456950500c3 100755 --- a/examples/pytorch/question-answering/run_qa.py +++ b/examples/pytorch/question-answering/run_qa.py @@ -599,6 +599,9 @@ def compute_metrics(p: EvalPrediction): trainer.log_metrics("test", metrics) trainer.save_metrics("test", metrics) + if training_args.push_to_hub: + trainer.push_to_hub() + def _mp_fn(index): # For xla_spawn (TPUs) diff --git a/examples/pytorch/question-answering/run_qa_beam_search.py b/examples/pytorch/question-answering/run_qa_beam_search.py index 490ecae86eb79b..ca8d620fe0e014 100755 --- a/examples/pytorch/question-answering/run_qa_beam_search.py +++ b/examples/pytorch/question-answering/run_qa_beam_search.py @@ -638,6 +638,9 @@ def compute_metrics(p: EvalPrediction): trainer.log_metrics("test", metrics) trainer.save_metrics("test", metrics) + if training_args.push_to_hub: + trainer.push_to_hub() + def _mp_fn(index): # For xla_spawn (TPUs) diff --git a/examples/pytorch/summarization/run_summarization.py 
b/examples/pytorch/summarization/run_summarization.py index 0856d6dce5cb22..25a117010e3e53 100755 --- a/examples/pytorch/summarization/run_summarization.py +++ b/examples/pytorch/summarization/run_summarization.py @@ -579,6 +579,9 @@ def compute_metrics(eval_preds): with open(output_test_preds_file, "w") as writer: writer.write("\n".join(test_preds)) + if training_args.push_to_hub: + trainer.push_to_hub() + return results diff --git a/examples/pytorch/text-classification/run_glue.py b/examples/pytorch/text-classification/run_glue.py index 0af73cf5ddb59f..eff048a65a0b14 100755 --- a/examples/pytorch/text-classification/run_glue.py +++ b/examples/pytorch/text-classification/run_glue.py @@ -517,6 +517,9 @@ def compute_metrics(p: EvalPrediction): item = label_list[item] writer.write(f"{index}\t{item}\n") + if training_args.push_to_hub: + trainer.push_to_hub() + def _mp_fn(index): # For xla_spawn (TPUs) diff --git a/examples/pytorch/token-classification/run_ner.py b/examples/pytorch/token-classification/run_ner.py index 2f31b1f64da55f..2907a78033b1dc 100755 --- a/examples/pytorch/token-classification/run_ner.py +++ b/examples/pytorch/token-classification/run_ner.py @@ -491,6 +491,9 @@ def compute_metrics(p): for prediction in true_predictions: writer.write(" ".join(prediction) + "\n") + if training_args.push_to_hub: + trainer.push_to_hub() + def _mp_fn(index): # For xla_spawn (TPUs) diff --git a/examples/pytorch/translation/run_translation.py b/examples/pytorch/translation/run_translation.py index e42dce9cf626de..24ab358cc30da4 100755 --- a/examples/pytorch/translation/run_translation.py +++ b/examples/pytorch/translation/run_translation.py @@ -571,6 +571,9 @@ def compute_metrics(eval_preds): with open(output_test_preds_file, "w") as writer: writer.write("\n".join(test_preds)) + if training_args.push_to_hub: + trainer.push_to_hub() + return results diff --git a/hubconf.py b/hubconf.py index c23d5ed8ed2f90..6c60cd4213d5c4 100644 --- a/hubconf.py +++ b/hubconf.py @@ -31,7 +31,7 @@ ) -dependencies = ["torch", "numpy", "tokenizers", "filelock", "requests", "tqdm", "regex", "sentencepiece", "sacremoses", "importlib_metadata"] +dependencies = ["torch", "numpy", "tokenizers", "filelock", "requests", "tqdm", "regex", "sentencepiece", "sacremoses", "importlib_metadata", "huggingface_hub"] @add_start_docstrings(AutoConfig.__doc__) diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index 3aa671251c89dc..2a1ad215158b4a 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -22,14 +22,14 @@ from typing import Any, Dict, Tuple, Union from . import __version__ -from .file_utils import CONFIG_NAME, cached_path, hf_bucket_url, is_offline_mode, is_remote_url +from .file_utils import CONFIG_NAME, PushToHubMixin, cached_path, hf_bucket_url, is_offline_mode, is_remote_url from .utils import logging logger = logging.get_logger(__name__) -class PretrainedConfig(object): +class PretrainedConfig(PushToHubMixin): r""" Base class for all configuration classes. Handles a few parameters common to all models' configurations as well as methods for loading/downloading/saving configurations. 
@@ -310,7 +310,7 @@ def num_labels(self, num_labels: int): self.id2label = {i: f"LABEL_{i}" for i in range(num_labels)} self.label2id = dict(zip(self.id2label.values(), self.id2label.keys())) - def save_pretrained(self, save_directory: Union[str, os.PathLike]): + def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs): """ Save a configuration object to the directory ``save_directory``, so that it can be re-loaded using the :func:`~transformers.PretrainedConfig.from_pretrained` class method. @@ -318,6 +318,11 @@ def save_pretrained(self, save_directory: Union[str, os.PathLike]): Args: save_directory (:obj:`str` or :obj:`os.PathLike`): Directory where the configuration JSON file will be saved (will be created if it does not exist). + push_to_hub (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether or not to push your model to the Hugging Face model hub after saving it. + kwargs: + Additional keyword arguments passed along to the + :meth:`~transformers.file_utils.PushToHubMixin.push_to_hub` method. """ if os.path.isfile(save_directory): raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file") @@ -328,6 +333,10 @@ def save_pretrained(self, save_directory: Union[str, os.PathLike]): self.to_json_file(output_config_file, use_diff=True) logger.info(f"Configuration saved in {output_config_file}") + if push_to_hub: + url = self._push_to_hub(save_files=[output_config_file], **kwargs) + logger.info(f"Configuration pushed to the hub in this commit: {url}") + @classmethod def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig": r""" diff --git a/src/transformers/file_utils.py b/src/transformers/file_utils.py index 7522d38785de5f..392728fdf0e994 100644 --- a/src/transformers/file_utils.py +++ b/src/transformers/file_utils.py @@ -31,6 +31,7 @@ from collections import OrderedDict, UserDict from contextlib import contextmanager from dataclasses import fields +from distutils.dir_util import copy_tree from enum import Enum from functools import partial, wraps from hashlib import sha256 @@ -47,10 +48,10 @@ import requests from filelock import FileLock +from huggingface_hub import HfApi, HfFolder, Repository from transformers.utils.versions import importlib_metadata from . import __version__ -from .hf_api import HfFolder from .utils import logging @@ -229,7 +230,12 @@ S3_BUCKET_PREFIX = "https://s3.amazonaws.com/models.huggingface.co/bert" CLOUDFRONT_DISTRIB_PREFIX = "https://cdn.huggingface.co" -HUGGINGFACE_CO_PREFIX = "https://huggingface.co/{model_id}/resolve/{revision}/{filename}" + +_staging_mode = os.environ.get("HUGGINGFACE_CO_STAGING", "NO").upper() in ENV_VARS_TRUE_VALUES +_default_endpoint = "https://moon-staging.huggingface.co" if _staging_mode else "https://huggingface.co" + +HUGGINGFACE_CO_RESOLVE_ENDPOINT = os.environ.get("HUGGINGFACE_CO_RESOLVE_ENDPOINT", _default_endpoint) +HUGGINGFACE_CO_PREFIX = HUGGINGFACE_CO_RESOLVE_ENDPOINT + "/{model_id}/resolve/{revision}/{filename}" PRESET_MIRROR_DICT = { "tuna": "https://mirrors.tuna.tsinghua.edu.cn/hugging-face-models", @@ -1684,3 +1690,125 @@ def copy_func(f): g = functools.update_wrapper(g, f) g.__kwdefaults__ = f.__kwdefaults__ return g + + +class PushToHubMixin: + """ + A Mixin containing the functionality to push a model or tokenizer to the hub. 
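+ Classes that inherit from it get a :obj:`push_to_hub` method, which saves the object with :obj:`save_pretrained` and uploads the resulting files to a repository on the hub.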
+ """ + + def push_to_hub( + self, + repo_name: Optional[str] = None, + repo_url: Optional[str] = None, + commit_message: Optional[str] = None, + organization: Optional[str] = None, + private: bool = None, + use_auth_token: Optional[Union[bool, str]] = None, + ) -> str: + """ + Upload model checkpoint or tokenizer files to the 🤗 model hub. + + Parameters: + repo_name (:obj:`str`, `optional`): + Repository name for your model or tokenizer in the hub. If not specified, the repository name will be + the stem of :obj:`save_directory`. + repo_url (:obj:`str`, `optional`): + Specify this in case you want to push to an existing repository in the hub. If unspecified, a new + repository will be created in your namespace (unless you specify an :obj:`organization`) with + :obj:`repo_name`. + commit_message (:obj:`str`, `optional`): + Message to commit while pushing. Will default to :obj:`"add config"`, :obj:`"add tokenizer"` or + :obj:`"add model"` depending on the type of the class. + organization (:obj:`str`, `optional`): + Organization in which you want to push your model or tokenizer (you must be a member of this + organization). + private (:obj:`bool`, `optional`): + Whether or not the repository created should be private (requires a paying subscription). + use_auth_token (:obj:`bool` or :obj:`str`, `optional`): + The token to use as HTTP bearer authorization for remote files. If :obj:`True`, will use the token + generated when running :obj:`transformers-cli login` (stored in :obj:`~/.huggingface`). Will default to + :obj:`True` if :obj:`repo_url` is not specified. + + + Returns: + The url of the commit of your model in the given repository. + """ + with tempfile.TemporaryDirectory() as tmp_dir: + self.save_pretrained(tmp_dir) + self._push_to_hub( + save_directory=tmp_dir, + repo_name=repo_name, + repo_url=repo_url, + commit_message=commit_message, + organization=organization, + private=private, + use_auth_token=use_auth_token, + ) + + @classmethod + def _push_to_hub( + cls, + save_directory: Optional[str] = None, + save_files: Optional[List[str]] = None, + repo_name: Optional[str] = None, + repo_url: Optional[str] = None, + commit_message: Optional[str] = None, + organization: Optional[str] = None, + private: bool = None, + use_auth_token: Optional[Union[bool, str]] = None, + ) -> str: + # Private version of push_to_hub, that either accepts a folder to push or a list of files. + if save_directory is None and save_files is None: + raise ValueError("_push_to_hub requires either a `save_directory` or a list of `save_files`.") + if repo_name is None and repo_url is None and save_directory is None: + raise ValueError("Need either a `repo_name` or `repo_url` to know where to push!") + + if repo_name is None and repo_url is None and save_files is None: + repo_name = Path(save_directory).name + if use_auth_token is None and repo_url is None: + use_auth_token = True + + if isinstance(use_auth_token, str): + token = use_auth_token + elif use_auth_token: + token = HfFolder.get_token() + if token is None: + raise ValueError( + "You must login to the Hugging Face hub on this computer by typing `transformers-cli login` and " + "entering your credentials to use `use_auth_token=True`. Alternatively, you can pass your own " + "token as the `use_auth_token` argument." 
) + else: + token = None + + if repo_url is None: + # Special provision for the test endpoint (CI) + repo_url = HfApi(endpoint=HUGGINGFACE_CO_RESOLVE_ENDPOINT).create_repo( + token, + repo_name, + organization=organization, + private=private, + repo_type=None, + exist_ok=True, + ) + + if commit_message is None: + if "Tokenizer" in cls.__name__: + commit_message = "add tokenizer" + elif "Config" in cls.__name__: + commit_message = "add config" + else: + commit_message = "add model" + + with tempfile.TemporaryDirectory() as tmp_dir: + # First create the repo (and clone its content if it's nonempty), then add the files (otherwise there is + # no diff so nothing is pushed). + repo = Repository(tmp_dir, clone_from=repo_url, use_auth_token=use_auth_token) + if save_directory is None: + for filename in save_files: + shutil.copy(filename, Path(tmp_dir) / Path(filename).name) + else: + copy_tree(save_directory, tmp_dir) + + return repo.push_to_hub(commit_message=commit_message) diff --git a/src/transformers/modeling_flax_utils.py b/src/transformers/modeling_flax_utils.py index 88db0678d9e54d..1331b3ba399788 100644 --- a/src/transformers/modeling_flax_utils.py +++ b/src/transformers/modeling_flax_utils.py @@ -14,7 +14,6 @@ # limitations under the License. import os -from abc import ABC from functools import partial from pickle import UnpicklingError from typing import Dict, Set, Tuple, Union @@ -29,8 +28,10 @@ from .configuration_utils import PretrainedConfig from .file_utils import ( + CONFIG_NAME, FLAX_WEIGHTS_NAME, WEIGHTS_NAME, + PushToHubMixin, add_start_docstrings_to_model_forward, cached_path, copy_func, @@ -54,7 +55,7 @@ } -class FlaxPreTrainedModel(ABC): +class FlaxPreTrainedModel(PushToHubMixin): r""" Base class for all models. @@ -385,7 +386,7 @@ def from_pretrained( return model - def save_pretrained(self, save_directory: Union[str, os.PathLike]): + def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub=False, **kwargs): """ Save a model and its configuration file to a directory, so that it can be re-loaded using the `:func:`~transformers.FlaxPreTrainedModel.from_pretrained`` class method @@ -393,6 +394,11 @@ def save_pretrained(self, save_directory: Union[str, os.PathLike]): Arguments: save_directory (:obj:`str` or :obj:`os.PathLike`): Directory to which to save. Will be created if it doesn't exist. + push_to_hub (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether or not to push your model to the Hugging Face model hub after saving it. + kwargs: + Additional keyword arguments passed along to the + :meth:`~transformers.file_utils.PushToHubMixin.push_to_hub` method. 
""" if os.path.isfile(save_directory): logger.error(f"Provided path ({save_directory}) should be a directory, not a file") @@ -406,10 +412,18 @@ def save_pretrained(self, save_directory: Union[str, os.PathLike]): self.config.save_pretrained(save_directory) # save model - with open(os.path.join(save_directory, FLAX_WEIGHTS_NAME), "wb") as f: + output_model_file = os.path.join(save_directory, FLAX_WEIGHTS_NAME) + with open(output_model_file, "wb") as f: model_bytes = to_bytes(self.params) f.write(model_bytes) + logger.info(f"Model weights saved in {output_model_file}") + + if push_to_hub: + saved_files = [os.path.join(save_directory, CONFIG_NAME), output_model_file] + url = self._push_to_hub(save_files=saved_files, **kwargs) + logger.info(f"Model pushed to the hub in this commit: {url}") + def overwrite_call_docstring(model_class, docstring): # copy __call__ function to be sure docstring is changed only for this function diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py index 002a7667f20487..ea4cff65592bf9 100644 --- a/src/transformers/modeling_tf_utils.py +++ b/src/transformers/modeling_tf_utils.py @@ -30,10 +30,12 @@ from .configuration_utils import PretrainedConfig from .file_utils import ( + CONFIG_NAME, DUMMY_INPUTS, TF2_WEIGHTS_NAME, WEIGHTS_NAME, ModelOutput, + PushToHubMixin, cached_path, hf_bucket_url, is_offline_mode, @@ -591,7 +593,7 @@ def init_copy_embeddings(old_embeddings, new_num_tokens): return mask, current_weights -class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin): +class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, PushToHubMixin): r""" Base class for all TF models. @@ -1011,7 +1013,7 @@ def prune_heads(self, heads_to_prune): """ raise NotImplementedError - def save_pretrained(self, save_directory, saved_model=False, version=1): + def save_pretrained(self, save_directory, saved_model=False, version=1, push_to_hub=False, **kwargs): """ Save a model and its configuration file to a directory, so that it can be re-loaded using the :func:`~transformers.TFPreTrainedModel.from_pretrained` class method. @@ -1025,6 +1027,11 @@ def save_pretrained(self, save_directory, saved_model=False, version=1): The version of the saved model. A saved model needs to be versioned in order to be properly loaded by TensorFlow Serving as detailed in the official documentation https://www.tensorflow.org/tfx/serving/serving_basic + push_to_hub (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether or not to push your model to the Hugging Face model hub after saving it. + kwargs: + Additional key word arguments passed along to the + :meth:`~transformers.file_utils.PushToHubMixin.push_to_hub` method. 
""" if os.path.isfile(save_directory): logger.error(f"Provided path ({save_directory}) should be a directory, not a file") @@ -1045,6 +1052,11 @@ def save_pretrained(self, save_directory, saved_model=False, version=1): self.save_weights(output_model_file) logger.info(f"Model weights saved in {output_model_file}") + if push_to_hub: + saved_files = [os.path.join(save_directory, CONFIG_NAME), output_model_file] + url = self._push_to_hub(save_files=saved_files, **kwargs) + logger.info(f"Model pushed to the hub in this commit: {url}") + @classmethod def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): r""" diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index d3c2c78f1b0213..44513b39945171 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -29,12 +29,14 @@ from .activations import get_activation from .configuration_utils import PretrainedConfig from .file_utils import ( + CONFIG_NAME, DUMMY_INPUTS, FLAX_WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME, WEIGHTS_NAME, ModelOutput, + PushToHubMixin, cached_path, hf_bucket_url, is_offline_mode, @@ -385,7 +387,7 @@ def floating_point_ops( return 6 * self.estimate_tokens(input_dict) * self.num_parameters(exclude_embeddings=exclude_embeddings) -class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin): +class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin): r""" Base class for all models. @@ -799,6 +801,8 @@ def save_pretrained( save_config: bool = True, state_dict: Optional[dict] = None, save_function: Callable = torch.save, + push_to_hub: bool = False, + **kwargs, ): """ Save a model and its configuration file to a directory, so that it can be re-loaded using the @@ -818,6 +822,11 @@ def save_pretrained( save_function (:obj:`Callable`): The function to use to save the state dictionary. Useful on distributed training like TPUs when one need to replace :obj:`torch.save` by another method. + push_to_hub (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether or not to push your model to the Hugging Face model hub after saving it. + kwargs: + Additional key word arguments passed along to the + :meth:`~transformers.file_utils.PushToHubMixin.push_to_hub` method. """ if os.path.isfile(save_directory): logger.error(f"Provided path ({save_directory}) should be a directory, not a file") @@ -848,6 +857,13 @@ def save_pretrained( logger.info(f"Model weights saved in {output_model_file}") + if push_to_hub: + saved_files = [output_model_file] + if save_config: + saved_files.append(os.path.join(save_directory, CONFIG_NAME)) + url = self._push_to_hub(save_files=saved_files, **kwargs) + logger.info(f"Model pushed to the hub in this commit: {url}") + @classmethod def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs): r""" diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index 283ec1eb4d8dd3..2292acb662225f 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -50,6 +50,11 @@ DUMMY_DIFF_TOKENIZER_IDENTIFIER = "julien-c/dummy-diff-tokenizer" # Used to test Auto{Config, Model, Tokenizer} model_type detection. 
+# Used to test the hub +USER = "__DUMMY_TRANSFORMERS_USER__" +PASS = "__DUMMY_TRANSFORMERS_PASS__" +ENDPOINT_STAGING = "https://moon-staging.huggingface.co" + def parse_flag_from_env(key, default=False): try: @@ -84,6 +89,7 @@ def parse_int_from_env(key, default=None): _run_pt_tf_cross_tests = parse_flag_from_env("RUN_PT_TF_CROSS_TESTS", default=False) _run_pt_flax_cross_tests = parse_flag_from_env("RUN_PT_FLAX_CROSS_TESTS", default=False) _run_custom_tokenizers = parse_flag_from_env("RUN_CUSTOM_TOKENIZERS", default=False) +_run_staging = parse_flag_from_env("HUGGINGFACE_CO_STAGING", default=False) _run_pipeline_tests = parse_flag_from_env("RUN_PIPELINE_TESTS", default=False) _run_git_lfs_tests = parse_flag_from_env("RUN_GIT_LFS_TESTS", default=False) _tf_gpu_memory_limit = parse_int_from_env("TF_GPU_MEMORY_LIMIT", default=None) @@ -146,6 +152,23 @@ def is_pipeline_test(test_case): return pytest.mark.is_pipeline_test()(test_case) +def is_staging_test(test_case): + """ + Decorator marking a test as a staging test. + + Those tests will run using the staging environment of huggingface.co instead of the real model hub. + """ + if not _run_staging: + return unittest.skip("test is staging test")(test_case) + else: + try: + import pytest # We don't need a hard dependency on pytest in the main library + except ImportError: + return test_case + else: + return pytest.mark.is_staging_test()(test_case) + + def slow(test_case): """ Decorator marking a test as slow. diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py index e9b2b149ef67be..e72b8897677cff 100644 --- a/src/transformers/tokenization_utils_base.py +++ b/src/transformers/tokenization_utils_base.py @@ -34,6 +34,7 @@ from .file_utils import ( ExplicitEnum, PaddingStrategy, + PushToHubMixin, TensorType, _is_jax, _is_numpy, @@ -1415,7 +1416,7 @@ def all_special_ids(self) -> List[int]: @add_end_docstrings(INIT_TOKENIZER_DOCSTRING) -class PreTrainedTokenizerBase(SpecialTokensMixin): +class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin): """ Base class for :class:`~transformers.PreTrainedTokenizer` and :class:`~transformers.PreTrainedTokenizerFast`. @@ -1850,6 +1851,8 @@ def save_pretrained( save_directory: Union[str, os.PathLike], legacy_format: Optional[bool] = None, filename_prefix: Optional[str] = None, + push_to_hub: bool = False, + **kwargs, ) -> Tuple[str]: """ Save the full tokenizer state. @@ -1925,13 +1928,21 @@ def convert_added_tokens(obj: Union[AddedToken, Any], add_type_field=True): file_names = (tokenizer_config_file, special_tokens_map_file) - return self._save_pretrained( + save_files = self._save_pretrained( save_directory=save_directory, file_names=file_names, legacy_format=legacy_format, filename_prefix=filename_prefix, ) + if push_to_hub: + # Annoyingly, the return contains files that don't exist. 
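+ # Keep only the files that actually exist on disk so the push doesn't fail on a missing path.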
+ existing_files = [f for f in save_files if os.path.isfile(f)] + url = self._push_to_hub(save_files=existing_files, **kwargs) + logger.info(f"Tokenizer pushed to the hub in this commit: {url}") + + return save_files + def _save_pretrained( self, save_directory: Union[str, os.PathLike], diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index ab2291be095c41..d53ad9ac44dc16 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -23,6 +23,7 @@ import re import shutil import sys +import tempfile import time import warnings from logging import StreamHandler @@ -62,6 +63,7 @@ from .file_utils import ( CONFIG_NAME, WEIGHTS_NAME, + PushToHubMixin, is_apex_available, is_datasets_available, is_in_notebook, @@ -2274,6 +2276,71 @@ def floating_point_ops(self, inputs: Dict[str, Union[torch.Tensor, Any]]): else: return 0 + def push_to_hub( + self, + save_directory: Optional[str] = None, + repo_name: Optional[str] = None, + repo_url: Optional[str] = None, + commit_message: Optional[str] = "add model", + organization: Optional[str] = None, + private: Optional[bool] = None, + use_auth_token: Optional[Union[bool, str]] = None, + ): + """ + Upload `self.model` to the 🤗 model hub. + + Parameters: + save_directory (:obj:`str` or :obj:`os.PathLike`): + Folder containing the model weights and config. Will default to :obj:`self.args.output_dir`. + repo_name (:obj:`str`, `optional`): + Repository name for your model or tokenizer in the hub. If not specified, the repository name will be + the stem of :obj:`save_directory`. + repo_url (:obj:`str`, `optional`): + Specify this in case you want to push to an existing repository in the hub. If unspecified, a new + repository will be created in your namespace (unless you specify an :obj:`organization`) with + :obj:`repo_name`. + commit_message (:obj:`str`, `optional`, defaults to :obj:`"add model"`): + Message to commit while pushing. + organization (:obj:`str`, `optional`): + Organization in which you want to push your model or tokenizer (you must be a member of this + organization). + private (:obj:`bool`, `optional`): + Whether or not the repository created should be private (requires a paying subscription). + use_auth_token (:obj:`bool` or :obj:`str`, `optional`): + The token to use as HTTP bearer authorization for remote files. If :obj:`True`, will use the token + generated when running :obj:`transformers-cli login` (stored in :obj:`~/.huggingface`). Will default to + :obj:`True` if :obj:`repo_url` is not specified. + + Returns: + The url of the commit of your model in the given repository. + """ + if not self.is_world_process_zero(): + return + + if not isinstance(unwrap_model(self.model), PushToHubMixin): + raise ValueError( + "The `push_to_hub` method only works for models that inherit from `PushToHubMixin`." + ) + if save_directory is None: + save_directory = self.args.output_dir + + # To avoid pushing all checkpoints, we just copy all the files in save_directory into a tmp dir. 
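+ # Checkpoint subfolders (e.g. checkpoint-500) are directories, so the os.path.isfile check below skips them.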
+ with tempfile.TemporaryDirectory() as tmp_dir: + for f in os.listdir(save_directory): + fname = os.path.join(save_directory, f) + if os.path.isfile(fname): + shutil.copy(fname, os.path.join(tmp_dir, f)) + + return unwrap_model(self.model)._push_to_hub( + save_directory=tmp_dir, + repo_name=repo_name, + repo_url=repo_url, + commit_message=commit_message, + organization=organization, + private=private, + use_auth_token=use_auth_token, + ) + # # Deprecated code # diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index be98825e2282f8..a91921d466a5db 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -295,11 +295,15 @@ class TrainingArguments: When using distributed training, the value of the flag :obj:`find_unused_parameters` passed to :obj:`DistributedDataParallel`. Will default to :obj:`False` if gradient checkpointing is used, :obj:`True` otherwise. - dataloader_pin_memory (:obj:`bool`, `optional`, defaults to :obj:`True`)): + dataloader_pin_memory (:obj:`bool`, `optional`, defaults to :obj:`True`): Whether you want to pin memory in data loaders or not. Will default to :obj:`True`. - skip_memory_metrics (:obj:`bool`, `optional`, defaults to :obj:`False`)): + skip_memory_metrics (:obj:`bool`, `optional`, defaults to :obj:`False`): Whether to skip adding of memory profiler reports to metrics. Defaults to :obj:`False`. - + push_to_hub (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether or not to upload the trained model to the hub after training. This argument is not directly used by + :class:`~transformers.Trainer`, it's intended to be used by your training/evaluation scripts instead. See + the `example scripts `__ for more + details. """ output_dir: str = field( @@ -527,6 +531,9 @@ class TrainingArguments: use_legacy_prediction_loop: bool = field( default=False, metadata={"help": "Whether or not to use the legacy prediction_loop in the Trainer."} ) + push_to_hub: bool = field( + default=False, metadata={"help": "Whether or not to upload the trained model to the model hub after training."} + ) _n_gpu: int = field(init=False, repr=False, default=-1) mp_parameters: str = field( default="", diff --git a/tests/conftest.py b/tests/conftest.py index 104a1394fdf4a5..7c5f161436dcea 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -38,6 +38,7 @@ def pytest_configure(config): config.addinivalue_line( "markers", "is_pt_flax_cross_test: mark test to run only when PT and FLAX interactions are tested" ) + config.addinivalue_line("markers", "is_staging_test: mark test to run only in the staging environment") def pytest_addoption(parser): diff --git a/tests/test_configuration_common.py b/tests/test_configuration_common.py index 125755e06c4a16..eeebf1c9215d4d 100644 --- a/tests/test_configuration_common.py +++ b/tests/test_configuration_common.py @@ -17,6 +17,12 @@ import json import os import tempfile +import unittest + +from huggingface_hub import HfApi +from requests.exceptions import HTTPError +from transformers import BertConfig +from transformers.testing_utils import ENDPOINT_STAGING, PASS, USER, is_staging_test class ConfigTester(object): @@ -81,3 +87,54 @@ def run_common_tests(self): self.create_and_test_config_from_and_save_pretrained() self.create_and_test_config_with_num_labels() self.check_config_can_be_init_without_params() + + +@is_staging_test +class ConfigPushToHubTester(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls._api = HfApi(endpoint=ENDPOINT_STAGING) + cls._token = 
cls._api.login(username=USER, password=PASS) + + @classmethod + def tearDownClass(cls): + try: + cls._api.delete_repo(token=cls._token, name="test-model") + except HTTPError: + pass + + try: + cls._api.delete_repo(token=cls._token, name="test-model-org", organization="valid_org") + except HTTPError: + pass + + def test_push_to_hub(self): + config = BertConfig( + vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37 + ) + with tempfile.TemporaryDirectory() as tmp_dir: + config.save_pretrained(tmp_dir, push_to_hub=True, repo_name="test-model", use_auth_token=self._token) + + new_config = BertConfig.from_pretrained(f"{USER}/test-model") + for k, v in config.__dict__.items(): + if k != "transformers_version": + self.assertEqual(v, getattr(new_config, k)) + + def test_push_to_hub_in_organization(self): + config = BertConfig( + vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37 + ) + + with tempfile.TemporaryDirectory() as tmp_dir: + config.save_pretrained( + tmp_dir, + push_to_hub=True, + repo_name="test-model-org", + use_auth_token=self._token, + organization="valid_org", + ) + + new_config = BertConfig.from_pretrained("valid_org/test-model-org") + for k, v in config.__dict__.items(): + if k != "transformers_version": + self.assertEqual(v, getattr(new_config, k)) diff --git a/tests/test_hf_api.py b/tests/test_hf_api.py index 3e42fa20d513e0..8b7b1ddc868ab7 100644 --- a/tests/test_hf_api.py +++ b/tests/test_hf_api.py @@ -22,13 +22,9 @@ from requests.exceptions import HTTPError from transformers.hf_api import HfApi, HfFolder, ModelInfo, RepoObj -from transformers.testing_utils import require_git_lfs +from transformers.testing_utils import ENDPOINT_STAGING, PASS, USER, is_staging_test, require_git_lfs -USER = "__DUMMY_TRANSFORMERS_USER__" -PASS = "__DUMMY_TRANSFORMERS_PASS__" - -ENDPOINT_STAGING = "https://moon-staging.huggingface.co" ENDPOINT_STAGING_BASIC_AUTH = f"https://{USER}:{PASS}@moon-staging.huggingface.co" REPO_NAME = f"my-model-{int(time.time())}" @@ -106,6 +102,7 @@ def test_token_workflow(self): @require_git_lfs +@is_staging_test class HfLargefilesTest(HfApiCommonTest): @classmethod def setUpClass(cls): diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 419b92e280278a..22b0d6609a3b3b 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -22,10 +22,22 @@ import unittest from typing import List, Tuple +from huggingface_hub import HfApi +from requests.exceptions import HTTPError from transformers import is_torch_available, logging from transformers.file_utils import WEIGHTS_NAME from transformers.models.auto import get_values -from transformers.testing_utils import CaptureLogger, require_torch, require_torch_multi_gpu, slow, torch_device +from transformers.testing_utils import ( + ENDPOINT_STAGING, + PASS, + USER, + CaptureLogger, + is_staging_test, + require_torch, + require_torch_multi_gpu, + slow, + torch_device, +) if is_torch_available(): @@ -1300,3 +1312,54 @@ def test_model_from_pretrained_with_different_pretrained_model_name(self): with CaptureLogger(logger) as cl: BertModel.from_pretrained(TINY_T5) self.assertTrue("You are using a model of type t5 to instantiate a model of type bert" in cl.out) + + +@require_torch +@is_staging_test +class ModelPushToHubTester(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls._api = HfApi(endpoint=ENDPOINT_STAGING) + cls._token = cls._api.login(username=USER, password=PASS) + + @classmethod 
+ def tearDownClass(cls): + try: + cls._api.delete_repo(token=cls._token, name="test-model") + except HTTPError: + pass + + try: + cls._api.delete_repo(token=cls._token, name="test-model-org", organization="valid_org") + except HTTPError: + pass + + def test_push_to_hub(self): + config = BertConfig( + vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37 + ) + model = BertModel(config) + with tempfile.TemporaryDirectory() as tmp_dir: + model.save_pretrained(tmp_dir, push_to_hub=True, repo_name="test-model", use_auth_token=self._token) + + new_model = BertModel.from_pretrained(f"{USER}/test-model") + for p1, p2 in zip(model.parameters(), new_model.parameters()): + self.assertTrue(torch.equal(p1, p2)) + + def test_push_to_hub_in_organization(self): + config = BertConfig( + vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37 + ) + model = BertModel(config) + with tempfile.TemporaryDirectory() as tmp_dir: + model.save_pretrained( + tmp_dir, + push_to_hub=True, + repo_name="test-model-org", + use_auth_token=self._token, + organization="valid_org", + ) + + new_model = BertModel.from_pretrained("valid_org/test-model-org") + for p1, p2 in zip(model.parameters(), new_model.parameters()): + self.assertTrue(torch.equal(p1, p2)) diff --git a/tests/test_modeling_tf_common.py b/tests/test_modeling_tf_common.py index 51daf3779dc593..5052a1a0a6190b 100644 --- a/tests/test_modeling_tf_common.py +++ b/tests/test_modeling_tf_common.py @@ -24,11 +24,17 @@ from importlib import import_module from typing import List, Tuple +from huggingface_hub import HfApi +from requests.exceptions import HTTPError from transformers import is_tf_available from transformers.models.auto import get_values from transformers.testing_utils import ( + ENDPOINT_STAGING, + PASS, + USER, _tf_gpu_memory_limit, is_pt_tf_cross_test, + is_staging_test, require_onnx, require_tf, slow, @@ -50,6 +56,8 @@ TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING, TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING, TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING, + BertConfig, + TFBertModel, TFSharedEmbeddings, tf_top_k_top_p_filtering, ) @@ -1326,3 +1334,62 @@ def test_top_k_top_p_filtering(self): tf.debugging.assert_near(non_inf_output, non_inf_expected_output, rtol=1e-12) tf.debugging.assert_equal(non_inf_idx, non_inf_expected_idx) + + +@require_tf +@is_staging_test +class TFModelPushToHubTester(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls._api = HfApi(endpoint=ENDPOINT_STAGING) + cls._token = cls._api.login(username=USER, password=PASS) + + @classmethod + def tearDownClass(cls): + try: + cls._api.delete_repo(token=cls._token, name="test-model") + except HTTPError: + pass + + try: + cls._api.delete_repo(token=cls._token, name="test-model-org", organization="valid_org") + except HTTPError: + pass + + def test_push_to_hub(self): + config = BertConfig( + vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37 + ) + model = TFBertModel(config) + # Make sure model is properly initialized + _ = model(model.dummy_inputs) + with tempfile.TemporaryDirectory() as tmp_dir: + model.save_pretrained(tmp_dir, push_to_hub=True, repo_name="test-model", use_auth_token=self._token) + + new_model = TFBertModel.from_pretrained(f"{USER}/test-model") + models_equal = True + for p1, p2 in zip(model.weights, new_model.weights): + if tf.math.reduce_sum(tf.math.abs(p1 - p2)) > 0: + models_equal = False + self.assertTrue(models_equal) + + def 
test_push_to_hub_in_organization(self): + config = BertConfig( + vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37 + ) + model = TFBertModel(config) + with tempfile.TemporaryDirectory() as tmp_dir: + model.save_pretrained( + tmp_dir, + push_to_hub=True, + repo_name="test-model-org", + use_auth_token=self._token, + organization="valid_org", + ) + + new_model = TFBertModel.from_pretrained("valid_org/test-model-org") + models_equal = True + for p1, p2 in zip(model.weights, new_model.weights): + if tf.math.reduce_sum(tf.math.abs(p1 - p2)) > 0: + models_equal = False + self.assertTrue(models_equal) diff --git a/tests/test_tokenization_common.py b/tests/test_tokenization_common.py index aa83b749d49af9..febb9a05c08a54 100644 --- a/tests/test_tokenization_common.py +++ b/tests/test_tokenization_common.py @@ -20,11 +20,15 @@ import re import shutil import tempfile +import unittest from collections import OrderedDict from itertools import takewhile from typing import TYPE_CHECKING, Dict, List, Tuple, Union +from huggingface_hub import HfApi +from requests.exceptions import HTTPError from transformers import ( + BertTokenizer, PreTrainedTokenizer, PreTrainedTokenizerBase, PreTrainedTokenizerFast, @@ -32,8 +36,12 @@ is_torch_available, ) from transformers.testing_utils import ( + ENDPOINT_STAGING, + PASS, + USER, get_tests_dir, is_pt_tf_cross_test, + is_staging_test, require_tf, require_tokenizers, require_torch, @@ -2863,3 +2871,53 @@ def test_compare_prepare_for_model(self): ) for key in python_output: self.assertEqual(python_output[key], rust_output[key]) + + +@is_staging_test +class TokenizerPushToHubTester(unittest.TestCase): + vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]", "bla", "blou"] + + @classmethod + def setUpClass(cls): + cls._api = HfApi(endpoint=ENDPOINT_STAGING) + cls._token = cls._api.login(username=USER, password=PASS) + + @classmethod + def tearDownClass(cls): + try: + cls._api.delete_repo(token=cls._token, name="test-model") + except HTTPError: + pass + + try: + cls._api.delete_repo(token=cls._token, name="test-model-org", organization="valid_org") + except HTTPError: + pass + + def test_push_to_hub(self): + with tempfile.TemporaryDirectory() as tmp_dir: + vocab_file = os.path.join(tmp_dir, "vocab.txt") + with open(vocab_file, "w", encoding="utf-8") as vocab_writer: + vocab_writer.write("".join([x + "\n" for x in self.vocab_tokens])) + tokenizer = BertTokenizer(vocab_file) + tokenizer.save_pretrained(tmp_dir, push_to_hub=True, repo_name="test-model", use_auth_token=self._token) + + new_tokenizer = BertTokenizer.from_pretrained(f"{USER}/test-model") + self.assertDictEqual(new_tokenizer.vocab, tokenizer.vocab) + + def test_push_to_hub_in_organization(self): + with tempfile.TemporaryDirectory() as tmp_dir: + vocab_file = os.path.join(tmp_dir, "vocab.txt") + with open(vocab_file, "w", encoding="utf-8") as vocab_writer: + vocab_writer.write("".join([x + "\n" for x in self.vocab_tokens])) + tokenizer = BertTokenizer(vocab_file) + tokenizer.save_pretrained( + tmp_dir, + push_to_hub=True, + repo_name="test-model-org", + use_auth_token=self._token, + organization="valid_org", + ) + + new_tokenizer = BertTokenizer.from_pretrained("valid_org/test-model-org") + self.assertDictEqual(new_tokenizer.vocab, tokenizer.vocab) diff --git a/tests/test_trainer.py b/tests/test_trainer.py index b5071783f2bde8..2edce418a03686 100644 --- a/tests/test_trainer.py +++ b/tests/test_trainer.py @@ -16,16 +16,23 @@ import dataclasses import gc import os 
+import re import tempfile import unittest import numpy as np +from huggingface_hub import HfApi +from requests.exceptions import HTTPError from transformers import AutoTokenizer, IntervalStrategy, PretrainedConfig, TrainingArguments, is_torch_available from transformers.file_utils import WEIGHTS_NAME from transformers.testing_utils import ( + ENDPOINT_STAGING, + PASS, + USER, TestCasePlus, get_tests_dir, + is_staging_test, require_datasets, require_optuna, require_ray, @@ -1081,6 +1088,60 @@ def test_no_wd_param_group(self): self.assertListEqual(trainer.optimizer.param_groups[1]["params"], no_wd_params) +@require_torch +@is_staging_test +class TrainerIntegrationWithHubTester(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls._api = HfApi(endpoint=ENDPOINT_STAGING) + cls._token = cls._api.login(username=USER, password=PASS) + + @classmethod + def tearDownClass(cls): + try: + cls._api.delete_repo(token=cls._token, name="test-model") + except HTTPError: + pass + + try: + cls._api.delete_repo(token=cls._token, name="test-model-org", organization="valid_org") + except HTTPError: + pass + + def test_push_to_hub(self): + with tempfile.TemporaryDirectory() as tmp_dir: + trainer = get_regression_trainer(output_dir=tmp_dir) + trainer.save_model() + url = trainer.push_to_hub(repo_name="test-model", use_auth_token=self._token) + + # Extract repo_name from the url + re_search = re.search(ENDPOINT_STAGING + r"/([^/]+/[^/]+)/", url) + self.assertTrue(re_search is not None) + repo_name = re_search.groups()[0] + + self.assertEqual(repo_name, f"{USER}/test-model") + + model = RegressionPreTrainedModel.from_pretrained(repo_name) + self.assertEqual(model.a.item(), trainer.model.a.item()) + self.assertEqual(model.b.item(), trainer.model.b.item()) + + def test_push_to_hub_in_organization(self): + with tempfile.TemporaryDirectory() as tmp_dir: + trainer = get_regression_trainer(output_dir=tmp_dir) + trainer.save_model() + url = trainer.push_to_hub(repo_name="test-model-org", organization="valid_org", use_auth_token=self._token) + + # Extract repo_name from the url + re_search = re.search(ENDPOINT_STAGING + r"/([^/]+/[^/]+)/", url) + self.assertTrue(re_search is not None) + repo_name = re_search.groups()[0] + self.assertEqual(repo_name, "valid_org/test-model-org") + + model = RegressionPreTrainedModel.from_pretrained("valid_org/test-model-org") + self.assertEqual(model.a.item(), trainer.model.a.item()) + self.assertEqual(model.b.item(), trainer.model.b.item()) + + @require_torch @require_optuna class TrainerHyperParameterOptunaIntegrationTest(unittest.TestCase):
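

A minimal end-to-end sketch of the API this diff introduces (the repo name, organization, and save directory below are placeholders, and a token is assumed to be cached via ``transformers-cli login``):

.. code-block:: python

    from transformers import AutoModel, AutoTokenizer

    # Reload a fine-tuned model saved earlier with save_pretrained("finetuned-dir") (placeholder path).
    model = AutoModel.from_pretrained("finetuned-dir")
    tokenizer = AutoTokenizer.from_pretrained("finetuned-dir")

    # Create <username>/my-awesome-model if needed and upload the weights and config.
    url = model.push_to_hub("my-awesome-model")
    print(url)  # URL of the commit that was pushed

    # Add the tokenizer files to the same repo.
    tokenizer.push_to_hub("my-awesome-model")

    # Equivalent shortcut at save time:
    model.save_pretrained("finetuned-dir", push_to_hub=True, repo_name="my-awesome-model")

    # Push to an organization's namespace as a private repo instead:
    model.push_to_hub("my-awesome-model", organization="my-org", private=True)

Trainer-based scripts get the same behavior from the new :obj:`push_to_hub` training argument: the example scripts call :obj:`trainer.push_to_hub()` after training, which uploads the contents of :obj:`self.args.output_dir`.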