[Misc] Add BNB quantization for Whisper #12381

Merged 7 commits on Feb 4, 2025
102 changes: 60 additions & 42 deletions vllm/model_executor/model_loader/loader.py
@@ -803,9 +803,11 @@ def _hf_weight_iter(self, hf_weights_files, use_safetensors: bool):
iterator = safetensors_weights_iterator(hf_weights_files)
else:
iterator = pt_weights_iterator(hf_weights_files)
for name, param in iterator:
# mapping weight names from transformers to vllm.
yield self.weight_mapper(name), param
for org_name, param in iterator:
# mapping weight names from transformers to vllm while preserving
# original names.
mapped_name = self.weight_mapper(org_name)
yield org_name, mapped_name, param

def _get_quantized_weights_iterator(
self,
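Note: the rewritten _hf_weight_iter above now yields a three-tuple, so the quantized-weight generators below can track quantization state under the mapped (vLLM) names while still handing the original (HF) names back to the model loader. A minimal, self-contained sketch of that contract follows; the mapper and weight dict are made up for illustration and are not the actual vLLM classes.

import torch

def weight_mapper(name: str) -> str:
    # Stand-in for the real WeightsMapper: substring renaming from HF to vLLM.
    return name.replace(".fc1.", ".mlp.fc1.").replace(".fc2.", ".mlp.fc2.")

def hf_weight_iter(hf_weights: dict):
    # Mirrors the new contract: yield (original name, mapped name, tensor).
    for org_name, param in hf_weights.items():
        yield org_name, weight_mapper(org_name), param

weights = {"model.encoder.layers.0.fc1.weight": torch.zeros(4, 4)}
for org, mapped, tensor in hf_weight_iter(weights):
    print(org, "->", mapped, tuple(tensor.shape))
# model.encoder.layers.0.fc1.weight -> model.encoder.layers.0.mlp.fc1.weight (4, 4)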
@@ -866,24 +868,30 @@ def _is_4bit_weight_name(self, weight_name: str):

def _quantized_8bit_generator(self, hf_weights_files, use_safetensors,
quant_state_dict) -> Generator:
for weight_name, weight_tensor in self._hf_weight_iter(
hf_weights_files, use_safetensors):
if not weight_name.lower().endswith(".scb"):
for (
org_weight_name,
mapped_weight_name,
weight_tensor,
) in self._hf_weight_iter(hf_weights_files, use_safetensors):
if not mapped_weight_name.lower().endswith(".scb"):
continue

weight_key = weight_name.lower().replace(".scb", ".weight")
weight_key = mapped_weight_name.lower().replace(".scb", ".weight")
quant_state_dict[weight_key] = weight_tensor

for weight_name, weight_tensor in self._hf_weight_iter(
hf_weights_files, use_safetensors):
if self._is_8bit_weight_name(weight_name):
for (
org_weight_name,
mapped_weight_name,
weight_tensor,
) in self._hf_weight_iter(hf_weights_files, use_safetensors):
if self._is_8bit_weight_name(mapped_weight_name):
continue

if weight_name in quant_state_dict:
if mapped_weight_name in quant_state_dict:
set_weight_attrs(weight_tensor, {"load_in_8bit": True})
yield weight_name, weight_tensor
yield org_weight_name, weight_tensor
else:
yield weight_name, weight_tensor
yield org_weight_name, weight_tensor

def _quantized_4bit_generator(self, hf_weights_files, use_safetensors,
quant_state_dict) -> Generator:
@@ -893,15 +901,19 @@ def _quantized_4bit_generator(self, hf_weights_files, use_safetensors,
weight_iterator = self._hf_weight_iter(hf_weights_files,
use_safetensors)
temp_state_dict = {}
for weight_name, weight_tensor in weight_iterator:
if not self._is_4bit_weight_name(weight_name):
for (
org_weight_name,
mapped_weight_name,
weight_tensor,
) in weight_iterator:
if not self._is_4bit_weight_name(mapped_weight_name):
continue
# The bitsandbytes library requires
# weight.quant_state.bitsandbytes__* to be on the CPU
if "quant_state.bitsandbytes" in weight_name:
temp_state_dict[weight_name] = weight_tensor.cpu().data
if "quant_state.bitsandbytes" in mapped_weight_name:
temp_state_dict[mapped_weight_name] = weight_tensor.cpu().data
else:
temp_state_dict[weight_name] = weight_tensor
temp_state_dict[mapped_weight_name] = weight_tensor

# Closure to parse quant_state for each prequant weight
def _parse_quant_state(param_name: str,
@@ -915,20 +927,24 @@ def _parse_quant_state(param_name: str,

# Second, iterate over all prequant and normal weights;
# pre-quantized weights will have a quant_state
for weight_name, weight_tensor in self._hf_weight_iter(
hf_weights_files, use_safetensors):
if self._is_4bit_weight_name(weight_name):
for (
org_weight_name,
mapped_weight_name,
weight_tensor,
) in self._hf_weight_iter(hf_weights_files, use_safetensors):
if self._is_4bit_weight_name(mapped_weight_name):
continue

if (f"{weight_name}.quant_state.bitsandbytes__nf4"
if (f"{mapped_weight_name}.quant_state.bitsandbytes__nf4"
in temp_state_dict) or (
f"{weight_name}.quant_state.bitsandbytes__fp4"
f"{mapped_weight_name}.quant_state.bitsandbytes__fp4"
in temp_state_dict):
quant_state = _parse_quant_state(weight_name, temp_state_dict)
quant_state_dict[weight_name] = quant_state
yield weight_name, weight_tensor
quant_state = _parse_quant_state(mapped_weight_name,
temp_state_dict)
quant_state_dict[mapped_weight_name] = quant_state
yield org_weight_name, weight_tensor
else:
yield weight_name, weight_tensor
yield org_weight_name, weight_tensor

def _unquantized_generator(self, hf_weights_files, use_safetensors,
quant_state_dict) -> Generator:
@@ -937,18 +953,22 @@ def _unquantized_generator(self, hf_weights_files, use_safetensors,
tp_size = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank()

for weight_name, weight_tensor in self._hf_weight_iter(
hf_weights_files, use_safetensors):
if any(target_module in weight_name for target_module in
self.target_modules) and weight_name.endswith(".weight"):
for (
org_weight_name,
mapped_weight_name,
weight_tensor,
) in self._hf_weight_iter(hf_weights_files, use_safetensors):
if any(target_module in mapped_weight_name
for target_module in self.target_modules
) and mapped_weight_name.endswith(".weight"):
# Without sharding
if any(
weight_name.startswith(module)
mapped_weight_name.startswith(module)
for module in self.unsharded_weights_modules):
weight_sub_tensor = weight_tensor
# Shard by column
elif any(
weight_name.startswith(module)
mapped_weight_name.startswith(module)
for module in self.column_sharded_weights_modules):
total_size = weight_tensor.size(-1)
start_index = total_size // tp_size * tp_rank
@@ -958,14 +978,14 @@ def _unquantized_generator(self, hf_weights_files, use_safetensors,
# Weights are fused on disk. In this case, we assume that the
# weight and module use the same name.
elif any(
weight_name.startswith(module)
mapped_weight_name.startswith(module)
for module in self.maybe_fused_weights_modules):
# special case for fused weights
# get the size of each shard weight tensor
total_shard_sizes = next(
(sizes for module, sizes in
self.maybe_fused_weights_modules.items()
if weight_name.startswith(module)))
if mapped_weight_name.startswith(module)))
total_size = weight_tensor.size(0)
assert total_size == sum(total_shard_sizes)
# get the start/end index of each shard weight tensor
@@ -1008,23 +1028,21 @@ def _unquantized_generator(self, hf_weights_files, use_safetensors,
quant_type="nf4",
)

quant_state_dict[weight_name] = quant_state
quant_state_dict[mapped_weight_name] = quant_state
else:
processed_weight = weight_tensor

yield weight_name, processed_weight
yield org_weight_name, processed_weight

def _get_bnb_target_modules(self, model: nn.Module) -> None:

for name, module in model.named_modules():
if isinstance(module, (LinearBase, )):
last_name = name.split(".")[-1]
if sub_modules := self.modules_mapping.packed_mapping.get(
last_name, []):
if modules_info := self.modules_mapping.get_sub_modules(name):
# Map vllm's names to transformers' names.
rep_name, sub_modules = modules_info
for sub_name in sub_modules:
self.target_modules.append(
name.replace(last_name, sub_name))
name.replace(rep_name, sub_name))
# Add the original module name even if the module has a stacked map,
# in case the model has a mixture of disk-merged and disk-split
# weights with the same last name.
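The column-sharding branch in _unquantized_generator above slices each target weight along its last dimension per tensor-parallel rank. A standalone sketch of that index arithmetic (the end-index computation is assumed to mirror the start-index line, since it is elided in the hunk above):

import torch

def shard_by_column(weight: torch.Tensor, tp_rank: int, tp_size: int) -> torch.Tensor:
    # Each rank takes an equal contiguous slice of the last dimension.
    total_size = weight.size(-1)
    start_index = total_size // tp_size * tp_rank
    end_index = total_size // tp_size * (tp_rank + 1)
    return weight[..., start_index:end_index]

w = torch.arange(12.0).reshape(2, 6)
print(shard_by_column(w, tp_rank=1, tp_size=2))
# tensor([[ 3.,  4.,  5.],
#         [ 9., 10., 11.]])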
7 changes: 7 additions & 0 deletions vllm/model_executor/model_loader/utils.py
@@ -131,3 +131,10 @@ def __post_init__(self):
packed_name,
index,
)

def get_sub_modules(self,
module_name: str) -> Optional[Tuple[str, List[str]]]:
for key, value in self.packed_mapping.items():
if module_name.endswith(key):
return key, value
return None
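get_sub_modules resolves a full module name against the packed-module mapping by suffix match, returning the matched key and its sub-module names. A usage sketch with a made-up mapping (the real packed_mapping is built from the model's packed_modules_mapping):

from typing import Dict, List, Optional, Tuple

packed_mapping: Dict[str, List[str]] = {
    "qkv_proj": ["q_proj", "k_proj", "v_proj"],
}

def get_sub_modules(module_name: str) -> Optional[Tuple[str, List[str]]]:
    # Same suffix lookup as the method added above, without the dataclass.
    for key, value in packed_mapping.items():
        if module_name.endswith(key):
            return key, value
    return None

print(get_sub_modules("model.decoder.layers.0.self_attn.qkv_proj"))
# ('qkv_proj', ['q_proj', 'k_proj', 'v_proj'])
print(get_sub_modules("model.decoder.layers.0.self_attn.out_proj"))
# None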
17 changes: 15 additions & 2 deletions vllm/model_executor/models/whisper.py
@@ -638,6 +638,19 @@ def input_mapper_for_whisper(
@MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
"audio", get_max_whisper_audio_tokens)
class WhisperForConditionalGeneration(nn.Module, SupportsMultiModal):
packed_modules_mapping = {
"self_attn.qkv_proj": [
"self_attn.q_proj",
"self_attn.k_proj",
"self_attn.v_proj",
],
"encoder_attn.kv_proj": ["encoder_attn.k_proj", "encoder_attn.v_proj"],
}

hf_to_vllm_mapper = WeightsMapper(orig_to_new_substr={
".fc1.": ".mlp.fc1.",
".fc2.": ".mlp.fc2."
})

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
@@ -731,10 +744,10 @@ def sample(
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
loader = AutoWeightsLoader(self, skip_prefixes=["proj_out."])
mapper = WeightsMapper({".fc1.": ".mlp.fc1.", ".fc2.": ".mlp.fc2."})

# add fake zeros bias for k_proj to state_dict
weights = _create_fake_bias_for_k_proj(weights)
return loader.load_weights(weights, mapper=mapper)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)


def _create_fake_bias_for_k_proj(
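Taken together, the packed_modules_mapping added to WhisperForConditionalGeneration and the loader changes above let the BNB path expand Whisper's fused projections into per-weight target modules. An illustrative sketch of that expansion (module names are examples only; the real _get_bnb_target_modules walks nn.Module instances):

packed_modules_mapping = {
    "self_attn.qkv_proj": [
        "self_attn.q_proj",
        "self_attn.k_proj",
        "self_attn.v_proj",
    ],
    "encoder_attn.kv_proj": ["encoder_attn.k_proj", "encoder_attn.v_proj"],
}

def expand_targets(module_name: str) -> list:
    # Replace the fused (stacked) suffix with each of its sub-projections,
    # mirroring name.replace(rep_name, sub_name) in the loader diff.
    for rep_name, sub_modules in packed_modules_mapping.items():
        if module_name.endswith(rep_name):
            return [module_name.replace(rep_name, sub) for sub in sub_modules]
    return [module_name]

print(expand_targets("model.decoder.layers.0.encoder_attn.kv_proj"))
# ['model.decoder.layers.0.encoder_attn.k_proj',
#  'model.decoder.layers.0.encoder_attn.v_proj']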