@@ -107,17 +107,25 @@ def replace_linear_class(
107107 raise ValueError (
108108 f"Unsupported parallel style type { type (style )} , expected str" )
109109
110- vllm_linear_cls = {
111- "colwise" : ColumnParallelLinear ,
112- "rowwise" : RowParallelLinear ,
113- }.get (style , ReplicatedLinear )
110+ vllm_linear_cls , vllm_linear_kwargs = {
111+ "colwise" : (ColumnParallelLinear , {}),
112+ "colwise_rep" : (ColumnParallelLinear , {
113+ "gather_output" : True
114+ }),
115+ "rowwise" : (RowParallelLinear , {}),
116+ "rowwise_rep" : (RowParallelLinear , {
117+ "input_is_parallel" : False
118+ }),
119+ "replicate" : (ReplicatedLinear , {}),
120+ }.get (style , (ReplicatedLinear , {}))
114121
115122 return vllm_linear_cls (
116123 input_size = linear .in_features ,
117124 output_size = linear .out_features ,
118125 bias = linear .bias is not None ,
119126 quant_config = quant_config ,
120127 return_bias = False ,
128+ ** vllm_linear_kwargs ,
121129 )
122130
123131
@@ -506,7 +514,7 @@ def tensor_parallel(self):
506514 # Some weight loaders expect linear layers to inherit from vLLM's
507515 # LinearBase class, so we set a default style which causes any
508516 # unspecified linear layers to be replaced with ReplicatedLinear
509- tp_plan [".*" ] = "replicated "
517+ tp_plan [".*" ] = "replicate "
510518
511519 def _tensor_parallel (module : nn .Module , prefix : str = "" ):
512520 for child_name , child_module in module .named_children ():
0 commit comments