[WIP][Core] Support tensor parallel division with remainder of attention heads #5367

Closed · wants to merge 24 commits

Changes from all commits (24 commits)
b86675d
Change model config support unequal tp division
NadavShmayo Jun 9, 2024
a789569
Add unequal tp division util functions
NadavShmayo Jun 9, 2024
428b85f
Change parallel layers to support unequal tp division
NadavShmayo Jun 9, 2024
c485d50
Add unequal tp division support for opt model
NadavShmayo Jun 9, 2024
1cf543b
Add unequal tp division support for commandr model
NadavShmayo Jun 9, 2024
a6970c0
Add unequal tp division support for llama model
NadavShmayo Jun 9, 2024
6a4b70e
Remove asserts in Llama and CommandR implementation
NadavShmayo Jun 9, 2024
6b33c87
Add tp_rank to EmbeddingModelRunner class
NadavShmayo Jun 11, 2024
90d9f6c
Fix QKVLinear to work with packed dim
NadavShmayo Jun 11, 2024
014b682
Fix imports formatting in layer/linear.py file
NadavShmayo Jun 11, 2024
a30e120
Merge branch 'main' into unequal_tp_division
NadavShmayo Jun 30, 2024
73c0159
Merge branch 'main' into unequal_tp_division
NadavShmayo Jul 3, 2024
cdb2e27
Remove unused variable
NadavShmayo Jul 3, 2024
b9e5309
Fix failing tests
NadavShmayo Jul 3, 2024
a268f20
Fix formatting
NadavShmayo Jul 3, 2024
b033a43
Add uneven tensor parallel test cases
NadavShmayo Jul 3, 2024
34f9850
Fix review comments
NadavShmayo Jul 3, 2024
a154ade
Fix uneven TP tests and add to .buildkite
NadavShmayo Jul 4, 2024
fe906b5
Fix formatting and imports in new uneven TP tests
NadavShmayo Jul 4, 2024
537e16b
Fix uneven TP chunked prefill tests and buildkit config
NadavShmayo Jul 4, 2024
5639427
Change default padding size of ParallelLMHead to None
NadavShmayo Jul 4, 2024
6f7c0de
Add validation for LoRA with tensor parallel
NadavShmayo Jul 8, 2024
b8e870a
Fix LLama uneven TP lm head
NadavShmayo Jul 8, 2024
fc777b5
Merge branch 'main' into unequal_tp_division
NadavShmayo Jul 8, 2024
5 changes: 5 additions & 0 deletions .buildkite/test-pipeline.yaml
@@ -72,7 +72,12 @@ steps:
# See https://github.com/vllm-project/vllm/pull/5473#issuecomment-2166601837 for context.
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_uneven_distributed_correctness.py
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_uneven_distributed_correctness.py
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_uneven_chunked_prefill_distributed.py
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_uneven_chunked_prefill_distributed.py
- pytest -v -s spec_decode/e2e/test_integration_dist_tp4.py
- pytest -v -s spec_decode/e2e/test_integration_dist_tp3.py

- label: Pipeline Parallelism Test
working_dir: "/vllm-workspace/tests"
75 changes: 75 additions & 0 deletions tests/distributed/test_uneven_chunked_prefill_distributed.py
@@ -0,0 +1,75 @@
"""Compare the outputs of HF and distributed vLLM when using greedy sampling.
vLLM will allocate all the available memory, so we need to run the tests one
by one. The solution is to pass arguments (model name) by environment
variables.

Run:
```sh
TEST_DIST_MODEL=facebook/opt-125m pytest \
    test_uneven_chunked_prefill_distributed.py
TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest \
    test_uneven_chunked_prefill_distributed.py
```
"""
import os

import pytest

from vllm.utils import cuda_device_count_stateless

from ..models.utils import check_outputs_equal

MODELS = [
    os.environ["TEST_DIST_MODEL"],
]
DISTRIBUTED_EXECUTOR_BACKEND = "DISTRIBUTED_EXECUTOR_BACKEND"


@pytest.mark.skipif(cuda_device_count_stateless() < 3,
                    reason="Need at least 3 GPUs to run the test.")
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [5])
@pytest.mark.parametrize("chunked_prefill_token_size", [16])
def test_models_uneven_tensor_parallel(
    hf_runner,
    vllm_runner,
    example_prompts,
    model: str,
    dtype: str,
    max_tokens: int,
    chunked_prefill_token_size: int,
) -> None:
    distributed_executor_backend = os.getenv(DISTRIBUTED_EXECUTOR_BACKEND)

    # Add a chunked prefill config.
    max_num_seqs = min(chunked_prefill_token_size, 256)
    assert chunked_prefill_token_size != -1
    enable_chunked_prefill = True
    max_num_batched_tokens = chunked_prefill_token_size

    # NOTE: take care of the order. run vLLM first, and then run HF.
    # vLLM needs a fresh new process without cuda initialization.
    # if we run HF first, the cuda initialization will be done and it
    # will hurt multiprocessing backend with fork method (the default method).

    with vllm_runner(
            model,
            dtype=dtype,
            tensor_parallel_size=3,
            max_num_seqs=max_num_seqs,
            enable_chunked_prefill=enable_chunked_prefill,
            max_num_batched_tokens=max_num_batched_tokens,
            distributed_executor_backend=distributed_executor_backend,
    ) as vllm_model:
        vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)

    with hf_runner(model, dtype=dtype) as hf_model:
        hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)

    check_outputs_equal(
        outputs_0_lst=hf_outputs,
        outputs_1_lst=vllm_outputs,
        name_0="hf",
        name_1="vllm",
    )
63 changes: 63 additions & 0 deletions tests/distributed/test_uneven_distributed_correctness.py
@@ -0,0 +1,63 @@
"""Compare the outputs of HF and distributed vLLM when using greedy sampling.
vLLM will allocate all the available memory, so we need to run the tests one
by one. The solution is to pass arguments (model name) by environment
variables.
Run:
```sh
cd $VLLM_PATH/tests

TEST_DIST_MODEL=facebook/opt-125m pytest \
    distributed/test_uneven_distributed_correctness.py
TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest \
    distributed/test_uneven_distributed_correctness.py
```
"""
import os

import pytest

from vllm.utils import cuda_device_count_stateless

from ..models.utils import check_outputs_equal

MODELS = [
    os.environ["TEST_DIST_MODEL"],
]
DISTRIBUTED_EXECUTOR_BACKEND = "DISTRIBUTED_EXECUTOR_BACKEND"


@pytest.mark.skipif(cuda_device_count_stateless() < 3,
                    reason="Need at least 3 GPUs to run the test.")
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [5])
def test_models_uneven_tensor_parallel(
    hf_runner,
    vllm_runner,
    example_prompts,
    model: str,
    dtype: str,
    max_tokens: int,
) -> None:
    distributed_executor_backend = os.getenv(DISTRIBUTED_EXECUTOR_BACKEND)

    # NOTE: take care of the order. run vLLM first, and then run HF.
    # vLLM needs a fresh new process without cuda initialization.
    # if we run HF first, the cuda initialization will be done and it
    # will hurt multiprocessing backend with fork method (the default method).
    with vllm_runner(model,
                     dtype=dtype,
                     tensor_parallel_size=3,
                     distributed_executor_backend=distributed_executor_backend
                     ) as vllm_model:
        vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)

    with hf_runner(model, dtype=dtype) as hf_model:
        hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)

    check_outputs_equal(
        outputs_0_lst=hf_outputs,
        outputs_1_lst=vllm_outputs,
        name_0="hf",
        name_1="vllm",
    )
60 changes: 60 additions & 0 deletions tests/spec_decode/e2e/test_integration_dist_tp3.py
@@ -0,0 +1,60 @@
"""Tests which cover integration of the speculative decoding framework with
tensor parallelism.
"""

import pytest
import torch

from .conftest import run_greedy_equality_correctness_test


@pytest.mark.skipif(torch.cuda.device_count() < 3,
                    reason="Need at least 3 GPUs to run the test.")
@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        # Use a small model for a fast test.
        # Note this is repeated in the test body; to initialize a tokenizer.
        "model": "JackFram/llama-68m",

        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Required for spec decode.
        "use_v2_block_manager": True,
        "tensor_parallel_size": 3,

        # Use AsyncLLM engine, so that the engine runs in its own process.
        # Otherwise, since vLLM does not follow true SPMD, the test runner
        # process will have both the engine and the rank0 worker. NCCL is not
        # cleaned up properly, and its server host thread leaks, causing the
        # second run of the test to fail with internal NCCL error.
        "use_async": True,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [
    {
        "speculative_model": "JackFram/llama-68m",
        "num_speculative_tokens": 5,
    },
])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
    "test_llm_kwargs",
    [
        # TODO(wooyeon): add spec_draft_dp=2 case
        {
            "speculative_draft_tensor_parallel_size": 1,
        },
    ])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize("seed", [1])
def test_draft_model_tp_lt_target_model_tp3(test_llm_generator,
                                            baseline_llm_generator,
                                            batch_size: int):
    """Verify spec decode works well with smaller tp for draft models.
    """
    run_greedy_equality_correctness_test(baseline_llm_generator,
                                         test_llm_generator,
                                         batch_size,
                                         max_output_len=32,
                                         force_output_len=True)
47 changes: 34 additions & 13 deletions vllm/config.py
@@ -7,6 +7,7 @@
from transformers import PretrainedConfig

import vllm.envs as envs
from vllm.distributed import get_current_tp_rank_partition_size
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.model_executor.models import ModelRegistry
@@ -259,11 +260,13 @@ def verify_with_parallel_config(
        total_num_attention_heads = getattr(self.hf_text_config,
                                            "num_attention_heads", 0)
        tensor_parallel_size = parallel_config.tensor_parallel_size
        if total_num_attention_heads % tensor_parallel_size != 0:
        if (total_num_attention_heads % tensor_parallel_size != 0
                and self.quantization is not None):
            raise ValueError(
                f"Total number of attention heads ({total_num_attention_heads})"
                f"Total number of attention heads "
                f"({total_num_attention_heads})"
                " must be divisible by tensor parallel size "
                f"({tensor_parallel_size}).")
                f"({tensor_parallel_size}) when quantization is used.")

        pipeline_parallel_size = parallel_config.pipeline_parallel_size
        architectures = getattr(self.hf_config, "architectures", [])
@@ -361,20 +364,32 @@ def get_total_num_kv_heads(self) -> int:
        # equal to the number of attention heads.
        return self.hf_text_config.num_attention_heads

    def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int:
    def get_num_kv_heads(self,
                         parallel_config: "ParallelConfig",
                         tp_rank: int = 0) -> int:
        """Returns the number of KV heads per GPU."""
        total_num_kv_heads = self.get_total_num_kv_heads()
        # If tensor parallelism is used, we divide the number of KV heads by
        # the tensor parallel size. We will replicate the KV heads in the
        # case where the number of KV heads is smaller than the tensor
        # parallel size so each GPU has at least one KV head.
        return max(1,
                   total_num_kv_heads // parallel_config.tensor_parallel_size)
        result = get_current_tp_rank_partition_size(
            total_num_kv_heads, tp_rank, parallel_config.tensor_parallel_size)
        return max(1, result)

    def get_num_attention_heads(self,
                                parallel_config: "ParallelConfig") -> int:
        num_heads = getattr(self.hf_text_config, "num_attention_heads", 0)
        return num_heads // parallel_config.tensor_parallel_size
                                parallel_config: "ParallelConfig",
                                tp_rank: int = 0) -> int:
        if getattr(self.hf_text_config, "num_attention_heads", None) is None:
            return 0

        num_total_kv_heads = self.get_total_num_kv_heads()
        num_kv_heads = self.get_num_kv_heads(parallel_config, tp_rank)
        num_total_attention_heads = self.hf_text_config.num_attention_heads
        num_heads_per_kv_head = num_total_attention_heads // num_total_kv_heads
        # For GQA attention we make sure the whole attention head group is
        # together on the same GPU.
        return num_kv_heads * num_heads_per_kv_head

    def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
        from vllm.distributed.utils import get_pp_indices
@@ -750,7 +765,7 @@ class SchedulerConfig:
enable_chunked_prefill: If True, prefill requests can be chunked based
on the remaining max_num_batched_tokens.
embedding_mode: Whether the running model is for embedding.
preemption_mode: Whether to perform preemption by swapping or
preemption_mode: Whether to perform preemption by swapping or
recomputation. If not specified, we determine the mode as follows:
We use recomputation by default since it incurs lower overhead than
swapping. However, when the sequence group has multiple sequences
@@ -921,12 +936,12 @@ def maybe_create_spec_config(
typical_acceptance_sampler_posterior_threshold (Optional[float]):
A threshold value that sets a lower bound on the posterior
probability of a token in the target model for it to be
accepted. This threshold is used only when we use the
accepted. This threshold is used only when we use the
TypicalAcceptanceSampler for token acceptance.
typical_acceptance_sampler_posterior_alpha (Optional[float]):
A scaling factor for the entropy-based threshold in the
TypicalAcceptanceSampler.

Returns:
Optional["SpeculativeConfig"]: An instance of SpeculativeConfig if
the necessary conditions are met, else None.
@@ -1156,7 +1171,7 @@ def __init__(
typical_acceptance_sampler_posterior_threshold (Optional[float]):
A threshold value that sets a lower bound on the posterior
probability of a token in the target model for it to be
accepted. This threshold is used only when we use the
accepted. This threshold is used only when we use the
TypicalAcceptanceSampler for token acceptance.
typical_acceptance_sampler_posterior_alpha (Optional[float]):
A scaling factor for the entropy-based threshold in the
@@ -1284,6 +1299,11 @@ def verify_with_scheduler_config(self, scheduler_config: SchedulerConfig):
        if scheduler_config.chunked_prefill_enabled:
            raise ValueError("LoRA is not supported with chunked prefill yet.")

    def verify_with_parallel_config(self, parallel_config: ParallelConfig):
        if self.lora_vocab_padding_size % parallel_config.world_size != 0:
            raise ValueError("LoRA vocab padding size must be divisible "
                             "by world size.")


@dataclass
class MultiModalConfig:
@@ -1529,6 +1549,7 @@ def __post_init__(self):
            self.lora_config.verify_with_model_config(self.model_config)
            self.lora_config.verify_with_scheduler_config(
                self.scheduler_config)
            self.lora_config.verify_with_parallel_config(self.parallel_config)

def to_dict(self):
"""Return the configs as a dictionary, for use in **kwargs.
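The new `get_num_kv_heads` / `get_num_attention_heads` logic in this file is easiest to see with concrete numbers. The sketch below is not part of the diff; the helper and the model sizes are illustrative stand-ins. It mirrors `get_current_tp_rank_partition_size` for a hypothetical GQA model with 32 attention heads and 8 KV heads at `tensor_parallel_size=3`: earlier ranks take one extra KV head, and every rank keeps whole GQA groups.

```python
# Standalone sketch (illustrative, not part of the diff): per-rank head counts
# under uneven tensor parallelism for a hypothetical GQA model.
def partition_size(total_size: int, tp_rank: int, tp_size: int) -> int:
    # Mirrors get_current_tp_rank_partition_size with multiple_of=1.
    return (total_size // tp_size) + (total_size % tp_size > tp_rank)


TOTAL_ATTENTION_HEADS = 32  # hypothetical model
TOTAL_KV_HEADS = 8
TP_SIZE = 3
HEADS_PER_KV_HEAD = TOTAL_ATTENTION_HEADS // TOTAL_KV_HEADS  # 4 (GQA group)

for tp_rank in range(TP_SIZE):
    num_kv_heads = max(1, partition_size(TOTAL_KV_HEADS, tp_rank, TP_SIZE))
    # Whole GQA groups stay on one GPU, as in get_num_attention_heads above.
    num_heads = num_kv_heads * HEADS_PER_KV_HEAD
    print(f"rank {tp_rank}: {num_kv_heads} KV heads, {num_heads} heads")
# rank 0: 3 KV heads, 12 heads
# rank 1: 3 KV heads, 12 heads
# rank 2: 2 KV heads, 8 heads
```

Note that, per the relaxed check in `verify_with_parallel_config`, an uneven split is still rejected when quantization is enabled.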
36 changes: 34 additions & 2 deletions vllm/distributed/parallel_state.py
@@ -7,7 +7,7 @@
The typical workflow is:

- call `init_distributed_environment` to initialize the distributed environment.
- call `initialize_model_parallel` or `ensure_model_parallel_initialized` to
- call `initialize_model_parallel` or `ensure_model_parallel_initialized` to
initialize the model parallel groups.

- any code dealing with the distributed stuff
@@ -272,7 +272,7 @@ def graph_capture(

def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
"""
NOTE: This operation will be applied in-place or out-of-place.
NOTE: This operation will be applied in-place or out-of-place.
Always assume this function modifies its input, but use the return
value as the output.
"""
@@ -1055,3 +1055,35 @@ def is_in_the_same_node(pg: ProcessGroup):
    torch.distributed.all_reduce(is_in_the_same_node, group=pg)

    return is_in_the_same_node.sum().item() == world_size


def get_current_tp_rank_partition_offset(total_size: int,
                                         tp_rank: Optional[int] = None,
                                         tp_size: Optional[int] = None,
                                         multiple_of: int = 1) -> int:
    if tp_rank is None:
        tp_rank = get_tensor_model_parallel_rank()

    if tp_size is None:
        tp_size = get_tensor_model_parallel_world_size()

    assert total_size % multiple_of == 0
    total_size = total_size // multiple_of
    return ((total_size // tp_size) * tp_rank +
            min(total_size % tp_size, tp_rank)) * multiple_of


def get_current_tp_rank_partition_size(total_size: int,
                                       tp_rank: Optional[int] = None,
                                       tp_size: Optional[int] = None,
                                       multiple_of: int = 1) -> int:
    if tp_rank is None:
        tp_rank = get_tensor_model_parallel_rank()

    if tp_size is None:
        tp_size = get_tensor_model_parallel_world_size()

    assert total_size % multiple_of == 0
    total_size = total_size // multiple_of
    return ((total_size // tp_size) +
            (total_size % tp_size > tp_rank)) * multiple_of
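As a quick sanity check of the two helpers (illustrative usage, assuming this branch is installed so the functions are importable; passing `tp_rank` and `tp_size` explicitly avoids any need for distributed initialization), the per-rank partitions tile the full range, with earlier ranks absorbing the remainder:

```python
# Illustrative usage (assumes this branch is installed); passing tp_rank and
# tp_size explicitly means no distributed initialization is required.
from vllm.distributed.parallel_state import (
    get_current_tp_rank_partition_offset, get_current_tp_rank_partition_size)

total_size, tp_size = 10, 3
parts = [(get_current_tp_rank_partition_offset(total_size, r, tp_size),
          get_current_tp_rank_partition_size(total_size, r, tp_size))
         for r in range(tp_size)]
print(parts)  # [(0, 4), (4, 3), (7, 3)] -> earlier ranks absorb the remainder

# The per-rank slices tile [0, total_size) with no gaps or overlaps.
end = 0
for offset, size in parts:
    assert offset == end
    end += size
assert end == total_size
```

The `multiple_of` argument keeps each rank's size and offset a multiple of that value, which is presumably what lets packed dimensions such as fused QKV projections stay aligned across ranks.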
2 changes: 1 addition & 1 deletion vllm/engine/metrics.py
@@ -457,4 +457,4 @@ def log(self, stats: Stats):

class RayPrometheusStatLogger(PrometheusStatLogger):
"""RayPrometheusStatLogger uses Ray metrics instead."""
_metrics_cls = RayMetrics
_metrics_cls = RayMetrics