Tensor parallelism (vllm-project#7)

* Return support for other models apart from jamba

* Support n>1

* Revert 2 commits

d054737 'Support n>1'
b5167cc 'Return support for other models apart from jamba'

* TP on input and output

* Basic TP implementation: runs, but correctness not working yet

* TP is working

* Roll back the verification that everything in the weights fits into the model

* Cleanup

* Use world size func

* clean up

* Import

* Apply whitespace suggestions from code review

* Organize imports

* Add comment on the unsqueeze in conv1d

* Organize and remove redundant code in forward pass

* Remove print

* Add comments

Co-authored-by: tomeras91 <57313761+tomeras91@users.noreply.github.com>

* White spaces

* Set as A

* better comment

---------

Co-authored-by: Mor Zusman <morz@ai21.com>
Co-authored-by: tomeras91 <57313761+tomeras91@users.noreply.github.com>
3 people committed Apr 16, 2024
1 parent 39c27b7 commit 7c75868
Showing 2 changed files with 74 additions and 46 deletions.
109 changes: 65 additions & 44 deletions vllm/model_executor/models/jamba.py
@@ -5,22 +5,20 @@
from typing import Dict, List, Optional, Tuple

import torch
from torch import conv_transpose3d, nn
import os
from torch import nn
from vllm.model_executor.mamba_metadata import MambaCacheParams

from vllm.transformers_utils.configs.jamba import JambaConfig
from transformers.activations import ACT2FN
from torch.nn.parameter import Parameter
from vllm.config import LoRAConfig
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.attention import PagedAttention
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
from vllm.model_executor.layers.linear import (ColumnParallelLinear, LinearMethodBase, MergedColumnParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead, DEFAULT_VOCAB_PADDING_SIZE)
@@ -59,35 +57,65 @@ def __init__(self, config: JambaConfig, layer_idx):
self.time_step_rank = config.mamba_dt_rank
self.use_conv_bias = config.mamba_conv_bias
self.use_bias = config.mamba_proj_bias
self.conv1d = nn.Conv1d(
in_channels=self.intermediate_size,
out_channels=self.intermediate_size,
self.conv1d = ColumnParallelLinear(
input_size=self.conv_kernel_size,
output_size=self.intermediate_size,
bias=self.use_conv_bias,
kernel_size=self.conv_kernel_size,
groups=self.intermediate_size,
padding=self.conv_kernel_size - 1,
)
# unsqueeze to fit conv1d weights shape into the linear weights shape. Can't do this in `weight_loader` since it already exists in `ColumnParallelLinear` and `set_weight_attrs` doesn't allow to override it
self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)

self.activation = config.hidden_act
self.act = ACT2FN[config.hidden_act]
self.apply_inner_layernorms = config.mamba_inner_layernorms

# projection of the input hidden states
self.in_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=self.use_bias)
self.in_proj = MergedColumnParallelLinear(
self.hidden_size,
[self.intermediate_size] * 2,
bias=self.use_bias
)
# selective projection used to make dt, B and C input dependant
self.x_proj = nn.Linear(self.intermediate_size, self.time_step_rank + self.ssm_state_size * 2, bias=False)
# time step projection (discretization)
self.dt_proj = nn.Linear(self.time_step_rank, self.intermediate_size, bias=True)
self.x_proj = RowParallelLinear(
self.intermediate_size,
self.time_step_rank + self.ssm_state_size * 2,
bias=False
)
# time step projection (discretization) - In the forward we need to apply dt_proj without the bias, as the bias is added in the selective scan kernel.
self.dt_proj = ColumnParallelLinear(
self.time_step_rank,
self.intermediate_size,
bias=True,
skip_bias_add=True
)

# S4D real initialization. These are not discretized!
# The core is to load them, compute the discrete states, then write the updated state. Keeps the memory bounded
A = torch.arange(1, self.ssm_state_size + 1, dtype=torch.float32)[None, :]
A = A.expand(self.intermediate_size, -1).contiguous()
def weight_loader(param:Parameter, loaded_weight:torch.Tensor):
tp_rank = get_tensor_model_parallel_rank()
tp_size = get_tensor_model_parallel_world_size()
param.data.copy_(loaded_weight.data.split(loaded_weight.shape[0] // tp_size, dim=0)[tp_rank])

self.A_log = nn.Parameter(torch.log(A))
self.D = nn.Parameter(torch.ones(self.intermediate_size))
self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=self.use_bias)
def A_weight_loader(param:Parameter, loaded_weight:torch.Tensor):
weight_loader(param,-torch.exp(loaded_weight.float()))

tp_size = get_tensor_model_parallel_world_size()
self.A = nn.Parameter(torch.empty(
self.intermediate_size // tp_size,
self.ssm_state_size,
dtype=torch.float32
))
self.D = nn.Parameter(torch.ones(self.intermediate_size // tp_size))

set_weight_attrs(self.D, {
"weight_loader": weight_loader
})
set_weight_attrs(self.A, {
"weight_loader": A_weight_loader
})

self.out_proj = RowParallelLinear(
self.intermediate_size,
self.hidden_size,
bias=self.use_bias,
input_is_parallel=True
)
self.activation = config.hidden_act
self.apply_inner_layernorms = config.mamba_inner_layernorms

if self.apply_inner_layernorms:
self.dt_layernorm = RMSNorm(self.time_step_rank, eps=config.rms_norm_eps)
self.B_layernorm = RMSNorm(self.ssm_state_size, eps=config.rms_norm_eps)
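
The weight loaders above are where the per-channel SSM parameters get sharded: each tensor-parallel rank keeps one contiguous slice of the intermediate dimension, and the -exp(A_log) transform is folded into load time so the forward pass can read self.A directly. A minimal standalone sketch of that logic, with plain tensors and hard-coded tp_size/tp_rank in place of vLLM's parallel-state helpers (sizes are illustrative, not Jamba's real hyperparameters):

import torch

# Stand-ins for get_tensor_model_parallel_world_size()/get_tensor_model_parallel_rank().
tp_size, tp_rank = 2, 0

# Full (unsharded) tensors as they appear in the checkpoint.
intermediate_size, ssm_state_size = 8, 4
A_log_full = torch.randn(intermediate_size, ssm_state_size)
D_full = torch.ones(intermediate_size)

def shard_dim0(loaded_weight: torch.Tensor) -> torch.Tensor:
    # Same slicing as weight_loader in the diff: split dim 0 into tp_size
    # equal chunks and keep the chunk that belongs to this rank.
    chunk = loaded_weight.shape[0] // tp_size
    return loaded_weight.split(chunk, dim=0)[tp_rank]

# A is stored as A_log in the checkpoint; the loader applies -exp(.) once at
# load time, so the forward pass can hand self.A straight to the scan kernels.
A_shard = shard_dim0(-torch.exp(A_log_full.float()))
D_shard = shard_dim0(D_full)

assert A_shard.shape == (intermediate_size // tp_size, ssm_state_size)
assert D_shard.shape == (intermediate_size // tp_size,)
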
@@ -108,8 +136,7 @@ def _apply_layernorms(self, dt, B, C):

def mamba_forward(self, hidden_states: torch.Tensor, cache_params: MambaCacheParams = None):
# 1. Gated MLP's linear projection
projected_states = self.in_proj(hidden_states).transpose(1, 2)

projected_states = self.in_proj(hidden_states)[0].transpose(1, 2)
hidden_states, gate = projected_states.chunk(2, dim=1)

# 2. Convolution sequence transformation
@@ -135,31 +162,22 @@ def mamba_forward(self, hidden_states: torch.Tensor, cache_params: MambaCacheParams = None):

# 3. State Space Model sequence transformation
# 3.a. input varying initialization of time_step, B and C
ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))
ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))[0]

time_step, B, C = torch.split(
ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1
)
time_step, B, C = self._apply_layernorms(time_step, B, C)

# Here we need to apply dt_proj without the bias, as the bias is added in the selective scan kernel.
# This is a hack to apply dt_proj while still using the forward pass of `torch.nn.Linear`, which is needed
# in order to make quantization work. Quantization code replaces `torch.nn.Linear` layers with quantized
# linear layers, and requires to call the forward pass directly.
# The original code here was: ```discrete_time_step = self.dt_proj.weight @ time_step.transpose(1, 2)```
dt_proj_bias = self.dt_proj.bias
self.dt_proj.bias = None
discrete_time_step = self.dt_proj(time_step).transpose(1, 2)
self.dt_proj.bias = dt_proj_bias

A = -torch.exp(self.A_log.float())
discrete_time_step = self.dt_proj(time_step)[0].transpose(1,2)
# 3.c perform the recurrence y ← SSM(A, B, C)(x)
time_proj_bias = self.dt_proj.bias.float() if hasattr(self.dt_proj, "bias") else None
if cache_params is not None and not cache_params.is_prompt:
scan_outputs = selective_state_update(
cache_params.ssm_state,
hidden_states[..., 0],
discrete_time_step[..., 0],
A,
self.A,
B[:, 0],
C[:, 0],
self.D,
@@ -171,7 +189,7 @@ def mamba_forward(self, hidden_states: torch.Tensor, cache_params: MambaCacheParams = None):
scan_outputs, ssm_state = selective_scan_fn(
hidden_states,
discrete_time_step,
A,
self.A,
B.transpose(1, 2),
C.transpose(1, 2),
self.D.float(),
@@ -184,7 +202,7 @@ def mamba_forward(self, hidden_states: torch.Tensor, cache_params: MambaCacheParams = None):
cache_params.ssm_state.copy_(ssm_state)

# 4. Final linear projection
contextualized_states = self.out_proj(scan_outputs.transpose(1, 2))
contextualized_states = self.out_proj(scan_outputs.transpose(1, 2))[0]
return contextualized_states

def forward(self, hidden_states: torch.Tensor, input_metadata: InputMetadata, conv_state: torch.Tensor, ssm_state: torch.Tensor):
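
In mamba_forward above, dt_proj is applied without its bias (the layer is built with skip_bias_add=True), and the bias is handed to the selective-scan kernel instead, which adds it before the softplus. A rough sketch of that split, using a plain nn.Linear as a stand-in for ColumnParallelLinear; the kernel call itself is only indicated in a comment:

import torch
from torch import nn

time_step_rank, intermediate_size = 4, 8  # illustrative sizes
dt_proj = nn.Linear(time_step_rank, intermediate_size, bias=True)

time_step = torch.randn(2, 16, time_step_rank)  # (batch, seq_len, dt_rank)

# skip_bias_add=True makes the vLLM layer return (x @ W^T, bias) without
# adding the bias; reproduce that here with the functional form.
discrete_time_step = nn.functional.linear(time_step, dt_proj.weight).transpose(1, 2)
time_proj_bias = dt_proj.bias.float()

# The bias then goes to the selective-scan kernel (delta_bias in the
# mamba_ssm kernels), which applies it before the softplus, e.g.:
#   selective_scan_fn(..., delta_bias=time_proj_bias, delta_softplus=True)

assert discrete_time_step.shape == (2, intermediate_size, 16)
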
@@ -625,6 +643,9 @@ def load_weights(self,
if "rotary_emb.inv_freq" in name:
continue

if "A_log" in name:
name = name.replace("A_log","A")

if ".self_attn." in name:
name = name.replace(".self_attn", "")

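One more note on the jamba.py changes: the depthwise nn.Conv1d is replaced by a ColumnParallelLinear purely so the conv weights get the standard tensor-parallel sharding and weight loading; the linear layer's 2-D (output_size, input_size) parameter is then unsqueezed back to the (channels, 1, kernel) layout that the checkpoint and the causal-conv kernel use. A small standalone illustration of that shape juggling, with plain tensors in place of the vLLM layer:

import torch

intermediate_size, kernel_size = 8, 4  # illustrative sizes

# ColumnParallelLinear(input_size=kernel_size, output_size=intermediate_size)
# allocates a 2-D (output_size, input_size) weight.
linear_weight = torch.empty(intermediate_size, kernel_size)

# The diff unsqueezes that parameter once, right after construction, so it
# matches the (channels, 1, kernel) shape of the depthwise conv weight in the
# checkpoint and the stock weight_loader can copy it in unchanged.
linear_weight = linear_weight.unsqueeze(1)

conv1d_ckpt_weight = torch.randn(intermediate_size, 1, kernel_size)
linear_weight.copy_(conv1d_ckpt_weight)

assert linear_weight.shape == (intermediate_size, 1, kernel_size)
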
11 changes: 9 additions & 2 deletions vllm/worker/model_runner.py
@@ -23,6 +23,12 @@
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.mamba_metadata import RequestInfo
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.parallel_utils import custom_all_reduce, pynccl_utils
from vllm.model_executor.parallel_utils.communication_op import (
broadcast_tensor_dict)
from vllm.model_executor.parallel_utils.parallel_state import (
with_pynccl_for_all_reduce,
get_tensor_model_parallel_world_size)
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import (MultiModalData, SamplerOutput, SequenceData,
SequenceGroupMetadata)
@@ -169,16 +175,17 @@ def prepare_contiguous_mamba_cache(self, dtype):
hf_config = self.model_config.hf_config
num_layers = hf_config.num_hidden_layers
max_batch_size = _BATCH_SIZES_TO_CAPTURE[-1]
world_size = get_tensor_model_parallel_world_size()
conv_state_shape = (
num_layers,
max_batch_size,
hf_config.mamba_expand * hf_config.hidden_size,
hf_config.mamba_expand * hf_config.hidden_size // world_size,
hf_config.mamba_d_conv,
)
ssm_state_shape = (
num_layers,
max_batch_size,
hf_config.mamba_expand * hf_config.hidden_size,
hf_config.mamba_expand * hf_config.hidden_size // world_size,
hf_config.mamba_d_state,
)
if self.mamba_cache is None:
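
The cache change mirrors the sharding in the model: each rank's conv and SSM state buffers only cover its slice of the expanded channels, so both shapes divide mamba_expand * hidden_size by the tensor-parallel world size. A quick sketch with illustrative numbers (not the real Jamba config):

# Illustrative values only, not Jamba's actual hyperparameters.
num_layers = 4
max_batch_size = 256
hidden_size = 4096
mamba_expand = 2
mamba_d_conv = 4
mamba_d_state = 16
world_size = 2  # get_tensor_model_parallel_world_size()

channels_per_rank = mamba_expand * hidden_size // world_size  # 4096 per rank
conv_state_shape = (num_layers, max_batch_size, channels_per_rank, mamba_d_conv)
ssm_state_shape = (num_layers, max_batch_size, channels_per_rank, mamba_d_state)

print(conv_state_shape)  # (4, 256, 4096, 4)
print(ssm_state_shape)   # (4, 256, 4096, 16)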
