diff --git a/download-model.py b/download-model.py index 306784a355..8fe94371f2 100644 --- a/download-model.py +++ b/download-model.py @@ -72,7 +72,7 @@ def sanitize_model_and_branch_names(self, model, branch): return model, branch - def get_download_links_from_huggingface(self, model, branch, text_only=False, specific_file=None): + def get_download_links_from_huggingface(self, model, branch, text_only=False, specific_file=None, exclude_pattern=None): session = self.session page = f"/api/models/{model}/tree/{branch}" cursor = b"" @@ -100,13 +100,17 @@ def get_download_links_from_huggingface(self, model, branch, text_only=False, sp if specific_file not in [None, ''] and fname != specific_file: continue + # Exclude files matching the exclude pattern + if exclude_pattern is not None and re.match(exclude_pattern, fname): + continue + if not is_lora and fname.endswith(('adapter_config.json', 'adapter_model.bin')): is_lora = True is_pytorch = re.match(r"(pytorch|adapter|gptq)_model.*\.bin", fname) is_safetensors = re.match(r".*\.safetensors", fname) is_pt = re.match(r".*\.pt", fname) - is_gguf = re.match(r'.*\.gguf', fname) + is_gguf = re.match(r".*\.gguf", fname) is_tiktoken = re.match(r".*\.tiktoken", fname) is_tokenizer = re.match(r"(tokenizer|ice|spiece).*\.model", fname) or is_tiktoken is_text = re.match(r".*\.(txt|json|py|md)", fname) or is_tokenizer @@ -140,7 +144,6 @@ def get_download_links_from_huggingface(self, model, branch, text_only=False, sp # If both pytorch and safetensors are available, download safetensors only # Also if GGUF and safetensors are available, download only safetensors - # (why do people do this?) if (has_pytorch or has_pt or has_gguf) and has_safetensors: has_gguf = False for i in range(len(classifications) - 1, -1, -1): @@ -148,8 +151,6 @@ def get_download_links_from_huggingface(self, model, branch, text_only=False, sp links.pop(i) # For GGUF, try to download only the Q4_K_M if no specific file is specified. - # If not present, exclude all GGUFs, as that's likely a repository with both - # GGUF and fp16 files. if has_gguf and specific_file is None: has_q4km = False for i in range(len(classifications) - 1, -1, -1): @@ -312,6 +313,7 @@ def check_model_files(self, model, branch, links, sha256, output_folder): parser.add_argument('--threads', type=int, default=4, help='Number of files to download simultaneously.') parser.add_argument('--text-only', action='store_true', help='Only download text files (txt/json).') parser.add_argument('--specific-file', type=str, default=None, help='Name of the specific file to download (if not provided, downloads all).') + parser.add_argument('--exclude-pattern', type=str, default=None, help='Regex pattern to exclude files from download.') parser.add_argument('--output', type=str, default=None, help='Save the model files to this folder.') parser.add_argument('--model-dir', type=str, default=None, help='Save the model files to a subfolder of this folder instead of the default one (text-generation-webui/models).') parser.add_argument('--clean', action='store_true', help='Does not resume the previous download.') @@ -322,6 +324,7 @@ def check_model_files(self, model, branch, links, sha256, output_folder): branch = args.branch model = args.MODEL specific_file = args.specific_file + exclude_pattern = args.exclude_pattern if model is None: print("Error: Please specify the model you'd like to download (e.g. 'python download-model.py facebook/opt-1.3b').") @@ -336,7 +339,9 @@ def check_model_files(self, model, branch, links, sha256, output_folder): sys.exit() # Get the download links from Hugging Face - links, sha256, is_lora, is_llamacpp = downloader.get_download_links_from_huggingface(model, branch, text_only=args.text_only, specific_file=specific_file) + links, sha256, is_lora, is_llamacpp = downloader.get_download_links_from_huggingface( + model, branch, text_only=args.text_only, specific_file=specific_file, exclude_pattern=exclude_pattern + ) # Get the output folder if args.output: @@ -349,4 +354,7 @@ def check_model_files(self, model, branch, links, sha256, output_folder): downloader.check_model_files(model, branch, links, sha256, output_folder) else: # Download files - downloader.download_model_files(model, branch, links, sha256, output_folder, specific_file=specific_file, threads=args.threads, is_llamacpp=is_llamacpp) + downloader.download_model_files( + model, branch, links, sha256, output_folder, + specific_file=specific_file, threads=args.threads, is_llamacpp=is_llamacpp + )