Commit 75b5425

NickLucche committed: address review
Signed-off-by: NickLucche <nlucches@redhat.com>
1 parent 58e7308 commit 75b5425

File tree

4 files changed: +103 −104 lines

  vllm/model_executor/layers/linear.py
  vllm/model_executor/models/bart.py
  vllm/model_executor/models/mllama.py
  vllm/model_executor/models/utils.py

vllm/model_executor/layers/linear.py

Lines changed: 95 additions & 0 deletions
@@ -1178,3 +1178,98 @@ def extra_repr(self) -> str:
         s += f", tp_size={self.tp_size}"
         s += f", reduce_results={self.reduce_results}"
         return s
+
+
+class QKVCrossParallelLinear(torch.nn.Module):
+
+    def __init__(self,
+                 hidden_size: int,
+                 head_size: int,
+                 total_num_heads: int,
+                 total_num_kv_heads: Optional[int] = None,
+                 bias: bool = True,
+                 skip_bias_add: bool = False,
+                 params_dtype: Optional[torch.dtype] = None,
+                 quant_config: Optional[QuantizationConfig] = None,
+                 prefix: str = ""):
+        super().__init__()
+        # Empty placeholders for loading as a single module.
+        self.weight = torch.nn.Parameter()
+        set_weight_attrs(self.weight, {
+            "weight_loader": self.weight_loader_weight,
+        })
+        # Use a dictionary to avoid submodules parameters auto-registration:
+        # drop-in replacement for a `QKVParallelLinear` module.
+        self.proj = dict()
+        self.proj["q_proj_decoder"] = ColumnParallelLinear(
+            input_size=hidden_size,
+            output_size=total_num_heads * head_size,
+            bias=bias,
+            quant_config=quant_config,
+            skip_bias_add=skip_bias_add,
+            params_dtype=params_dtype,
+            prefix=f"{prefix}.q_proj_decoder")
+
+        self.proj["kv_proj_encoder"] = QKVParallelLinear(
+            hidden_size=hidden_size,
+            head_size=head_size,
+            total_num_heads=0,
+            total_num_kv_heads=total_num_kv_heads,
+            bias=bias,
+            quant_config=quant_config,
+            skip_bias_add=skip_bias_add,
+            params_dtype=params_dtype,
+            prefix=f"{prefix}.kv_proj_encoder")
+
+        # `kv_proj_encoder.num_kv_heads` accounts for sharding with tp>1.
+        self.kv_size = self.kv_proj_encoder.num_kv_heads * head_size
+
+        if bias:
+            self.bias = torch.nn.Parameter()
+            set_weight_attrs(self.bias, {
+                "weight_loader": self.weight_loader_bias,
+            })
+
+    @property
+    def q_proj_decoder(self):
+        return self.proj["q_proj_decoder"]
+
+    @property
+    def kv_proj_encoder(self):
+        return self.proj["kv_proj_encoder"]
+
+    def forward(self, decoder_hidden_states, encoder_hidden_states):
+        q, _ = self.q_proj_decoder(decoder_hidden_states)
+        if encoder_hidden_states is None:
+            # Encoder KV already cached.
+            k = None
+            v = None
+        else:
+            # Prefill phase, encoder KV cached here.
+            kv_enc, _ = self.kv_proj_encoder(encoder_hidden_states)
+            # Split kv in half
+            k, v = kv_enc.split(self.kv_size, dim=-1)
+        return q, k, v
+
+    def weight_loader_weight(self,
+                             param: torch.nn.Parameter,
+                             loaded_weight: torch.Tensor,
+                             loaded_shard_id: Optional[str] = None):
+        # NOTE Use QKV/ColumnParallel weight_loader, ignore placeholder param.
+        param = self.q_proj_decoder.weight if loaded_shard_id == "q" \
+            else self.kv_proj_encoder.weight
+        param.weight_loader(
+            param,
+            loaded_weight) if loaded_shard_id == "q" else param.weight_loader(
+                param, loaded_weight, loaded_shard_id)
+
+    def weight_loader_bias(self,
+                           param: torch.nn.Parameter,
+                           loaded_weight: torch.Tensor,
+                           loaded_shard_id: Optional[str] = None):
+        param = self.q_proj_decoder.bias if loaded_shard_id == "q" \
+            else self.kv_proj_encoder.bias
+        param.weight_loader(
+            param,
+            loaded_weight) if loaded_shard_id == "q" else param.weight_loader(
+                param, loaded_weight, loaded_shard_id)
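
For reference, here is a minimal usage sketch of the new layer; it is not part of the commit, and the sizes, tensor shapes, and prefix string are assumed purely for illustration. During prefill the caller passes the encoder hidden states so K/V are projected (and cached by the attention backend); on later decode steps it passes None and only Q is recomputed, mirroring forward() above.

import torch

from vllm.model_executor.layers.linear import QKVCrossParallelLinear

# Assumes vLLM's distributed/TP state is already initialized, as it is when a
# model's __init__ runs; all sizes below are illustrative, not from a real config.
hidden_size, head_size, num_heads = 1024, 64, 16
qkv_proj = QKVCrossParallelLinear(hidden_size=hidden_size,
                                  head_size=head_size,
                                  total_num_heads=num_heads,
                                  total_num_kv_heads=num_heads,
                                  prefix="decoder.layers.0.encoder_attn.qkv_proj")

decoder_hidden_states = torch.randn(4, hidden_size)    # 4 decoder tokens
encoder_hidden_states = torch.randn(10, hidden_size)   # 10 encoder tokens

# Prefill: encoder output is available, so K/V are projected here
# (per TP rank: q -> [4, num_heads * head_size], k/v -> [10, kv_size]).
q, k, v = qkv_proj(decoder_hidden_states, encoder_hidden_states)

# Decode: encoder K/V are already cached by the attention backend,
# so only Q is recomputed and k and v come back as None.
q, k, v = qkv_proj(decoder_hidden_states, None)

Because the Q and KV projections live in separate submodules while the model maps the checkpoint's q/k/v shards onto this single module, the placeholder weight and bias parameters route each loaded shard to the right submodule via weight_loader_weight and weight_loader_bias.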

vllm/model_executor/models/bart.py

Lines changed: 4 additions & 3 deletions
@@ -31,6 +31,7 @@
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
+                                               QKVCrossParallelLinear,
                                                QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
@@ -44,7 +45,7 @@
 from vllm.sequence import IntermediateTensors

 from .interfaces import SupportsV0Only
-from .utils import QKVCrossParallelLinear, maybe_prefix
+from .utils import maybe_prefix

 logger = logging.get_logger(__name__)

@@ -169,7 +170,7 @@ def __init__(
             # Number of KV heads is less than TP size, so we replicate
             # the KV heads across multiple tensor parallel GPUs.
             assert tp_world_size % self.total_num_kv_heads == 0
-        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world_size)
+        self.num_kv_heads = self.num_heads
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim

@@ -248,7 +249,7 @@ def __init__(
             # Number of KV heads is less than TP size, so we replicate
             # the KV heads across multiple tensor parallel GPUs.
             assert tp_world_size % self.total_num_kv_heads == 0
-        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world_size)
+        self.num_kv_heads = self.num_heads
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
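
For context on the self.num_kv_heads = self.num_heads change above: BART uses equal query and KV head counts (no GQA), and QKVCrossParallelLinear now handles KV-head sharding and replication internally, so the attention module's per-rank bookkeeping reduces to the sketch below. The head count, head dim, and TP size are assumed example values, not taken from the diff.

# Illustrative per-rank sizes; the concrete numbers are assumptions.
total_num_heads = 16        # BART-style MHA: KV heads == query heads
head_dim = 64
tp_world_size = 2

num_heads = total_num_heads // tp_world_size   # 8 query heads on this rank
num_kv_heads = num_heads                       # what the new code assigns
q_size = num_heads * head_dim                  # 512
kv_size = num_kv_heads * head_dim              # 512, which should match the kv_size
                                               # that QKVCrossParallelLinear splits on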

vllm/model_executor/models/mllama.py

Lines changed: 3 additions & 1 deletion
@@ -44,6 +44,7 @@
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                QKVParallelLinear,
+                                               QKVCrossParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
@@ -66,7 +67,7 @@
 from .clip import CLIPMLP
 from .interfaces import SupportsMultiModal, SupportsV0Only
 from .llama import LlamaDecoderLayer, LlamaMLP
-from .utils import QKVCrossParallelLinear, maybe_prefix
+from .utils import maybe_prefix

 logger = init_logger(__name__)

@@ -806,6 +807,7 @@ def __init__(
             self.num_key_value_heads // self.tensor_parallel_size
         self.hidden_size = config.hidden_size
         self.head_dim = config.hidden_size // self.num_heads
+        self.num_key_value_heads = config.num_key_value_heads

         self.layer_idx = layer_idx
         self.num_key_value_groups = self.num_heads // self.num_key_value_heads
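
The added self.num_key_value_heads = config.num_key_value_heads keeps the GQA grouping computed from the unsharded config values (see the num_key_value_groups line above). A tiny sketch with assumed Llama-3.2-Vision-like head counts, not taken from the diff:

# Assumed example values; in the model they come from the HF config.
num_heads = 32             # e.g. config.num_attention_heads
num_key_value_heads = 8    # config.num_key_value_heads, as assigned above

# GQA grouping uses the unsharded head counts:
num_key_value_groups = num_heads // num_key_value_heads   # 4 query heads per KV head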

vllm/model_executor/models/utils.py

Lines changed: 1 addition & 100 deletions
@@ -12,12 +12,7 @@

 from vllm.config import VllmConfig
 from vllm.logger import init_logger
-from vllm.model_executor.layers.linear import (ColumnParallelLinear,
-                                               QKVParallelLinear)
-from vllm.model_executor.layers.quantization.base_config import (
-    QuantizationConfig)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.utils import set_weight_attrs
 from vllm.multimodal import MultiModalPlaceholderMap, NestedTensors
 from vllm.sequence import IntermediateTensors
 from vllm.utils import is_pin_memory_available
@@ -655,98 +650,4 @@ def cast_overflow_tensors(
     if tensors.isinf().any() or tensors.isnan().any():
         clamp_value = torch.finfo(tensors.dtype).max - offset
         tensors = torch.clamp(tensors, min=-clamp_value, max=clamp_value)
-    return tensors
-
-
-class QKVCrossParallelLinear(torch.nn.Module):
-
-    def __init__(self,
-                 hidden_size: int,
-                 head_size: int,
-                 total_num_heads: int,
-                 total_num_kv_heads: Optional[int] = None,
-                 bias: bool = True,
-                 skip_bias_add: bool = False,
-                 params_dtype: Optional[torch.dtype] = None,
-                 quant_config: Optional[QuantizationConfig] = None,
-                 prefix: str = ""):
-        super().__init__()
-        # Empty placeholders for loading as a single module.
-        self.weight = torch.nn.Parameter()
-        set_weight_attrs(self.weight, {
-            "weight_loader": self.weight_loader_weight,
-        })
-        # Use a dictionary to avoid submodules parameters auto-registration:
-        # drop-in replacement for a `QKVParallelLinear` module.
-        self.proj = dict()
-        self.proj["q_proj_decoder"] = ColumnParallelLinear(
-            input_size=hidden_size,
-            output_size=total_num_heads * head_size,
-            bias=bias,
-            quant_config=quant_config,
-            skip_bias_add=skip_bias_add,
-            params_dtype=params_dtype,
-            prefix=f"{prefix}.q_proj_decoder")
-
-        self.proj["kv_proj_encoder"] = QKVParallelLinear(
-            hidden_size=hidden_size,
-            head_size=head_size,
-            total_num_heads=0,
-            total_num_kv_heads=total_num_kv_heads,
-            bias=bias,
-            quant_config=quant_config,
-            skip_bias_add=skip_bias_add,
-            params_dtype=params_dtype,
-            prefix=f"{prefix}.kv_proj_encoder")
-
-        # `kv_proj_encoder.num_kv_heads` accounts for sharding with tp>1.
-        self.kv_size = self.kv_proj_encoder.num_kv_heads * head_size
-
-        if bias:
-            self.bias = torch.nn.Parameter()
-            set_weight_attrs(self.bias, {
-                "weight_loader": self.weight_loader_bias,
-            })
-
-    @property
-    def q_proj_decoder(self):
-        return self.proj["q_proj_decoder"]
-
-    @property
-    def kv_proj_encoder(self):
-        return self.proj["kv_proj_encoder"]
-
-    def forward(self, decoder_hidden_states, encoder_hidden_states):
-        q, _ = self.q_proj_decoder(decoder_hidden_states)
-        if encoder_hidden_states is None:
-            # Encoder KV already cached.
-            k = None
-            v = None
-        else:
-            # Prefill phase, encoder KV cached here.
-            kv_enc, _ = self.kv_proj_encoder(encoder_hidden_states)
-            # Split kv in half
-            k, v = kv_enc.split(self.kv_size, dim=-1)
-        return q, k, v
-
-    def weight_loader_weight(self,
-                             param: torch.nn.Parameter,
-                             loaded_weight: torch.Tensor,
-                             loaded_shard_id: Optional[str] = None):
-        # NOTE Use QKV/ColumnParallel weight_loader, ignore placeholder param.
-        param = self.q_proj_decoder.weight if loaded_shard_id == "q" \
-            else self.kv_proj_encoder.weight
-        param.weight_loader(
-            param,
-            loaded_weight) if loaded_shard_id == "q" else param.weight_loader(
-                param, loaded_weight, loaded_shard_id)
-
-    def weight_loader_bias(self,
-                           param: torch.nn.Parameter,
-                           loaded_weight: torch.Tensor,
-                           loaded_shard_id: Optional[str] = None):
-        param = self.q_proj_decoder.bias if loaded_shard_id == "q" \
-            else self.kv_proj_encoder.bias
-        param.weight_loader(
-            param,
-            loaded_weight) if loaded_shard_id == "q" else param.weight_loader(
-                param, loaded_weight, loaded_shard_id)
+    return tensors

0 commit comments
