diff --git a/src/grag/quantize/utils.py b/src/grag/quantize/utils.py index 7338a04..b497be2 100644 --- a/src/grag/quantize/utils.py +++ b/src/grag/quantize/utils.py @@ -10,7 +10,8 @@ import requests from git import Repo from grag.components.utils import get_config -from huggingface_hub import snapshot_download +from huggingface_hub import login, snapshot_download +from huggingface_hub.utils import GatedRepoError config = get_config() @@ -107,12 +108,29 @@ def fetch_model_repo(repo_id: str, model_path: Union[str, Path] = './grag-quanti model_path = Path(model_path) local_dir = model_path / f"{repo_id.split('/')[1]}" local_dir.mkdir(parents=True, exist_ok=True) - snapshot_download( - repo_id=repo_id, - local_dir=local_dir, - local_dir_use_symlinks="auto", - resume_download=True, - ) + + try: + snapshot_download( + repo_id=repo_id, + local_dir=local_dir, + local_dir_use_symlinks="auto", + resume_download=True, + ) + except GatedRepoError: + print( + "This model comes under gated repository. You must be authenticated to download the model. For more: https://huggingface.co/docs/hub/en/models-gated") + resp = input("If you have auth token, please provide it here ['n' or enter to exit]: ") + if resp == 'n' or resp == '': + print("No token provided, exiting.") + exit(0) + else: + login(resp) + snapshot_download( + repo_id=repo_id, + local_dir=local_dir, + local_dir_use_symlinks="auto", + resume_download=True, + ) print(f"Model downloaded in {local_dir}") return local_dir