diff --git a/ludwig/cli.py b/ludwig/cli.py index c89bc47f7cb..345a69d0e7d 100644 --- a/ludwig/cli.py +++ b/ludwig/cli.py @@ -56,7 +56,7 @@ def __init__(self): init_config Initialize a user config from a dataset and targets render_config Renders the fully populated config with all defaults set check_install Runs a quick training run on synthetic data to verify installation status - upload Push trained model artifacts to a registry (e.g., HuggingFace Hub) + upload Push trained model artifacts to a registry (e.g., Predibase, HuggingFace Hub) """, ) parser.add_argument("command", help="Subcommand to run") diff --git a/ludwig/upload.py b/ludwig/upload.py index 4cc86e488ad..fa626f3a1e1 100644 --- a/ludwig/upload.py +++ b/ludwig/upload.py @@ -4,7 +4,7 @@ from typing import Optional from ludwig.utils.print_utils import get_logging_level_registry -from ludwig.utils.upload_utils import HuggingFaceHub +from ludwig.utils.upload_utils import HuggingFaceHub, Predibase logger = logging.getLogger(__name__) @@ -12,6 +12,7 @@ def get_upload_registry(): return { "hf_hub": HuggingFaceHub, + "predibase": Predibase, } @@ -23,6 +24,8 @@ def upload_cli( private: bool = False, commit_message: str = "Upload trained [Ludwig](https://ludwig.ai/latest/) model weights", commit_description: Optional[str] = None, + dataset_file: Optional[str] = None, + dataset_name: Optional[str] = None, **kwargs, ) -> None: """Create an empty repo on the HuggingFace Hub and upload trained model artifacts to that repo. @@ -30,7 +33,7 @@ def upload_cli( Args: service (`str`): Name of the hosted model service to push the trained artifacts to. - Currently, this only supports `hf_hub`. + Currently, this only supports `hf_hub` and `predibase`. repo_id (`str`): A namespace (user or an organization) and a repo name separated by a `/`. 
@@ -49,10 +52,15 @@ def upload_cli( `f"Upload {path_in_repo} with huggingface_hub"` commit_description (`str` *optional*): The description of the generated commit + dataset_file (`str`, *optional*): + The path to the dataset file. Required if `service` is set to + `"predibase"` for new model repos. + dataset_name (`str`, *optional*): + The name of the dataset. Used by the `service` + `"predibase"`. """ model_service = get_upload_registry().get(service, "hf_hub") hub = model_service() - hub.login() hub.upload( repo_id=repo_id, model_path=model_path, @@ -60,6 +68,8 @@ def upload_cli( private=private, commit_message=commit_message, commit_description=commit_description, + dataset_file=dataset_file, + dataset_name=dataset_name, ) @@ -77,7 +87,7 @@ def cli(sys_argv): "service", help="Name of the model repository service.", default="hf_hub", - choices=["hf_hub"], + choices=["hf_hub", "predibase"], ) parser.add_argument( @@ -115,6 +125,11 @@ def cli(sys_argv): choices=["critical", "error", "warning", "info", "debug", "notset"], ) + parser.add_argument("-df", "--dataset_file", help="The location of the dataset file", default=None) + parser.add_argument( + "-dn", "--dataset_name", help="(Optional) The name of the dataset in the Provider", default=None + ) + args = parser.parse_args(sys_argv) args.logging_level = get_logging_level_registry()[args.logging_level] diff --git a/ludwig/utils/upload_utils.py b/ludwig/utils/upload_utils.py index 4dc193d6410..cc166170282 100644 --- a/ludwig/utils/upload_utils.py +++ b/ludwig/utils/upload_utils.py @@ -37,6 +37,8 @@ def upload( private: Optional[bool] = False, commit_message: Optional[str] = None, commit_description: Optional[str] = None, + dataset_file: Optional[str] = None, + dataset_name: Optional[str] = None, ) -> bool: """Abstract method to upload trained model artifacts to the target repository. @@ -68,9 +70,7 @@ def _validate_upload_parameters( trained model artifacts to the target repository. 
Args: - repo_id (str): The ID of the target repository. It must be a namespace (user or an organization) - and a repository name separated by a '/'. For example, if your HF username is 'johndoe' and you - want to create a repository called 'test', the repo_id should be 'johndoe/test'. + repo_id (str): The ID of the target repository. Each provider will verify their specific rules. model_path (str): The path to the directory containing the trained model artifacts. It should contain the model's weights, usually saved under 'model/model_weights'. repo_type (str, optional): The type of the repository. Not used in the base class, but subclasses @@ -85,18 +85,10 @@ def _validate_upload_parameters( implementations. Defaults to None. Raises: - AssertionError: If the repo_id does not have both a namespace and a repo name separated by a '/'. FileNotFoundError: If the model_path does not exist. Exception: If the trained model artifacts are not found at the expected location within model_path, or if the artifacts are not in the required format (i.e., 'pytorch_model.bin' or 'adapter_model.bin'). """ - # Validate repo_id has both a namespace and a repo name - assert "/" in repo_id, ( - "`repo_id` must be a namespace (user or an organization) and a repo name separated by a `/`." - " For example, if your HF username is `johndoe` and you want to create a repository called `test`, the" - " repo_id should be johndoe/test" - ) - # Make sure the model's save path is actually a valid path if not os.path.exists(model_path): raise FileNotFoundError(f"The path '{model_path}' does not exist.") @@ -110,21 +102,11 @@ def _validate_upload_parameters( "wrong during training where the model's weights were not saved." ) - # Make sure the model's saved artifacts either contain: - # 1. pytorch_model.bin -> regular model training, such as ECD or for LLMs - # 2. 
adapter_model.bin -> LLM fine-tuning using PEFT - files = set(os.listdir(trained_model_artifacts_path)) - if "pytorch_model.bin" not in files and "adapter_model.bin" not in files: - raise Exception( - f"Can't find model weights at {trained_model_artifacts_path}. Trained model weights should " - "either be saved as `pytorch_model.bin` for regular model training, or have `adapter_model.bin`" - "if using parameter efficient fine-tuning methods like LoRA." - ) - class HuggingFaceHub(BaseModelUpload): def __init__(self): self.api = None + self.login() def login(self): """Login to huggingface hub using the token stored in ~/.cache/huggingface/token and returns a HfApi client @@ -142,6 +124,68 @@ def login(self): self.api = hf_api + @staticmethod + def _validate_upload_parameters( + repo_id: str, + model_path: str, + repo_type: Optional[str] = None, + private: Optional[bool] = False, + commit_message: Optional[str] = None, + commit_description: Optional[str] = None, + ): + """Validate parameters before uploading trained model artifacts. + + This method checks if the input parameters meet the necessary requirements before uploading + trained model artifacts to the target repository. + + Args: + repo_id (str): The ID of the target repository. It must be a namespace (user or an organization) + and a repository name separated by a '/'. For example, if your HF username is 'johndoe' and you + want to create a repository called 'test', the repo_id should be 'johndoe/test'. + model_path (str): The path to the directory containing the trained model artifacts. It should contain + the model's weights, usually saved under 'model/model_weights'. + repo_type (str, optional): The type of the repository. Not used in the base class, but subclasses + may use it for specific repository implementations. Defaults to None. + private (bool, optional): Whether the repository should be private or not. Not used in the base class, + but subclasses may use it for specific repository implementations. 
Defaults to False. + commit_message (str, optional): A message to attach to the commit when uploading to version control + systems. Not used in the base class, but subclasses may use it for specific repository + implementations. Defaults to None. + commit_description (str, optional): A description of the commit when uploading to version control + systems. Not used in the base class, but subclasses may use it for specific repository + implementations. Defaults to None. + + Raises: + ValueError: If the repo_id does not have both a namespace and a repo name separated by a '/'. + """ + # Validate repo_id has both a namespace and a repo name + if "/" not in repo_id: + raise ValueError( + "`repo_id` must be a namespace (user or an organization) and a repo name separated by a `/`." + " For example, if your HF username is `johndoe` and you want to create a repository called `test`, the" + " repo_id should be johndoe/test" + ) + BaseModelUpload._validate_upload_parameters( + repo_id, + model_path, + repo_type, + private, + commit_message, + commit_description, + ) + + trained_model_artifacts_path = os.path.join(model_path, "model", "model_weights") + # Make sure the model's saved artifacts either contain: + # 1. pytorch_model.bin -> regular model training, such as ECD or for LLMs + # 2. adapter_model.bin -> LLM fine-tuning using PEFT + files = set(os.listdir(trained_model_artifacts_path)) + if "pytorch_model.bin" not in files and "adapter_model.bin" not in files: + raise Exception( + f"Can't find model weights at {trained_model_artifacts_path}. Trained model weights should " + "either be saved as `pytorch_model.bin` for regular model training, or have `adapter_model.bin`" + "if using parameter efficient fine-tuning methods like LoRA." 
+ ) + def upload( self, repo_id: str, @@ -150,6 +194,7 @@ def upload( private: Optional[bool] = False, commit_message: Optional[str] = None, commit_description: Optional[str] = None, + **kwargs, ) -> bool: """Create an empty repo on the HuggingFace Hub and upload trained model artifacts to that repo. @@ -205,3 +250,143 @@ def upload( return True return False + + +class Predibase(BaseModelUpload): + def __init__(self): + self.pc = None + self.login() + + def login(self): + """Login to Predibase using the token stored in the PREDIBASE_API_TOKEN environment variable and store a + PredibaseClient object on `self.pc` that can be used to interact with Predibase.""" + from predibase import PredibaseClient + + token = os.environ.get("PREDIBASE_API_TOKEN") + if token is None: + raise ValueError( + "Unable to find PREDIBASE_API_TOKEN environment variable. Please log into Predibase, " + "generate a token and use `export PREDIBASE_API_TOKEN=` to use Predibase" + ) + + try: + pc = PredibaseClient() + + # TODO: Check if subscription has expired + + self.pc = pc + except Exception as e: + raise Exception(f"Failed to login to Predibase: {e}") + return False + + return True + + @staticmethod + def _validate_upload_parameters( + repo_id: str, + model_path: str, + repo_type: Optional[str] = None, + private: Optional[bool] = False, + commit_message: Optional[str] = None, + commit_description: Optional[str] = None, + ): + """Validate parameters before uploading trained model artifacts. + + This method checks if the input parameters meet the necessary requirements before uploading + trained model artifacts to the target repository. + + Args: + repo_id (str): The ID of the target repository. It must be 255 characters or fewer. + model_path (str): The path to the directory containing the trained model artifacts. It should contain + the model's weights, usually saved under 'model/model_weights'. + repo_type (str, optional): The type of the repository.
Not used in the base class, but subclasses + may use it for specific repository implementations. Defaults to None. + private (bool, optional): Whether the repository should be private or not. Not used in the base class, + but subclasses may use it for specific repository implementations. Defaults to False. + commit_message (str, optional): A message to attach to the commit when uploading to version control + systems. Not used in the base class, but subclasses may use it for specific repository + implementations. Defaults to None. + commit_description (str, optional): A description of the commit when uploading to version control + systems. Not used in the base class, but subclasses may use it for specific repository + implementations. Defaults to None. + + Raises: + ValueError: If the repo_id is too long. + """ + if len(repo_id) > 255: + raise ValueError("`repo_id` must be 255 characters or less.") + + BaseModelUpload._validate_upload_parameters( + repo_id, + model_path, + repo_type, + private, + commit_message, + commit_description, + ) + + def upload( + self, + repo_id: str, + model_path: str, + commit_description: Optional[str] = None, + dataset_file: Optional[str] = None, + dataset_name: Optional[str] = None, + **kwargs, + ) -> bool: + """Create an empty repo in Predibase and upload trained model artifacts to that repo. + + Args: + model_path (`str`): + The path of the saved model. This is the top level directory where + the model's weights as well as other associated training artifacts + are saved. + repo_id (`str`): + A repo name. + commit_description (`str` *optional*): + The description of the repo. + dataset_file (`str` *optional*): + The path to the dataset file. Required if `service` is set to + `"predibase"` for new model repos. + dataset_name (`str` *optional*): + The name of the dataset. Used by the `service` + `"predibase"`. Falls back to the filename.
+ """ + # Validate upload parameters are in the right format + Predibase._validate_upload_parameters( + repo_id, + model_path, + None, + False, + "", + commit_description, + ) + + # Upload the dataset to Predibase + try: + dataset = self.pc.upload_dataset(file_path=dataset_file, name=dataset_name) + except Exception as e: + raise RuntimeError("Failed to upload dataset to Predibase") from e + + # Create empty model repo using repo_id, but it is okay if it already exists. + try: + repo = self.pc.create_model_repo( + name=repo_id, + description=commit_description, + exists_ok=True, + ) + except Exception as e: + raise RuntimeError("Failed to create repo in Predibase") from e + + # Upload the trained model to Predibase + try: + self.pc.upload_model( + repo=repo, + model_path=model_path, + dataset=dataset, + ) + except Exception as e: + raise RuntimeError("Failed to upload model to Predibase") from e + + logger.info(f"Model uploaded to Predibase with repository name `{repo_id}`") + return True diff --git a/requirements_extra.txt b/requirements_extra.txt index 4be0e04497a..26fe48eb998 100644 --- a/requirements_extra.txt +++ b/requirements_extra.txt @@ -3,3 +3,6 @@ horovod[pytorch]>=0.24.0,!=0.26.0 # alternative to Dask modin[ray] + +# Allows users to upload trained models to Predibase +predibase>=2023.10.2