Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 13 additions & 5 deletions vllm/model_executor/models/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,17 +107,25 @@ def replace_linear_class(
raise ValueError(
f"Unsupported parallel style type {type(style)}, expected str")

vllm_linear_cls = {
"colwise": ColumnParallelLinear,
"rowwise": RowParallelLinear,
}.get(style, ReplicatedLinear)
vllm_linear_cls, vllm_linear_kwargs = {
"colwise": (ColumnParallelLinear, {}),
"colwise_rep": (ColumnParallelLinear, {
"gather_output": True
}),
"rowwise": (RowParallelLinear, {}),
"rowwise_rep": (RowParallelLinear, {
"input_is_parallel": False
}),
"replicate": (ReplicatedLinear, {}),
}.get(style, (ReplicatedLinear, {}))

return vllm_linear_cls(
input_size=linear.in_features,
output_size=linear.out_features,
bias=linear.bias is not None,
quant_config=quant_config,
return_bias=False,
**vllm_linear_kwargs,
)


Expand Down Expand Up @@ -506,7 +514,7 @@ def tensor_parallel(self):
# Some weight loaders expect linear layers to inherit from vLLM's
# LinearBase class, so we set a default style which causes any
# unspecified linear layers to be replaced with ReplicatedLinear
tp_plan[".*"] = "replicated"
tp_plan[".*"] = "replicate"

def _tensor_parallel(module: nn.Module, prefix: str = ""):
for child_name, child_module in module.named_children():
Expand Down