From 6ea1e63dd50d69ad7e6f5aa6d53ccda20ac58a0c Mon Sep 17 00:00:00 2001 From: Lucain Date: Wed, 27 Mar 2024 17:41:01 +0100 Subject: [PATCH] Allow saving / loading from Huggingface Hub preset (#1510) * first draft * update upload_preset * lint * consistent error messages * lint --- keras_nlp/utils/preset_utils.py | 61 ++++++++++++++++++++++++++++++--- 1 file changed, 56 insertions(+), 5 deletions(-) diff --git a/keras_nlp/utils/preset_utils.py b/keras_nlp/utils/preset_utils.py index e2a445271..5ddde4415 100644 --- a/keras_nlp/utils/preset_utils.py +++ b/keras_nlp/utils/preset_utils.py @@ -27,8 +27,16 @@ except ImportError: kagglehub = None +try: + import huggingface_hub + from huggingface_hub.utils import HFValidationError +except ImportError: + huggingface_hub = None + KAGGLE_PREFIX = "kaggle://" GS_PREFIX = "gs://" +HF_PREFIX = "hf://" + TOKENIZER_ASSET_DIR = "assets/tokenizer" CONFIG_FILE = "config.json" TOKENIZER_CONFIG_FILE = "tokenizer.json" @@ -69,15 +77,33 @@ def get_file(preset, path): url, cache_subdir=os.path.join("models", subdir), ) + elif preset.startswith(HF_PREFIX): + if huggingface_hub is None: + raise ImportError( + f"`from_preset()` requires the `huggingface_hub` package to load from '{preset}'. " + "Please install with `pip install huggingface_hub`." + ) + hf_handle = preset.removeprefix(HF_PREFIX) + try: + return huggingface_hub.hf_hub_download( + repo_id=hf_handle, filename=path + ) + except HFValidationError as e: + raise ValueError( + "Unexpected Hugging Face preset. Hugging Face model handles " + "should have the form 'hf://{org}/{model}'. For example, " + f"'hf://username/bert_base_en'. Received: preset={preset}." + ) from e elif os.path.exists(preset): # Assume a local filepath. return os.path.join(preset, path) else: raise ValueError( "Unknown preset identifier. 
A preset must be a one of:\n" - "1) a built in preset identifier like `'bert_base_en'`\n" + "1) a built-in preset identifier like `'bert_base_en'`\n" "2) a Kaggle Models handle like `'kaggle://keras/bert/keras/bert_base_en'`\n" - "3) a path to a local preset directory like `'./bert_base_en`\n" + "3) a Hugging Face handle like `'hf://username/bert_base_en'`\n" + "4) a path to a local preset directory like `'./bert_base_en`\n" "Use `print(cls.presets.keys())` to view all built-in presets for " "API symbol `cls`.\n" f"Received: preset='{preset}'" @@ -245,7 +271,9 @@ def upload_preset( uri: The URI identifying model to upload to. URIs with format `kaggle://<KAGGLE_USERNAME>/<MODEL>/<FRAMEWORK>/<VARIATION>` - will be uploaded to Kaggle Hub. + will be uploaded to Kaggle Hub while URIs with format + `hf://[<HF_USERNAME>/]<MODEL>` will be uploaded to the Hugging + Face Hub. preset: The path to the local model preset directory. allow_incomplete: If True, allows the upload of presets without a tokenizer configuration. Otherwise, a tokenizer @@ -262,10 +290,33 @@ def upload_preset( if uri.startswith(KAGGLE_PREFIX): kaggle_handle = uri.removeprefix(KAGGLE_PREFIX) kagglehub.model_upload(kaggle_handle, preset) + elif uri.startswith(HF_PREFIX): + if huggingface_hub is None: + raise ImportError( + f"`upload_preset()` requires the `huggingface_hub` package to upload to '{uri}'. " + "Please install with `pip install huggingface_hub`." + ) + hf_handle = uri.removeprefix(HF_PREFIX) + try: + repo_url = huggingface_hub.create_repo( + repo_id=hf_handle, exist_ok=True + ) + except HFValidationError as e: + raise ValueError( + "Unexpected Hugging Face URI. Hugging Face model handles " + "should have the form 'hf://[{org}/]{model}'. For example, " + "'hf://username/bert_base_en' or 'hf://bert_case_en' to implicitly" + f"upload to your user account. Received: URI={uri}." + ) from e + huggingface_hub.upload_folder( + repo_id=repo_url.repo_id, folder_path=preset + ) else: raise ValueError( - f"Unexpected URI `'{uri}'`. 
Kaggle upload format should follow " - "`kaggle://<KAGGLE_USERNAME>/<MODEL>/<FRAMEWORK>/<VARIATION>`." + "Unknown URI. An URI must be a one of:\n" + "1) a Kaggle Model handle like `'kaggle://<KAGGLE_USERNAME>/<MODEL>/<FRAMEWORK>/<VARIATION>'`\n" + "2) a Hugging Face handle like `'hf://[<HF_USERNAME>/]<MODEL>'`\n" + f"Received: uri='{uri}'." )