Skip to content

Commit

Permalink
Allow saving / loading from Huggingface Hub preset (#1510)
Browse files Browse the repository at this point in the history
* first draft

* update upload_preset

* lint

* consistent error messages

* lint
  • Loading branch information
Wauplin authored and abuelnasr0 committed Apr 2, 2024
1 parent 6a8166e commit 6ea1e63
Showing 1 changed file with 56 additions and 5 deletions.
61 changes: 56 additions & 5 deletions keras_nlp/utils/preset_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,16 @@
except ImportError:
kagglehub = None

try:
import huggingface_hub
from huggingface_hub.utils import HFValidationError
except ImportError:
huggingface_hub = None

KAGGLE_PREFIX = "kaggle://"
GS_PREFIX = "gs://"
HF_PREFIX = "hf://"

TOKENIZER_ASSET_DIR = "assets/tokenizer"
CONFIG_FILE = "config.json"
TOKENIZER_CONFIG_FILE = "tokenizer.json"
Expand Down Expand Up @@ -69,15 +77,33 @@ def get_file(preset, path):
url,
cache_subdir=os.path.join("models", subdir),
)
elif preset.startswith(HF_PREFIX):
if huggingface_hub is None:
raise ImportError(
f"`from_preset()` requires the `huggingface_hub` package to load from '{preset}'. "
"Please install with `pip install huggingface_hub`."
)
hf_handle = preset.removeprefix(HF_PREFIX)
try:
return huggingface_hub.hf_hub_download(
repo_id=hf_handle, filename=path
)
except HFValidationError as e:
raise ValueError(
"Unexpected Hugging Face preset. Hugging Face model handles "
"should have the form 'hf://{org}/{model}'. For example, "
f"'hf://username/bert_base_en'. Received: preset={preset}."
) from e
elif os.path.exists(preset):
# Assume a local filepath.
return os.path.join(preset, path)
else:
raise ValueError(
"Unknown preset identifier. A preset must be a one of:\n"
"1) a built in preset identifier like `'bert_base_en'`\n"
"1) a built-in preset identifier like `'bert_base_en'`\n"
"2) a Kaggle Models handle like `'kaggle://keras/bert/keras/bert_base_en'`\n"
"3) a path to a local preset directory like `'./bert_base_en`\n"
"3) a Hugging Face handle like `'hf://username/bert_base_en'`\n"
"4) a path to a local preset directory like `'./bert_base_en`\n"
"Use `print(cls.presets.keys())` to view all built-in presets for "
"API symbol `cls`.\n"
f"Received: preset='{preset}'"
Expand Down Expand Up @@ -245,7 +271,9 @@ def upload_preset(
uri: The URI identifying model to upload to.
URIs with format
`kaggle://<KAGGLE_USERNAME>/<MODEL>/<FRAMEWORK>/<VARIATION>`
will be uploaded to Kaggle Hub.
will be uploaded to Kaggle Hub while URIs with format
`hf://[<HF_USERNAME>/]<MODEL>` will be uploaded to the Hugging
Face Hub.
preset: The path to the local model preset directory.
allow_incomplete: If True, allows the upload of presets without
a tokenizer configuration. Otherwise, a tokenizer
Expand All @@ -262,10 +290,33 @@ def upload_preset(
if uri.startswith(KAGGLE_PREFIX):
kaggle_handle = uri.removeprefix(KAGGLE_PREFIX)
kagglehub.model_upload(kaggle_handle, preset)
elif uri.startswith(HF_PREFIX):
if huggingface_hub is None:
raise ImportError(
f"`upload_preset()` requires the `huggingface_hub` package to upload to '{uri}'. "
"Please install with `pip install huggingface_hub`."
)
hf_handle = uri.removeprefix(HF_PREFIX)
try:
repo_url = huggingface_hub.create_repo(
repo_id=hf_handle, exist_ok=True
)
except HFValidationError as e:
raise ValueError(
"Unexpected Hugging Face URI. Hugging Face model handles "
"should have the form 'hf://[{org}/]{model}'. For example, "
"'hf://username/bert_base_en' or 'hf://bert_case_en' to implicitly"
f"upload to your user account. Received: URI={uri}."
) from e
huggingface_hub.upload_folder(
repo_id=repo_url.repo_id, folder_path=preset
)
else:
raise ValueError(
f"Unexpected URI `'{uri}'`. Kaggle upload format should follow "
"`kaggle://<KAGGLE_USERNAME>/<MODEL>/<FRAMEWORK>/<VARIATION>`."
"Unknown URI. An URI must be a one of:\n"
"1) a Kaggle Model handle like `'kaggle://<KAGGLE_USERNAME>/<MODEL>/<FRAMEWORK>/<VARIATION>'`\n"
"2) a Hugging Face handle like `'hf://[<HF_USERNAME>/]<MODEL>'`\n"
f"Received: uri='{uri}'."
)


Expand Down

0 comments on commit 6ea1e63

Please sign in to comment.