diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 1224b94d56e0..fa8a261db7d7 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -223,6 +223,7 @@ class LinearBase(CustomOp): quant_config: Quantization configure. prefix: Prefix for parameter names. return_bias: If true, return bias together with outputs in forward pass. + disable_tp: If true, tensor parallelism will be disabled for this layer. """ def __init__( @@ -235,6 +236,7 @@ def __init__( prefix: str = "", *, return_bias: bool = True, + disable_tp: bool = False, ): super().__init__() @@ -254,6 +256,17 @@ def __init__( self.quant_method = quant_config.get_quant_method(self, prefix=prefix) self.return_bias = return_bias + self.disable_tp = disable_tp + self.tp_rank = (get_tensor_model_parallel_rank() + if not disable_tp else 0) + self.tp_size = (get_tensor_model_parallel_world_size() + if not disable_tp else 1) + + def __post_init__(self): + for param in self.parameters(): + if isinstance(param, BasevLLMParameter): + param.tp_rank = self.tp_rank + param.tp_size = self.tp_size @CustomOp.register("replicated_linear") @@ -270,6 +283,7 @@ class ReplicatedLinear(LinearBase): prefix: The name of the layer in the state dict, including all parents (e.g. model.layers.0.qkv_proj) return_bias: If true, return bias together with outputs in forward pass. + disable_tp: Take no effect for replicated linear layers. """ def __init__( @@ -283,26 +297,21 @@ def __init__( prefix: str = "", *, return_bias: bool = True, + disable_tp: bool = False, ): - # If MergedReplicatedLinear, use output size of each partition. - if hasattr(self, "output_sizes"): - self.output_partition_sizes = self.output_sizes - else: - self.output_partition_sizes = [output_size] - super().__init__(input_size, output_size, skip_bias_add, params_dtype, quant_config, prefix=prefix, - return_bias=return_bias) + return_bias=return_bias, + disable_tp=disable_tp) # All the linear layer supports quant method. assert self.quant_method is not None self.quant_method.create_weights(self, - self.input_size, - self.output_partition_sizes, + self.input_size, [self.output_size], self.input_size, self.output_size, self.params_dtype, @@ -358,74 +367,6 @@ def extra_repr(self) -> str: return s -class MergedReplicatedLinear(ReplicatedLinear): - """Replicated linear layer. - - Args: - input_size: input dimension of the linear layer. - output_sizes: list of output dimensions of the linear layer. - bias: If true, add bias. - skip_bias_add: If true, skip adding bias but instead return it. - params_dtype: Data type for the parameters. - quant_config: Quantization configure. - prefix: The name of the layer in the state dict, including all parents - (e.g. model.layers.0.qkv_proj) - return_bias: If true, return bias together with outputs in forward pass. 
- """ - - def __init__( - self, - input_size: int, - output_sizes: list[int], - bias: bool = True, - skip_bias_add: bool = False, - params_dtype: Optional[torch.dtype] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - *, - return_bias: bool = True, - ): - self.output_sizes = output_sizes - super().__init__(input_size, - sum(output_sizes), - bias, - skip_bias_add, - params_dtype, - quant_config, - prefix=prefix, - return_bias=return_bias) - - def weight_loader(self, - param: Union[Parameter, BasevLLMParameter], - loaded_weight: torch.Tensor, - loaded_shard_id: Optional[int] = None): - assert loaded_shard_id is not None - assert loaded_shard_id < len(self.output_sizes) - - if isinstance(param, BlockQuantScaleParameter): - from vllm.model_executor.layers.quantization.fp8 import ( - Fp8LinearMethod, Fp8MoEMethod) - assert self.quant_method is not None - assert isinstance(self.quant_method, - (Fp8LinearMethod, Fp8MoEMethod)) - weight_block_size = self.quant_method.quant_config.weight_block_size - assert weight_block_size is not None - block_n, _ = weight_block_size[0], weight_block_size[1] - shard_offset = ( - (sum(self.output_sizes[:loaded_shard_id]) + block_n - 1) // - block_n) - shard_size = ((self.output_sizes[loaded_shard_id] + block_n - 1) // - block_n) - elif isinstance(param, PerTensorScaleParameter): - shard_offset = loaded_shard_id - shard_size = 1 - else: - shard_offset = sum(self.output_sizes[:loaded_shard_id]) - shard_size = self.output_sizes[loaded_shard_id] - - param.data[shard_offset:shard_offset + shard_size] = loaded_weight - - @CustomOp.register("column_parallel_linear") class ColumnParallelLinear(LinearBase): """Linear layer with column parallelism. @@ -448,7 +389,9 @@ class ColumnParallelLinear(LinearBase): output_sizes: list of output sizes packed into one output, like for QKV the list would be size 3. prefix: The name of the layer in the state dict, including all parents - (e.g. model.layers.0.qkv_proj) + (e.g. model.layers.0.qkv_proj) + return_bias: If true, return bias together with outputs in forward pass. + disable_tp: If true, weights matrix won't be sharded through tp rank. """ def __init__( @@ -464,9 +407,13 @@ def __init__( prefix: str = "", *, return_bias: bool = True, + disable_tp: bool = False, ): # Divide the weight matrix along the last dimension. 
- self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = (get_tensor_model_parallel_rank() + if not disable_tp else 0) + self.tp_size = (get_tensor_model_parallel_world_size() + if not disable_tp else 1) self.input_size_per_partition = input_size self.output_size_per_partition = divide(output_size, self.tp_size) self.output_partition_sizes = [self.output_size_per_partition] @@ -483,7 +430,8 @@ def __init__( params_dtype, quant_config, prefix, - return_bias=return_bias) + return_bias=return_bias, + disable_tp=disable_tp) self.gather_output = gather_output @@ -512,8 +460,6 @@ def __init__( else: self.register_parameter("bias", None) - self.tp_rank = get_tensor_model_parallel_rank() - def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): output_dim = getattr(param, "output_dim", None) @@ -554,7 +500,8 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): assert param_data.shape == loaded_weight.shape param_data.copy_(loaded_weight) - def weight_loader_v2(self, param: Parameter, loaded_weight: torch.Tensor): + def weight_loader_v2(self, param: BasevLLMParameter, + loaded_weight: torch.Tensor): # Special case for loading scales off disk, which often do not # have a shape (such as in the case of AutoFP8). if len(loaded_weight.shape) == 0: @@ -570,7 +517,7 @@ def forward( # Matrix multiply. assert self.quant_method is not None output_parallel = self.quant_method.apply(self, input_, bias) - if self.gather_output: + if self.gather_output and self.tp_size > 1: # All-gather across the partitions. output = tensor_model_parallel_all_gather(output_parallel) else: @@ -584,7 +531,7 @@ def extra_repr(self) -> str: s = f"in_features={self.input_size}" s += f", output_features={self.output_size_per_partition}" s += f", bias={self.bias is not None}" - s += f", tp_size={get_tensor_model_parallel_world_size()}" + s += f", tp_size={self.tp_size}" s += f", gather_output={self.gather_output}" return s @@ -611,6 +558,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear): prefix: The name of the layer in the state dict, including all parents (e.g. model.layers.0.qkv_proj) return_bias: If true, return bias together with outputs in forward pass. + disable_tp: If true, all weights matrix won't be sharded, this layer + will be treated as a "Replicated" MergedLinear. 
""" def __init__( @@ -625,10 +574,13 @@ def __init__( prefix: str = "", *, return_bias: bool = True, + disable_tp: bool = False, ): self.output_sizes = output_sizes - self.tp_size = get_tensor_model_parallel_world_size() - self.tp_rank = get_tensor_model_parallel_rank() + self.tp_size = (get_tensor_model_parallel_world_size() + if not disable_tp else 1) + self.tp_rank = (get_tensor_model_parallel_rank() + if not disable_tp else 0) assert all(output_size % self.tp_size == 0 for output_size in output_sizes) @@ -640,7 +592,8 @@ def __init__( params_dtype=params_dtype, quant_config=quant_config, prefix=prefix, - return_bias=return_bias) + return_bias=return_bias, + disable_tp=disable_tp) def weight_loader(self, param: Parameter, @@ -832,8 +785,6 @@ def weight_loader_v2(self, assert loaded_shard_id < len(self.output_sizes) - tp_size = get_tensor_model_parallel_world_size() - if isinstance(param, BlockQuantScaleParameter): from vllm.model_executor.layers.quantization.fp8 import ( Fp8LinearMethod, Fp8MoEMethod) @@ -845,17 +796,19 @@ def weight_loader_v2(self, block_n, _ = weight_block_size[0], weight_block_size[1] shard_offset = ( (sum(self.output_sizes[:loaded_shard_id]) + block_n - 1) // - block_n) // tp_size + block_n) // self.tp_size shard_size = ((self.output_sizes[loaded_shard_id] + block_n - 1) // - block_n // tp_size) + block_n // self.tp_size) else: - shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size - shard_size = self.output_sizes[loaded_shard_id] // tp_size + shard_offset = sum( + self.output_sizes[:loaded_shard_id]) // self.tp_size + shard_size = self.output_sizes[loaded_shard_id] // self.tp_size param.load_merged_column_weight(loaded_weight=loaded_weight, shard_id=loaded_shard_id, shard_offset=shard_offset, - shard_size=shard_size) + shard_size=shard_size, + tp_rank=self.tp_rank) class QKVParallelLinear(ColumnParallelLinear): @@ -883,6 +836,7 @@ class QKVParallelLinear(ColumnParallelLinear): prefix: The name of the layer in the state dict, including all parents (e.g. model.layers.0.qkv_proj) return_bias: If true, return bias together with outputs in forward pass. + disable_tp: If true, weights matrix won't be sharded through tp rank. """ def __init__( @@ -898,6 +852,7 @@ def __init__( prefix: str = "", *, return_bias: bool = True, + disable_tp: bool = False, ): self.hidden_size = hidden_size self.head_size = head_size @@ -906,7 +861,8 @@ def __init__( total_num_kv_heads = total_num_heads self.total_num_kv_heads = total_num_kv_heads # Divide the weight matrix along the last dimension. 
- tp_size = get_tensor_model_parallel_world_size() + tp_size = (get_tensor_model_parallel_world_size() + if not disable_tp else 1) self.num_heads = divide(self.total_num_heads, tp_size) if tp_size >= self.total_num_kv_heads: self.num_kv_heads = 1 @@ -932,7 +888,8 @@ def __init__( params_dtype=params_dtype, quant_config=quant_config, prefix=prefix, - return_bias=return_bias) + return_bias=return_bias, + disable_tp=disable_tp) def _get_shard_offset_mapping(self, loaded_shard_id: str): shard_offset_mapping = { @@ -993,10 +950,13 @@ def weight_loader_v2(self, loaded_shard_id: Optional[str] = None): if loaded_shard_id is None: # special case for certain models if isinstance(param, PerTensorScaleParameter): - param.load_qkv_weight(loaded_weight=loaded_weight, shard_id=0) + param.load_qkv_weight(loaded_weight=loaded_weight, + shard_id=0, + tp_rank=self.tp_rank) return elif type(param) in (RowvLLMParameter, BasevLLMParameter): - param.load_qkv_weight(loaded_weight=loaded_weight) + param.load_qkv_weight(loaded_weight=loaded_weight, + tp_rank=self.tp_rank) return # TODO: @dsikka - move to parameter.py self._load_fused_module_from_checkpoint(param, loaded_weight) @@ -1020,7 +980,8 @@ def weight_loader_v2(self, num_heads=self.num_kv_head_replicas, shard_id=loaded_shard_id, shard_offset=shard_offset, - shard_size=shard_size) + shard_size=shard_size, + tp_rank=self.tp_rank) def weight_loader(self, param: Parameter, @@ -1226,6 +1187,7 @@ class RowParallelLinear(LinearBase): prefix: The name of the layer in the state dict, including all parents (e.g. model.layers.0.down_proj) return_bias: If true, return bias together with outputs in forward pass. + disable_tp: If true, weights matrix won't be sharded through tp rank. """ def __init__( @@ -1241,10 +1203,13 @@ def __init__( prefix: str = "", *, return_bias: bool = True, + disable_tp: bool = False, ): # Divide the weight matrix along the first dimension. - self.tp_rank = get_tensor_model_parallel_rank() - self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = (get_tensor_model_parallel_rank() + if not disable_tp else 0) + self.tp_size = (get_tensor_model_parallel_world_size() + if not disable_tp else 1) self.input_size_per_partition = divide(input_size, self.tp_size) self.output_size_per_partition = output_size self.output_partition_sizes = [output_size] @@ -1255,7 +1220,8 @@ def __init__( params_dtype, quant_config, prefix, - return_bias=return_bias) + return_bias=return_bias, + disable_tp=disable_tp) self.input_is_parallel = input_is_parallel self.reduce_results = reduce_results @@ -1339,10 +1305,9 @@ def forward( if self.input_is_parallel: input_parallel = input_ else: - tp_rank = get_tensor_model_parallel_rank() splitted_input = split_tensor_along_last_dim( input_, num_partitions=self.tp_size) - input_parallel = splitted_input[tp_rank].contiguous() + input_parallel = splitted_input[self.tp_rank].contiguous() # Matrix multiply. assert self.quant_method is not None diff --git a/vllm/model_executor/model_loader/bitsandbytes_loader.py b/vllm/model_executor/model_loader/bitsandbytes_loader.py index b8393956eed3..c8dd1ec0ec3c 100644 --- a/vllm/model_executor/model_loader/bitsandbytes_loader.py +++ b/vllm/model_executor/model_loader/bitsandbytes_loader.py @@ -69,6 +69,7 @@ def __init__(self, load_config: LoadConfig): # Store all module names (from transformers) that support # BNB quantization. self.target_modules: list[str] = [] + self.tp_disabled_modules: list[str] = [] # Store the mapping of expert parameters for MoE models. 
self.expert_params_mapping: list[tuple[str, str, int, str]] = [] # mapping weight names from transformers to vllm. @@ -322,14 +323,24 @@ def _unquantized_generator(self, hf_weights_files, use_safetensors, quant_state_dict) -> Generator: from bitsandbytes.functional import quantize_4bit - tp_size = get_tensor_model_parallel_world_size() - tp_rank = get_tensor_model_parallel_rank() + global_tp_size = get_tensor_model_parallel_world_size() + global_tp_rank = get_tensor_model_parallel_rank() for ( org_weight_name, mapped_weight_name, weight_tensor, ) in self._hf_weight_iter(hf_weights_files, use_safetensors): + + # override tp_size and tp_rank if the module has disabled TP + if any(tp_disabled_module in mapped_weight_name + for tp_disabled_module in self.tp_disabled_modules): + tp_size = 1 + tp_rank = 0 + else: + tp_size = global_tp_size + tp_rank = global_tp_rank + if any(target_module in mapped_weight_name for target_module in self.target_modules ) and mapped_weight_name.endswith(".weight"): @@ -418,12 +429,16 @@ def _get_bnb_target_modules(self, model: nn.Module) -> None: # Map vllm's names to transformers's names. rep_name, sub_modules = modules_info for sub_name in sub_modules: - self.target_modules.append( - name.replace(rep_name, sub_name)) + new_name = name.replace(rep_name, sub_name) + self.target_modules.append(new_name) + if module.disable_tp: + self.tp_disabled_modules.append(new_name) # Add original module name even if the module has stacked map, # in case model has a mixture of disk-merged and disk-split # weights with same last name. self.target_modules.append(name) + if module.disable_tp: + self.tp_disabled_modules.append(name) elif isinstance(module, FusedMoE) and hasattr( module.quant_method, "quant_config"): # TODO: support FusedMoE with prequant and 8bit. 
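
Note on the model changes below: they all drop the per-module switching between `ReplicatedLinear`/`MergedReplicatedLinear` and the `*Parallel` classes, keeping the parallel classes and passing the data-parallel flag through as `disable_tp`. A minimal sketch of that pattern follows; the `ExampleVisionMLP` module and its sizes are illustrative only, not part of this diff:

```python
# Sketch only: a hypothetical vision-tower MLP showing the pattern the model
# diffs below converge on. Instead of choosing Replicated*/Merged* vs. the
# *Parallel classes based on `use_data_parallel`, the module always uses the
# parallel classes and forwards the flag as `disable_tp`.
from typing import Optional

import torch
import torch.nn as nn

from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig


class ExampleVisionMLP(nn.Module):
    """Hypothetical module; sizes and names are illustrative."""

    def __init__(self,
                 hidden_size: int,
                 intermediate_size: int,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = "",
                 use_data_parallel: bool = False):
        super().__init__()
        # With disable_tp=True both layers report tp_size == 1 / tp_rank == 0,
        # so their weights stay replicated even when TP is enabled globally.
        self.gate_up_proj = MergedColumnParallelLinear(
            input_size=hidden_size,
            output_sizes=[intermediate_size] * 2,  # [gate_proj, up_proj]
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
            disable_tp=use_data_parallel)
        self.down_proj = RowParallelLinear(intermediate_size,
                                           hidden_size,
                                           bias=False,
                                           quant_config=quant_config,
                                           prefix=f"{prefix}.down_proj",
                                           disable_tp=use_data_parallel)
        self.act_fn = SiluAndMul()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate_up, _ = self.gate_up_proj(x)
        out, _ = self.down_proj(self.act_fn(gate_up))
        return out
```

With `disable_tp=True` the layer fixes `tp_size = 1` and `tp_rank = 0`, so weights are created and loaded unsharded and the gather/reduce paths are skipped for a size-1 group, matching the behaviour of the removed replicated variants without a separate class.
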
diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index bb95a1dbf122..d65dcfebaeff 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -43,7 +43,6 @@ from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (ColumnParallelLinear, MergedColumnParallelLinear, - MergedReplicatedLinear, ReplicatedLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor @@ -435,12 +434,13 @@ def __init__( self.max_position_embeddings = max_position_embeddings if self.q_lora_rank is not None: - self.fused_qkv_a_proj = MergedReplicatedLinear( + self.fused_qkv_a_proj = MergedColumnParallelLinear( self.hidden_size, [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], bias=False, quant_config=quant_config, - prefix=f"{prefix}.fused_qkv_a_proj") + prefix=f"{prefix}.fused_qkv_a_proj", + disable_tp=True) else: self.kv_a_proj_with_mqa = ReplicatedLinear( self.hidden_size, diff --git a/vllm/model_executor/models/glm4_1v.py b/vllm/model_executor/models/glm4_1v.py index f9fd5163d66b..fd5fecac67d6 100644 --- a/vllm/model_executor/models/glm4_1v.py +++ b/vllm/model_executor/models/glm4_1v.py @@ -51,14 +51,10 @@ from vllm.logger import init_logger from vllm.model_executor import SamplingMetadata from vllm.model_executor.layers.layernorm import RMSNorm -# yapf: disable from vllm.model_executor.layers.linear import (ColumnParallelLinear, MergedColumnParallelLinear, - MergedReplicatedLinear, QKVParallelLinear, - ReplicatedLinear, RowParallelLinear) -# yapf: enable from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.module_mapping import MultiModelKeys @@ -174,20 +170,22 @@ def __init__( use_data_parallel: bool = False, ): super().__init__() - cls_gate_up = (MergedReplicatedLinear - if use_data_parallel else MergedColumnParallelLinear) - self.gate_up_proj = cls_gate_up(input_size=in_features, - output_sizes=[hidden_features] * 2, - bias=bias, - quant_config=quant_config, - prefix=f"{prefix}.gate_up_proj") - cls_down = (ReplicatedLinear - if use_data_parallel else RowParallelLinear) - self.down_proj = cls_down(hidden_features, - in_features, - bias=bias, - quant_config=quant_config, - prefix=f"{prefix}.down_proj") + self.gate_up_proj = MergedColumnParallelLinear( + input_size=in_features, + output_sizes=[hidden_features] * 2, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj", + disable_tp=use_data_parallel, + ) + self.down_proj = RowParallelLinear( + hidden_features, + in_features, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.down_proj", + disable_tp=use_data_parallel, + ) self.act_fn = SiluAndMul() def forward(self, x: torch.Tensor): @@ -234,48 +232,32 @@ def __init__( # Per attention head and per partition values. 
self.tp_size = (1 if use_data_parallel else get_tensor_model_parallel_world_size()) - self.tp_rank = parallel_state.get_tensor_model_parallel_rank() + self.tp_rank = (0 if use_data_parallel else + parallel_state.get_tensor_model_parallel_rank()) self.hidden_size_per_attention_head = dist_utils.divide( projection_size, num_heads) self.num_attention_heads_per_partition = dist_utils.divide( num_heads, self.tp_size) - if use_data_parallel: - self.qkv = ReplicatedLinear( - input_size=embed_dim, - output_size=3 * projection_size, - bias=False, - quant_config=quant_config, - # Change qkv prefix to align with GLM-4.5V-FP8 quantization cfg - prefix=f"{prefix}.qkv_proj" - if quant_config else f"{prefix}.qkv", - ) - self.proj = ReplicatedLinear( - input_size=projection_size, - output_size=embed_dim, - quant_config=quant_config, - prefix=f"{prefix}.proj", - bias=False, - ) - else: - self.qkv = QKVParallelLinear( - hidden_size=embed_dim, - head_size=self.hidden_size_per_attention_head, - total_num_heads=num_heads, - total_num_kv_heads=num_heads, - bias=False, - quant_config=quant_config, - # Change qkv prefix to align with GLM-4.5V-FP8 quantization cfg - prefix=f"{prefix}.qkv_proj" - if quant_config else f"{prefix}.qkv", - ) - self.proj = RowParallelLinear( - input_size=projection_size, - output_size=embed_dim, - quant_config=quant_config, - prefix=f"{prefix}.proj", - bias=False, - ) + self.qkv = QKVParallelLinear( + hidden_size=embed_dim, + head_size=self.hidden_size_per_attention_head, + total_num_heads=num_heads, + total_num_kv_heads=num_heads, + bias=False, + quant_config=quant_config, + # Change qkv prefix to align with GLM-4.5V-FP8 quantization cfg + prefix=f"{prefix}.qkv_proj" if quant_config else f"{prefix}.qkv", + disable_tp=use_data_parallel, + ) + self.proj = RowParallelLinear( + input_size=projection_size, + output_size=embed_dim, + quant_config=quant_config, + prefix=f"{prefix}.proj", + bias=False, + disable_tp=use_data_parallel, + ) # Detect attention implementation. 
self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True) @@ -494,41 +476,31 @@ def __init__( ) -> None: super().__init__() self.hidden_size = d_model - if use_data_parallel: - self.proj = ReplicatedLinear( - input_size=self.hidden_size, - output_size=self.hidden_size, - bias=bias, - quant_config=quant_config, - prefix=f"{prefix}.proj", - ) - else: - self.proj = ColumnParallelLinear( - self.hidden_size, - self.hidden_size, - bias=bias, - gather_output=True, - quant_config=quant_config, - prefix=f"{prefix}.proj", - ) + self.proj = ColumnParallelLinear( + self.hidden_size, + self.hidden_size, + bias=bias, + gather_output=True, + quant_config=quant_config, + prefix=f"{prefix}.proj", + disable_tp=use_data_parallel, + ) self.post_projection_norm = nn.LayerNorm(self.hidden_size) - cls_gate_up = (MergedReplicatedLinear - if use_data_parallel else MergedColumnParallelLinear) - self.gate_up_proj = cls_gate_up( + self.gate_up_proj = MergedColumnParallelLinear( input_size=self.hidden_size, output_sizes=[context_dim] * 2, bias=bias, quant_config=quant_config, prefix=f"{prefix}.gate_up_proj", + disable_tp=use_data_parallel, ) - cls_down = (ReplicatedLinear - if use_data_parallel else RowParallelLinear) - self.down_proj = cls_down( + self.down_proj = RowParallelLinear( context_dim, self.hidden_size, bias=bias, quant_config=quant_config, prefix=f"{prefix}.down_proj", + disable_tp=use_data_parallel, ) self.act_fn = SiluAndMul() self.extra_activation_func = nn.GELU() diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index c8f7fc16b4e8..0a89f86fc738 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -48,7 +48,6 @@ # yapf: disable from vllm.model_executor.layers.linear import (ColumnParallelLinear, MergedColumnParallelLinear, - MergedReplicatedLinear, QKVParallelLinear, ReplicatedLinear, RowParallelLinear) @@ -178,22 +177,20 @@ def __init__(self, prefix: str = "", use_data_parallel: bool = False): super().__init__() - cls_gate_up_proj = (MergedReplicatedLinear if use_data_parallel else - MergedColumnParallelLinear) - self.gate_up_proj = cls_gate_up_proj( + self.gate_up_proj = MergedColumnParallelLinear( input_size=in_features, output_sizes=[hidden_features] * 2, # [gate_proj, up_proj] bias=bias, quant_config=quant_config, - prefix=f"{prefix}.gate_up_proj") - - cls_down_proj = (ReplicatedLinear - if use_data_parallel else RowParallelLinear) - self.down_proj = cls_down_proj(hidden_features, - in_features, - bias=bias, - quant_config=quant_config, - prefix=f"{prefix}.down_proj") + prefix=f"{prefix}.gate_up_proj", + disable_tp=use_data_parallel) + + self.down_proj = RowParallelLinear(hidden_features, + in_features, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.down_proj", + disable_tp=use_data_parallel) self.act_fn = act_fn def forward(self, x: torch.Tensor): @@ -243,30 +240,21 @@ def __init__( self.num_attention_heads_per_partition = dist_utils.divide( num_heads, self.tp_size) - if use_data_parallel: - self.qkv = ReplicatedLinear(embed_dim, - self.hidden_size_per_attention_head * - 3 * num_heads, - bias=True, - quant_config=quant_config, - prefix=f"{prefix}.qkv") - - else: - self.qkv = QKVParallelLinear( - hidden_size=embed_dim, - head_size=self.hidden_size_per_attention_head, - total_num_heads=num_heads, - total_num_kv_heads=num_heads, - bias=True, - quant_config=quant_config, - prefix=f"{prefix}.qkv") - - cls_proj = (ReplicatedLinear - if use_data_parallel else RowParallelLinear) - 
self.proj = cls_proj(input_size=projection_size, - output_size=embed_dim, - quant_config=quant_config, - prefix=f"{prefix}.proj") + self.qkv = QKVParallelLinear( + hidden_size=embed_dim, + head_size=self.hidden_size_per_attention_head, + total_num_heads=num_heads, + total_num_kv_heads=num_heads, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.qkv", + disable_tp=use_data_parallel) + + self.proj = RowParallelLinear(input_size=projection_size, + output_size=embed_dim, + quant_config=quant_config, + prefix=f"{prefix}.proj", + disable_tp=use_data_parallel) # Detect attention implementation. self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True) diff --git a/vllm/model_executor/models/step3_vl.py b/vllm/model_executor/models/step3_vl.py index f379d2c15fb6..17299b64978e 100644 --- a/vllm/model_executor/models/step3_vl.py +++ b/vllm/model_executor/models/step3_vl.py @@ -21,7 +21,6 @@ from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, QKVParallelLinear, - ReplicatedLinear, RowParallelLinear) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler @@ -667,35 +666,21 @@ def __init__(self, self.q_size = self.num_heads * self.head_dim - if use_data_parallel: - self.qkv_proj = ReplicatedLinear( - self.embed_dim, - 3 * self.q_size, - bias=True, - quant_config=quant_config, - prefix=prefix, - ) - self.out_proj = ReplicatedLinear( - self.total_num_heads * self.head_dim, - self.embed_dim, - bias=True, - quant_config=quant_config, - prefix=prefix, - ) - else: - self.qkv_proj = QKVParallelLinear( - self.embed_dim, - self.head_dim, - self.total_num_heads, - bias=True, - quant_config=quant_config, - prefix=prefix, - ) - self.out_proj = RowParallelLinear(self.embed_dim, - self.embed_dim, - bias=True, - quant_config=quant_config, - prefix=prefix) + self.qkv_proj = QKVParallelLinear( + self.embed_dim, + self.head_dim, + self.total_num_heads, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + disable_tp=use_data_parallel, + ) + self.out_proj = RowParallelLinear(self.embed_dim, + self.embed_dim, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.out_proj", + disable_tp=use_data_parallel) def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, @@ -740,20 +725,18 @@ def __init__(self, super().__init__() self.config = config self.activation_fn = get_act_fn(config.hidden_act) - cls_fc1 = (ReplicatedLinear - if use_data_parallel else ColumnParallelLinear) - self.fc1 = cls_fc1(config.hidden_size, - config.intermediate_size, - bias=True, - quant_config=quant_config, - prefix=prefix) - cls_fc2 = (ReplicatedLinear - if use_data_parallel else RowParallelLinear) - self.fc2 = cls_fc2(config.intermediate_size, - config.hidden_size, - bias=True, - quant_config=quant_config, - prefix=prefix) + self.fc1 = ColumnParallelLinear(config.hidden_size, + config.intermediate_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.fc1", + disable_tp=use_data_parallel) + self.fc2 = RowParallelLinear(config.intermediate_size, + config.hidden_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.fc2", + disable_tp=use_data_parallel) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states, _ = self.fc1(hidden_states) diff --git a/vllm/model_executor/parameter.py b/vllm/model_executor/parameter.py index 
9465308e94e6..221712ba9a33 100644 --- a/vllm/model_executor/parameter.py +++ b/vllm/model_executor/parameter.py @@ -57,6 +57,8 @@ def __init__(self, data: torch.Tensor, weight_loader: Callable): weight_loader = _make_synced_weight_loader(weight_loader) self._weight_loader = weight_loader + self.tp_rank = get_tensor_model_parallel_rank() + self.tp_size = get_tensor_model_parallel_world_size() @property def weight_loader(self): @@ -116,10 +118,10 @@ def output_dim(self): return self._output_dim def load_column_parallel_weight(self, loaded_weight: torch.Tensor): - tp_rank = get_tensor_model_parallel_rank() shard_size = self.data.shape[self.output_dim] loaded_weight = loaded_weight.narrow(self.output_dim, - tp_rank * shard_size, shard_size) + self.tp_rank * shard_size, + shard_size) assert self.data.shape == loaded_weight.shape self.data.copy_(loaded_weight) @@ -127,6 +129,7 @@ def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs): shard_offset = kwargs.get("shard_offset") shard_size = kwargs.get("shard_size") + # TODO: move these to PackedColumnParameter and PackedvLLMParameter if isinstance( self, @@ -137,11 +140,11 @@ def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs): param_data = self.data - tp_rank = get_tensor_model_parallel_rank() param_data = param_data.narrow(self.output_dim, shard_offset, shard_size) loaded_weight = loaded_weight.narrow(self.output_dim, - tp_rank * shard_size, shard_size) + self.tp_rank * shard_size, + shard_size) assert param_data.shape == loaded_weight.shape param_data.copy_(loaded_weight) @@ -161,8 +164,8 @@ def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs): shard_offset=shard_offset, shard_size=shard_size) param_data = self.data - tp_rank = get_tensor_model_parallel_rank() - shard_id = tp_rank if shard_id == "q" else tp_rank // num_heads + shard_id = (self.tp_rank if shard_id == "q" else self.tp_rank // + num_heads) param_data = param_data.narrow(self.output_dim, shard_offset, shard_size) loaded_weight = loaded_weight.narrow(self.output_dim, @@ -189,10 +192,10 @@ def input_dim(self): return self._input_dim def load_row_parallel_weight(self, loaded_weight: torch.Tensor): - tp_rank = get_tensor_model_parallel_rank() shard_size = self.data.shape[self.input_dim] loaded_weight = loaded_weight.narrow(self.input_dim, - tp_rank * shard_size, shard_size) + self.tp_rank * shard_size, + shard_size) if len(loaded_weight.shape) == 0: loaded_weight = loaded_weight.reshape(1) @@ -414,9 +417,6 @@ def __init__(self, input_dim: int = 1, output_dim: int = 0, **kwargs): "weight_loader": self._fake_weight_loader } - self.tp_rank = get_tensor_model_parallel_rank() - self.tp_size = get_tensor_model_parallel_world_size() - if self.tp_size > 1: raise NotImplementedError(f"{self.__class__.__name__} does not " "currently support tensor parallelism")
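
For reference, the `parameter.py` hunks above make every `BasevLLMParameter` cache `tp_rank`/`tp_size` at construction, with the new `LinearBase.__post_init__` hook overriding them per layer when `disable_tp=True`. The sketch below restates the column-parallel narrowing logic as a standalone function with toy shapes; the helper name and shapes are illustrative, not vLLM API:

```python
# Standalone sketch (toy shapes, assumed free function) of the narrowing logic
# that load_column_parallel_weight performs after this change, now driven by
# the tp_rank cached on the parameter instead of a fresh
# get_tensor_model_parallel_rank() call.
import torch


def load_column_parallel_weight(param_data: torch.Tensor,
                                loaded_weight: torch.Tensor,
                                output_dim: int,
                                tp_rank: int) -> None:
    # Each rank copies its contiguous slice of the checkpoint tensor.
    shard_size = param_data.shape[output_dim]
    shard = loaded_weight.narrow(output_dim, tp_rank * shard_size, shard_size)
    assert param_data.shape == shard.shape
    param_data.copy_(shard)


# For a TP-disabled layer, the new __post_init__ hook overwrites the
# parameter's tp_rank/tp_size with 0/1, so the shard spans the full output
# dimension and narrow() returns the whole tensor, i.e. replicated loading
# without the removed MergedReplicatedLinear class.
full = torch.randn(8, 4)
dst = torch.empty(8, 4)
load_column_parallel_weight(dst, full, output_dim=0, tp_rank=0)
assert torch.equal(dst, full)
```

The same cached-rank substitution is what the merged-column, QKV, and row-parallel loaders above rely on when a layer opts out of TP.
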