[BugFix][Refactor] Fix some bugs and refine codes for large scale simulator test #93

Open · wants to merge 55 commits into base: main

Commits (55)
183676f
Refine logger output text
s5u13b Jan 13, 2025
43a9577
Customize prefix for actor logs
s5u13b Jan 13, 2025
d0779d8
Upgrade logger to vLLM v0.6.6.post1
s5u13b Jan 13, 2025
1f6938a
Reorganize logger
s5u13b Jan 14, 2025
6fa2cd3
Remove date format
s5u13b Jan 14, 2025
1b1598d
Add constants module
s5u13b Jan 14, 2025
00aef3d
Log ray id for logging
s5u13b Jan 14, 2025
6fa01fc
Refine logging handlers configuration
s5u13b Jan 14, 2025
8fb6dac
Fix lint
s5u13b Jan 14, 2025
1890445
Refine logger
s5u13b Jan 14, 2025
0edd964
Optimize constants
s5u13b Jan 17, 2025
9ca22eb
Fix constants
s5u13b Jan 15, 2025
fef0cd6
Refactor
s5u13b Jan 16, 2025
b101bcd
Fix benchmark_serving
s5u13b Jan 16, 2025
25016df
Minors
s5u13b Jan 16, 2025
4108477
Minors
s5u13b Jan 17, 2025
61a30d3
Add poll instance infos and migration tasks log
s5u13b Jan 17, 2025
c5ee9ef
Minors
s5u13b Jan 20, 2025
d43fac5
Minors
s5u13b Jan 20, 2025
92b7f1d
Minors
s5u13b Jan 20, 2025
72feec3
Fix
s5u13b Jan 20, 2025
b7a600b
Minors
s5u13b Jan 20, 2025
d8d17c2
Fix
s5u13b Jan 20, 2025
1b8bff9
Minors
s5u13b Jan 20, 2025
b9b445c
Fix
s5u13b Jan 20, 2025
c55b2cd
Minors
s5u13b Jan 21, 2025
1cdba8e
Reorg simulator files
s5u13b Jan 21, 2025
c0545ea
Minors
s5u13b Jan 22, 2025
37eadf5
Assert enable_scaling
s5u13b Jan 22, 2025
467b49b
Minors
s5u13b Jan 22, 2025
9e5aa22
Set max_instances for auto scale up
s5u13b Jan 22, 2025
7156b7e
Add retry bind address for zmq server
s5u13b Jan 22, 2025
cf588bb
Fix lint
s5u13b Jan 22, 2025
5888e58
Fix unit test
s5u13b Jan 22, 2025
3651cf8
Refine dispatch scheduler implementation
s5u13b Jan 23, 2025
5c545a6
Support power-of-k-choice for dispatch
s5u13b Jan 23, 2025
1d74b1b
Fix lint
s5u13b Jan 23, 2025
7d28e36
Fix lint
s5u13b Feb 7, 2025
d4028be
Fix global scheduler unit test
s5u13b Feb 7, 2025
c358075
Fix entrypoints unit test
s5u13b Feb 7, 2025
6a68c94
Squashed commit of the following:
s5u13b Feb 7, 2025
8d29dc0
Fix host, num_cpus, serve
s5u13b Feb 8, 2025
cffff07
Minors
s5u13b Feb 8, 2025
c715e02
Simulator test done
s5u13b Feb 11, 2025
26e2618
Fix manager unit test
s5u13b Feb 11, 2025
76b2fd6
Fix init_instances and simulator test
s5u13b Feb 12, 2025
24c9df8
Fix simulator test
s5u13b Feb 12, 2025
ea91e07
Minors
s5u13b Feb 12, 2025
13e68b1
Fix ip address
s5u13b Feb 12, 2025
aec7700
Refine instance ready & migration size sort
s5u13b Feb 12, 2025
ff1f317
Fix lint
s5u13b Feb 12, 2025
bf7769c
Refine timestamps
s5u13b Feb 12, 2025
c91f34f
Resort manager and launcher functions & Fix test_manager
s5u13b Feb 12, 2025
2c4cc50
Fix lint
s5u13b Feb 12, 2025
4b38e5e
Fix correctness test
s5u13b Feb 12, 2025
16 changes: 8 additions & 8 deletions benchmark/benchmark_serving.py
@@ -372,7 +372,7 @@ def __init__(self):
self._decode_sum_latencies = []
self._all_decode_token_latencies = []
self._inference_latencies = []
self._per_token_latencies_breakdown_dict = []
self._per_token_latency_breakdown_list = []

def measure(self, f):
async def measured(*args, **kwargs):
@@ -400,9 +400,9 @@ async def measured(*args, **kwargs):
self._all_token_latencies.append(lat_arr)
self._decode_sum_latencies.append(decode_sum_latency)
self._all_decode_token_latencies.extend(lat_arr[1:,1])
if 'per_token_latency_breakdown_dict' in output:
self._inference_latencies.append(np.mean(output['per_token_latency_breakdown_dict']['step_latency_engine']))
self._per_token_latencies_breakdown_dict.append(output['per_token_latency_breakdown_dict'])
self._inference_latencies.append(0.0)
if 'per_token_latency_breakdown_list' in output:
self._per_token_latency_breakdown_list.append(output['per_token_latency_breakdown_list'])
return prompt, output
return measured

@@ -494,7 +494,7 @@ async def benchmark(
m._decode_sum_latencies, \
m._request_lens, \
m._all_decode_token_latencies, \
m._per_token_latencies_breakdown_dict
m._per_token_latency_breakdown_list

def gen_random_response_lens(distribution: str, len_mean, len_range, num_prompts):
if distribution == 'uniform':
@@ -785,7 +785,7 @@ def main():
decode_sum_latencies, \
request_lens, \
all_decode_token_latencies, \
per_token_latencies_breakdown_dict = asyncio.run(benchmark(
per_token_latency_breakdown_list = asyncio.run(benchmark(
backend,
tokenizer,
prompts,
@@ -823,8 +823,8 @@ def main():
"decode_sum_latencies": decode_sum_latencies,
"all_decode_token_latencies": all_decode_token_latencies,
"inference_latencies": inference_latencies,
"per_token_latencies_breakdown_dict": per_token_latencies_breakdown_dict,
"throughput": throughput,
"per_token_latency_breakdown_list": per_token_latency_breakdown_list,
"throughput": throughput,
"instance_num": avg_instance_num})
json.dump(results, f)

11 changes: 8 additions & 3 deletions docs/Arguments.md
@@ -25,10 +25,11 @@ usage: -m llumnix.entrypoints.vllm.api_server [-h]
[--scaling-load-metric {remaining_steps,usage_ratio}]
[--polling-interval POLLING_INTERVAL]
[--dispatch-policy {balanced,load,queue,rr}]
[--power-of-k-choice POWER_OF_K_CHOICE]
[--enable-migration]
[--enable-defrag]
[--pair-migration-frequency PAIR_MIGRATION_FREQUENCY]
[--pair-migration-policy {balanced,defrag_constrained,defrag_relaxed}]
[--pair-migration-policy {balanced,defrag}]
[--migrate-out-threshold MIGRATE_OUT_THRESHOLD]
[--request-migration-policy {LCR,SR,LR,FCW,FCWSR}]
[--enable-scaling]
@@ -139,6 +140,10 @@ usage: -m llumnix.entrypoints.vllm.api_server [-h]
- Possible choices: balanced, load, queue, rr
- Default: "load"

`--power-of-k-choice`
Review comment from @zhypku (Collaborator), Feb 12, 2025:

I don't recommend using the term power-of-k-choices here. It's a well-established concept in the specific context of decentralized load balancing. But here it seems to me that you are only doing some sort of randomizing inside a centralized scheduler. Calling it power-of-k-choices is pretty misleading.

- Number of candidate instances for dispatch policy
- Default: 1

`--enable-migration`
- Enable migrate requests between instances.

@@ -151,8 +156,8 @@ usage: -m llumnix.entrypoints.vllm.api_server [-h]

`--pair-migration-policy`
- Pair migration policy.
- Possible choices: balanced, defrag_constrained, defrag_relaxed
- Default: "defrag_constrained"
- Possible choices: balanced, defrag
- Default: "defrag"

`--migrate-out-threshold`
- Migrate out instance load threshold.
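For readers unfamiliar with the flag, here is a minimal sketch of the candidate-then-random dispatch that `--power-of-k-choice` describes. The function name, the `instance_loads` mapping, and the assumption that lower load is better are all illustrative, not the PR's actual implementation:

```python
import random
from typing import Dict, List

def dispatch_with_k_candidates(instance_loads: Dict[str, float], k: int = 1) -> str:
    """Rank instances by the dispatch policy's load metric, keep the top k,
    then pick one of them at random (hypothetical helper, not Llumnix code)."""
    if not instance_loads:
        raise ValueError("no instances available for dispatch")
    # Lower load is assumed to be better here; the real policy may combine
    # factors such as load and waiting-queue length.
    ranked: List[str] = sorted(instance_loads, key=instance_loads.get)
    candidates = ranked[:max(1, k)]
    # Randomizing among the top-k candidates avoids sending a burst of
    # requests to the single least-loaded instance.
    return random.choice(candidates)
```

With `k=1` this degenerates to the plain policy choice, matching the documented default; as the review comment above notes, this is randomization inside a centralized scheduler rather than classical decentralized power-of-k-choices.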
2 changes: 1 addition & 1 deletion docs/Simulator.md
@@ -3,7 +3,7 @@ Llumnix can generate latency data from logs. After run a real benchmark with `--

After running profiling with `python llumnix.backends.profiling.py`. You can get a `$PROFILING_RESULT_FILE_PATH.pkl`

Then, you can run simulator with `--profiling-result-file-path PROFILING_RESULT_FILE_PATH`.
Then, you can run simulator with `--simulator-mode` and `--profiling-result-file-path PROFILING_RESULT_FILE_PATH`.


23 changes: 18 additions & 5 deletions llumnix/arg_utils.py
@@ -44,6 +44,7 @@ def add_argument(self, *args, **kwargs):
kwargs['default'] = None
super().add_argument(*args, **kwargs)


@dataclass
class EntrypointsArgs:
host: str = None
@@ -112,13 +113,15 @@ def add_cli_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
help="path to config file of arguments")
return parser


@dataclass
class ManagerArgs:
initial_instances: int = None

polling_interval: float = None
dispatch_policy: str = None
scaling_load_metric: str = None
power_of_k_choice: int = None

enable_migration: bool = None
pair_migration_frequency: int = None
@@ -174,6 +177,7 @@ def create_global_scheduler_config(self, is_group_kind_migration_backend) -> Tup
# Create the GlobalScheduler Configuration.
global_scheduler_config = GlobalSchedulerConfig(self.initial_instances,
self.dispatch_policy,
self.power_of_k_choice,
self.pair_migration_policy,
self.migrate_out_threshold,
self.scaling_policy,
@@ -205,6 +209,8 @@ def check_args(cls, args: 'ManagerArgs', parser: argparse.ArgumentParser):
assert not args.enable_port_offset_store or args.enable_port_increment, \
"Set enable_port_increment when enable_port_offset_store"

assert not args.enable_scaling, "Proactive auto-scaling is deprecated now, all auto-scaling related args will not take effects."

@staticmethod
def add_cli_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
parser.add_argument('--initial-instances',
@@ -226,6 +232,13 @@ def add_cli_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
'* "queue" dispatch request to the instance with minimum waiting request queue length.\n'
'* "flood" dispatch request to the instance with maximum requests dispatched.\n'
'* "rr" dispatch requests with round-robin policy.\n')
parser.add_argument('--power-of-k-choice',
type=int,
help='number of candidate instances for dispatch policy.\n\n'
'The candidate instances are first selected according to the load'
'(including factors such as load, queue size, etc.) based on the dispatch policy,'
'and then one of them is randomly chosen to receive the request for better load balancing.')

parser.add_argument('--enable-migration',
action='store_true',
help='enable migrate requests between instances')
@@ -234,13 +247,11 @@ def add_cli_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
help='pair migration frequency')
parser.add_argument('--pair-migration-policy',
type=str,
choices=['balanced', 'defrag_constrained', 'defrag_relaxed'],
choices=['balanced', 'defrag'],
help='The pair migration policy.\n\n'
'* "balanced" pair migration to make the instance load of instance more balanced.\n'
'* "defrag_constrained" pair migration without balanced constraint to '
'achieve defragmentation thoroughly (with instance constraints).\n'
'* "defrag_relaxed" pair migration to without balanced constraint '
'to achieve defragmentation thoroughly (without instance constraints).\n')
'* "defrag" pair migration without balanced constraint to '
'achieve defragmentation thoroughly (with instance constraints).\n')
parser.add_argument('--migrate-out-threshold',
type=float,
help='migrate out instance load threshold')
@@ -289,11 +300,13 @@ def add_cli_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
help='the prefill decode ratio used in gloabl launch model e.g. "1:1"')
return parser


@dataclass
class LaunchArgs:
launch_mode: LaunchMode = None
backend_type: BackendType = None


@dataclass
class InstanceArgs:
instance_type: str = None
4 changes: 4 additions & 0 deletions llumnix/backends/profiling.py
@@ -35,6 +35,7 @@
def _pad_to_alignment(x, multiple_of):
return x + ((-1*x) % multiple_of)


@dataclasses.dataclass
class LatencyMemData:
# The latency of each stage
@@ -69,6 +70,7 @@ def get_prefill_dict_kv(self):
def get_decode_dict_kv(self):
return map(list, zip(*self.decode_latency.items()))


@dataclasses.dataclass
class ProfilingResult:
"""Store the profiling result of a model."""
@@ -127,6 +129,7 @@ def fit_from_database(self, parallel_config: SimParallelConfig):
avg_loss += abs(sim_lat - latency_list[idx])
print(f"decode sim avg_loss={avg_loss/len(latency_list)}")


class ProfilingDatabase:
"""Store the profiling results of all the models"""
def __init__(self, database_filename: str, new_database: bool = False):
@@ -198,6 +201,7 @@ def get_latency_mem(backend_type: BackendType, profiling_database: ProfilingData
return latency_mem
raise ValueError(f'Unsupported simulator backend: {backend_type}')


if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
9 changes: 4 additions & 5 deletions llumnix/backends/utils.py
@@ -26,6 +26,7 @@
from llumnix.logging.logger import init_logger
from llumnix.utils import get_instance_name
from llumnix.internal_config import MigrationConfig
from llumnix.metrics.timestamps import set_timestamp

logger = init_logger(__name__)

@@ -54,16 +55,14 @@ async def put_nowait_to_servers(self,
tasks = []
for server_id, req_outputs in server_request_outputs.items():
server_info = server_info_dict[server_id]
for req_output in req_outputs:
if hasattr(req_output, 'request_timestamps'):
req_output.request_timestamps.engine_actor_put_queue_timestamp = time.time()
set_timestamp(req_outputs, 'engine_actor_put_queue_timestamp', time.time())
tasks.append(asyncio.create_task(self.request_output_queue_client.put_nowait(req_outputs, server_info)))
rets = await asyncio.gather(*tasks, return_exceptions=True)
for idx, ret in enumerate(rets):
if isinstance(ret, Exception):
server_id = list(server_request_outputs.keys())[idx]
server_info = server_info_dict[server_id]
logger.warning("Server {} is dead.".format(server_id))
logger.error("Server {} is dead, exception: {}".format(server_id, ret))
if self.request_output_queue_type == QueueType.ZMQ:
logger.warning("request output queue ip: {}, port: {}".format(server_info.request_output_queue_ip,
server_info.request_output_queue_port))
@@ -96,7 +95,7 @@ def init_backend_engine(instance_id: str,
engine_args)
elif backend_type == BackendType.SIM_VLLM:
# pylint: disable=import-outside-toplevel
from llumnix.backends.vllm.simulator import BackendSimVLLM
from llumnix.backends.vllm.sim_llm_engine import BackendSimVLLM
backend_engine = BackendSimVLLM(instance_id,
placement_group,
request_output_queue_type,
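The first hunk above folds a per-output `hasattr` check into a single `set_timestamp` call imported from `llumnix.metrics.timestamps`. A rough sketch of what such a helper might look like, assuming it accepts either one output or a batch (the actual implementation in the PR may differ):

```python
from typing import Any, Iterable, Union

def set_timestamp(outputs: Union[Any, Iterable[Any]], attr: str, timestamp: float) -> None:
    """Stamp `attr` on each output's request_timestamps, skipping outputs
    that have no request_timestamps attribute (sketch, not the real helper)."""
    if not isinstance(outputs, (list, tuple)):
        outputs = [outputs]  # accept a single output as well as a batch
    for output in outputs:
        request_timestamps = getattr(output, "request_timestamps", None)
        if request_timestamps is not None:
            setattr(request_timestamps, attr, timestamp)
```

The call site in the diff, `set_timestamp(req_outputs, 'engine_actor_put_queue_timestamp', time.time())`, then replaces the explicit loop.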
67 changes: 0 additions & 67 deletions llumnix/backends/vllm/executor.py
@@ -12,7 +12,6 @@
# limitations under the License.

import time
import asyncio

from collections import defaultdict
from typing import Callable, Dict, List, Optional, Tuple, Type
@@ -22,7 +21,6 @@
from ray.util.placement_group import PlacementGroup

from vllm.executor.executor_base import ExecutorBase
from vllm.model_executor.layers.sampler import SamplerOutput, CompletionSequenceGroupOutput
from vllm.executor.ray_gpu_executor import RayGPUExecutor, RayGPUExecutorAsync, RayWorkerWrapper, envs, \
get_ip, get_vllm_instance_id, get_distributed_init_method, get_open_port
from vllm.worker.worker_base import WorkerBase
@@ -262,68 +260,3 @@ async def execute_model_async(self, *args, **kwargs):
t1 = time.time()
self.last_inference_latency = (t1 - t0) * 1000
return outputs

class SimGPUExecutor(RayGPUExecutor):
latency_mem: LatencyMemData = None
def __init__(self, *args, **kwargs) -> None:
RayGPUExecutor.__init__(self, *args, **kwargs)
self.last_inference_latency = 0
self.migration_bandwidth = self.latency_mem.migration_bandwidth
# TODO(ZeldaHuang): add swap bandwidth

self.cache_block_size = get_cache_block_size(
self.cache_config.block_size, self.model_config, self.parallel_config)
self.cache_block_size /= GiB_bytes
self.sim_cache_config = SimCacheConfig(self.cache_config.gpu_memory_utilization,
self.cache_config.block_size,
self.scheduler_config.max_num_batched_tokens)

def _init_executor(self) -> None:
pass

def determine_num_available_blocks(self) -> Tuple[int, int]:
num_gpu_blocks = self.latency_mem.cache_dict.get(self.sim_cache_config, 880)
num_cpu_blocks = 2048
return (num_gpu_blocks, num_cpu_blocks)

def initialize_cache(self, num_gpu_blocks: int,
num_cpu_blocks: int) -> None:
logger.info("# GPU blocks: {}, # CPU blocks: {}".format(num_gpu_blocks, num_cpu_blocks))

async def execute_model_async(
self,
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
prefill_seq_len = 0
decode_seq_len = 0
decode_bs = 0
for meta_data in execute_model_req.seq_group_metadata_list:
if meta_data.is_prompt:
prefill_seq_len += meta_data.token_chunk_size
else:
decode_bs += meta_data.token_chunk_size
decode_seq_len += list(meta_data.seq_data.values())[0].get_len()
decode_bs = _pad_to_alignment(decode_bs, 8)
prefill_seq_len = _pad_to_alignment(prefill_seq_len, 8)
latency = 0
if prefill_seq_len:
latency += self.latency_mem.prefill_latency[prefill_seq_len][0] if prefill_seq_len in self.latency_mem.prefill_latency \
else model_prefill(prefill_seq_len, *self.latency_mem.prefill_model_params)
if decode_bs:
decode_meta_data = (decode_bs, decode_seq_len)
latency += self.latency_mem.decode_latency[decode_meta_data][0] if decode_meta_data in self.latency_mem.decode_latency \
else model_decode((decode_bs, decode_seq_len), *self.latency_mem.decode_model_params)
await asyncio.sleep(latency/1000)
sampler_outputs = []
for meta_data in execute_model_req.seq_group_metadata_list:
samples = []
for seq_id in meta_data.seq_data.keys():
dummy_sample_output = SequenceOutput(seq_id, 20, {20: Logprob(1.0)})
samples.append(dummy_sample_output)
if samples:
output = CompletionSequenceGroupOutput(samples, None)
sampler_outputs.append(output)
return [SamplerOutput(outputs=sampler_outputs)]

async def send_blocks(self, blocks_len) -> None:
migration_latency = (self.cache_block_size * blocks_len) / self.migration_bandwidth
await asyncio.sleep(migration_latency)