[Speculative Decoding] MLPSpeculator Tensor Parallel support (1/2) #6050

Merged: 5 commits, Jul 2, 2024
tests/spec_decode/e2e/test_integration_dist_tp2.py (36 changes: 24 additions & 12 deletions)
@@ -70,10 +70,6 @@ def test_target_model_tp_gt_1(baseline_llm_generator, test_llm_generator,
 @pytest.mark.parametrize(
     "common_llm_kwargs",
     [{
-        # Use a small model for a fast test.
-        # Note this is repeated in the test body; to initialize a tokenizer.
-        "model": "JackFram/llama-68m",
-
         # Skip cuda graph recording for fast test.
         "enforce_eager": True,
@@ -88,15 +84,31 @@ def test_target_model_tp_gt_1(baseline_llm_generator, test_llm_generator,
         # second run of the test to fail with internal NCCL error.
         "use_async": True,
     }])
-@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
 @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
-@pytest.mark.parametrize("test_llm_kwargs", [
-    {
-        "speculative_model": "JackFram/llama-68m",
-        "num_speculative_tokens": 5,
-        "speculative_draft_tensor_parallel_size": 1,
-    },
-])
+@pytest.mark.parametrize(
+    "per_test_common_llm_kwargs, test_llm_kwargs",
+    [
+        (
+            {
+                # Use a small model for a fast test.
+                # Note this is repeated in the test body; to initialize a
+                # tokenizer.
+                "model": "JackFram/llama-68m",
+            },
+            {
+                "speculative_model": "JackFram/llama-68m",
+                "num_speculative_tokens": 5,
+                "speculative_draft_tensor_parallel_size": 1,
+            }),
+        ({
+            "model": "ibm-granite/granite-3b-code-instruct",
+        }, {
+            "speculative_model":
+            "ibm-granite/granite-3b-code-instruct-accelerator",
+            "num_speculative_tokens": 5,
+            "speculative_draft_tensor_parallel_size": 1,
+        })
+    ])
 @pytest.mark.parametrize("batch_size", [2])
 @pytest.mark.parametrize("seed", [1])
 def test_draft_model_tp_lt_target_model_tp2(test_llm_generator,
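The new parametrization pairs each target model with its draft, so the MLPSpeculator case (granite-3b-code-instruct with its -accelerator speculator) exercises the same TP=2 test as the llama-68m self-draft case. As a rough illustration, and assuming vLLM's usual engine kwargs (not part of this diff), the granite case corresponds to a direct invocation like:

```python
# Minimal sketch, not part of the PR: how the new granite parametrization
# maps onto a direct vLLM invocation. Assumes the LLM entrypoint forwards
# these speculative-decoding kwargs to the engine, as the test fixtures do.
from vllm import LLM, SamplingParams

llm = LLM(
    model="ibm-granite/granite-3b-code-instruct",
    tensor_parallel_size=2,  # target model runs TP=2
    speculative_model="ibm-granite/granite-3b-code-instruct-accelerator",
    num_speculative_tokens=5,
    speculative_draft_tensor_parallel_size=1,  # MLPSpeculator draft runs TP=1
    enforce_eager=True,  # skip cuda graph recording, as in the test
)
outputs = llm.generate(["def fibonacci(n):"], SamplingParams(max_tokens=64))
print(outputs[0].outputs[0].text)
```

The draft runs at TP=1 while the target runs at TP=2, which is exactly the split the SmallerTpProposerWorker path below has to support.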
vllm/config.py (6 changes: 0 additions & 6 deletions)
@@ -957,12 +957,6 @@ def maybe_create_spec_config(
         )

         draft_hf_config = draft_model_config.hf_config
-        if (draft_hf_config.model_type == "mlp_speculator"
-                and target_parallel_config.world_size != 1):
-            # MLPSpeculator TP support will be added very soon
-            raise ValueError(
-                "Speculative decoding with mlp_speculator models does not "
-                "yet support distributed inferencing (TP > 1).")

         if (num_speculative_tokens is not None
                 and hasattr(draft_hf_config, "num_lookahead_tokens")):
vllm/spec_decode/spec_decode_worker.py (18 changes: 11 additions & 7 deletions)
@@ -113,24 +113,28 @@ def create_worker(
             draft_worker_kwargs.pop("ngram_prompt_lookup_min"))

         disable_bonus_tokens = True
+
         if ngram_prompt_lookup_max > 0:
             disable_bonus_tokens = False
             proposer_worker = NGramWorker(**draft_worker_kwargs)
             proposer_worker.set_ngram_window_size(ngram_prompt_lookup_min,
                                                   ngram_prompt_lookup_max)
-        elif draft_worker_kwargs[
-                "model_config"].hf_config.model_type == "mlp_speculator":
-            proposer_worker = MLPSpeculatorWorker(**draft_worker_kwargs)
-            disable_bonus_tokens = False
         else:
             draft_parallel_config: ParallelConfig = draft_worker_kwargs[
                 'parallel_config']
             draft_tp = draft_parallel_config.tensor_parallel_size
             target_tp = scorer_worker.parallel_config.tensor_parallel_size

-            if draft_tp == 1:
-                draft_worker_kwargs["model_runner_cls"] = TP1DraftModelRunner
-            proposer_worker = MultiStepWorker(**draft_worker_kwargs)
+            if draft_worker_kwargs[
+                    "model_config"].hf_config.model_type == "mlp_speculator":
+                disable_bonus_tokens = False
+                proposer_worker = MLPSpeculatorWorker(**draft_worker_kwargs)
+            else:
+                if draft_tp == 1:
+                    draft_worker_kwargs[
+                        "model_runner_cls"] = TP1DraftModelRunner
+                proposer_worker = MultiStepWorker(**draft_worker_kwargs)

             proposer_worker = SmallerTpProposerWorker.maybe_wrap_worker(
                 proposer_worker, draft_tp, target_tp)
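Net effect: the mlp_speculator case moves from a standalone elif into the TP-aware else branch, so an MLPSpeculatorWorker now flows through SmallerTpProposerWorker.maybe_wrap_worker just like a MultiStepWorker, while the draft_tp == 1 fast path (TP1DraftModelRunner) stays exclusive to regular draft models. A simplified, self-contained sketch of the resulting dispatch (illustrative strings, not vLLM's actual classes):

```python
# Simplified sketch of the restructured proposer selection in create_worker.
# Strings stand in for worker classes; the wrap-when-smaller rule mirrors
# what SmallerTpProposerWorker.maybe_wrap_worker is expected to do.
def select_proposer(model_type: str, draft_tp: int, target_tp: int) -> str:
    if model_type == "ngram":
        return "NGramWorker"
    # Both MLPSpeculator and regular draft models take the TP-aware path now.
    worker = ("MLPSpeculatorWorker" if model_type == "mlp_speculator"
              else "MultiStepWorker")
    if draft_tp < target_tp:
        # Wrap only when the draft runs on fewer ranks than the target.
        worker = f"SmallerTpProposerWorker({worker}, tp={draft_tp})"
    return worker

# The configuration the new granite test exercises: draft TP=1, target TP=2.
assert select_proposer("mlp_speculator", 1, 2) == \
    "SmallerTpProposerWorker(MLPSpeculatorWorker, tp=1)"
```

Keeping the wrap decision in one place means part 2 of this series only has to teach SmallerTpProposerWorker about MLPSpeculator weights, not duplicate the TP logic per worker type.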