@@ -25,7 +25,6 @@
 import torch
 import uvloop
 from vllm.distributed.kv_events import ZmqEventPublisher
-from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.inputs.data import TokensPrompt
 from vllm.usage.usage_lib import UsageContext
 from vllm.utils import FlexibleArgumentParser
@@ -107,14 +106,15 @@ def endpoint_overwrite(args):
     def __init__(
         self,
         args: argparse.Namespace,
-        engine_args: AsyncEngineArgs,
         component: Component,
         endpoint: Endpoint,
+        config: Config,
     ):
         self.enable_disagg = args.enable_disagg
         self.endpoint = args.endpoint
         self.downstream_endpoint = args.downstream_endpoint
-        self.engine_args = engine_args
+        self.engine_args = config.engine_args
+        self.config = config
         self.setup_vllm_engine(component, endpoint)
 
     async def async_init(self, runtime: DistributedRuntime):
@@ -142,6 +142,7 @@ def setup_vllm_engine(self, component: Component, endpoint: Endpoint):
         self.stats_logger = StatLoggerFactory(
             component,
             self.engine_args.data_parallel_rank or 0,
+            metrics_labels=[("model", self.config.model)],
         )
         self.engine_client = AsyncLLM.from_vllm_config(
             vllm_config=vllm_config,
@@ -444,20 +445,24 @@ async def init(runtime: DistributedRuntime, args: argparse.Namespace, config: Config):
 
     if args.worker_type in ["prefill", "encode_prefill"]:
         handler: VllmBaseWorker = VllmPDWorker(
-            args, config.engine_args, component, generate_endpoint
+            args, component, generate_endpoint, config
         )
     elif args.worker_type == "decode":
-        handler = VllmDecodeWorker(
-            args, config.engine_args, component, generate_endpoint
-        )
+        handler = VllmDecodeWorker(args, component, generate_endpoint, config)
     await handler.async_init(runtime)
 
     logger.info(f"Starting to serve the {args.endpoint} endpoint...")
 
+    metrics_labels = [("model", config.model)]
+
     try:
         await asyncio.gather(
-            generate_endpoint.serve_endpoint(handler.generate),
-            clear_endpoint.serve_endpoint(handler.clear_kv_blocks),
+            generate_endpoint.serve_endpoint(
+                handler.generate, metrics_labels=metrics_labels
+            ),
+            clear_endpoint.serve_endpoint(
+                handler.clear_kv_blocks, metrics_labels=metrics_labels
+            ),
         )
     except Exception as e:
         logger.error(f"Failed to serve endpoints: {e}")
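
Taken together, the hunks swap the workers' bare AsyncEngineArgs parameter for the whole Config object, which also carries the model name used to label metrics on both the stats logger and every served endpoint. Below is a minimal sketch of the resulting call pattern; the Config dataclass and StubEndpoint are hypothetical stand-ins for the project's actual types, assumed only to expose engine_args, model, and a serve_endpoint that accepts a metrics_labels keyword.

import argparse
import asyncio
from dataclasses import dataclass


@dataclass
class Config:
    # Assumed shape: the real Config also exposes the vLLM AsyncEngineArgs.
    engine_args: object
    model: str


class VllmPDWorker:
    # Post-refactor signature from the diff: config replaces engine_args,
    # and the whole object is kept so stats can later be labeled with
    # config.model.
    def __init__(self, args, component, endpoint, config: Config):
        self.engine_args = config.engine_args
        self.config = config

    async def generate(self, request):
        return f"generated for {request!r} on {self.config.model}"


class StubEndpoint:
    # Hypothetical stand-in for the runtime's Endpoint type.
    async def serve_endpoint(self, handler, metrics_labels=None):
        print("serving", handler.__name__, "with labels", metrics_labels)


async def main():
    args = argparse.Namespace(worker_type="prefill")
    config = Config(engine_args=None, model="meta-llama/Llama-3-8B")
    generate_endpoint = StubEndpoint()
    handler = VllmPDWorker(args, None, generate_endpoint, config)
    # Every endpoint served by this worker now carries the same model label.
    metrics_labels = [("model", config.model)]
    await asyncio.gather(
        generate_endpoint.serve_endpoint(
            handler.generate, metrics_labels=metrics_labels
        ),
    )


asyncio.run(main())

Passing the single Config object rather than peeling off engine_args keeps the worker constructors stable as further per-model settings, such as these metrics labels, are threaded through.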