feat: pipeline parallel support for DeepSeek v2
Signed-off-by: Travis Johnson <tsjohnso@us.ibm.com>
tjohnson31415 committed Jul 18, 2024
1 parent ecdb462 commit 2522798
Showing 3 changed files with 84 additions and 23 deletions.
1 change: 1 addition & 0 deletions vllm/config.py
@@ -28,6 +28,7 @@
 _PP_SUPPORTED_MODELS = [
     "AquilaModel",
     "AquilaForCausalLM",
+    "DeepseekV2ForCausalLM",
     "InternLMForCausalLM",
     "LlamaForCausalLM",
     "LLaMAForCausalLM",
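Registering DeepseekV2ForCausalLM in _PP_SUPPORTED_MODELS is what lets the engine accept a pipeline-parallel configuration for this architecture. A minimal usage sketch, assuming enough GPUs and a PP-capable executor backend; the model name and parallel sizes below are illustrative, not part of this commit:

from vllm import LLM, SamplingParams

# Illustrative: split DeepSeek-V2 across 2 pipeline stages with
# 2 tensor-parallel GPUs per stage (4 GPUs total).
llm = LLM(
    model="deepseek-ai/DeepSeek-V2-Lite",
    trust_remote_code=True,
    tensor_parallel_size=2,
    pipeline_parallel_size=2,
)
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)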
88 changes: 70 additions & 18 deletions vllm/model_executor/models/deepseek_v2.py
@@ -21,6 +21,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only DeepseekV2 model."""
+import functools
 from typing import Any, Dict, Iterable, List, Optional, Tuple
 
 import torch
@@ -29,7 +30,8 @@
 
 from vllm.attention import Attention, AttentionMetadata
 from vllm.config import CacheConfig
-from vllm.distributed import (get_tensor_model_parallel_world_size,
+from vllm.distributed import (get_pp_group,
+                              get_tensor_model_parallel_world_size,
                               tensor_model_parallel_all_reduce)
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.fused_moe import FusedMoE
@@ -49,6 +51,8 @@
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors, SamplerOutput
 
+from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers
+
 
 class DeepseekV2MLP(nn.Module):
 
@@ -394,33 +398,56 @@ def __init__(
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
 
-        self.embed_tokens = VocabParallelEmbedding(
-            config.vocab_size,
-            config.hidden_size,
-        )
-        self.layers = nn.ModuleList([
-            DeepseekV2DecoderLayer(config,
-                                   layer_idx,
-                                   cache_config=cache_config,
-                                   quant_config=quant_config)
-            for layer_idx in range(config.num_hidden_layers)
-        ])
-        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        if get_pp_group().is_first_rank:
+            self.embed_tokens = VocabParallelEmbedding(
+                config.vocab_size,
+                config.hidden_size,
+            )
+        else:
+            self.embed_tokens = PPMissingLayer()
+
+        self.start_layer, self.end_layer, self.layers = make_layers(
+            config.num_hidden_layers,
+            # layer_idx is still an argument
+            functools.partial(DeepseekV2DecoderLayer,
+                              config,
+                              cache_config=cache_config,
+                              quant_config=quant_config),
+        )
+
+        if get_pp_group().is_last_rank:
+            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        else:
+            self.norm = PPMissingLayer()
 
     def forward(
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
+        intermediate_tensors: Optional[IntermediateTensors],
     ) -> torch.Tensor:
-        hidden_states = self.embed_tokens(input_ids)
-        residual = None
-        for i in range(len(self.layers)):
+        if get_pp_group().is_first_rank:
+            hidden_states = self.embed_tokens(input_ids)
+            residual = None
+        else:
+            assert intermediate_tensors is not None
+            hidden_states = intermediate_tensors["hidden_states"]
+            residual = intermediate_tensors["residual"]
+
+        for i in range(self.start_layer, self.end_layer):
             layer = self.layers[i]
             hidden_states, residual = layer(positions, hidden_states,
-                                            kv_caches[i], attn_metadata,
-                                            residual)
+                                            kv_caches[i - self.start_layer],
+                                            attn_metadata, residual)
+
+        if not get_pp_group().is_last_rank:
+            return IntermediateTensors({
+                "hidden_states": hidden_states,
+                "residual": residual
+            })
+
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
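Note that the decoder loop now covers only this rank's slice [start_layer, end_layer) and indexes kv_caches[i - self.start_layer]: each pipeline rank allocates KV caches only for its own layers, so the cache list is local while layer indices stay global. A standalone sketch of the even-split arithmetic, mirroring the intent of vLLM's get_pp_indices (the real helper's remainder handling may differ):

def pp_indices(num_layers: int, rank: int, world_size: int) -> tuple:
    # Evenly split num_layers across world_size pipeline ranks;
    # earlier ranks absorb any remainder.
    per_rank, remainder = divmod(num_layers, world_size)
    start = rank * per_rank + min(rank, remainder)
    end = start + per_rank + (1 if rank < remainder else 0)
    return start, end

# e.g. DeepSeek-V2's 60 decoder layers over 4 stages:
print([pp_indices(60, r, 4) for r in range(4)])
# [(0, 15), (15, 30), (30, 45), (45, 60)]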

@@ -452,7 +479,7 @@ def forward(
         intermediate_tensors: Optional[IntermediateTensors] = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, kv_caches,
-                                   attn_metadata)
+                                   attn_metadata, intermediate_tensors)
         return hidden_states
 
     def compute_logits(self, hidden_states: torch.Tensor,
@@ -469,6 +496,20 @@ def sample(
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
+    def make_empty_intermediate_tensors(
+            self, batch_size: int, dtype: torch.dtype,
+            device: torch.device) -> IntermediateTensors:
+        return IntermediateTensors({
+            "hidden_states":
+            torch.zeros((batch_size, self.config.hidden_size),
+                        dtype=dtype,
+                        device=device),
+            "residual":
+            torch.zeros((batch_size, self.config.hidden_size),
+                        dtype=dtype,
+                        device=device),
+        })
+
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
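make_empty_intermediate_tensors gives the runner correctly shaped hidden_states/residual buffers for a non-first rank to receive into before the previous stage's activations arrive; the shapes match exactly what the model's forward returns in its IntermediateTensors. A toy, single-process sketch of the handoff contract (nn.Linear layers stand in for decoder layers, and plain dict passing stands in for the real inter-rank send/recv; sizes are illustrative):

import torch
import torch.nn as nn

hidden_size, num_tokens = 64, 8
layers = nn.ModuleList(nn.Linear(hidden_size, hidden_size) for _ in range(4))
norm = nn.LayerNorm(hidden_size)

def run_stage(start, end, hidden, is_last):
    for i in range(start, end):
        hidden = torch.relu(layers[i](hidden))
    if not is_last:
        # Non-last stages hand off activations instead of finishing.
        return {"hidden_states": hidden}
    return norm(hidden)

stage0_out = run_stage(0, 2, torch.randn(num_tokens, hidden_size), False)
final = run_stage(2, 4, stage0_out["hidden_states"], True)
print(final.shape)  # torch.Size([8, 64])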
@@ -504,6 +545,10 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
+
+                if is_pp_missing_parameter(name, self):
+                    continue
+
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
@@ -514,6 +559,10 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                     if weight_name not in name:
                         continue
                     name = name.replace(weight_name, param_name)
+
+                    if is_pp_missing_parameter(name, self):
+                        continue
+
                     param = params_dict[name]
                     weight_loader = param.weight_loader
                     weight_loader(param,
@@ -527,6 +576,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                     if name.endswith(".bias") and name not in params_dict:
                         continue
 
+                    if is_pp_missing_parameter(name, self):
+                        continue
+
                     param = params_dict[name]
                     weight_loader = getattr(param, "weight_loader",
                                             default_weight_loader)
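All three is_pp_missing_parameter guards in load_weights do the same job: on a given rank, checkpoint weights for layers outside [start_layer, end_layer) have no backing parameter (those slots are PPMissingLayer instances), so the loader must skip them rather than fail on the params_dict[name] lookup. A hedged sketch of the kind of check involved; the real helper in vllm/model_executor/models/utils.py may be implemented differently:

import re

def layer_is_missing(name: str, start_layer: int, end_layer: int) -> bool:
    # Decoder weights are named like "model.layers.<idx>.<...>"; any
    # index outside this rank's slice lives on another pipeline stage.
    match = re.match(r"model\.layers\.(\d+)\.", name)
    if match is None:
        return False  # embeddings/norm/lm_head belong to first/last rank
    return not (start_layer <= int(match.group(1)) < end_layer)

assert layer_is_missing("model.layers.3.mlp.gate_proj.weight", 10, 20)
assert not layer_is_missing("model.layers.12.self_attn.q_proj.weight", 10, 20)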
18 changes: 13 additions & 5 deletions vllm/model_executor/models/utils.py
@@ -1,3 +1,4 @@
+import inspect
 from typing import Callable, Dict, List, Tuple
 
 import torch
@@ -119,7 +120,8 @@ def forward(*args, **kwargs):
 
 
 def make_layers(
-    num_hidden_layers: int, layer_fn: Callable[[], torch.nn.Module]
+    num_hidden_layers: int,
+    layer_fn: Callable[..., torch.nn.Module],
 ) -> Tuple[int, int, torch.nn.ModuleList]:
     """Make a list of layers with the given layer function, taking
     pipeline parallelism into account.
@@ -129,11 +131,17 @@ def make_layers(
     start_layer, end_layer = get_pp_indices(num_hidden_layers,
                                             get_pp_group().rank_in_group,
                                             get_pp_group().world_size)
+
+    include_layer_idx = "layer_idx" in inspect.signature(layer_fn).parameters
+    layer_modules = (
+        [layer_fn(layer_idx=i)
+         for i in range(start_layer, end_layer)] if include_layer_idx else
+        [layer_fn() for _ in range(start_layer, end_layer)])
     modules = torch.nn.ModuleList(
-        [PPMissingLayer() for _ in range(start_layer)] + [
-            maybe_offload_to_cpu(layer_fn())
-            for _ in range(start_layer, end_layer)
-        ] + [PPMissingLayer() for _ in range(end_layer, num_hidden_layers)])
+        [PPMissingLayer() for _ in range(start_layer)] +
+        [maybe_offload_to_cpu(m) for m in layer_modules] +
+        [PPMissingLayer() for _ in range(end_layer, num_hidden_layers)])
 
     return start_layer, end_layer, modules
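The inspect.signature check keeps make_layers backward compatible: factories that take no arguments are called as before, while DeepSeek-V2's functools.partial factory, whose first unbound parameter is layer_idx, receives its global layer index. A minimal standalone sketch of that detection on a toy factory:

import functools
import inspect

class ToyLayer:
    def __init__(self, hidden: int, layer_idx: int):
        self.hidden, self.layer_idx = hidden, layer_idx

layer_fn = functools.partial(ToyLayer, 64)  # binds hidden, leaves layer_idx

# The same check make_layers performs: the partial's remaining
# signature still exposes layer_idx.
assert "layer_idx" in inspect.signature(layer_fn).parameters

layers = [layer_fn(layer_idx=i) for i in range(2, 5)]
print([layer.layer_idx for layer in layers])  # [2, 3, 4]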

