diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py
index fc4585618b04..25b8b69e081b 100644
--- a/vllm/model_executor/models/transformers.py
+++ b/vllm/model_executor/models/transformers.py
@@ -107,10 +107,17 @@ def replace_linear_class(
         raise ValueError(
             f"Unsupported parallel style type {type(style)}, expected str")
 
-    vllm_linear_cls = {
-        "colwise": ColumnParallelLinear,
-        "rowwise": RowParallelLinear,
-    }.get(style, ReplicatedLinear)
+    vllm_linear_cls, vllm_linear_kwargs = {
+        "colwise": (ColumnParallelLinear, {}),
+        "colwise_rep": (ColumnParallelLinear, {
+            "gather_output": True
+        }),
+        "rowwise": (RowParallelLinear, {}),
+        "rowwise_rep": (RowParallelLinear, {
+            "input_is_parallel": False
+        }),
+        "replicate": (ReplicatedLinear, {}),
+    }.get(style, (ReplicatedLinear, {}))
 
     return vllm_linear_cls(
         input_size=linear.in_features,
@@ -118,6 +125,7 @@ def replace_linear_class(
         bias=linear.bias is not None,
         quant_config=quant_config,
         return_bias=False,
+        **vllm_linear_kwargs,
     )
 
 
@@ -506,7 +514,7 @@ def tensor_parallel(self):
         # Some weight loaders expect linear layers to inherit from vLLM's
         # LinearBase class, so we set a default style which causes any
         # unspecified linear layers to be replaced with ReplicatedLinear
-        tp_plan[".*"] = "replicate"
+        tp_plan[".*"] = "replicate"
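
For context (not part of the patch): below is a minimal standalone sketch of the dispatch the first hunk introduces, where each Transformers tp_plan style maps to a (linear class, extra kwargs) pair and unknown styles fall back to a replicated layer. The Stub* classes and build_linear helper are hypothetical stand-ins for vLLM's ColumnParallelLinear / RowParallelLinear / ReplicatedLinear so the snippet runs without vLLM or a distributed setup; only the mapping itself mirrors the diff.

# sketch.py -- hypothetical stand-ins, not vLLM code
class StubColumnParallelLinear:
    def __init__(self, input_size, output_size, gather_output=False):
        self.gather_output = gather_output  # True => all-gather the sharded output

class StubRowParallelLinear:
    def __init__(self, input_size, output_size, input_is_parallel=True):
        self.input_is_parallel = input_is_parallel  # False => scatter a replicated input

class StubReplicatedLinear:
    def __init__(self, input_size, output_size):
        pass  # unsharded layer, identical on every rank

# Mirrors the style -> (class, kwargs) table added in the first hunk:
# the "*_rep" styles keep the weight sharded but replicate the layer's I/O.
_STYLE_MAP = {
    "colwise": (StubColumnParallelLinear, {}),
    "colwise_rep": (StubColumnParallelLinear, {"gather_output": True}),
    "rowwise": (StubRowParallelLinear, {}),
    "rowwise_rep": (StubRowParallelLinear, {"input_is_parallel": False}),
    "replicate": (StubReplicatedLinear, {}),
}

def build_linear(style: str, in_features: int, out_features: int):
    # Unknown styles fall back to a replicated layer, matching the
    # diff's .get(style, (ReplicatedLinear, {})) default.
    cls, kwargs = _STYLE_MAP.get(style, (StubReplicatedLinear, {}))
    return cls(in_features, out_features, **kwargs)

assert build_linear("colwise_rep", 16, 32).gather_output is True
assert build_linear("rowwise_rep", 16, 32).input_is_parallel is False
assert isinstance(build_linear("unknown", 16, 32), StubReplicatedLinear)

Returning kwargs alongside the class (rather than subclassing per style) keeps replace_linear_class a single constructor call; the second hunk simply splats the kwargs into it. The third hunk fixes the default style string to "replicate" so it actually matches a key in the new table instead of silently hitting the fallback.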