[Misc] Add BNB quantization for Whisper (#12381)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
jeejeelee authored Feb 4, 2025
1 parent c36ac98 commit 96b2362
Showing 3 changed files with 82 additions and 44 deletions.
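The change lets Whisper checkpoints be loaded with bitsandbytes (BNB) quantization. A minimal usage sketch, assuming the openai/whisper-large-v3 checkpoint and vLLM's existing bitsandbytes options for in-flight quantization (the transcription call itself is omitted):

from vllm import LLM

# In-flight BNB quantization of an unquantized Whisper checkpoint (model id assumed).
# Both options below are vLLM's pre-existing bitsandbytes flags, not part of this commit.
llm = LLM(
    model="openai/whisper-large-v3",
    quantization="bitsandbytes",
    load_format="bitsandbytes",
)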
vllm/model_executor/model_loader/loader.py (102 changes: 60 additions & 42 deletions)
@@ -803,9 +803,11 @@ def _hf_weight_iter(self, hf_weights_files, use_safetensors: bool):
             iterator = safetensors_weights_iterator(hf_weights_files)
         else:
             iterator = pt_weights_iterator(hf_weights_files)
-        for name, param in iterator:
-            # mapping weight names from transformers to vllm.
-            yield self.weight_mapper(name), param
+        for org_name, param in iterator:
+            # mapping weight names from transformers to vllm while preserving
+            # original names.
+            mapped_name = self.weight_mapper(org_name)
+            yield org_name, mapped_name, param
 
     def _get_quantized_weights_iterator(
         self,
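The iterator now yields both the original checkpoint name and the vLLM-mapped name. A standalone sketch of why both are needed (names and the mapping rule are illustrative): quant-state bookkeeping is keyed on the mapped name, while the name that is ultimately yielded stays the original checkpoint name so that the model's own load_weights (see whisper.py below) can apply its mapper exactly once.

def iter_weights(raw_weights, rules):
    # Sketch of the new contract: yield (original_name, mapped_name, tensor).
    for org_name, tensor in raw_weights:
        mapped_name = org_name
        for old, new in rules.items():
            mapped_name = mapped_name.replace(old, new)
        yield org_name, mapped_name, tensor

rules = {".fc1.": ".mlp.fc1.", ".fc2.": ".mlp.fc2."}
quant_state_dict = {}
for org_name, mapped_name, tensor in iter_weights(
        [("model.encoder.layers.0.fc1.weight", None)], rules):
    quant_state_dict[mapped_name] = "dummy-quant-state"  # keyed on the mapped name
    print(org_name, "->", mapped_name)                   # the original name is what gets yielded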
@@ -866,24 +868,30 @@ def _is_4bit_weight_name(self, weight_name: str):
 
     def _quantized_8bit_generator(self, hf_weights_files, use_safetensors,
                                   quant_state_dict) -> Generator:
-        for weight_name, weight_tensor in self._hf_weight_iter(
-                hf_weights_files, use_safetensors):
-            if not weight_name.lower().endswith(".scb"):
+        for (
+                org_weight_name,
+                mapped_weight_name,
+                weight_tensor,
+        ) in self._hf_weight_iter(hf_weights_files, use_safetensors):
+            if not mapped_weight_name.lower().endswith(".scb"):
                 continue
 
-            weight_key = weight_name.lower().replace(".scb", ".weight")
+            weight_key = mapped_weight_name.lower().replace(".scb", ".weight")
             quant_state_dict[weight_key] = weight_tensor
 
-        for weight_name, weight_tensor in self._hf_weight_iter(
-                hf_weights_files, use_safetensors):
-            if self._is_8bit_weight_name(weight_name):
+        for (
+                org_weight_name,
+                mapped_weight_name,
+                weight_tensor,
+        ) in self._hf_weight_iter(hf_weights_files, use_safetensors):
+            if self._is_8bit_weight_name(mapped_weight_name):
                 continue
 
-            if weight_name in quant_state_dict:
+            if mapped_weight_name in quant_state_dict:
                 set_weight_attrs(weight_tensor, {"load_in_8bit": True})
-                yield weight_name, weight_tensor
+                yield org_weight_name, weight_tensor
             else:
-                yield weight_name, weight_tensor
+                yield org_weight_name, weight_tensor
 
     def _quantized_4bit_generator(self, hf_weights_files, use_safetensors,
                                   quant_state_dict) -> Generator:
@@ -893,15 +901,19 @@ def _quantized_4bit_generator(self, hf_weights_files, use_safetensors,
         weight_iterator = self._hf_weight_iter(hf_weights_files,
                                                use_safetensors)
         temp_state_dict = {}
-        for weight_name, weight_tensor in weight_iterator:
-            if not self._is_4bit_weight_name(weight_name):
+        for (
+                org_weight_name,
+                mapped_weight_name,
+                weight_tensor,
+        ) in weight_iterator:
+            if not self._is_4bit_weight_name(mapped_weight_name):
                 continue
             # bitsandbytes library requires
             # weight.quant_state.bitsandbytes__* in CPU
-            if "quant_state.bitsandbytes" in weight_name:
-                temp_state_dict[weight_name] = weight_tensor.cpu().data
+            if "quant_state.bitsandbytes" in mapped_weight_name:
+                temp_state_dict[mapped_weight_name] = weight_tensor.cpu().data
             else:
-                temp_state_dict[weight_name] = weight_tensor
+                temp_state_dict[mapped_weight_name] = weight_tensor
 
         # Closure to parse quant_state for each prequant weight
         def _parse_quant_state(param_name: str,
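The first pass above only collects the auxiliary tensors that a pre-quantized bitsandbytes checkpoint stores next to each 4-bit weight. Purely as an illustration (layer name made up; the exact auxiliary keys depend on the bitsandbytes version and double-quantization settings), the key layout looks roughly like this:

# Hypothetical key layout for one BNB-prequantized linear layer.
checkpoint_keys = [
    "model.decoder.layers.0.fc1.weight",                                # packed 4-bit data
    "model.decoder.layers.0.fc1.weight.absmax",
    "model.decoder.layers.0.fc1.weight.quant_map",
    "model.decoder.layers.0.fc1.weight.quant_state.bitsandbytes__nf4",  # serialized quant state
]

# The second pass below checks for a serialized quant_state under the weight's name.
packed = "model.decoder.layers.0.fc1.weight"
print(f"{packed}.quant_state.bitsandbytes__nf4" in checkpoint_keys)  # True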
@@ -915,20 +927,24 @@ def _parse_quant_state(param_name: str,
 
         # Second iterate over all prequant and normal weights
         # pre quantized weights would have a quant_state
-        for weight_name, weight_tensor in self._hf_weight_iter(
-                hf_weights_files, use_safetensors):
-            if self._is_4bit_weight_name(weight_name):
+        for (
+                org_weight_name,
+                mapped_weight_name,
+                weight_tensor,
+        ) in self._hf_weight_iter(hf_weights_files, use_safetensors):
+            if self._is_4bit_weight_name(mapped_weight_name):
                 continue
 
-            if (f"{weight_name}.quant_state.bitsandbytes__nf4"
+            if (f"{mapped_weight_name}.quant_state.bitsandbytes__nf4"
                     in temp_state_dict) or (
-                        f"{weight_name}.quant_state.bitsandbytes__fp4"
+                        f"{mapped_weight_name}.quant_state.bitsandbytes__fp4"
                         in temp_state_dict):
-                quant_state = _parse_quant_state(weight_name, temp_state_dict)
-                quant_state_dict[weight_name] = quant_state
-                yield weight_name, weight_tensor
+                quant_state = _parse_quant_state(mapped_weight_name,
+                                                 temp_state_dict)
+                quant_state_dict[mapped_weight_name] = quant_state
+                yield org_weight_name, weight_tensor
             else:
-                yield weight_name, weight_tensor
+                yield org_weight_name, weight_tensor
 
     def _unquantized_generator(self, hf_weights_files, use_safetensors,
                                quant_state_dict) -> Generator:
@@ -937,18 +953,22 @@ def _unquantized_generator(self, hf_weights_files, use_safetensors,
         tp_size = get_tensor_model_parallel_world_size()
         tp_rank = get_tensor_model_parallel_rank()
 
-        for weight_name, weight_tensor in self._hf_weight_iter(
-                hf_weights_files, use_safetensors):
-            if any(target_module in weight_name for target_module in
-                   self.target_modules) and weight_name.endswith(".weight"):
+        for (
+                org_weight_name,
+                mapped_weight_name,
+                weight_tensor,
+        ) in self._hf_weight_iter(hf_weights_files, use_safetensors):
+            if any(target_module in mapped_weight_name
+                   for target_module in self.target_modules
+                   ) and mapped_weight_name.endswith(".weight"):
                 # Without sharding
                 if any(
-                        weight_name.startswith(module)
+                        mapped_weight_name.startswith(module)
                         for module in self.unsharded_weights_modules):
                     weight_sub_tensor = weight_tensor
                 # Shard by column
                 elif any(
-                        weight_name.startswith(module)
+                        mapped_weight_name.startswith(module)
                         for module in self.column_sharded_weights_modules):
                     total_size = weight_tensor.size(-1)
                     start_index = total_size // tp_size * tp_rank
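The column-sharded branch slices the last dimension of the weight evenly across tensor-parallel ranks. A small worked sketch of that arithmetic with made-up sizes (the end_index line is cut off in the hunk above, so it is reconstructed here):

import torch

tp_size, tp_rank = 2, 1             # assumed tensor-parallel world size and this rank
weight = torch.randn(1280, 5120)    # made-up column-sharded weight as stored on disk

total_size = weight.size(-1)
start_index = total_size // tp_size * tp_rank        # 2560 on rank 1
end_index = total_size // tp_size * (tp_rank + 1)    # 5120 on rank 1
weight_sub_tensor = weight[..., start_index:end_index]
print(weight_sub_tensor.shape)      # torch.Size([1280, 2560])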
@@ -958,14 +978,14 @@ def _unquantized_generator(self, hf_weights_files, use_safetensors,
                 # Weights have fused on disk. In this case, we assume that the
                 # weight and module use same name.
                 elif any(
-                        weight_name.startswith(module)
+                        mapped_weight_name.startswith(module)
                         for module in self.maybe_fused_weights_modules):
                     # special case for fused weights
                     # get the size of each shard weight tensor
                     total_shard_sizes = next(
                         (sizes for module, sizes in
                          self.maybe_fused_weights_modules.items()
-                         if weight_name.startswith(module)))
+                         if mapped_weight_name.startswith(module)))
                     total_size = weight_tensor.size(0)
                     assert total_size == sum(total_shard_sizes)
                     # get the start/end index of each shard weight tensor
@@ -1008,23 +1028,21 @@ def _unquantized_generator(self, hf_weights_files, use_safetensors,
                         quant_type="nf4",
                     )
 
-                quant_state_dict[weight_name] = quant_state
+                quant_state_dict[mapped_weight_name] = quant_state
             else:
                 processed_weight = weight_tensor
 
-            yield weight_name, processed_weight
+            yield org_weight_name, processed_weight
 
     def _get_bnb_target_modules(self, model: nn.Module) -> None:
 
         for name, module in model.named_modules():
             if isinstance(module, (LinearBase, )):
-                last_name = name.split(".")[-1]
-                if sub_modules := self.modules_mapping.packed_mapping.get(
-                        last_name, []):
+                if modules_info := self.modules_mapping.get_sub_modules(name):
+                    # Map vllm's names to transformers's names.
+                    rep_name, sub_modules = modules_info
                     for sub_name in sub_modules:
                         self.target_modules.append(
-                            name.replace(last_name, sub_name))
+                            name.replace(rep_name, sub_name))
                 # Add original module name even if the module has stacked map,
                 # in case model has a mixture of disk-merged and disk-splitted
                 # weights with same last name.
vllm/model_executor/model_loader/utils.py (7 changes: 7 additions & 0 deletions)
@@ -131,3 +131,10 @@ def __post_init__(self):
                 packed_name,
                 index,
             )
+
+    def get_sub_modules(self,
+                        module_name: str) -> Optional[Tuple[str, List[str]]]:
+        for key, value in self.packed_mapping.items():
+            if module_name.endswith(key):
+                return key, value
+        return None
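The new helper matches a dotted suffix instead of only the last path component, which is what allows packed-mapping keys such as "encoder_attn.kv_proj". A quick illustration with a hypothetical module name:

packed_mapping = {
    "self_attn.qkv_proj": ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"],
    "encoder_attn.kv_proj": ["encoder_attn.k_proj", "encoder_attn.v_proj"],
}

def get_sub_modules(module_name):
    # Suffix match, mirroring the helper added above.
    for key, value in packed_mapping.items():
        if module_name.endswith(key):
            return key, value
    return None

print(get_sub_modules("model.decoder.layers.0.encoder_attn.kv_proj"))
# ('encoder_attn.kv_proj', ['encoder_attn.k_proj', 'encoder_attn.v_proj'])
# A lookup keyed only on the last component ("kv_proj") would have missed this entry.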
vllm/model_executor/models/whisper.py (17 changes: 15 additions & 2 deletions)
@@ -638,6 +638,19 @@ def input_mapper_for_whisper(
 @MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
     "audio", get_max_whisper_audio_tokens)
 class WhisperForConditionalGeneration(nn.Module, SupportsMultiModal):
+    packed_modules_mapping = {
+        "self_attn.qkv_proj": [
+            "self_attn.q_proj",
+            "self_attn.k_proj",
+            "self_attn.v_proj",
+        ],
+        "encoder_attn.kv_proj": ["encoder_attn.k_proj", "encoder_attn.v_proj"],
+    }
+
+    hf_to_vllm_mapper = WeightsMapper(orig_to_new_substr={
+        ".fc1.": ".mlp.fc1.",
+        ".fc2.": ".mlp.fc2."
+    })
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
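packed_modules_mapping tells the loader which separate checkpoint projections map onto one fused vLLM parameter; for Whisper's cross-attention only k_proj and v_proj are packed, since the query comes from the decoder side. A sketch of the fused layout with made-up sizes:

import torch

hidden = 1280                        # made-up model width
k = torch.randn(hidden, hidden)      # encoder_attn.k_proj.weight in the checkpoint
v = torch.randn(hidden, hidden)      # encoder_attn.v_proj.weight in the checkpoint

# vLLM keeps one fused parameter for the packed module; each sub-weight is
# copied into its slice along the output dimension at load time.
kv_proj = torch.cat([k, v], dim=0)   # encoder_attn.kv_proj.weight
print(kv_proj.shape)                 # torch.Size([2560, 1280])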
@@ -731,10 +744,10 @@ def sample(
     def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:
         loader = AutoWeightsLoader(self, skip_prefixes=["proj_out."])
-        mapper = WeightsMapper({".fc1.": ".mlp.fc1.", ".fc2.": ".mlp.fc2."})
 
         # add fake zeros bias for k_proj to state_dict
         weights = _create_fake_bias_for_k_proj(weights)
-        return loader.load_weights(weights, mapper=mapper)
+        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
 
 
 def _create_fake_bias_for_k_proj(
