Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added exclude pattern param to download-model.py script #6542

Merged
merged 2 commits into from
Jan 8, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 15 additions & 7 deletions download-model.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def sanitize_model_and_branch_names(self, model, branch):

return model, branch

def get_download_links_from_huggingface(self, model, branch, text_only=False, specific_file=None):
def get_download_links_from_huggingface(self, model, branch, text_only=False, specific_file=None, exclude_pattern=None):
session = self.session
page = f"/api/models/{model}/tree/{branch}"
cursor = b""
Expand Down Expand Up @@ -100,13 +100,17 @@ def get_download_links_from_huggingface(self, model, branch, text_only=False, sp
if specific_file not in [None, ''] and fname != specific_file:
continue

# Exclude files matching the exclude pattern
if exclude_pattern is not None and re.match(exclude_pattern, fname):
continue

if not is_lora and fname.endswith(('adapter_config.json', 'adapter_model.bin')):
is_lora = True

is_pytorch = re.match(r"(pytorch|adapter|gptq)_model.*\.bin", fname)
is_safetensors = re.match(r".*\.safetensors", fname)
is_pt = re.match(r".*\.pt", fname)
is_gguf = re.match(r'.*\.gguf', fname)
is_gguf = re.match(r".*\.gguf", fname)
is_tiktoken = re.match(r".*\.tiktoken", fname)
is_tokenizer = re.match(r"(tokenizer|ice|spiece).*\.model", fname) or is_tiktoken
is_text = re.match(r".*\.(txt|json|py|md)", fname) or is_tokenizer
Expand Down Expand Up @@ -140,16 +144,13 @@ def get_download_links_from_huggingface(self, model, branch, text_only=False, sp

# If both pytorch and safetensors are available, download safetensors only
# Also if GGUF and safetensors are available, download only safetensors
# (why do people do this?)
if (has_pytorch or has_pt or has_gguf) and has_safetensors:
has_gguf = False
for i in range(len(classifications) - 1, -1, -1):
if classifications[i] in ['pytorch', 'pt', 'gguf']:
links.pop(i)

# For GGUF, try to download only the Q4_K_M if no specific file is specified.
# If not present, exclude all GGUFs, as that's likely a repository with both
# GGUF and fp16 files.
if has_gguf and specific_file is None:
has_q4km = False
for i in range(len(classifications) - 1, -1, -1):
Expand Down Expand Up @@ -312,6 +313,7 @@ def check_model_files(self, model, branch, links, sha256, output_folder):
parser.add_argument('--threads', type=int, default=4, help='Number of files to download simultaneously.')
parser.add_argument('--text-only', action='store_true', help='Only download text files (txt/json).')
parser.add_argument('--specific-file', type=str, default=None, help='Name of the specific file to download (if not provided, downloads all).')
parser.add_argument('--exclude-pattern', type=str, default=None, help='Regex pattern to exclude files from download.')
parser.add_argument('--output', type=str, default=None, help='Save the model files to this folder.')
parser.add_argument('--model-dir', type=str, default=None, help='Save the model files to a subfolder of this folder instead of the default one (text-generation-webui/models).')
parser.add_argument('--clean', action='store_true', help='Does not resume the previous download.')
Expand All @@ -322,6 +324,7 @@ def check_model_files(self, model, branch, links, sha256, output_folder):
branch = args.branch
model = args.MODEL
specific_file = args.specific_file
exclude_pattern = args.exclude_pattern

if model is None:
print("Error: Please specify the model you'd like to download (e.g. 'python download-model.py facebook/opt-1.3b').")
Expand All @@ -336,7 +339,9 @@ def check_model_files(self, model, branch, links, sha256, output_folder):
sys.exit()

# Get the download links from Hugging Face
links, sha256, is_lora, is_llamacpp = downloader.get_download_links_from_huggingface(model, branch, text_only=args.text_only, specific_file=specific_file)
links, sha256, is_lora, is_llamacpp = downloader.get_download_links_from_huggingface(
model, branch, text_only=args.text_only, specific_file=specific_file, exclude_pattern=exclude_pattern
)

# Get the output folder
if args.output:
Expand All @@ -349,4 +354,7 @@ def check_model_files(self, model, branch, links, sha256, output_folder):
downloader.check_model_files(model, branch, links, sha256, output_folder)
else:
# Download files
downloader.download_model_files(model, branch, links, sha256, output_folder, specific_file=specific_file, threads=args.threads, is_llamacpp=is_llamacpp)
downloader.download_model_files(
model, branch, links, sha256, output_folder,
specific_file=specific_file, threads=args.threads, is_llamacpp=is_llamacpp
)