From 7c7586808e616dd4aa85a2e81cf7aa345af0d6e7 Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Sun, 14 Apr 2024 13:46:40 +0300 Subject: [PATCH] Tensor parallelism (#7) * Return support for other models apart from jamba * Support n>1 * Revert 2 commits d054737 'Support n>1' b5167cc 'Return support for other models apart from jamba' * TP on input and output * Basic TP impl , working, correctness not working * TP is working * Roll back the verification that everything in the weights fits into the model * Cleanup * Use world size func * clean up * Import * Apply whitespace suggestions from code review * Organize imports * Add comment on the unsqueeze in conv1d * Organize and remove redundant code in forward pass * Remove print * Add comments Co-authored-by: tomeras91 <57313761+tomeras91@users.noreply.github.com> * White spaces * Set as A * better comment --------- Co-authored-by: Mor Zusman Co-authored-by: tomeras91 <57313761+tomeras91@users.noreply.github.com> --- vllm/model_executor/models/jamba.py | 109 +++++++++++++++++----------- vllm/worker/model_runner.py | 11 ++- 2 files changed, 74 insertions(+), 46 deletions(-) diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index 73902d3e25783..b0de0a23636db 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -5,22 +5,20 @@ from typing import Dict, List, Optional, Tuple import torch -from torch import conv_transpose3d, nn -import os +from torch import nn from vllm.model_executor.mamba_metadata import MambaCacheParams from vllm.transformers_utils.configs.jamba import JambaConfig -from transformers.activations import ACT2FN +from torch.nn.parameter import Parameter from vllm.config import LoRAConfig from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.attention import PagedAttention from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (LinearMethodBase, +from vllm.model_executor.layers.linear import (ColumnParallelLinear, LinearMethodBase, MergedColumnParallelLinear, QKVParallelLinear, ReplicatedLinear, RowParallelLinear) -from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding, ParallelLMHead, DEFAULT_VOCAB_PADDING_SIZE) @@ -59,35 +57,65 @@ def __init__(self, config: JambaConfig, layer_idx): self.time_step_rank = config.mamba_dt_rank self.use_conv_bias = config.mamba_conv_bias self.use_bias = config.mamba_proj_bias - self.conv1d = nn.Conv1d( - in_channels=self.intermediate_size, - out_channels=self.intermediate_size, + self.conv1d = ColumnParallelLinear( + input_size=self.conv_kernel_size, + output_size=self.intermediate_size, bias=self.use_conv_bias, - kernel_size=self.conv_kernel_size, - groups=self.intermediate_size, - padding=self.conv_kernel_size - 1, ) + # unsqueeze to fit conv1d weights shape into the linear weights shape. 
Can't do this in `weight_loader` since it already exists in `ColumnParallelLinear` and `set_weight_attrs` doesn't allow to override it + self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1) - self.activation = config.hidden_act - self.act = ACT2FN[config.hidden_act] - self.apply_inner_layernorms = config.mamba_inner_layernorms - - # projection of the input hidden states - self.in_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=self.use_bias) + self.in_proj = MergedColumnParallelLinear( + self.hidden_size, + [self.intermediate_size] * 2, + bias=self.use_bias + ) # selective projection used to make dt, B and C input dependant - self.x_proj = nn.Linear(self.intermediate_size, self.time_step_rank + self.ssm_state_size * 2, bias=False) - # time step projection (discretization) - self.dt_proj = nn.Linear(self.time_step_rank, self.intermediate_size, bias=True) + self.x_proj = RowParallelLinear( + self.intermediate_size, + self.time_step_rank + self.ssm_state_size * 2, + bias=False + ) + # time step projection (discretization) - In the forward we need to apply dt_proj without the bias, as the bias is added in the selective scan kernel. + self.dt_proj = ColumnParallelLinear( + self.time_step_rank, + self.intermediate_size, + bias=True, + skip_bias_add=True + ) - # S4D real initialization. These are not discretized! - # The core is to load them, compute the discrete states, then write the updated state. Keeps the memory bounded - A = torch.arange(1, self.ssm_state_size + 1, dtype=torch.float32)[None, :] - A = A.expand(self.intermediate_size, -1).contiguous() + def weight_loader(param:Parameter, loaded_weight:torch.Tensor): + tp_rank = get_tensor_model_parallel_rank() + tp_size = get_tensor_model_parallel_world_size() + param.data.copy_(loaded_weight.data.split(loaded_weight.shape[0] // tp_size, dim=0)[tp_rank]) - self.A_log = nn.Parameter(torch.log(A)) - self.D = nn.Parameter(torch.ones(self.intermediate_size)) - self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=self.use_bias) + def A_weight_loader(param:Parameter, loaded_weight:torch.Tensor): + weight_loader(param,-torch.exp(loaded_weight.float())) + + tp_size = get_tensor_model_parallel_world_size() + self.A = nn.Parameter(torch.empty( + self.intermediate_size // tp_size, + self.ssm_state_size, + dtype=torch.float32 + )) + self.D = nn.Parameter(torch.ones(self.intermediate_size // tp_size)) + + set_weight_attrs(self.D, { + "weight_loader": weight_loader + }) + set_weight_attrs(self.A, { + "weight_loader": A_weight_loader + }) + self.out_proj = RowParallelLinear( + self.intermediate_size, + self.hidden_size, + bias=self.use_bias, + input_is_parallel=True + ) + self.activation = config.hidden_act + self.apply_inner_layernorms = config.mamba_inner_layernorms + if self.apply_inner_layernorms: self.dt_layernorm = RMSNorm(self.time_step_rank, eps=config.rms_norm_eps) self.B_layernorm = RMSNorm(self.ssm_state_size, eps=config.rms_norm_eps) @@ -108,8 +136,7 @@ def _apply_layernorms(self, dt, B, C): def mamba_forward(self, hidden_states: torch.Tensor, cache_params: MambaCacheParams = None): # 1. Gated MLP's linear projection - projected_states = self.in_proj(hidden_states).transpose(1, 2) - + projected_states = self.in_proj(hidden_states)[0].transpose(1, 2) hidden_states, gate = projected_states.chunk(2, dim=1) # 2. Convolution sequence transformation @@ -135,23 +162,14 @@ def mamba_forward(self, hidden_states: torch.Tensor, cache_params: MambaCachePar # 3. 
State Space Model sequence transformation # 3.a. input varying initialization of time_step, B and C - ssm_parameters = self.x_proj(hidden_states.transpose(1, 2)) + ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))[0] + time_step, B, C = torch.split( ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1 ) time_step, B, C = self._apply_layernorms(time_step, B, C) - # Here we need to apply dt_proj without the bias, as the bias is added in the selective scan kernel. - # This is a hack to apply dt_proj while still using the forward pass of `torch.nn.Linear`, which is needed - # in order to make quantization work. Quantization code replaces `torch.nn.Linear` layers with quantized - # linear layers, and requires to call the forward pass directly. - # The original code here was: ```discrete_time_step = self.dt_proj.weight @ time_step.transpose(1, 2)``` - dt_proj_bias = self.dt_proj.bias - self.dt_proj.bias = None - discrete_time_step = self.dt_proj(time_step).transpose(1, 2) - self.dt_proj.bias = dt_proj_bias - - A = -torch.exp(self.A_log.float()) + discrete_time_step = self.dt_proj(time_step)[0].transpose(1,2) # 3.c perform the recurrence y ← SSM(A, B, C)(x) time_proj_bias = self.dt_proj.bias.float() if hasattr(self.dt_proj, "bias") else None if cache_params is not None and not cache_params.is_prompt: @@ -159,7 +177,7 @@ def mamba_forward(self, hidden_states: torch.Tensor, cache_params: MambaCachePar cache_params.ssm_state, hidden_states[..., 0], discrete_time_step[..., 0], - A, + self.A, B[:, 0], C[:, 0], self.D, @@ -171,7 +189,7 @@ def mamba_forward(self, hidden_states: torch.Tensor, cache_params: MambaCachePar scan_outputs, ssm_state = selective_scan_fn( hidden_states, discrete_time_step, - A, + self.A, B.transpose(1, 2), C.transpose(1, 2), self.D.float(), @@ -184,7 +202,7 @@ def mamba_forward(self, hidden_states: torch.Tensor, cache_params: MambaCachePar cache_params.ssm_state.copy_(ssm_state) # 4. Final linear projection - contextualized_states = self.out_proj(scan_outputs.transpose(1, 2)) + contextualized_states = self.out_proj(scan_outputs.transpose(1, 2))[0] return contextualized_states def forward(self, hidden_states: torch.Tensor, input_metadata: InputMetadata, conv_state: torch.Tensor, ssm_state: torch.Tensor): @@ -625,6 +643,9 @@ def load_weights(self, if "rotary_emb.inv_freq" in name: continue + if "A_log" in name: + name = name.replace("A_log","A") + if ".self_attn." 
in name: name = name.replace(".self_attn", "") diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 6c4cdf9a2e05f..d4b6ba6c9d032 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -23,6 +23,12 @@ from vllm.model_executor import SamplingMetadata from vllm.model_executor.mamba_metadata import RequestInfo from vllm.model_executor.model_loader import get_model +from vllm.model_executor.parallel_utils import custom_all_reduce, pynccl_utils +from vllm.model_executor.parallel_utils.communication_op import ( + broadcast_tensor_dict) +from vllm.model_executor.parallel_utils.parallel_state import ( + with_pynccl_for_all_reduce, + get_tensor_model_parallel_world_size) from vllm.sampling_params import SamplingParams, SamplingType from vllm.sequence import (MultiModalData, SamplerOutput, SequenceData, SequenceGroupMetadata) @@ -169,16 +175,17 @@ def prepare_contiguous_mamba_cache(self, dtype): hf_config = self.model_config.hf_config num_layers = hf_config.num_hidden_layers max_batch_size = _BATCH_SIZES_TO_CAPTURE[-1] + world_size = get_tensor_model_parallel_world_size() conv_state_shape = ( num_layers, max_batch_size, - hf_config.mamba_expand * hf_config.hidden_size, + hf_config.mamba_expand * hf_config.hidden_size // world_size, hf_config.mamba_d_conv, ) ssm_state_shape = ( num_layers, max_batch_size, - hf_config.mamba_expand * hf_config.hidden_size, + hf_config.mamba_expand * hf_config.hidden_size // world_size, hf_config.mamba_d_state, ) if self.mamba_cache is None:
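
Note on the sharding above: the custom `weight_loader` splits a full checkpoint tensor into `tp_size` contiguous row-chunks and keeps only the chunk owned by the current tensor-parallel rank, and `A_weight_loader` additionally converts the checkpoint's `A_log` into `A = -exp(A_log)` at load time (which is why `load_weights` renames `A_log` to `A`). The following is a minimal, self-contained sketch of that per-rank slicing, not part of the patch itself; `shard_rows` and the hard-coded `tp_rank`/`tp_size` values are illustrative stand-ins for vLLM's `get_tensor_model_parallel_rank()` / `get_tensor_model_parallel_world_size()`.

import torch

def shard_rows(loaded_weight: torch.Tensor, tp_rank: int, tp_size: int) -> torch.Tensor:
    # Keep only the contiguous row-slice owned by this tensor-parallel rank,
    # mirroring the split(...)[tp_rank] done in the patch's weight_loader.
    chunk = loaded_weight.shape[0] // tp_size
    return loaded_weight.split(chunk, dim=0)[tp_rank]

# Toy checkpoint tensors with the Mamba mixer's layout:
# A_log is (intermediate_size, ssm_state_size), D is (intermediate_size,).
intermediate_size, ssm_state_size = 8, 4
A_log = torch.randn(intermediate_size, ssm_state_size)
D = torch.ones(intermediate_size)

tp_size, tp_rank = 2, 0
# D is sharded as-is; A is negated/exponentiated first, as in A_weight_loader.
D_local = shard_rows(D, tp_rank, tp_size)
A_local = shard_rows(-torch.exp(A_log.float()), tp_rank, tp_size)

assert D_local.shape == (intermediate_size // tp_size,)
assert A_local.shape == (intermediate_size // tp_size, ssm_state_size)

The same per-rank split is the reason `prepare_contiguous_mamba_cache` divides both the conv state and SSM state shapes by the tensor-parallel world size: each rank only ever sees its own slice of the intermediate dimension.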