Commit 61d61f7

init

Signed-off-by: Isotr0py <2037008807@qq.com>
1 parent: 7caff01

File tree: 1 file changed (+21 −7 lines)


vllm/model_executor/layers/linear.py

Lines changed: 21 additions & 7 deletions
```diff
@@ -1229,7 +1229,7 @@ def extra_repr(self) -> str:
         return s
 
 
-class QKVCrossParallelLinear(torch.nn.Module):
+class QKVCrossParallelLinear(LinearBase):
 
     def __init__(self,
                  hidden_size: int,
```
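Switching the base class from torch.nn.Module to LinearBase is what lets vLLM's quantization machinery treat this layer like any other linear layer, since a QuantizationConfig's get_quant_method can now hand it a per-layer quant method. A hedged instantiation sketch follows; only hidden_size, params_dtype, quant_config, and prefix are visible in this diff's __init__ signature, while head_size, total_num_heads, and total_num_kv_heads are inferred from the output_size computation in the next hunk and from QKVParallelLinear, which the in-code comment names as the module this class drop-in replaces, so the exact argument list may differ.

```python
# Hypothetical usage sketch; argument names beyond those visible in this diff
# are inferred from QKVParallelLinear. Note that this revision calls
# quant_config.get_quant_method() unconditionally in __init__, so quant_config
# must be a real QuantizationConfig instance here, not None.
qkv = QKVCrossParallelLinear(
    hidden_size=4096,           # becomes the placeholder input_size
    head_size=128,              # per-head dim
    total_num_heads=32,         # query heads
    total_num_kv_heads=8,       # key/value heads
    quant_config=quant_config,  # some QuantizationConfig (assumed in scope)
    prefix="decoder.layers.0.cross_attn.qkv_proj",
)
# placeholder output_size = (32 + (8 or 0)) * 128 = 5120
```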
```diff
@@ -1241,12 +1241,26 @@ def __init__(self,
                  params_dtype: Optional[torch.dtype] = None,
                  quant_config: Optional[QuantizationConfig] = None,
                  prefix: str = ""):
-        super().__init__()
+        # input_size and output_size are not used, just for alignment
+        input_size = hidden_size
+        output_size = (total_num_heads + (total_num_kv_heads or 0)) * head_size
+        super().__init__(input_size=input_size,
+                         output_size=output_size,
+                         skip_bias_add=skip_bias_add,
+                         params_dtype=params_dtype,
+                         quant_config=quant_config,
+                         prefix=prefix)
+
         # Empty placeholders for loading as a single module.
-        self.weight = torch.nn.Parameter()
-        set_weight_attrs(self.weight, {
-            "weight_loader": self.weight_loader_weight,
-        })
+        placeholder_size = 0
+        quant_method = quant_config.get_quant_method(self, prefix=prefix)
+        quant_method.create_weights(self,
+                                    placeholder_size, [placeholder_size],
+                                    placeholder_size,
+                                    placeholder_size,
+                                    self.params_dtype,
+                                    weight_loader=self.weight_loader_weight)
+
         # Use a dictionary to avoid submodules parameters auto-registration:
         # drop-in replacement for a `QKVParallelLinear` module.
         self.proj = dict()
```
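Two things happen in this hunk. First, the placeholder weight is no longer a bare torch.nn.Parameter; it is created through quant_config.get_quant_method(...).create_weights(...) with zero-sized dimensions, so the placeholder carries whatever parameter format the active quantization method expects while still routing checkpoint loads through weight_loader_weight. Second, the `self.proj = dict()` context line relies on a general PyTorch behavior worth spelling out: nn.Module.__setattr__ auto-registers any attribute that is itself an nn.Module, so hiding the real sub-layers in a plain dict keeps their parameters out of this module's own registry. A minimal, vLLM-independent demonstration:

```python
# Minimal sketch of the registration behavior behind `self.proj = dict()`:
# nn.Module.__setattr__ registers module-typed attributes as submodules, so
# their parameters show up in state_dict(); dict values are not registered.
import torch

class WithAttr(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.q_proj = torch.nn.Linear(8, 8)        # auto-registered submodule

class WithDict(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = {"q": torch.nn.Linear(8, 8)}   # hidden from registration

print(list(WithAttr().state_dict()))  # ['q_proj.weight', 'q_proj.bias']
print(list(WithDict().state_dict()))  # []
```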
```diff
@@ -1321,4 +1335,4 @@ def weight_loader_bias(self,
         param.weight_loader(
             param,
             loaded_weight) if loaded_shard_id == "q" else param.weight_loader(
-                param, loaded_weight, loaded_shard_id)
+                param, loaded_weight, loaded_shard_id)
```
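The last hunk's visible change is whitespace-level on the closing line (old line 1324 becomes new line 1338); the surrounding context shows weight_loader_bias dispatching via a conditional expression evaluated for its side effects. As a readability sketch, not part of the commit, the equivalent if/else form is:

```python
# Equivalent control flow for the tail of weight_loader_bias: the "q" shard's
# loader is called without a shard id; any other shard forwards loaded_shard_id.
if loaded_shard_id == "q":
    param.weight_loader(param, loaded_weight)
else:
    param.weight_loader(param, loaded_weight, loaded_shard_id)
```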
