Skip to content

Commit

Permalink
[misc] only tqdm for first rank (#6672)
Browse files Browse the repository at this point in the history
  • Loading branch information
youkaichao authored Jul 23, 2024
1 parent 97234be commit c520124
Showing 1 changed file with 31 additions and 6 deletions.
37 changes: 31 additions & 6 deletions vllm/model_executor/model_loader/weight_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,13 @@ def filter_files_not_needed_for_inference(
return hf_weights_files


# explicitly use pure text format, with a newline at the end
# this makes it impossible to see the animation in the progress bar
# but will avoid messing up with ray or multiprocessing, which wraps
# each line of output with some prefix.
_BAR_FORMAT = "{desc}: {percentage:3.0f}% Completed | {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]\n" # noqa: E501


def np_cache_weights_iterator(
model_name_or_path: str, cache_dir: Optional[str], hf_folder: str,
hf_weights_files: List[str]
Expand All @@ -321,6 +328,8 @@ def np_cache_weights_iterator(
Will dump the model weights to numpy files if they are not already dumped.
"""
enable_tqdm = not torch.distributed.is_initialized(
) or torch.distributed.get_rank() == 0
# Convert the model weights from torch tensors to numpy arrays for
# faster loading.
np_folder = os.path.join(hf_folder, "np")
Expand All @@ -331,8 +340,12 @@ def np_cache_weights_iterator(
with get_lock(model_name_or_path, cache_dir):
if not os.path.exists(weight_names_file):
weight_names: List[str] = []
for bin_file in tqdm(hf_weights_files,
desc="Loading np_cache checkpoint shards"):
for bin_file in tqdm(
hf_weights_files,
desc="Loading np_cache checkpoint shards",
disable=not enable_tqdm,
bar_format=_BAR_FORMAT,
):
state = torch.load(bin_file, map_location="cpu")
for name, param in state.items():
param_path = os.path.join(np_folder, name)
Expand All @@ -356,8 +369,14 @@ def safetensors_weights_iterator(
hf_weights_files: List[str]
) -> Generator[Tuple[str, torch.Tensor], None, None]:
"""Iterate over the weights in the model safetensor files."""
for st_file in tqdm(hf_weights_files,
desc="Loading safetensors checkpoint shards"):
enable_tqdm = not torch.distributed.is_initialized(
) or torch.distributed.get_rank() == 0
for st_file in tqdm(
hf_weights_files,
desc="Loading safetensors checkpoint shards",
disable=not enable_tqdm,
bar_format=_BAR_FORMAT,
):
with safe_open(st_file, framework="pt") as f:
for name in f.keys(): # noqa: SIM118
param = f.get_tensor(name)
Expand All @@ -368,8 +387,14 @@ def pt_weights_iterator(
hf_weights_files: List[str]
) -> Generator[Tuple[str, torch.Tensor], None, None]:
"""Iterate over the weights in the model bin/pt files."""
for bin_file in tqdm(hf_weights_files,
desc="Loading pt checkpoint shards"):
enable_tqdm = not torch.distributed.is_initialized(
) or torch.distributed.get_rank() == 0
for bin_file in tqdm(
hf_weights_files,
desc="Loading pt checkpoint shards",
disable=not enable_tqdm,
bar_format=_BAR_FORMAT,
):
state = torch.load(bin_file, map_location="cpu")
for name, param in state.items():
yield name, param
Expand Down

0 comments on commit c520124

Please sign in to comment.