[WIP][Core] Support tensor parallel division with remainder of attention heads #5367

Closed
NadavShmayo wants to merge 24 commits into main from unequal_tp_division

Changes from 15 commits

Commits (24)
b86675d
Change model config support unequal tp division
NadavShmayo Jun 9, 2024
a789569
Add unequal tp division util functions
NadavShmayo Jun 9, 2024
428b85f
Change parallel layers to support unequal tp division
NadavShmayo Jun 9, 2024
c485d50
Add unequal tp division support for opt model
NadavShmayo Jun 9, 2024
1cf543b
Add unequal tp division support for commandr model
NadavShmayo Jun 9, 2024
a6970c0
Add unequal tp division support for llama model
NadavShmayo Jun 9, 2024
6a4b70e
Remove asserts in Llama and CommandR implementation
NadavShmayo Jun 9, 2024
6b33c87
Add tp_rank to EmbeddingModelRunner class
NadavShmayo Jun 11, 2024
90d9f6c
Fix QKVLinear to work with packed dim
NadavShmayo Jun 11, 2024
014b682
Fix imports formatting in layer/linear.py file
NadavShmayo Jun 11, 2024
a30e120
Merge branch 'main' into unequal_tp_division
NadavShmayo Jun 30, 2024
73c0159
Merge branch 'main' into unequal_tp_division
NadavShmayo Jul 3, 2024
cdb2e27
Remove unused variable
NadavShmayo Jul 3, 2024
b9e5309
Fix failing tests
NadavShmayo Jul 3, 2024
a268f20
Fix formatting
NadavShmayo Jul 3, 2024
b033a43
Add uneven tensor parallel test cases
NadavShmayo Jul 3, 2024
34f9850
Fix review comments
NadavShmayo Jul 3, 2024
a154ade
Fix uneven TP tests and add to .buildkite
NadavShmayo Jul 4, 2024
fe906b5
Fix formatting and imports in new uneven TP tests
NadavShmayo Jul 4, 2024
537e16b
Fix uneven TP chunked prefill tests and buildkit config
NadavShmayo Jul 4, 2024
5639427
Change default padding size of ParallelLMHead to None
NadavShmayo Jul 4, 2024
6f7c0de
Add validation for LoRA with tensor parallel
NadavShmayo Jul 8, 2024
b8e870a
Fix LLama uneven TP lm head
NadavShmayo Jul 8, 2024
fc777b5
Merge branch 'main' into unequal_tp_division
NadavShmayo Jul 8, 2024
41 changes: 28 additions & 13 deletions vllm/config.py
@@ -8,6 +8,7 @@
from transformers import PretrainedConfig

import vllm.envs as envs
from vllm.distributed import get_current_tp_rank_partition_size
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.model_executor.models import ModelRegistry
@@ -260,11 +261,13 @@ def verify_with_parallel_config(
total_num_attention_heads = getattr(self.hf_text_config,
"num_attention_heads", 0)
tensor_parallel_size = parallel_config.tensor_parallel_size
if total_num_attention_heads % tensor_parallel_size != 0:
if (total_num_attention_heads % tensor_parallel_size != 0
and self.quantization is not None):
raise ValueError(
f"Total number of attention heads ({total_num_attention_heads})"
f"Total number of attention heads "
f"({total_num_attention_heads})"
" must be divisible by tensor parallel size "
f"({tensor_parallel_size}).")
f"({tensor_parallel_size}) when quantization is used.")

total_num_hidden_layers = getattr(self.hf_text_config,
"num_hidden_layers", 0)
@@ -370,20 +373,32 @@ def get_total_num_kv_heads(self) -> int:
# equal to the number of attention heads.
return self.hf_text_config.num_attention_heads

def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int:
def get_num_kv_heads(self,
parallel_config: "ParallelConfig",
tp_rank: int = 0) -> int:
"""Returns the number of KV heads per GPU."""
total_num_kv_heads = self.get_total_num_kv_heads()
# If tensor parallelism is used, we divide the number of KV heads by
# the tensor parallel size. We will replicate the KV heads in the
# case where the number of KV heads is smaller than the tensor
# parallel size so each GPU has at least one KV head.
return max(1,
total_num_kv_heads // parallel_config.tensor_parallel_size)
result = get_current_tp_rank_partition_size(
total_num_kv_heads, tp_rank, parallel_config.tensor_parallel_size)
return max(1, result)

def get_num_attention_heads(self,
parallel_config: "ParallelConfig") -> int:
num_heads = getattr(self.hf_text_config, "num_attention_heads", 0)
return num_heads // parallel_config.tensor_parallel_size
parallel_config: "ParallelConfig",
tp_rank: int = 0) -> int:
if getattr(self.hf_text_config, "num_attention_heads", None) is None:
return 0

num_total_kv_heads = self.get_total_num_kv_heads()
num_kv_heads = self.get_num_kv_heads(parallel_config, tp_rank)
num_total_attention_heads = self.hf_text_config.num_attention_heads
num_heads_per_kv_head = num_total_attention_heads // num_total_kv_heads
# For GQA attention we make sure the whole attention head group is
# together on the same GPU.
return num_kv_heads * num_heads_per_kv_head

def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
total_num_hidden_layers = getattr(self.hf_text_config,
@@ -760,7 +775,7 @@ class SchedulerConfig:
enable_chunked_prefill: If True, prefill requests can be chunked based
on the remaining max_num_batched_tokens.
embedding_mode: Whether the running model is for embedding.
preemption_mode: Whether to perform preemption by swapping or
preemption_mode: Whether to perform preemption by swapping or
recomputation. If not specified, we determine the mode as follows:
We use recomputation by default since it incurs lower overhead than
swapping. However, when the sequence group has multiple sequences
@@ -931,12 +946,12 @@ def maybe_create_spec_config(
typical_acceptance_sampler_posterior_threshold (Optional[float]):
A threshold value that sets a lower bound on the posterior
probability of a token in the target model for it to be
accepted. This threshold is used only when we use the
accepted. This threshold is used only when we use the
TypicalAcceptanceSampler for token acceptance.
typical_acceptance_sampler_posterior_alpha (Optional[float]):
A scaling factor for the entropy-based threshold in the
TypicalAcceptanceSampler.

Returns:
Optional["SpeculativeConfig"]: An instance of SpeculativeConfig if
the necessary conditions are met, else None.
@@ -1166,7 +1181,7 @@ def __init__(
typical_acceptance_sampler_posterior_threshold (Optional[float]):
A threshold value that sets a lower bound on the posterior
probability of a token in the target model for it to be
accepted. This threshold is used only when we use the
accepted. This threshold is used only when we use the
TypicalAcceptanceSampler for token acceptance.
typical_acceptance_sampler_posterior_alpha (Optional[float]):
A scaling factor for the entropy-based threshold in the
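To make the config changes above concrete, here is a hedged, self-contained sketch of the per-rank head assignment that the new `get_num_kv_heads` / `get_num_attention_heads` logic implies for a GQA model. The `heads_per_rank` helper and the example sizes are assumptions for illustration only, and the KV-head replication path (used when there are fewer KV heads than ranks) is ignored.

```python
# Sketch of the per-rank head assignment implied by the config changes above.
def heads_per_rank(total_q_heads: int, total_kv_heads: int, tp_size: int):
    group = total_q_heads // total_kv_heads  # query heads per KV head (GQA group)
    plan = []
    for rank in range(tp_size):
        # The first (total_kv_heads % tp_size) ranks take one extra KV head.
        kv = total_kv_heads // tp_size + (total_kv_heads % tp_size > rank)
        plan.append((kv, kv * group))  # (KV heads, query heads) on this rank
    return plan


# 40 query heads, 10 KV heads, tensor_parallel_size=4:
# KV heads [3, 3, 2, 2], query heads [12, 12, 8, 8], so every GQA group of
# 4 query heads stays on the same GPU as its KV head.
assert heads_per_rank(40, 10, 4) == [(3, 12), (3, 12), (2, 8), (2, 8)]
```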
36 changes: 34 additions & 2 deletions vllm/distributed/parallel_state.py
@@ -7,7 +7,7 @@
The typical workflow is:

- call `init_distributed_environment` to initialize the distributed environment.
- call `initialize_model_parallel` or `ensure_model_parallel_initialized` to
- call `initialize_model_parallel` or `ensure_model_parallel_initialized` to
initialize the model parallel groups.

- any code dealing with the distributed stuff
@@ -272,7 +272,7 @@ def graph_capture(

def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
"""
NOTE: This operation will be applied in-place or out-of-place.
NOTE: This operation will be applied in-place or out-of-place.
Always assume this function modifies its input, but use the return
value as the output.
"""
@@ -1047,3 +1047,35 @@ def is_in_the_same_node(pg: ProcessGroup):
torch.distributed.all_reduce(is_in_the_same_node, group=pg)

return is_in_the_same_node.sum().item() == world_size


def get_current_tp_rank_partition_offset(total_size: int,
tp_rank: Optional[int] = None,
tp_size: Optional[int] = None,
multiple_of: int = 1) -> int:
if tp_rank is None:
tp_rank = get_tensor_model_parallel_rank()

if tp_size is None:
tp_size = get_tensor_model_parallel_world_size()

assert total_size % multiple_of == 0
total_size = total_size // multiple_of
return ((total_size // tp_size) * tp_rank +
min(total_size % tp_size, tp_rank)) * multiple_of


def get_current_tp_rank_partition_size(total_size: int,
tp_rank: Optional[int] = None,
tp_size: Optional[int] = None,
multiple_of: int = 1) -> int:
if tp_rank is None:
tp_rank = get_tensor_model_parallel_rank()

if tp_size is None:
tp_size = get_tensor_model_parallel_world_size()

assert total_size % multiple_of == 0
total_size = total_size // multiple_of
return ((total_size // tp_size) +
(total_size % tp_size > tp_rank)) * multiple_of
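
For reference, a minimal standalone sketch of the partitioning scheme these two helpers implement, without the vLLM distributed state; the `partition_size` / `partition_offset` names and the numbers are illustrative only. The first `total_size % tp_size` ranks receive one extra unit, and `multiple_of` keeps each partition aligned to whole units (for example, whole attention-head groups).

```python
def partition_size(total_size: int, tp_rank: int, tp_size: int,
                   multiple_of: int = 1) -> int:
    # Mirror of get_current_tp_rank_partition_size: count in units of
    # `multiple_of`, then hand the remainder to the lowest-numbered ranks.
    assert total_size % multiple_of == 0
    units = total_size // multiple_of
    return (units // tp_size + (units % tp_size > tp_rank)) * multiple_of


def partition_offset(total_size: int, tp_rank: int, tp_size: int,
                     multiple_of: int = 1) -> int:
    # Mirror of get_current_tp_rank_partition_offset.
    assert total_size % multiple_of == 0
    units = total_size // multiple_of
    return ((units // tp_size) * tp_rank +
            min(units % tp_size, tp_rank)) * multiple_of


# 10 KV heads over 4 ranks: sizes [3, 3, 2, 2], offsets [0, 3, 6, 8].
assert [partition_size(10, r, 4) for r in range(4)] == [3, 3, 2, 2]
assert [partition_offset(10, r, 4) for r in range(4)] == [0, 3, 6, 8]
```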
2 changes: 1 addition & 1 deletion vllm/engine/metrics.py
@@ -457,4 +457,4 @@ def log(self, stats: Stats):

class RayPrometheusStatLogger(PrometheusStatLogger):
"""RayPrometheusStatLogger uses Ray metrics instead."""
_metrics_cls = RayMetrics
_metrics_cls = RayMetrics
96 changes: 63 additions & 33 deletions vllm/model_executor/layers/linear.py
@@ -5,11 +5,15 @@
import torch.nn.functional as F
from torch.nn.parameter import Parameter

from vllm.distributed import (divide, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
split_tensor_along_last_dim,
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce)
# FIXME: isort and yapf conflict here, so we disable isort for this block
# isort: off
from vllm.distributed import (
divide, get_current_tp_rank_partition_offset,
get_current_tp_rank_partition_size, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, split_tensor_along_last_dim,
tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce)
# isort: on

from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
@@ -259,14 +263,17 @@ def __init__(self,
self.gather_output = gather_output

# Divide the weight matrix along the last dimension.
tp_rank = get_tensor_model_parallel_rank()
tp_size = get_tensor_model_parallel_world_size()
assert self.quant_method is not None
self.output_size_per_partition = divide(self.output_size, tp_size)
self.output_size_per_partition = output_size // tp_size + (
output_size % tp_size > tp_rank)
Collaborator

Could you separate this into its own variable as the remainder for clarity, and/or add a comment describing the intended behavior? The condition makes it a bit unclear.

Contributor Author

Nice catch, this should use the util function I added for this logic, which should make it more readable.
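
As a quick illustration of the condition discussed in this thread (before it is folded into the shared util function), a tiny sketch with made-up numbers showing how the boolean remainder term assigns the leftover rows to the lowest-numbered ranks:

```python
# output_size = 10, tp_size = 4: the remainder of 2 goes to ranks 0 and 1.
sizes = [10 // 4 + (10 % 4 > rank) for rank in range(4)]
assert sizes == [3, 3, 2, 2] and sum(sizes) == 10
```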

self.output_partition_sizes = [self.output_size_per_partition]
# If QKV or MergedColumn, use output size of each partition.
if hasattr(self, "output_sizes"):
self.output_partition_sizes = [
divide(output_size, tp_size)
get_current_tp_rank_partition_size(output_size, tp_rank,
tp_size)
for output_size in self.output_sizes
]

@@ -353,17 +360,17 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
quant_config: Quantization configure.
"""

def __init__(self,
input_size: int,
output_sizes: List[int],
bias: bool = True,
gather_output: bool = False,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None):
def __init__(
self,
input_size: int,
output_sizes: List[int],
bias: bool = True,
gather_output: bool = False,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
):
self.output_sizes = output_sizes
tp_size = get_tensor_model_parallel_world_size()
assert all(output_size % tp_size == 0 for output_size in output_sizes)
super().__init__(input_size=input_size,
output_size=sum(output_sizes),
bias=bias,
@@ -420,8 +427,12 @@ def weight_loader(self,
tp_rank = get_tensor_model_parallel_rank()
tp_size = get_tensor_model_parallel_world_size()
if output_dim is not None:
shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
shard_size = self.output_sizes[loaded_shard_id] // tp_size
shard_offset = sum(
get_current_tp_rank_partition_size(output_size, tp_rank,
tp_size)
for output_size in self.output_sizes[:loaded_shard_id])
shard_size = get_current_tp_rank_partition_size(
self.output_sizes[loaded_shard_id], tp_rank, tp_size)
# Special case for quantization.
# If quantized, we need to adjust the offset and size to account
# for the packing.
@@ -441,7 +452,8 @@

param_data = param_data.narrow(output_dim, shard_offset,
shard_size)
start_idx = tp_rank * shard_size
start_idx = get_current_tp_rank_partition_offset(
loaded_weight.shape[output_dim], tp_rank, tp_size)
loaded_weight = loaded_weight.narrow(output_dim, start_idx,
shard_size)
# Special case for AQLM codebooks.
Expand Down Expand Up @@ -509,14 +521,18 @@ def __init__(self,
self.total_num_kv_heads = total_num_kv_heads
# Divide the weight matrix along the last dimension.
tp_size = get_tensor_model_parallel_world_size()
self.num_heads = divide(self.total_num_heads, tp_size)
tp_rank = get_tensor_model_parallel_rank()
self.num_heads_per_kv_head = (self.total_num_heads //
self.total_num_kv_heads)
self.num_kv_heads = get_current_tp_rank_partition_size(
self.total_num_kv_heads, tp_rank, tp_size)
self.num_heads = self.num_kv_heads * self.num_heads_per_kv_head
self.num_kv_head_replicas = 1
if tp_size >= self.total_num_kv_heads:
self.num_kv_heads = 1
self.num_kv_head_replicas = divide(tp_size,
self.total_num_kv_heads)
else:
self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
self.num_kv_head_replicas = 1

input_size = self.hidden_size
output_size = (self.num_heads +
2 * self.num_kv_heads) * tp_size * self.head_size
@@ -590,20 +606,25 @@ def weight_loader(self,
if loaded_shard_id == "q":
shard_offset = 0
shard_size = self.num_heads * self.head_size
multiple_of = self.head_size * self.num_heads_per_kv_head
elif loaded_shard_id == "k":
shard_offset = self.num_heads * self.head_size
shard_size = self.num_kv_heads * self.head_size
multiple_of = self.head_size
elif loaded_shard_id == "v":
shard_offset = (self.num_heads +
self.num_kv_heads) * self.head_size
shard_size = self.num_kv_heads * self.head_size
multiple_of = self.head_size

# Special case for Quantized Weights.
# If quantized, we need to adjust the offset and size to account
# for the packing.
packed_dim = getattr(param, "packed_dim", None)
if packed_dim == output_dim:
shard_size = shard_size // param.pack_factor
shard_offset = shard_offset // param.pack_factor
multiple_of = multiple_of // param.pack_factor

# Special case for Marlin.
shard_size, shard_offset = adjust_marlin_shard(
@@ -627,11 +648,11 @@

param_data = param_data.narrow(output_dim, shard_offset,
shard_size)
if loaded_shard_id == "q":
shard_id = tp_rank
else:
shard_id = tp_rank // self.num_kv_head_replicas
start_idx = shard_id * shard_size

tp_size = get_tensor_model_parallel_world_size()
total_size = loaded_weight.shape[output_dim]
start_idx = get_current_tp_rank_partition_offset(
total_size, tp_rank, tp_size, multiple_of=multiple_of)
loaded_weight = loaded_weight.narrow(output_dim, start_idx,
shard_size)
# Special case for AQLM codebooks.
@@ -681,6 +702,8 @@ class RowParallelLinear(LinearBase):
We skip adding bias but instead return it.
params_dtype: Data type for the parameters.
quant_config: Quantization configure.
partition_multiple_of: Each rank's input partition size will be
a multiple of this number.
"""

def __init__(self,
Expand All @@ -691,7 +714,8 @@ def __init__(self,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
reduce_results: bool = True,
quant_config: Optional[QuantizationConfig] = None):
quant_config: Optional[QuantizationConfig] = None,
partition_multiple_of: int = 1):
super().__init__(input_size, output_size, skip_bias_add, params_dtype,
quant_config)

Expand All @@ -700,7 +724,10 @@ def __init__(self,

# Divide the weight matrix along the last dimension.
self.tp_size = get_tensor_model_parallel_world_size()
self.input_size_per_partition = divide(input_size, self.tp_size)
self.tp_rank = get_tensor_model_parallel_rank()
self.partition_multiple_of = partition_multiple_of
self.input_size_per_partition = get_current_tp_rank_partition_size(
input_size, self.tp_rank, self.tp_size, partition_multiple_of)
assert self.quant_method is not None
self.quant_method.create_weights(
layer=self,
@@ -725,12 +752,15 @@ def __init__(self,
self.register_parameter("bias", None)

def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
tp_rank = get_tensor_model_parallel_rank()
input_dim = getattr(param, "input_dim", None)
param_data = param.data
if input_dim is not None:
shard_size = param_data.shape[input_dim]
start_idx = tp_rank * shard_size
start_idx = get_current_tp_rank_partition_offset(
self.input_size,
self.tp_rank,
self.tp_size,
multiple_of=self.partition_multiple_of)
loaded_weight = loaded_weight.narrow(input_dim, start_idx,
shard_size)

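To summarize the reworked QKV weight loading above, here is a rough standalone sketch of how a rank's slice of the unsharded checkpoint weight might be located: `multiple_of` keeps the q slice aligned to whole GQA groups and the k/v slices aligned to whole heads. The `shard_start` helper and the model sizes below are illustrative assumptions, not the PR's exact API.

```python
def shard_start(total_size: int, tp_rank: int, tp_size: int,
                multiple_of: int) -> int:
    # Same scheme as get_current_tp_rank_partition_offset: count in units
    # of `multiple_of`, giving the remainder units to the lowest ranks.
    assert total_size % multiple_of == 0
    units = total_size // multiple_of
    return ((units // tp_size) * tp_rank +
            min(units % tp_size, tp_rank)) * multiple_of


# Illustrative model: 40 q heads, 10 kv heads, head_size=128, tp_size=4, rank=1.
head_size, group = 128, 40 // 10
# q: slice the (40 * head_size)-row q weight in units of whole 4-head groups.
q_start = shard_start(40 * head_size, 1, 4, multiple_of=head_size * group)
# k/v: slice the (10 * head_size)-row k/v weights in units of whole heads.
kv_start = shard_start(10 * head_size, 1, 4, multiple_of=head_size)
assert q_start == 12 * head_size and kv_start == 3 * head_size
```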
3 changes: 2 additions & 1 deletion vllm/model_executor/layers/vocab_parallel_embedding.py
@@ -168,10 +168,11 @@ def __init__(self,
embedding_dim: int,
params_dtype: Optional[torch.dtype] = None,
org_num_embeddings: Optional[int] = None,
padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
padding_size: Optional[int] = None,
quant_config: Optional[QuantizationConfig] = None):
super().__init__()

padding_size = padding_size or get_tensor_model_parallel_world_size()
# Keep the input dimensions.
tp_rank = get_tensor_model_parallel_rank()
self.tp_size = get_tensor_model_parallel_world_size()
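For context on the `padding_size` default changing to `None`: with this change, when no padding size is passed the vocab is padded only up to a multiple of the tensor parallel world size, which keeps padding minimal under uneven partitioning. A hedged sketch of that rounding follows, where `pad_to_multiple` is an illustrative stand-in and 64 stands in for a larger fixed padding default.

```python
def pad_to_multiple(vocab_size: int, multiple: int) -> int:
    # Round the vocab size up to the nearest multiple.
    return ((vocab_size + multiple - 1) // multiple) * multiple


# A 32_003-token vocab with tp_size=4 pads to 32_004 (one pad token),
# versus 32_064 under a fixed padding of 64.
assert pad_to_multiple(32_003, 4) == 32_004
assert pad_to_multiple(32_003, 64) == 32_064
```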