
Commit 249824c

Refactor Linear handling in TransformersModel (#12727)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
1 parent 64862d1 · commit 249824c

File tree: 2 files changed (+48, -58 lines)


vllm/model_executor/layers/linear.py

Lines changed: 15 additions & 15 deletions
@@ -2,7 +2,7 @@
 
 import itertools
 from abc import abstractmethod
-from typing import Dict, List, Optional, Tuple
+from typing import Optional
 
 import torch
 import torch.nn.functional as F
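Throughout both files, the `typing` aliases are replaced with the PEP 585 built-in generics, which are subscriptable directly on Python 3.9+. A minimal illustration of the pattern (the function names are hypothetical, not part of the diff):

    # Old style: requires importing aliases from typing
    from typing import Dict, List, Tuple
    def shard_sizes_old(offsets: Dict[str, Tuple[int, int]]) -> List[int]:
        return [size for _, size in offsets.values()]

    # New style: built-in types are generic on Python 3.9+
    def shard_sizes_new(offsets: dict[str, tuple[int, int]]) -> list[int]:
        return [size for _, size in offsets.values()]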
@@ -47,8 +47,8 @@ def adjust_marlin_shard(param, shard_size, shard_offset):
 
 
 def adjust_bitsandbytes_4bit_shard(param: Parameter,
-                                   shard_offsets: Dict[str, Tuple[int, int]],
-                                   loaded_shard_id: str) -> Tuple[int, int]:
+                                   shard_offsets: dict[str, tuple[int, int]],
+                                   loaded_shard_id: str) -> tuple[int, int]:
     """Adjust the quantization offsets and sizes for BitsAndBytes sharding."""
 
     total, _ = shard_offsets["total"]
@@ -90,7 +90,7 @@ class LinearMethodBase(QuantizeMethodBase):
     @abstractmethod
     def create_weights(self, layer: torch.nn.Module,
                        input_size_per_partition: int,
-                       output_partition_sizes: List[int], input_size: int,
+                       output_partition_sizes: list[int], input_size: int,
                        output_size: int, params_dtype: torch.dtype,
                        **extra_weight_attrs):
         """Create weights for a linear layer.
@@ -123,7 +123,7 @@ class UnquantizedLinearMethod(LinearMethodBase):
 
     def create_weights(self, layer: torch.nn.Module,
                        input_size_per_partition: int,
-                       output_partition_sizes: List[int], input_size: int,
+                       output_partition_sizes: list[int], input_size: int,
                        output_size: int, params_dtype: torch.dtype,
                        **extra_weight_attrs):
         weight = Parameter(torch.empty(sum(output_partition_sizes),
@@ -179,7 +179,8 @@ def __init__(
         self.quant_method = quant_config.get_quant_method(self,
                                                           prefix=prefix)
 
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward(self,
+                x: torch.Tensor) -> tuple[torch.Tensor, Optional[Parameter]]:
         raise NotImplementedError
 
 
@@ -240,9 +241,8 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
         assert param.size() == loaded_weight.size()
         param.data.copy_(loaded_weight)
 
-    def forward(
-        self, x: torch.Tensor
-    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
+    def forward(self,
+                x: torch.Tensor) -> tuple[torch.Tensor, Optional[Parameter]]:
         bias = self.bias if not self.skip_bias_add else None
         assert self.quant_method is not None
         output = self.quant_method.apply(self, x, bias)
@@ -288,7 +288,7 @@ def __init__(self,
                  skip_bias_add: bool = False,
                  params_dtype: Optional[torch.dtype] = None,
                  quant_config: Optional[QuantizationConfig] = None,
-                 output_sizes: Optional[List[int]] = None,
+                 output_sizes: Optional[list[int]] = None,
                  prefix: str = ""):
         super().__init__(input_size, output_size, skip_bias_add, params_dtype,
                          quant_config, prefix)
@@ -374,7 +374,7 @@ def weight_loader_v2(self, param: Parameter, loaded_weight: torch.Tensor):
             loaded_weight = loaded_weight.reshape(1)
         param.load_column_parallel_weight(loaded_weight=loaded_weight)
 
-    def forward(self, input_):
+    def forward(self, input_) -> tuple[torch.Tensor, Optional[Parameter]]:
         bias = self.bias if not self.skip_bias_add else None
 
         # Matrix multiply.
@@ -422,7 +422,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
 
     def __init__(self,
                  input_size: int,
-                 output_sizes: List[int],
+                 output_sizes: list[int],
                  bias: bool = True,
                  gather_output: bool = False,
                  skip_bias_add: bool = False,
@@ -500,7 +500,7 @@ def weight_loader(self,
             current_shard_offset = 0
             use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
                                             False)
-            shard_offsets: List[Tuple[int, int, int]] = []
+            shard_offsets: list[tuple[int, int, int]] = []
             for i, output_size in enumerate(self.output_sizes):
                 shard_offsets.append((i, current_shard_offset, output_size))
                 current_shard_offset += output_size
@@ -602,7 +602,7 @@ def _load_fused_module_from_checkpoint(self, param: BasevLLMParameter,
         """
 
         current_shard_offset = 0
-        shard_offsets: List[Tuple[int, int, int]] = []
+        shard_offsets: list[tuple[int, int, int]] = []
         for i, output_size in enumerate(self.output_sizes):
             shard_offsets.append((i, current_shard_offset, output_size))
             current_shard_offset += output_size
@@ -1124,7 +1124,7 @@ def weight_loader_v2(self, param: BasevLLMParameter,
 
         param.load_row_parallel_weight(loaded_weight=loaded_weight)
 
-    def forward(self, input_):
+    def forward(self, input_) -> tuple[torch.Tensor, Optional[Parameter]]:
         if self.input_is_parallel:
             input_parallel = input_
         else:
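Taken together, the `forward` annotations above make explicit that every vLLM parallel linear layer returns an `(output, output_bias)` pair rather than a bare tensor. A minimal sketch of the calling convention, assuming `layer` is a ColumnParallelLinear or RowParallelLinear built elsewhere (illustrative only, not from the diff):

    # `layer` and `hidden_states` are assumed to exist; with the default
    # skip_bias_add=False the bias is already folded into `output` and
    # `output_bias` is None.
    output, output_bias = layer(hidden_states)
    if output_bias is not None:  # only when skip_bias_add=True
        output = output + output_bias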

vllm/model_executor/models/transformers.py

Lines changed: 33 additions & 43 deletions
@@ -1,4 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
+
 # Copyright 2024 The vLLM team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -14,7 +15,7 @@
 # limitations under the License.
 """Wrapper around `transformers` models"""
 import re
-from typing import Iterable, List, Optional, Set, Tuple, Union
+from typing import Iterable, Optional, Union
 
 import torch
 from torch import nn
@@ -71,23 +72,10 @@ def vllm_flash_attention_forward(
 ALL_ATTENTION_FUNCTIONS["vllm"] = vllm_flash_attention_forward
 
 
-# Linear Layer that is compatible with transformers internal forward
-# TODO: This is a temporary solution, we should find a better way to integrate
-class HFColumnParallelLinear(ColumnParallelLinear):
-
-    def forward(self, input: torch.Tensor) -> torch.Tensor:
-        return super().forward(input)[0]
-
-
-class HFRowParallelLinear(RowParallelLinear):
-
-    def forward(self, input: torch.Tensor) -> torch.Tensor:
-        return super().forward(input)[0]
-
-
-def replace_tp_linear_class(orig_module: nn.Linear,
-                            style: str,
-                            quant_config=None):
+def replace_linear_class(
+        linear: nn.Linear,
+        style: str,
+        quant_config=None) -> Union[ColumnParallelLinear, RowParallelLinear]:
     """
     In model configurations, we use a neutral type (string) to specify parallel
     styles, here we use it to translate nn.Linear into vllm-style tp Linear.
@@ -99,26 +87,28 @@ def replace_tp_linear_class(orig_module: nn.Linear,
         raise ValueError(
             f"Unsupported parallel style type {type(style)}, expected str")
 
-    input_size = orig_module.in_features
-    output_size = orig_module.out_features
-    bias = orig_module.bias is not None
+    vllm_linear_cls = {
+        "colwise": ColumnParallelLinear,
+        "rowwise": RowParallelLinear,
+    }.get(style)
 
-    if style == "colwise":
-        return HFColumnParallelLinear(
-            input_size,
-            output_size,
-            bias,
-        )
-    elif style == "rowwise":
-        return HFRowParallelLinear(
-            input_size,
-            output_size,
-            bias,
-        )
-    # We don't consider colwise_rep since it's used in lm_head
-    else:
+    if vllm_linear_cls is None:
         raise ValueError(f"Unsupported parallel style value: {style}")
 
+    class HFCompatibleLinear(vllm_linear_cls):
+        """
+        Wrapper class that removes `output_bias` from returned output.
+        """
+
+        def forward(self, input: torch.Tensor) -> torch.Tensor:
+            return super().forward(input)[0]
+
+    return HFCompatibleLinear(
+        input_size=linear.in_features,
+        output_size=linear.out_features,
+        bias=linear.bias is not None,
+    )
+
 
 class TransformersModel(nn.Module):
     embedding_padding_modules = ["lm_head"]
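A sketch of how the consolidated helper might be used, assuming `replace_linear_class` is imported from `vllm.model_executor.models.transformers` and a vLLM tensor-parallel group has already been initialized (illustrative only, not part of the diff):

    import torch
    from torch import nn
    from vllm.model_executor.models.transformers import replace_linear_class

    # Requires an initialized vLLM distributed/tensor-parallel context.
    linear = nn.Linear(in_features=512, out_features=1024, bias=True)
    tp_linear = replace_linear_class(linear, style="colwise", quant_config=None)
    # HFCompatibleLinear.forward strips the output_bias element, so the module
    # behaves like a plain nn.Linear inside a Hugging Face forward pass.
    out = tp_linear(torch.randn(2, 512))

Defining `HFCompatibleLinear` inside the function lets one wrapper serve both parallel styles, replacing the two hand-written `HFColumnParallelLinear` and `HFRowParallelLinear` classes.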
@@ -192,16 +182,16 @@ def tensor_parallelize(self, module: nn.Module, prefix: str = ""):
                 "support it yet!")
 
         for child_name, child_module in module.named_children():
-            qual_name = prefix + child_name
+            qual_name = maybe_prefix(prefix, child_name)
             for pattern, style in self.config.base_model_tp_plan.items():
                 if re.match(pattern, qual_name) and isinstance(
                         child_module, nn.Linear):
-                    new_module = replace_tp_linear_class(
-                        child_module, style, self.quant_config)
+                    new_module = replace_linear_class(child_module, style,
+                                                      self.quant_config)
                     setattr(module, child_name, new_module)
                     self.log_replacement(qual_name, child_module, new_module)
             else:
-                self.tensor_parallelize(child_module, prefix=f"{qual_name}.")
+                self.tensor_parallelize(child_module, prefix=qual_name)
 
     def replace_vocab_embed_class(self, module: nn.Module):
         # Use native set input embeddings
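For reference, `base_model_tp_plan` maps module-name patterns to parallel styles; `tensor_parallelize` walks `named_children()`, builds each qualified name with `maybe_prefix`, and swaps matching `nn.Linear` modules via `replace_linear_class`. The plan below is an assumed example of its shape, not taken from the diff or any particular model config:

    # Hypothetical tensor-parallel plan; real patterns come from the
    # Transformers model config (config.base_model_tp_plan).
    base_model_tp_plan = {
        "layers.*.self_attn.q_proj": "colwise",
        "layers.*.self_attn.o_proj": "rowwise",
        "layers.*.mlp.gate_proj": "colwise",
        "layers.*.mlp.down_proj": "rowwise",
    }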
@@ -219,7 +209,7 @@ def forward(
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],  # argument not used
+        kv_caches: list[torch.Tensor],  # argument not used
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
@@ -249,10 +239,10 @@ def sample(self, logits: torch.Tensor,
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(self, weights: Iterable[Tuple[str,
-                                                   torch.Tensor]]) -> Set[str]:
+    def load_weights(self, weights: Iterable[tuple[str,
+                                                   torch.Tensor]]) -> set[str]:
         params_dict = dict(self.named_parameters())
-        loaded_params: Set[str] = set()
+        loaded_params = set[str]()
         for name, loaded_weight in weights:
             if name not in params_dict:
                 name = f"{self.model.base_model_prefix}.{name}"