Skip to content

Commit

Permalink
Add support for loading adapters from HuggingFace Model Hub (#162)
Browse files Browse the repository at this point in the history
  • Loading branch information
calpt authored Jun 11, 2021
1 parent 6ad0808 commit d4b61f1
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 5 deletions.
2 changes: 2 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@
"flake8>=3.8.3",
"flax>=0.3.2",
"fugashi>=1.0",
"huggingface_hub>=0.0.9",
"importlib_metadata",
"ipadic>=1.0.0,<2.0",
"isort>=5.5.4",
Expand Down Expand Up @@ -303,6 +304,7 @@ def run(self):
deps["dataclasses"] + ";python_version<'3.7'", # dataclasses for Python versions that don't have it
deps["importlib_metadata"] + ";python_version<'3.8'", # importlib_metadata for Python versions that don't have it
deps["filelock"], # filesystem locks, e.g., to prevent parallel downloads
deps["huggingface_hub"], # loading adapters from HF hub
deps["numpy"],
deps["packaging"], # utilities from PyPA to e.g., compare versions
deps["regex"], # for OpenAI GPT
Expand Down
10 changes: 9 additions & 1 deletion src/transformers/adapters/model_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,7 @@ def load_adapter(
version: str = None,
model_name: str = None,
load_as: str = None,
source: str = "ah",
custom_weights_loaders: Optional[List[WeightsLoader]] = None,
leave_out: Optional[List[int]] = None,
**kwargs
Expand All @@ -337,14 +338,19 @@ def load_adapter(
model_name (str, optional): The string identifier of the pre-trained model.
load_as (str, optional): Load the adapter using this name. By default, the name with which the adapter was
saved will be used.
source (str, optional): Identifier of the source(s) from where to load the adapter. Can be:
- "ah" (default): search on AdapterHub.
- "hf": search on HuggingFace model hub.
                - None: only search on the local file system.
leave_out: Dynamically drop adapter modules in the specified Transformer layers when loading the adapter.
Returns:
str: The name with which the adapter was added to the model.
"""
loader = AdapterLoader(self)
load_dir, load_name = loader.load(
adapter_name_or_path, config, version, model_name, load_as, leave_out=leave_out, **kwargs
adapter_name_or_path, config, version, model_name, load_as, source=source, leave_out=leave_out, **kwargs
)
# load additional custom weights
if custom_weights_loaders:
Expand Down Expand Up @@ -521,6 +527,7 @@ def load_adapter(
version: str = None,
model_name: str = None,
load_as: str = None,
source: str = "ah",
with_head: bool = True,
custom_weights_loaders: Optional[List[WeightsLoader]] = None,
leave_out: Optional[List[int]] = None,
Expand All @@ -536,6 +543,7 @@ def load_adapter(
version=version,
model_name=model_name,
load_as=load_as,
source=source,
custom_weights_loaders=custom_weights_loaders,
leave_out=leave_out,
**kwargs,
Expand Down
21 changes: 17 additions & 4 deletions src/transformers/adapters/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import json
import logging
import os
import re
import shutil
import tarfile
from collections.abc import Mapping
Expand All @@ -16,7 +15,9 @@

import requests
from filelock import FileLock
from huggingface_hub import snapshot_download

from .. import __version__
from ..file_utils import get_from_cache, is_remote_url, torch_cache_home


Expand All @@ -29,7 +30,6 @@
ADAPTERFUSION_CONFIG_NAME = "adapter_fusion_config.json"
ADAPTERFUSION_WEIGHTS_NAME = "pytorch_model_adapter_fusion.bin"

ADAPTER_IDENTIFIER_PATTERN = r"[0-9a-zA-Z\-_\/@]{2,}"
ADAPTER_HUB_URL = "https://raw.githubusercontent.com/Adapter-Hub/Hub/master/dist/v2/"
ADAPTER_HUB_INDEX_FILE = ADAPTER_HUB_URL + "index/{}.json"
ADAPTER_HUB_CONFIG_FILE = ADAPTER_HUB_URL + "architectures.json"
Expand Down Expand Up @@ -357,11 +357,23 @@ def pull_from_hub(
return download_path


def pull_from_hf_model_hub(specifier: str, version: str = None, **kwargs) -> str:
    """Download an adapter repository from the HuggingFace Model Hub.

    Args:
        specifier (str): Identifier of the repository on the HF model hub.
        version (str, optional): Revision of the repository to download.
            Defaults to None, i.e. the hub's default revision.

    Returns:
        str: The local path of the downloaded snapshot.
    """
    # Only cache_dir is forwarded; any other kwargs are intentionally ignored here.
    cache_dir = kwargs.pop("cache_dir", None)
    return snapshot_download(
        specifier,
        revision=version,
        cache_dir=cache_dir,
        library_name="adapter-transformers",
        library_version=__version__,
    )


def resolve_adapter_path(
adapter_name_or_path,
model_name: str = None,
adapter_config: Union[dict, str] = None,
version: str = None,
source: str = "ah",
**kwargs
) -> str:
"""
Expand Down Expand Up @@ -399,10 +411,11 @@ def resolve_adapter_path(
WEIGHTS_NAME, CONFIG_NAME, adapter_name_or_path
)
)
# matches possible form of identifier in hub
elif re.fullmatch(ADAPTER_IDENTIFIER_PATTERN, adapter_name_or_path):
elif source == "ah":
return pull_from_hub(
adapter_name_or_path, model_name, adapter_config=adapter_config, version=version, **kwargs
)
elif source == "hf":
return pull_from_hf_model_hub(adapter_name_or_path, version=version, **kwargs)
else:
raise ValueError("Unable to identify {} as a valid module location.".format(adapter_name_or_path))
1 change: 1 addition & 0 deletions src/transformers/dependency_versions_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
"flake8": "flake8>=3.8.3",
"flax": "flax>=0.3.2",
"fugashi": "fugashi>=1.0",
"huggingface_hub": "huggingface_hub>=0.0.9",
"importlib_metadata": "importlib_metadata",
"ipadic": "ipadic>=1.0.0,<2.0",
"isort": "isort>=5.5.4",
Expand Down

0 comments on commit d4b61f1

Please sign in to comment.