[Speculative Decoding] Support draft model on different tensor-parallel size than target model #5414

Merged Jun 25, 2024 (131 commits)
Changes from 89 commits
f5b5f94
tp1 draft worker
wooyeonlee0 Jun 10, 2024
709de21
refactor singlt_tp_worker
wooyeonlee0 Jun 10, 2024
0eacc96
update execute_model logic
wooyeonlee0 Jun 10, 2024
2011ed0
fix
wooyeonlee0 Jun 11, 2024
2e16c4e
DummyProposerWorker
wooyeonlee0 Jun 11, 2024
b412a51
fix
wooyeonlee0 Jun 11, 2024
593ccfa
init only partial workers
wooyeonlee0 Jun 11, 2024
c5d3476
Use multi_step_worker logic
wooyeonlee0 Jun 12, 2024
44e623b
self._patch_tp_group
wooyeonlee0 Jun 12, 2024
98caf17
refactor it to support other draft-tp than 1
wooyeonlee0 Jun 12, 2024
7fc4ff5
spec-tp configuarable
wooyeonlee0 Jun 12, 2024
a96e720
ngram worker support test
wooyeonlee0 Jun 12, 2024
db39576
minor refine
wooyeonlee0 Jun 12, 2024
b2e8595
cleanup
wooyeonlee0 Jun 12, 2024
756442a
return type fix
wooyeonlee0 Jun 12, 2024
32094f1
cleanup
wooyeonlee0 Jun 12, 2024
7890191
cleanup
wooyeonlee0 Jun 12, 2024
53b2ea9
typo
wooyeonlee0 Jun 12, 2024
a29c9c5
verify arg
wooyeonlee0 Jun 12, 2024
52ba09d
remove testing code
wooyeonlee0 Jun 12, 2024
d26ef08
cleanup
wooyeonlee0 Jun 12, 2024
80c4994
rename module
wooyeonlee0 Jun 12, 2024
0f16f3f
cleanup
wooyeonlee0 Jun 12, 2024
140f478
cleanup
wooyeonlee0 Jun 12, 2024
3fd7e91
remove unnecessary methods
wooyeonlee0 Jun 12, 2024
495aa30
fix
wooyeonlee0 Jun 12, 2024
3a5a47f
undo unrelated changes
wooyeonlee0 Jun 12, 2024
07ddbb8
minor fix
wooyeonlee0 Jun 12, 2024
b0a677d
fix ruff errors
wooyeonlee0 Jun 12, 2024
96782a2
Merge branch 'main' into spec-tp1-draft
wooyeonlee0 Jun 12, 2024
9998b9c
typo
wooyeonlee0 Jun 12, 2024
e92ecdc
temporal fix
wooyeonlee0 Jun 12, 2024
b421607
formatting
wooyeonlee0 Jun 12, 2024
386ab9b
isort
wooyeonlee0 Jun 12, 2024
b25f74e
line length
wooyeonlee0 Jun 12, 2024
8b51f08
fix
wooyeonlee0 Jun 13, 2024
d4b283c
Merge remote-tracking branch 'origin' into spec-tp1-draft
wooyeonlee0 Jun 13, 2024
dfc90cb
line length
wooyeonlee0 Jun 13, 2024
9bef5e4
comment
wooyeonlee0 Jun 13, 2024
85d087d
add type hint
wooyeonlee0 Jun 13, 2024
9af36b7
isort
wooyeonlee0 Jun 13, 2024
5a0bf45
add more type hints
wooyeonlee0 Jun 13, 2024
531c9f0
fix
wooyeonlee0 Jun 13, 2024
287da20
test
wooyeonlee0 Jun 13, 2024
08d1b2a
nit
wooyeonlee0 Jun 13, 2024
237c966
fix yapf
wooyeonlee0 Jun 13, 2024
0bb38c2
fix
wooyeonlee0 Jun 13, 2024
c097d6c
fix
wooyeonlee0 Jun 13, 2024
957a325
fix
wooyeonlee0 Jun 13, 2024
3ec8cb5
Merge remote-tracking branch 'origin' into spec-tp1-draft
wooyeonlee0 Jun 14, 2024
8a8a1e4
add comments
wooyeonlee0 Jun 14, 2024
7f06f64
combine smaller_tp_worker logic into multi_step_worker
wooyeonlee0 Jun 14, 2024
1e87579
fix
wooyeonlee0 Jun 14, 2024
abc546c
fix
wooyeonlee0 Jun 14, 2024
7880cb0
add small_tp correctness test
wooyeonlee0 Jun 14, 2024
2ebe6f3
nit
wooyeonlee0 Jun 14, 2024
90d46ee
fix
wooyeonlee0 Jun 14, 2024
7e1426c
refactor. remove log
wooyeonlee0 Jun 14, 2024
ad52d93
remove return
wooyeonlee0 Jun 14, 2024
355475b
fix
wooyeonlee0 Jun 14, 2024
9cfdb5b
fix about context managing
wooyeonlee0 Jun 14, 2024
6a6c5ff
nit
wooyeonlee0 Jun 14, 2024
ddef229
consistent condition. if self._is_dummy:
wooyeonlee0 Jun 14, 2024
965f648
fix ruff errors
wooyeonlee0 Jun 14, 2024
1bb5534
isort
wooyeonlee0 Jun 14, 2024
ea6b8f5
fix yapf
wooyeonlee0 Jun 14, 2024
71977d2
undo ngramworker support
wooyeonlee0 Jun 14, 2024
bc5f77a
add comment
wooyeonlee0 Jun 14, 2024
5655a49
remove smaller_tp_proposer_worker
wooyeonlee0 Jun 14, 2024
eabc16a
ruff
wooyeonlee0 Jun 14, 2024
f748edf
remove ranks arg
wooyeonlee0 Jun 17, 2024
c099c94
Merge remote-tracking branch 'origin' into spec-tp1-draft
wooyeonlee0 Jun 17, 2024
4b74a45
undo
wooyeonlee0 Jun 17, 2024
c9786ad
add dist test
wooyeonlee0 Jun 17, 2024
a42664a
nit
wooyeonlee0 Jun 17, 2024
ac7701a
fix
wooyeonlee0 Jun 17, 2024
eea6a7e
test fix
wooyeonlee0 Jun 17, 2024
a648f5d
yapf fix
wooyeonlee0 Jun 17, 2024
f23ba8c
update comment
wooyeonlee0 Jun 17, 2024
aa9af93
require 2 gpus
wooyeonlee0 Jun 17, 2024
56c8927
restore draft_ranks arg in MultiStepWorker.__init__
wooyeonlee0 Jun 18, 2024
385b4f8
comment
wooyeonlee0 Jun 18, 2024
43f37eb
ruff mypy
wooyeonlee0 Jun 18, 2024
99350e2
isort
wooyeonlee0 Jun 18, 2024
a9f3e23
yapf
wooyeonlee0 Jun 18, 2024
6ba250d
allow None for draft_ranks
wooyeonlee0 Jun 18, 2024
3e78613
spec-tp arg in benchmark_latency
wooyeonlee0 Jun 18, 2024
6532af7
yapf
wooyeonlee0 Jun 18, 2024
6839797
yapf
wooyeonlee0 Jun 18, 2024
aac586b
Merge remote-tracking branch 'origin' into spec-tp1-draft
wooyeonlee0 Jun 19, 2024
98e584d
remove is_dummy check from sampler_output
wooyeonlee0 Jun 19, 2024
2d5e64d
add comment
wooyeonlee0 Jun 20, 2024
ba88bd4
yapf
wooyeonlee0 Jun 20, 2024
46e5274
resolve cade comments
wooyeonlee0 Jun 21, 2024
85f4f25
refactoring patch_tp_group
wooyeonlee0 Jun 21, 2024
c1b5373
cleanup patch_tp_group logic
wooyeonlee0 Jun 21, 2024
4a58617
speculative_draft_tensor_parallel_size
wooyeonlee0 Jun 21, 2024
b09e7be
ruff, yapf
wooyeonlee0 Jun 21, 2024
7168d78
remove world group patch
wooyeonlee0 Jun 21, 2024
fe0bd5b
isort, yapf
wooyeonlee0 Jun 21, 2024
2e0d170
yield fix
wooyeonlee0 Jun 21, 2024
36f8aa5
debugging
wooyeonlee0 Jun 21, 2024
54bf514
log
wooyeonlee0 Jun 21, 2024
bfd7d2f
reintroduce smaller_tp_proposer_worker
wooyeonlee0 Jun 21, 2024
f337428
add lora methods
wooyeonlee0 Jun 21, 2024
4654b9f
missing method
wooyeonlee0 Jun 21, 2024
e39926e
remove world group related logics
wooyeonlee0 Jun 21, 2024
1c6eefd
Always wrapping MultiStepWorker
wooyeonlee0 Jun 21, 2024
f2d2ee5
remove unused logger
wooyeonlee0 Jun 21, 2024
302955c
isort. minor rename
wooyeonlee0 Jun 21, 2024
3d4754e
LoraNotSupported. return type
wooyeonlee0 Jun 21, 2024
620b224
yapf, ruff
wooyeonlee0 Jun 21, 2024
b245d3c
add skip_spec_test
wooyeonlee0 Jun 21, 2024
1e71e98
remove spec-tp 3 case
wooyeonlee0 Jun 21, 2024
a01c00d
spec-draft-tp
wooyeonlee0 Jun 21, 2024
debffc2
_TP_STATE_PATCHED
wooyeonlee0 Jun 24, 2024
39fe67f
remove stale comment
wooyeonlee0 Jun 24, 2024
af1b0be
dist_tp2, dist_tp4 tests
wooyeonlee0 Jun 24, 2024
834c6e0
remove unnecessary overriding methods
wooyeonlee0 Jun 24, 2024
5bc2bc3
comment
wooyeonlee0 Jun 24, 2024
8740369
yapf
wooyeonlee0 Jun 24, 2024
4d82ca1
comment
wooyeonlee0 Jun 24, 2024
7bf831c
undo change in test utils
wooyeonlee0 Jun 24, 2024
3fccc76
remove test_skip_speculation
wooyeonlee0 Jun 24, 2024
e8d0e93
tp4 test only for spec_tp1
wooyeonlee0 Jun 25, 2024
91c2e43
allow only value 1 for spec_tp
wooyeonlee0 Jun 25, 2024
fac7e68
yapf
wooyeonlee0 Jun 25, 2024
271822e
add todo comment
wooyeonlee0 Jun 25, 2024
ae0d7f1
add tests for check that test_skip fails even there's no spec_draft_t…
wooyeonlee0 Jun 25, 2024
b84a070
remove test_skip_speculation from dist tests
wooyeonlee0 Jun 25, 2024
86fda24
yapf
wooyeonlee0 Jun 25, 2024
46 changes: 26 additions & 20 deletions benchmarks/benchmark_latency.py
@@ -20,26 +20,28 @@ def main(args: argparse.Namespace):

# NOTE(woosuk): If the request cannot be processed in a single batch,
# the engine will automatically process the request in multiple batches.
llm = LLM(model=args.model,
speculative_model=args.speculative_model,
num_speculative_tokens=args.num_speculative_tokens,
tokenizer=args.tokenizer,
quantization=args.quantization,
tensor_parallel_size=args.tensor_parallel_size,
trust_remote_code=args.trust_remote_code,
dtype=args.dtype,
enforce_eager=args.enforce_eager,
kv_cache_dtype=args.kv_cache_dtype,
quantization_param_path=args.quantization_param_path,
device=args.device,
ray_workers_use_nsight=args.ray_workers_use_nsight,
use_v2_block_manager=args.use_v2_block_manager,
enable_chunked_prefill=args.enable_chunked_prefill,
download_dir=args.download_dir,
block_size=args.block_size,
gpu_memory_utilization=args.gpu_memory_utilization,
load_format=args.load_format,
distributed_executor_backend=args.distributed_executor_backend)
llm = LLM(
model=args.model,
speculative_model=args.speculative_model,
num_speculative_tokens=args.num_speculative_tokens,
speculative_tensor_parallel_size=args.speculative_tensor_parallel_size,
tokenizer=args.tokenizer,
quantization=args.quantization,
tensor_parallel_size=args.tensor_parallel_size,
trust_remote_code=args.trust_remote_code,
dtype=args.dtype,
enforce_eager=args.enforce_eager,
kv_cache_dtype=args.kv_cache_dtype,
quantization_param_path=args.quantization_param_path,
device=args.device,
ray_workers_use_nsight=args.ray_workers_use_nsight,
use_v2_block_manager=args.use_v2_block_manager,
enable_chunked_prefill=args.enable_chunked_prefill,
download_dir=args.download_dir,
block_size=args.block_size,
gpu_memory_utilization=args.gpu_memory_utilization,
load_format=args.load_format,
distributed_executor_backend=args.distributed_executor_backend)

sampling_params = SamplingParams(
n=args.n,
@@ -122,6 +124,10 @@ def run_to_completion(profile_dir: Optional[str] = None):
parser.add_argument('--model', type=str, default='facebook/opt-125m')
parser.add_argument('--speculative-model', type=str, default=None)
parser.add_argument('--num-speculative-tokens', type=int, default=None)
parser.add_argument('--speculative-tensor-parallel-size',
'-spec-tp',
type=int,
default=None)
parser.add_argument('--tokenizer', type=str, default=None)
parser.add_argument('--quantization',
'-q',
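For reference, a minimal sketch of how the new flag added above is meant to be exercised through the Python API. The model names and sizes are illustrative, not the benchmark defaults, and the keyword is shown as named at this revision (later commits rename it to speculative_draft_tensor_parallel_size / -spec-draft-tp).

from vllm import LLM, SamplingParams

# Sketch: target model sharded across 4 GPUs, draft model kept on a single GPU.
llm = LLM(
    model="meta-llama/Llama-2-70b-hf",        # illustrative target model
    speculative_model="JackFram/llama-68m",   # small draft model, as in the tests below
    num_speculative_tokens=5,
    tensor_parallel_size=4,
    speculative_tensor_parallel_size=1,
    use_v2_block_manager=True,                # required for speculative decoding
)
outputs = llm.generate(["The future of AI is"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)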
46 changes: 46 additions & 0 deletions tests/spec_decode/e2e/test_integration_dist.py
@@ -63,3 +63,49 @@ def test_target_model_tp_gt_1(baseline_llm_generator, test_llm_generator,
batch_size,
max_output_len=output_len,
force_output_len=True)


@pytest.mark.skipif(torch.cuda.device_count() < 2,
reason="Need at least 2 GPUs to run the test.")
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Use a small model for a fast test.
# Note this is repeated in the test body; to initialize a tokenizer.
"model": "JackFram/llama-68m",

# Skip cuda graph recording for fast test.
"enforce_eager": True,

# Required for spec decode.
"use_v2_block_manager": True,
"tensor_parallel_size": 2,

# Use AsyncLLM engine, so that the engine runs in its own process.
# Otherwise, since vLLM does not follow true SPMD, the test runner
# process will have both the engine and the rank0 worker. NCCL is not
# cleaned up properly, and its server host thread leaks, causing the
# second run of the test to fail with internal NCCL error.
"use_async": True,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"speculative_tensor_parallel_size": 1,
},
])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize("seed", [1])
def test_draft_model_tp_lt_target_model(test_llm_generator,
Collaborator:
Can we have a test where we disable some speculation? This will verify that the control-flow logic behaves correctly even when draft TP == 1 or draft TP == 2.

See this test for example:

def test_skip_speculation(baseline_llm_generator, test_llm_generator,
                          batch_size: int, output_len: int):
    """Verify greedy equality when some (or all) sequences skip speculation.
    We do this by setting the max model len of the draft model to an
    artificially low value, such that when the sequences grow beyond it, they
    are skipped in speculative decoding.
    """
    run_greedy_equality_correctness_test(baseline_llm_generator,
                                         test_llm_generator,
                                         batch_size,
                                         max_output_len=output_len,
                                         force_output_len=True)

baseline_llm_generator,
batch_size: int):
"""Verify spec decode works well with smaller tp for draft models.
"""
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=32,
force_output_len=True)
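One way the requested skip-speculation case could be combined with a smaller draft TP is sketched below. It assumes the same common_llm_kwargs / baseline_llm_kwargs parametrizations as the test above and uses speculative_max_model_len to force skipping, as in the quoted test; the test name, kwargs, and values are illustrative, not part of this PR.

@pytest.mark.parametrize("test_llm_kwargs", [
    {
        "speculative_model": "JackFram/llama-68m",
        "num_speculative_tokens": 5,
        "speculative_tensor_parallel_size": 1,

        # Artificially low draft max model len, so sequences outgrow it and
        # skip speculation, exercising the control flow with draft TP < target TP.
        "speculative_max_model_len": 32,
    },
])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("output_len", [64])
@pytest.mark.parametrize("seed", [1])
def test_skip_speculation_with_smaller_draft_tp(baseline_llm_generator,
                                                test_llm_generator,
                                                batch_size: int,
                                                output_len: int):
    """Sketch: greedy equality when some sequences skip speculation while the
    draft model runs with a smaller tensor-parallel size than the target."""
    run_greedy_equality_correctness_test(baseline_llm_generator,
                                         test_llm_generator,
                                         batch_size,
                                         max_output_len=output_len,
                                         force_output_len=True)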
32 changes: 20 additions & 12 deletions tests/spec_decode/utils.py
@@ -12,6 +12,7 @@
from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
SamplerOutput, SequenceData, SequenceGroupMetadata,
SequenceOutput)
from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.utils import get_distributed_init_method, get_ip, get_open_port
from vllm.worker.cache_engine import CacheEngine
from vllm.worker.worker import Worker
@@ -66,6 +67,7 @@ def create_worker(cls: Callable[..., T],
num_gpu_blocks: int,
seed: int,
is_driver_worker: bool = True,
draft_ranks: Optional[List[int]] = None,
enforce_eager: bool = True) -> T:
engine_args = EngineArgs(
model=model_name,
@@ -78,18 +80,24 @@
distributed_init_method = get_distributed_init_method(
get_ip(), get_open_port())

worker = cls(
model_config=engine_config.model_config,
parallel_config=engine_config.parallel_config,
scheduler_config=engine_config.scheduler_config,
device_config=engine_config.device_config,
cache_config=engine_config.cache_config,
load_config=engine_config.load_config,
local_rank=0,
rank=0,
distributed_init_method=distributed_init_method,
is_driver_worker=is_driver_worker,
)
worker_kwargs = {
'model_config': engine_config.model_config,
'parallel_config': engine_config.parallel_config,
'scheduler_config': engine_config.scheduler_config,
'device_config': engine_config.device_config,
'cache_config': engine_config.cache_config,
'load_config': engine_config.load_config,
'local_rank': 0,
'rank': 0,
'distributed_init_method': distributed_init_method,
'is_driver_worker': is_driver_worker,
}

if draft_ranks is not None:
assert cls is MultiStepWorker, "draft_ranks arg is for MultiStepWorker"
worker_kwargs['draft_ranks'] = draft_ranks

worker = cls(**worker_kwargs)

worker.init_device()
worker.load_model()
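For reference, a rough sketch of how a spec-decode test might call the updated helper with the new parameter. The values are illustrative, and it assumes model_name and block_size are among the earlier parameters of create_worker that are not shown in this hunk.

from vllm.spec_decode.multi_step_worker import MultiStepWorker

# Sketch: build a MultiStepWorker whose draft model is restricted to rank 0,
# while the target model keeps using the full tensor-parallel group.
multi_step_worker = create_worker(
    MultiStepWorker,
    model_name="JackFram/llama-68m",
    block_size=16,
    num_gpu_blocks=256,
    seed=0,
    draft_ranks=[0],   # only rank 0 participates in draft-model execution
)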
22 changes: 17 additions & 5 deletions vllm/config.py
@@ -788,6 +788,7 @@ def maybe_create_spec_config(
target_parallel_config: ParallelConfig,
target_dtype: str,
speculative_model: Optional[str],
speculative_tensor_parallel_size: Optional[int],
num_speculative_tokens: Optional[int],
speculative_max_model_len: Optional[int],
enable_chunked_prefill: bool,
@@ -914,7 +915,7 @@

draft_parallel_config = (
SpeculativeConfig.create_draft_parallel_config(
target_parallel_config))
target_parallel_config, speculative_tensor_parallel_size))

return SpeculativeConfig(
draft_model_config,
@@ -962,16 +963,27 @@ def _maybe_override_draft_max_model_len(

@staticmethod
def create_draft_parallel_config(
target_parallel_config: ParallelConfig) -> ParallelConfig:
target_parallel_config: ParallelConfig,
speculative_tensor_parallel_size: Optional[int]) -> ParallelConfig:
"""Create a parallel config for use by the draft worker.

This is mostly a copy of the target parallel config. In the future the
draft worker can have a different parallel strategy, e.g. TP=1.
This is mostly a copy of the target parallel config, except the tp_size.
"""

speculative_tensor_parallel_size = (
speculative_tensor_parallel_size
Collaborator:
nit: if the user provides --speculative-tensor-parallel-size 0, this branch causes unexpected behavior. Can we explicitly guard against this?

wooyeonlee0 (Contributor, Author) on Jun 21, 2024:
Thanks for the catch!
To prevent spec_tp from being set to target_tp when the given spec_tp is 0, I've changed the code as below:

        if speculative_tensor_parallel_size is None:
            speculative_tensor_parallel_size = target_parallel_config.tensor_parallel_size

In addition, to prevent the tp value from being 0, I think we need a separate PR that handles that case by adding a check in ParallelConfig._verify_args(), since the same issue exists for the --tensor-parallel-size 0 case.

What do you think?

Collaborator:
I'll look at your new changes; no need to completely fix this (just a nit).

or target_parallel_config.tensor_parallel_size)

if speculative_tensor_parallel_size > \
target_parallel_config.tensor_parallel_size:
raise ValueError(
f"{speculative_tensor_parallel_size=} cannot be "
f"larger than {target_parallel_config.tensor_parallel_size}")

draft_parallel_config = ParallelConfig(
pipeline_parallel_size=target_parallel_config.
pipeline_parallel_size,
tensor_parallel_size=target_parallel_config.tensor_parallel_size,
tensor_parallel_size=speculative_tensor_parallel_size,
distributed_executor_backend=target_parallel_config.
distributed_executor_backend,
max_parallel_loading_workers=target_parallel_config.
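The net effect of the guard discussed in the thread above is roughly the following. This is a hypothetical standalone helper for illustration only, not an actual vLLM function; note that later commits ("allow only value 1 for spec_tp") additionally restrict the draft TP to 1.

def resolve_draft_tp(speculative_tensor_parallel_size, target_tp_size):
    # Hypothetical helper illustrating the guard discussed above.
    if speculative_tensor_parallel_size is None:
        # Explicit is-None check so a value of 0 is not silently replaced
        # by the target TP size (it should instead fail validation).
        speculative_tensor_parallel_size = target_tp_size
    if speculative_tensor_parallel_size > target_tp_size:
        raise ValueError(
            f"speculative_tensor_parallel_size="
            f"{speculative_tensor_parallel_size} cannot be larger than "
            f"target tensor_parallel_size={target_tp_size}")
    return speculative_tensor_parallel_size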
30 changes: 29 additions & 1 deletion vllm/distributed/parallel_state.py
@@ -551,6 +551,10 @@ def init_distributed_environment(
global _WORLD
if _WORLD is None:
ranks = list(range(torch.distributed.get_world_size()))
if world_size != -1:
assert world_size == len(ranks), (
"given world_size does not match with world_size of torch")

_WORLD = GroupCoordinator(
group_ranks=[ranks],
local_rank=local_rank,
@@ -559,7 +563,7 @@
use_custom_allreduce=False,
)
else:
assert _WORLD.world_size == torch.distributed.get_world_size(), (
assert _WORLD.world_size == world_size, (
"world group already initialized with a different world size")


@@ -674,6 +678,30 @@ def model_parallel_is_initialized():
return (_TP is not None and _PP is not None)


OVERRIDE_TP_STATE = False


@contextlib.contextmanager
def patch_tensor_parallel_group(world_group, tp_group):
Contributor:
Will this global variable patching potentially create problems? For example, is it possible that other workers will use this context unknowingly?

wooyeonlee0 (Contributor, Author) on Jun 14, 2024:
In the current design of speculative decoding, draft and target workers execute sequentially, so there is no chance of target workers using the patched/overridden context.

But if draft and target workers execute concurrently in the future, the code should be redesigned to prevent their states from being mixed with each other.

"""Patch the tp group temporarily until this function ends."""
global OVERRIDE_TP_STATE
if not OVERRIDE_TP_STATE and world_group and tp_group:
OVERRIDE_TP_STATE = True
old_world_group = get_world_group()
zifeitong (Contributor) on Jun 18, 2024:
How about only overriding tp_group and not world_group here? Will that work?

I saw that get_world_group() is only used once in the codebase (very early in the initialization stage).

wooyeonlee0 (Contributor, Author) on Jun 19, 2024:
Yes, it is used only during initialization.
If we do not override world_group, initialization will fail due to an assertion check (link).
That check verifies world_group size consistency between workers when spawning multiple workers in the same process (or Ray worker).
In our case, it asserts that the current world_group has the same size as the world_group being initialized.

Note that this check was added by #5293 and is slightly modified in this PR to support the smaller draft-tp case.
I'm not sure in which scenarios this check matters, but I thought it would be safer to keep it.
What do you think?

It's a slightly different story, but when I opened this PR, things were a little different: get_world_group() was also used after initialization by broadcast_tensor_dict(), which the driver uses to control workers.
That use case has gone away after refactoring in other PRs.

Contributor:
Thanks for the explanation.

How about adding a comment about world_group, since this function is named patch_tensor_parallel_group? Or maybe rename the function to patch_distributed_group?

wooyeonlee0 (Contributor, Author):
Good suggestion :)
I've added a comment in the docstring of patch_tensor_parallel_group().

old_tp_group = get_tp_group()
global _WORLD, _TP
_WORLD = world_group
_TP = tp_group
try:
yield
finally:
# restore the original state
if OVERRIDE_TP_STATE:
OVERRIDE_TP_STATE = False
_WORLD = old_world_group
_TP = old_tp_group


def get_tensor_model_parallel_world_size():
"""Return world size for the tensor model parallel group."""
return get_tp_group().world_size
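To illustrate how patch_tensor_parallel_group() is meant to be used, here is a rough sketch of a proposer wrapper along the lines of the smaller_tp_proposer_worker mentioned in the commit history. The class and method names, and the way the draft groups are passed in, are illustrative assumptions rather than the merged implementation.

from vllm.distributed.parallel_state import patch_tensor_parallel_group

class SmallerTpProposerSketch:
    """Illustrative wrapper that runs the draft worker inside a patched TP group."""

    def __init__(self, draft_worker, draft_world_group, draft_tp_group,
                 is_dummy: bool):
        self._worker = draft_worker
        self._world_group = draft_world_group   # group spanning only the draft ranks
        self._tp_group = draft_tp_group
        self._is_dummy = is_dummy               # True on ranks holding no draft-model shard

    def get_spec_proposals(self, *args, **kwargs):
        if self._is_dummy:
            # Ranks outside the draft TP group skip draft execution entirely.
            return None
        # Inside the block, collectives issued by the draft model use the
        # smaller TP group; the original groups are restored on exit.
        with patch_tensor_parallel_group(self._world_group, self._tp_group):
            return self._worker.get_spec_proposals(*args, **kwargs)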
10 changes: 10 additions & 0 deletions vllm/engine/arg_utils.py
@@ -93,6 +93,7 @@ class EngineArgs:
guided_decoding_backend: str = 'outlines'
# Speculative decoding configuration.
speculative_model: Optional[str] = None
speculative_tensor_parallel_size: Optional[int] = None
num_speculative_tokens: Optional[int] = None
speculative_max_model_len: Optional[int] = None
speculative_disable_by_batch_size: Optional[int] = None
@@ -534,6 +535,13 @@ def add_cli_args(
default=EngineArgs.num_speculative_tokens,
help='The number of speculative tokens to sample from '
'the draft model in speculative decoding.')
parser.add_argument(
'--speculative-tensor-parallel-size',
'-spec-tp',
type=int,
default=EngineArgs.speculative_tensor_parallel_size,
help='Number of tensor parallel replicas for '
'the draft model in speculative decoding.')

parser.add_argument(
'--speculative-max-model-len',
@@ -676,6 +684,8 @@ def create_engine_config(self, ) -> EngineConfig:
target_parallel_config=parallel_config,
target_dtype=self.dtype,
speculative_model=self.speculative_model,
speculative_tensor_parallel_size = \
self.speculative_tensor_parallel_size,
num_speculative_tokens=self.num_speculative_tokens,
speculative_disable_by_batch_size=self.
speculative_disable_by_batch_size,