@@ -27,9 +27,19 @@ const IGNORED: [&str; 5] = [
2727
2828const HF_TOKEN_ENV_VAR : & str = "HF_TOKEN" ;
2929
30+ /// Checks if a file is a model weight file
31+ fn is_weight_file ( filename : & str ) -> bool {
32+ filename. ends_with ( ".bin" )
33+ || filename. ends_with ( ".safetensors" )
34+ || filename. ends_with ( ".h5" )
35+ || filename. ends_with ( ".msgpack" )
36+ || filename. ends_with ( ".ckpt.index" )
37+ }
38+
3039/// Attempt to download a model from Hugging Face
3140/// Returns the directory it is in
32- pub async fn from_hf ( name : impl AsRef < Path > ) -> anyhow:: Result < PathBuf > {
41+ /// If ignore_weights is true, model weight files will be skipped
42+ pub async fn from_hf ( name : impl AsRef < Path > , ignore_weights : bool ) -> anyhow:: Result < PathBuf > {
3343 let name = name. as_ref ( ) ;
3444 let token = env:: var ( HF_TOKEN_ENV_VAR ) . ok ( ) ;
3545 let api = ApiBuilder :: new ( )
@@ -66,6 +76,11 @@ pub async fn from_hf(name: impl AsRef<Path>) -> anyhow::Result<PathBuf> {
6676 continue ;
6777 }
6878
79+ // If ignore_weights is true, skip weight files
80+ if ignore_weights && is_weight_file ( & sib. rfilename ) {
81+ continue ;
82+ }
83+
6984 match repo. get ( & sib. rfilename ) . await {
7085 Ok ( path) => {
7186 p = path;
@@ -83,8 +98,14 @@ pub async fn from_hf(name: impl AsRef<Path>) -> anyhow::Result<PathBuf> {
8398 }
8499
85100 if !files_downloaded {
101+ let file_type = if ignore_weights {
102+ "non-weight"
103+ } else {
104+ "valid"
105+ } ;
86106 return Err ( anyhow:: anyhow!(
87- "No valid files found for model '{}'." ,
107+ "No {} files found for model '{}'." ,
108+ file_type,
88109 model_name
89110 ) ) ;
90111 }
0 commit comments