diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index eeb6eaa2165bc..aa74672f4bf67 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -42,6 +42,7 @@ steps: - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py + - pytest -v -s spec_decode/e2e/test_integration_dist.py - label: Distributed Tests (Multiple Groups) working_dir: "/vllm-workspace/tests" diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py index 44da3bad8d840..8f3168c115ae6 100644 --- a/benchmarks/benchmark_latency.py +++ b/benchmarks/benchmark_latency.py @@ -18,6 +18,8 @@ def main(args: argparse.Namespace): # NOTE(woosuk): If the request cannot be processed in a single batch, # the engine will automatically process the request in multiple batches. llm = LLM(model=args.model, + speculative_model=args.speculative_model, + num_speculative_tokens=args.num_speculative_tokens, tokenizer=args.tokenizer, quantization=args.quantization, tensor_parallel_size=args.tensor_parallel_size, @@ -28,6 +30,7 @@ def main(args: argparse.Namespace): quantization_param_path=args.quantization_param_path, device=args.device, ray_workers_use_nsight=args.ray_workers_use_nsight, + use_v2_block_manager=args.use_v2_block_manager, enable_chunked_prefill=args.enable_chunked_prefill, download_dir=args.download_dir, block_size=args.block_size) @@ -99,6 +102,8 @@ def run_to_completion(profile_dir: Optional[str] = None): description='Benchmark the latency of processing a single batch of ' 'requests till completion.') parser.add_argument('--model', type=str, default='facebook/opt-125m') + parser.add_argument('--speculative-model', type=str, default=None) + parser.add_argument('--num-speculative-tokens', type=int, default=None) parser.add_argument('--tokenizer', type=str, default=None) parser.add_argument('--quantization', '-q', @@ -181,6 +186,7 @@ def run_to_completion(profile_dir: Optional[str] = None): action='store_true', help='If True, the prefill requests can be chunked based on the ' 'max_num_batched_tokens') + parser.add_argument('--use-v2-block-manager', action='store_true') parser.add_argument( "--ray-workers-use-nsight", action='store_true', diff --git a/tests/spec_decode/e2e/test_compatibility.py b/tests/spec_decode/e2e/test_compatibility.py index 60c20ed7db7a3..81f91c5e10b0d 100644 --- a/tests/spec_decode/e2e/test_compatibility.py +++ b/tests/spec_decode/e2e/test_compatibility.py @@ -5,56 +5,6 @@ from .conftest import get_output_from_llm_generator -@pytest.mark.parametrize( - "common_llm_kwargs", - [{ - "model": "JackFram/llama-68m", - "speculative_model": "JackFram/llama-68m", - "num_speculative_tokens": 5, - - # Required for spec decode. - "use_v2_block_manager": True - }]) -@pytest.mark.parametrize( - "per_test_common_llm_kwargs", - [ - { - # Expect failure as spec decode not supported by - # Ray backend. - "worker_use_ray": True, - }, - ]) -@pytest.mark.parametrize("test_llm_kwargs", [{}]) -@pytest.mark.parametrize("seed", [1]) -def test_spec_decode_xfail_ray(test_llm_generator): - """Verify that speculative decoding with Ray fails. - """ - output_len = 128 - temperature = 0.0 - - prompts = [ - "Hello, my name is", - ] - - sampling_params = SamplingParams( - max_tokens=output_len, - ignore_eos=True, - temperature=temperature, - ) - - try: - with pytest.raises( - AssertionError, - match="Speculative decoding not yet supported for "): - get_output_from_llm_generator(test_llm_generator, prompts, - sampling_params) - finally: - # we need to free up ray resource, - # so that latter test could use the gpu we allocated here - import ray - ray.shutdown() - - @pytest.mark.parametrize( "common_llm_kwargs", [{ diff --git a/tests/spec_decode/e2e/test_integration.py b/tests/spec_decode/e2e/test_integration.py new file mode 100644 index 0000000000000..4a2b62151f8cd --- /dev/null +++ b/tests/spec_decode/e2e/test_integration.py @@ -0,0 +1,44 @@ +"""Tests which cover integration of the speculative decoding framework with +other features, e.g. cuda graphs. +""" + +import pytest + +from .conftest import run_greedy_equality_correctness_test + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + # Required for spec decode. + "use_v2_block_manager": True, + + # Verify equality when cuda graphs allowed. + "enforce_eager": False, + "model": "JackFram/llama-68m", + }]) +@pytest.mark.parametrize( + "per_test_common_llm_kwargs", + [ + { + # Identical models. + "speculative_model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + }, + ]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [{}]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("output_len", [32]) +@pytest.mark.parametrize("seed", [1]) +def test_spec_decode_cuda_graph(baseline_llm_generator, test_llm_generator, + batch_size, output_len): + """Verify spec decode equality when cuda graphs are enabled. + """ + run_greedy_equality_correctness_test( + baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len=output_len, + force_output_len=True, + ) diff --git a/tests/spec_decode/e2e/test_integration_dist.py b/tests/spec_decode/e2e/test_integration_dist.py new file mode 100644 index 0000000000000..d444ef24cbfda --- /dev/null +++ b/tests/spec_decode/e2e/test_integration_dist.py @@ -0,0 +1,65 @@ +"""Tests which cover integration of the speculative decoding framework with +tensor parallelism. +""" + +import pytest +import torch + +from vllm.utils import is_hip + +from .conftest import run_greedy_equality_correctness_test + + +@pytest.mark.skipif(torch.cuda.device_count() < 2, + reason="Need at least 2 GPUs to run the test.") +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + "model": "JackFram/llama-68m", + + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. + "use_v2_block_manager": True, + "tensor_parallel_size": 2, + + # Use AsyncLLM engine, so that the engine runs in its own process. + # Otherwise, since vLLM does not follow true SPMD, the test runner + # process will have both the engine and the rank0 worker. NCCL is not + # cleaned up properly, and its server host thread leaks, causing the + # second run of the test to fail with internal NCCL error. + "use_async": True, + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [ + { + "speculative_model": "JackFram/llama-68m", + "num_speculative_tokens": 3, + }, + { + "speculative_model": "[ngram]", + "num_speculative_tokens": 5, + "ngram_prompt_lookup_max": 3, + }, +]) +@pytest.mark.parametrize("batch_size", [2]) +@pytest.mark.parametrize( + "output_len", + [ + # Use smaller output len for fast test. + 32, + ]) +@pytest.mark.parametrize("seed", [1]) +def test_target_model_tp_gt_1(baseline_llm_generator, test_llm_generator, + batch_size: int, output_len: int): + """Verify greedy equality when tensor parallelism is used. + """ + if is_hip(): + pytest.skip("hip is not well-supported yet") + run_greedy_equality_correctness_test(baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len=output_len, + force_output_len=True) diff --git a/tests/spec_decode/e2e/test_multistep_correctness.py b/tests/spec_decode/e2e/test_multistep_correctness.py index d2da039e84c07..94d71fb012727 100644 --- a/tests/spec_decode/e2e/test_multistep_correctness.py +++ b/tests/spec_decode/e2e/test_multistep_correctness.py @@ -611,40 +611,3 @@ def test_many_k(baseline_llm_generator, test_llm_generator, batch_size: int, batch_size, max_output_len=output_len, force_output_len=True) - - -@pytest.mark.parametrize( - "common_llm_kwargs", - [{ - # Required for spec decode. - "use_v2_block_manager": True, - - # Verify equality when cuda graphs allowed. - "enforce_eager": False, - "model": "JackFram/llama-68m", - }]) -@pytest.mark.parametrize( - "per_test_common_llm_kwargs", - [ - { - # Identical models. - "speculative_model": "JackFram/llama-68m", - "num_speculative_tokens": 5, - }, - ]) -@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) -@pytest.mark.parametrize("test_llm_kwargs", [{}]) -@pytest.mark.parametrize("batch_size", [8]) -@pytest.mark.parametrize("output_len", [32]) -@pytest.mark.parametrize("seed", [1]) -def test_spec_decode_cuda_graph(baseline_llm_generator, test_llm_generator, - batch_size, output_len): - """Verify spec decode equality when cuda graphs are enabled. - """ - run_greedy_equality_correctness_test( - baseline_llm_generator, - test_llm_generator, - batch_size, - max_output_len=output_len, - force_output_len=True, - ) diff --git a/vllm/distributed/communication_op.py b/vllm/distributed/communication_op.py index 9cc776f8324f2..f8ee0f9796bcd 100644 --- a/vllm/distributed/communication_op.py +++ b/vllm/distributed/communication_op.py @@ -219,16 +219,16 @@ def broadcast_tensor_dict( to broadcast the metadata of the dict (e.g. dict structure, tensor sizes, dtypes). """ + # Bypass the function if we are using only 1 GPU. + if (not torch.distributed.is_initialized() + or torch.distributed.get_world_size(group=group) == 1): + return tensor_dict + group = group or torch.distributed.group.WORLD metadata_group = metadata_group or get_cpu_world_group() ranks = torch.distributed.get_process_group_ranks(group) assert src in ranks, f"Invalid src rank ({src})" - # Bypass the function if we are using only 1 GPU. - world_size = torch.distributed.get_world_size(group=group) - if world_size == 1: - return tensor_dict - rank = torch.distributed.get_rank() if rank == src: metadata_list: List[Tuple[Any, Any]] = [] diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py index 2b72b31b5f070..3ad201f4757ec 100644 --- a/vllm/executor/gpu_executor.py +++ b/vllm/executor/gpu_executor.py @@ -15,14 +15,13 @@ class GPUExecutor(ExecutorBase): def _init_executor(self) -> None: """Initialize the worker and load the model. - - If speculative decoding is enabled, we instead create the speculative - worker. """ - if self.speculative_config is None: - self._init_non_spec_worker() - else: - self._init_spec_worker() + assert self.parallel_config.world_size == 1, ( + "GPUExecutor only supports single GPU.") + + self.driver_worker = self._create_worker() + self.driver_worker.init_device() + self.driver_worker.load_model() def _get_worker_kwargs( self, @@ -45,6 +44,7 @@ def _get_worker_kwargs( distributed_init_method=distributed_init_method, lora_config=self.lora_config, vision_language_config=self.vision_language_config, + speculative_config=self.speculative_config, is_driver_worker=rank == 0, ) @@ -52,59 +52,22 @@ def _create_worker(self, local_rank: int = 0, rank: int = 0, distributed_init_method: Optional[str] = None): + + if self.speculative_config is None: + worker_module_name = "vllm.worker.worker" + worker_class_name = "Worker" + else: + worker_module_name = "vllm.spec_decode.spec_decode_worker" + worker_class_name = "create_spec_worker" + wrapper = WorkerWrapperBase( - worker_module_name="vllm.worker.worker", - worker_class_name="Worker", + worker_module_name=worker_module_name, + worker_class_name=worker_class_name, ) wrapper.init_worker(**self._get_worker_kwargs(local_rank, rank, distributed_init_method)) return wrapper.worker - def _init_non_spec_worker(self): - assert self.parallel_config.world_size == 1, ( - "GPUExecutor only supports single GPU.") - - self.driver_worker = self._create_worker() - self.driver_worker.init_device() - self.driver_worker.load_model() - - def _init_spec_worker(self): - """Initialize a SpecDecodeWorker, using a draft model for proposals. - """ - assert self.speculative_config is not None - - from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker - - target_worker = self._create_worker() - - draft_worker_kwargs = self._get_worker_kwargs() - # Override draft-model specific worker args. - draft_worker_kwargs.update( - model_config=self.speculative_config.draft_model_config, - parallel_config=self.speculative_config.draft_parallel_config, - ngram_prompt_lookup_max=self.speculative_config. - ngram_prompt_lookup_max, - ngram_prompt_lookup_min=self.speculative_config. - ngram_prompt_lookup_min, - # TODO allow draft-model specific load config. - #load_config=self.load_config, - ) - - spec_decode_worker = SpecDecodeWorker.create_worker( - scorer_worker=target_worker, - draft_worker_kwargs=draft_worker_kwargs, - disable_by_batch_size=self.speculative_config. - speculative_disable_by_batch_size, - ) - - assert self.parallel_config.world_size == 1, ( - "GPUExecutor only supports single GPU.") - - self.driver_worker = spec_decode_worker - - # Load model handled in spec decode worker. - self.driver_worker.init_device() - def determine_num_available_blocks(self) -> Tuple[int, int]: """Determine the number of available KV blocks by invoking the underlying worker. diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py index 9cb03ec8c3f5a..dd3ee60682d30 100644 --- a/vllm/executor/ray_gpu_executor.py +++ b/vllm/executor/ray_gpu_executor.py @@ -28,9 +28,6 @@ class RayGPUExecutor(DistributedGPUExecutor): def _init_executor(self) -> None: - assert (not self.speculative_config - ), "Speculative decoding not yet supported for RayGPU backend." - assert self.parallel_config.distributed_executor_backend == "ray" placement_group = self.parallel_config.placement_group @@ -90,14 +87,22 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", placement_group_capture_child_tasks=True, placement_group_bundle_index=bundle_id, ) + + if self.speculative_config is not None: + worker_module_name = "vllm.spec_decode.spec_decode_worker" + worker_class_name = "create_spec_worker" + else: + worker_module_name = "vllm.worker.worker" + worker_class_name = "Worker" + worker = ray.remote( num_cpus=0, num_gpus=num_gpus, scheduling_strategy=scheduling_strategy, **ray_remote_kwargs, )(RayWorkerWrapper).remote( - worker_module_name="vllm.worker.worker", - worker_class_name="Worker", + worker_module_name=worker_module_name, + worker_class_name=worker_class_name, trust_remote_code=self.model_config.trust_remote_code, ) @@ -107,8 +112,8 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", # as the resource holder for the driver process. self.driver_dummy_worker = worker self.driver_worker = RayWorkerWrapper( - worker_module_name="vllm.worker.worker", - worker_class_name="Worker", + worker_module_name=worker_module_name, + worker_class_name=worker_class_name, trust_remote_code=self.model_config.trust_remote_code, ) else: diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index a4e759095b294..ef17b8c1e2cc0 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -3,6 +3,7 @@ import torch +from vllm.distributed.communication_op import broadcast_tensor_dict from vllm.logger import init_logger from vllm.model_executor.layers.rejection_sampler import RejectionSampler from vllm.sequence import (ExecuteModelRequest, SamplerOutput, @@ -17,11 +18,43 @@ get_all_num_logprobs, get_all_seq_ids, get_sampled_token_logprobs, nvtx_range, split_batch_by_proposal_len) +from vllm.worker.worker import Worker from vllm.worker.worker_base import LoraNotSupportedWorkerBase, WorkerBase logger = init_logger(__name__) +def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker": + """Helper method that is the entrypoint for Executors which use + WorkerWrapper. It constructs a SpecDecodeWorker from the speculative config. + """ + assert "speculative_config" in kwargs + speculative_config = kwargs.get("speculative_config") + assert speculative_config is not None + + target_worker = Worker(*args, **kwargs) + + draft_worker_kwargs = kwargs.copy() + # Override draft-model specific worker args. + draft_worker_kwargs.update( + model_config=speculative_config.draft_model_config, + parallel_config=speculative_config.draft_parallel_config, + ngram_prompt_lookup_max=speculative_config.ngram_prompt_lookup_max, + ngram_prompt_lookup_min=speculative_config.ngram_prompt_lookup_min, + # TODO allow draft-model specific load config. + #load_config=load_config, + ) + + spec_decode_worker = SpecDecodeWorker.create_worker( + scorer_worker=target_worker, + draft_worker_kwargs=draft_worker_kwargs, + disable_by_batch_size=speculative_config. + speculative_disable_by_batch_size, + ) + + return spec_decode_worker + + class SpecDecodeWorker(LoraNotSupportedWorkerBase): """Worker which implements speculative decoding. @@ -142,6 +175,9 @@ def init_device(self) -> None: self._configure_model_sampler_for_spec_decode() + def load_model(self, *args, **kwargs): + pass + def _configure_model_sampler_for_spec_decode(self): """Configure model sampler to emit GPU tensors. This allows spec decode to keep data on device without transferring to CPU and serializing, @@ -195,39 +231,97 @@ def initialize_cache(self, num_gpu_blocks: int, self.proposer_worker.initialize_cache(num_gpu_blocks=num_gpu_blocks, num_cpu_blocks=num_cpu_blocks) + def _broadcast_control_flow_decision( + self, + execute_model_req: Optional[ExecuteModelRequest] = None, + disable_all_speculation: bool = False) -> Tuple[int, bool]: + """Broadcast how many lookahead slots are scheduled for this step, and + whether all speculation is disabled, to all non-driver workers. + + This is required as if the number of draft model runs changes + dynamically, the non-driver workers won't know unless we perform a + communication to inform then. + + Returns the broadcasted num_lookahead_slots and disable_all_speculation. + """ + + if self.rank == self._driver_rank: + assert execute_model_req is not None + + broadcast_dict = dict( + num_lookahead_slots=execute_model_req.num_lookahead_slots, + disable_all_speculation=disable_all_speculation, + ) + broadcast_tensor_dict(broadcast_dict, src=self._driver_rank) + else: + assert execute_model_req is None + broadcast_dict = broadcast_tensor_dict(src=self._driver_rank) + + return (broadcast_dict["num_lookahead_slots"], + broadcast_dict["disable_all_speculation"]) + @torch.inference_mode() def execute_model( - self, - execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: + self, + execute_model_req: Optional[ExecuteModelRequest] = None + ) -> List[SamplerOutput]: """Perform speculative decoding on the input batch. """ - assert execute_model_req.seq_group_metadata_list is not None, ( - "speculative decoding " - "requires non-None seq_group_metadata_list") + disable_all_speculation = False + if self.rank == self._driver_rank: + disable_all_speculation = self._should_disable_all_speculation( + execute_model_req) + + (num_lookahead_slots, + disable_all_speculation) = self._broadcast_control_flow_decision( + execute_model_req, disable_all_speculation) + + if self.rank == self._driver_rank: + assert execute_model_req is not None + assert execute_model_req.seq_group_metadata_list is not None, ( + "speculative decoding requires non-None seq_group_metadata_list" + ) + + self._maybe_disable_speculative_tokens( + disable_all_speculation, + execute_model_req.seq_group_metadata_list) + + # If no spec tokens, call the proposer and scorer workers normally. + # Used for prefill. + if num_lookahead_slots == 0 or len( + execute_model_req.seq_group_metadata_list) == 0: + return self._run_no_spec(execute_model_req, + skip_proposer=disable_all_speculation) + + return self._run_speculative_decoding_step(execute_model_req, + num_lookahead_slots) + else: + self._run_non_driver_rank(num_lookahead_slots) + return [] + def _should_disable_all_speculation( + self, execute_model_req: ExecuteModelRequest) -> bool: # When the batch size is too large, disable speculative decoding # to stop trading off throughput for latency. - disable_all = (execute_model_req.running_queue_size >= - self.disable_by_batch_size) - if disable_all: - for seq_group_metadata in execute_model_req.seq_group_metadata_list: - # Once num_speculative_tokens is set to 0, the spec decode - # of this request will be disabled forever. - # TODO(comaniac): We currently store spec decoding specific - # state in the global data structure, but we should maintain - # this state within spec decode worker. - seq_group_metadata.num_speculative_tokens = 0 - - # If no spec tokens, call the proposer and scorer workers normally. - # This happens for prefill, or when the spec decode is disabled - # for this batch. - if execute_model_req.num_lookahead_slots == 0 or len( - execute_model_req.seq_group_metadata_list) == 0: - return self._run_no_spec(execute_model_req, - skip_proposer=disable_all) - - return self._run_speculative_decoding_step(execute_model_req) + disable_all_speculation = (execute_model_req.running_queue_size >= + self.disable_by_batch_size) + + return disable_all_speculation + + def _maybe_disable_speculative_tokens( + self, disable_all_speculation: bool, + seq_group_metadata_list: List[SequenceGroupMetadata]) -> None: + if not disable_all_speculation: + return + + for seq_group_metadata in seq_group_metadata_list: + # Once num_speculative_tokens is set to 0, the spec decode + # of this request will be disabled forever. + # TODO(comaniac): We currently store spec decoding specific + # state in the global data structure, but we should maintain + # this state within spec decode worker. + seq_group_metadata.num_speculative_tokens = 0 @nvtx_range("spec_decode_worker._run_no_spec") def _run_no_spec(self, execute_model_req: ExecuteModelRequest, @@ -252,10 +346,28 @@ def _run_no_spec(self, execute_model_req: ExecuteModelRequest, sampler_output.logprobs = None return [sampler_output] + def _run_non_driver_rank(self, num_lookahead_slots: int) -> None: + """Run proposer and verifier model in non-driver workers. This is used + for both speculation cases (num_lookahead_slots>0) and non-speculation + cases (e.g. prefill). + """ + # In non-driver workers the input is None + execute_model_req = None + + # Even if num_lookahead_slots is zero, we want to run the proposer model + # as it may have KV. + # + # We run the proposer once per lookahead slot. In the future we should + # delegate how many times it runs to the proposer. + for _ in range(max(num_lookahead_slots, 1)): + self.proposer_worker.execute_model(execute_model_req) + + self.scorer_worker.execute_model(execute_model_req) + @nvtx_range("spec_decode_worker._run_speculative_decoding_step") def _run_speculative_decoding_step( - self, - execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: + self, execute_model_req: ExecuteModelRequest, + num_lookahead_slots: int) -> List[SamplerOutput]: """Execute a single step of speculative decoding. This invokes the proposer worker to get k speculative tokens for each @@ -264,6 +376,7 @@ def _run_speculative_decoding_step( Returns a list of SamplerOutput, each containing a single token per sequence. """ + assert num_lookahead_slots == execute_model_req.num_lookahead_slots # Generate proposals using draft worker. proposals = self.proposer_worker.get_spec_proposals(execute_model_req) @@ -455,6 +568,10 @@ def rank(self): def device(self): return self.scorer_worker.device + @property + def _driver_rank(self) -> int: + return 0 + def get_cache_block_size_bytes(self): """Return the size of a cache block in bytes. diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 82cf58101a95b..618e96b60b6c7 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -8,7 +8,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig, - VisionLanguageConfig) + SpeculativeConfig, VisionLanguageConfig) from vllm.distributed import (broadcast_tensor_dict, ensure_model_parallel_initialized, init_distributed_environment, @@ -43,6 +43,7 @@ def __init__( distributed_init_method: str, lora_config: Optional[LoRAConfig] = None, vision_language_config: Optional[VisionLanguageConfig] = None, + speculative_config: Optional[SpeculativeConfig] = None, is_driver_worker: bool = False, ) -> None: self.model_config = model_config diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index fb32feaca0c94..1f04f821eb0f0 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -121,7 +121,7 @@ def update_environment_variables(envs: Dict[str, str]) -> None: def init_worker(self, *args, **kwargs): """ Actual initialization of the worker class, and set up - function tracing if required. + function tracing if required. Arguments are passed to the worker class constructor. """ enable_trace_function_call_for_thread()