Minors
s5u13b committed Feb 7, 2025
1 parent b3f0688 commit 814521e
Showing 6 changed files with 23 additions and 16 deletions.
2 changes: 1 addition & 1 deletion docs/Arguments.md
@@ -284,6 +284,6 @@ usage: -m llumnix.entrypoints.vllm.api_server [-h]
- Llumnix does not support pipeline parallel currently.

`--num-schedule-steps`
-- Llumnix does not support multi-step scheduling.
+- Llumnix does not support multi-step scheduling currently.

Besides, Llumnix does not support sampling algorithms whose number of output sequences is greater than one (vllm.SamplingParams.n > 1), such as beam search.
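
As an aside, a minimal illustration of this constraint, assuming vLLM's standard `SamplingParams` API (this snippet is not part of the commit): requests asking for more than one output sequence are rejected at the Llumnix entrypoint, as the `llumnix/entrypoints/vllm/client.py` change further below also shows.

```python
from vllm import SamplingParams

# Supported: a single output sequence per request.
ok = SamplingParams(n=1, temperature=0.8)

# Not supported by Llumnix: n > 1 (e.g. beam-search-style decoding).
# The entrypoint raises ValueError("Unsupported feature: multiple
# sequence decoding") for such a request.
rejected = SamplingParams(n=4)
```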
16 changes: 2 additions & 14 deletions llumnix/backends/vllm/llm_engine.py
@@ -23,7 +23,7 @@
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

from vllm.engine.async_llm_engine import _AsyncLLMEngine
-from vllm.outputs import RequestOutput, RequestOutputFactory, EmbeddingRequestOutput
+from vllm.outputs import RequestOutput
from vllm.sequence import SequenceGroup, SequenceStatus
from vllm.engine.arg_utils import EngineArgs
from vllm.utils import Counter
@@ -46,18 +46,6 @@
NO_OUTPUTS_STEP_INTERVAL = 0.01


-class LlumnixRequestOutputFactory(RequestOutputFactory):
-    @staticmethod
-    def create(seq_group: SequenceGroupLlumnix, use_cache: bool = False):
-        # Determine the type based on a condition, for example:
-        if hasattr(seq_group,
-                   'embeddings') and seq_group.embeddings is not None:
-            return EmbeddingRequestOutput.from_seq_group(seq_group), seq_group.server_info
-        if RequestStatus.is_migrating(seq_group.status):
-            return None
-        # pylint: disable=too-many-function-args
-        return RequestOutput.from_seq_group(seq_group, use_cache), seq_group.server_info

class LLMEngineLlumnix(_AsyncLLMEngine):
    def __init__(self,
                 instance_id: str,
@@ -347,7 +335,7 @@ def get_waiting_queue(self) -> Deque[SequenceGroupLlumnix]:
    def get_request_incremental_blocks(self, *args, **kwargs) -> Tuple[List[int], List[int]]:
        return self.engine.scheduler[0].get_request_incremental_blocks(*args, **kwargs)

-    def remove_running_request(self, *args, **kwargs) -> None:
+    def remove_running_request(self, *args, **kwargs) -> bool:
        return self.engine.scheduler[0].remove_running_request(*args, **kwargs)

    def remove_waiting_request(self, *args, **kwargs) -> bool:
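
A hypothetical caller sketch of the corrected annotation (the names `backend` and `request_id` and the log message are illustrative, not from this commit); the underlying scheduler presumably reports whether the request was actually found in the running queue, which migration logic can check:

```python
# Illustrative only: consuming the bool now declared by
# remove_running_request. `backend` is a hypothetical wrapper object.
found = backend.remove_running_request(request_id)
if not found:
    # The request may have finished or migrated away in the meantime.
    logger.warning("request {} not found in running queue".format(request_id))
```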
16 changes: 16 additions & 0 deletions llumnix/backends/vllm/outputs.py
@@ -0,0 +1,16 @@
+from vllm.outputs import RequestOutput, RequestOutputFactory, EmbeddingRequestOutput
+
+from llumnix.backends.vllm.sequence import SequenceGroupLlumnix, RequestStatus
+
+
+class LlumnixRequestOutputFactory(RequestOutputFactory):
+    @staticmethod
+    def create(seq_group: SequenceGroupLlumnix, use_cache: bool = False):
+        # Determine the type based on a condition, for example:
+        if hasattr(seq_group,
+                   'embeddings') and seq_group.embeddings is not None:
+            return EmbeddingRequestOutput.from_seq_group(seq_group), seq_group.server_info
+        if RequestStatus.is_migrating(seq_group.status):
+            return None
+        # pylint: disable=too-many-function-args
+        return RequestOutput.from_seq_group(seq_group, use_cache), seq_group.server_info
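
A hedged usage sketch of the relocated factory (the driver helper below is illustrative, not code from this repository): `create` returns `None` for a sequence group that is mid-migration, and otherwise a `(request_output, server_info)` pair, so callers can skip migrating requests when draining outputs.

```python
from llumnix.backends.vllm.outputs import LlumnixRequestOutputFactory

def drain_outputs(seq_groups):
    # Illustrative helper: collect finished outputs, skipping
    # requests that are currently migrating between instances.
    results = []
    for seq_group in seq_groups:
        created = LlumnixRequestOutputFactory.create(seq_group)
        if created is None:
            continue  # mid-migration: no output to emit yet
        request_output, server_info = created
        results.append((request_output, server_info))
    return results
```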
2 changes: 2 additions & 0 deletions llumnix/backends/vllm/scheduler.py
@@ -191,6 +191,7 @@ def _set_status(self,

    def free_dst_pre_alloc_cache(self, request_id: str = None) -> None:
        if request_id:
+            logger.info("pop request {} from pre_alloc_cache_dict".format(request_id))
            block_table = self.pre_alloc_cache_dict.pop(request_id, None)
            if block_table:
                block_table.free()
@@ -199,6 +200,7 @@ def free_dst_pre_alloc_cache(self, request_id: str = None) -> None:
            # Clear all pre-allocated cache of dst instance when src instance encounters exception.
            request_ids = list(self.pre_alloc_cache_dict.keys())
            for req_id in request_ids:
+                logger.info("pop request {} from pre_alloc_cache_dict".format(req_id))
                block_table = self.pre_alloc_cache_dict.pop(req_id, None)
                if block_table:
                    block_table.free()
2 changes: 1 addition & 1 deletion llumnix/backends/vllm/worker.py
@@ -115,7 +115,7 @@ def migrate_cache(self, src_worker_handle_list, src_blocks: List[int], dst_block
        total_kv_cache_size = len(src_blocks) * CacheEngine.get_cache_block_size(
            self.cache_config, self.model_config, self.parallel_config)
        speed = total_kv_cache_size/GiB_bytes/(end_time - start_time)
-        logger.info("[migration_cache] blocks_num: {}, total_kv_cache_size: {}, time: {}s, speed: {}GB/s."
+        logger.info("Migrate kv cache done, blocks_num: {}, total_kv_cache_size: {}, time: {}s, speed: {}GB/s."
                    .format(len(src_blocks), convert_bytes(total_kv_cache_size), end_time-start_time, speed))

    def do_recv(self, *args, **kwargs):
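
For reference, a worked sketch of the speed figure in the reworded log line, assuming `GiB_bytes` is `1 << 30` as in `vllm.utils`; the block count, per-block size, and elapsed time below are made up for the arithmetic.

```python
GiB_bytes = 1 << 30  # assumed to match vllm.utils

num_blocks = 512
block_size_bytes = 2 * 1024 * 1024                   # hypothetical KV size per block
total_kv_cache_size = num_blocks * block_size_bytes  # 1 GiB total
elapsed_s = 0.25

speed = total_kv_cache_size / GiB_bytes / elapsed_s
print(speed)  # 4.0 -> logged as "speed: 4.0GB/s"
```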
1 change: 1 addition & 0 deletions llumnix/entrypoints/vllm/client.py
@@ -46,6 +46,7 @@ async def generate(self,
                       **kwargs) -> AsyncStream:
        if sampling_params.n > 1:
            raise ValueError("Unsupported feature: multiple sequence decoding")
+        logger.info("entrypoints receive request {}".format(request_id))
        # pylint: disable=unexpected-keyword-arg
        results_generator = AsyncStream(request_id, cancel=partial(self.abort, verbose=False))
        self.request_streams[request_id] = results_generator
