@@ -366,16 +366,22 @@ def filter_files_not_needed_for_inference(
366366_BAR_FORMAT = "{desc}: {percentage:3.0f}% Completed | {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]\n " # noqa: E501
367367
368368
def enable_tqdm(use_tqdm_on_load: bool):
    """Decide whether a tqdm progress bar should be displayed.

    Returns True only when the caller requested progress output AND this
    process is the one that should print it: either torch.distributed has
    not been initialized (single-process load) or this process is rank 0.
    """
    if not use_tqdm_on_load:
        return False
    # Only rank 0 prints in a distributed setting to avoid duplicate bars.
    if not torch.distributed.is_initialized():
        return True
    return torch.distributed.get_rank() == 0
372+
373+
369374def np_cache_weights_iterator (
370- model_name_or_path : str , cache_dir : Optional [str ], hf_folder : str ,
371- hf_weights_files : List [str ]
375+ model_name_or_path : str ,
376+ cache_dir : Optional [str ],
377+ hf_folder : str ,
378+ hf_weights_files : List [str ],
379+ use_tqdm_on_load : bool ,
372380) -> Generator [Tuple [str , torch .Tensor ], None , None ]:
373381 """Iterate over the weights in the model np files.
374382
375383 Will dump the model weights to numpy files if they are not already dumped.
376384 """
377- enable_tqdm = not torch .distributed .is_initialized (
378- ) or torch .distributed .get_rank () == 0
379385 # Convert the model weights from torch tensors to numpy arrays for
380386 # faster loading.
381387 np_folder = os .path .join (hf_folder , "np" )
@@ -389,7 +395,7 @@ def np_cache_weights_iterator(
389395 for bin_file in tqdm (
390396 hf_weights_files ,
391397 desc = "Loading np_cache checkpoint shards" ,
392- disable = not enable_tqdm ,
398+ disable = not enable_tqdm ( use_tqdm_on_load ) ,
393399 bar_format = _BAR_FORMAT ,
394400 ):
395401 state = torch .load (bin_file ,
@@ -414,15 +420,14 @@ def np_cache_weights_iterator(
414420
415421
416422def safetensors_weights_iterator (
417- hf_weights_files : List [str ]
423+ hf_weights_files : List [str ],
424+ use_tqdm_on_load : bool ,
418425) -> Generator [Tuple [str , torch .Tensor ], None , None ]:
419426 """Iterate over the weights in the model safetensor files."""
420- enable_tqdm = not torch .distributed .is_initialized (
421- ) or torch .distributed .get_rank () == 0
422427 for st_file in tqdm (
423428 hf_weights_files ,
424429 desc = "Loading safetensors checkpoint shards" ,
425- disable = not enable_tqdm ,
430+ disable = not enable_tqdm ( use_tqdm_on_load ) ,
426431 bar_format = _BAR_FORMAT ,
427432 ):
428433 with safe_open (st_file , framework = "pt" ) as f :
@@ -432,32 +437,30 @@ def safetensors_weights_iterator(
432437
433438
def runai_safetensors_weights_iterator(
    hf_weights_files: List[str],
    use_tqdm_on_load: bool,
) -> Generator[Tuple[str, torch.Tensor], None, None]:
    """Iterate over the weights in the model safetensor files.

    Each file is streamed through the Runai Model Streamer; (name, tensor)
    pairs are yielded as soon as the streamer makes them available.
    """
    with SafetensorsStreamer() as streamer:
        progress_bar = tqdm(
            hf_weights_files,
            desc="Loading safetensors using Runai Model Streamer",
            # Suppress the bar unless requested and we are the printing rank.
            disable=not enable_tqdm(use_tqdm_on_load),
            bar_format=_BAR_FORMAT,
        )
        for safetensors_file in progress_bar:
            streamer.stream_file(safetensors_file)
            yield from streamer.get_tensors()
449453
450454
451455def pt_weights_iterator (
452- hf_weights_files : List [str ]
456+ hf_weights_files : List [str ],
457+ use_tqdm_on_load : bool ,
453458) -> Generator [Tuple [str , torch .Tensor ], None , None ]:
454459 """Iterate over the weights in the model bin/pt files."""
455- enable_tqdm = not torch .distributed .is_initialized (
456- ) or torch .distributed .get_rank () == 0
457460 for bin_file in tqdm (
458461 hf_weights_files ,
459462 desc = "Loading pt checkpoint shards" ,
460- disable = not enable_tqdm ,
463+ disable = not enable_tqdm ( use_tqdm_on_load ) ,
461464 bar_format = _BAR_FORMAT ,
462465 ):
463466 state = torch .load (bin_file , map_location = "cpu" , weights_only = True )
0 commit comments