Skip to content

Commit f2dee40

Browse files
tzulingkdillon-cullinan
authored andcommitted
feat: Metrics labels for multimodal. (#2835)
Signed-off-by: tzulingk@nvidia.com <tzulingk@nvidia.com>
1 parent 4c34b78 commit f2dee40

File tree

5 files changed

+42
-17
lines changed

5 files changed

+42
-17
lines changed

examples/multimodal/components/encode_worker.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,9 @@ async def init(runtime: DistributedRuntime, args: argparse.Namespace, config: Co
261261

262262
try:
263263
await asyncio.gather(
264-
generate_endpoint.serve_endpoint(handler.generate),
264+
generate_endpoint.serve_endpoint(
265+
handler.generate, metrics_labels=[("model", config.model)]
266+
),
265267
)
266268
except Exception as e:
267269
logger.error(f"Failed to serve endpoints: {e}")

examples/multimodal/components/processor.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -332,7 +332,9 @@ async def init(runtime: DistributedRuntime, args: argparse.Namespace, config: Co
332332

333333
try:
334334
await asyncio.gather(
335-
generate_endpoint.serve_endpoint(handler.generate),
335+
generate_endpoint.serve_endpoint(
336+
handler.generate, metrics_labels=[("model", config.model)]
337+
),
336338
)
337339
except Exception as e:
338340
logger.error(f"Failed to serve endpoints: {e}")

examples/multimodal/components/publisher.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16-
from typing import Optional
16+
from typing import List, Optional, Tuple
1717

1818
from vllm.config import VllmConfig
1919
from vllm.v1.metrics.loggers import StatLoggerBase
@@ -48,9 +48,15 @@ def log_engine_initialized(self):
4848
class DynamoStatLoggerPublisher(StatLoggerBase):
4949
"""Stat logger publisher. Wrapper for the WorkerMetricsPublisher to match the StatLoggerBase interface."""
5050

51-
def __init__(self, component: Component, dp_rank: int) -> None:
51+
def __init__(
52+
self,
53+
component: Component,
54+
dp_rank: int,
55+
metrics_labels: Optional[List[Tuple[str, str]]] = None,
56+
) -> None:
5257
self.inner = WorkerMetricsPublisher()
53-
self.inner.create_endpoint(component)
58+
metrics_labels = metrics_labels or []
59+
self.inner.create_endpoint(component, metrics_labels)
5460
self.dp_rank = dp_rank
5561
self.num_gpu_block = 1
5662
self.request_total_slots = 1
@@ -141,15 +147,23 @@ def log_engine_initialized(self) -> None:
141147
class StatLoggerFactory:
142148
"""Factory for creating stat logger publishers. Required by vLLM."""
143149

144-
def __init__(self, component: Component, dp_rank: int = 0) -> None:
150+
def __init__(
151+
self,
152+
component: Component,
153+
dp_rank: int = 0,
154+
metrics_labels: Optional[List[Tuple[str, str]]] = None,
155+
) -> None:
145156
self.component = component
146157
self.created_logger: Optional[DynamoStatLoggerPublisher] = None
147158
self.dp_rank = dp_rank
159+
self.metrics_labels = metrics_labels or []
148160

149161
def create_stat_logger(self, dp_rank: int) -> StatLoggerBase:
150162
if self.dp_rank != dp_rank:
151163
return NullStatLogger()
152-
logger = DynamoStatLoggerPublisher(self.component, dp_rank)
164+
logger = DynamoStatLoggerPublisher(
165+
self.component, dp_rank, metrics_labels=self.metrics_labels
166+
)
153167
self.created_logger = logger
154168

155169
return logger

examples/multimodal/components/video_encode_worker.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -308,7 +308,9 @@ async def init(runtime: DistributedRuntime, args: argparse.Namespace, config: Co
308308

309309
try:
310310
await asyncio.gather(
311-
generate_endpoint.serve_endpoint(handler.generate),
311+
generate_endpoint.serve_endpoint(
312+
handler.generate, metrics_labels=[("model", config.model)]
313+
),
312314
)
313315
except Exception as e:
314316
logger.error(f"Failed to serve endpoints: {e}")

examples/multimodal/components/worker.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
import torch
2626
import uvloop
2727
from vllm.distributed.kv_events import ZmqEventPublisher
28-
from vllm.engine.arg_utils import AsyncEngineArgs
2928
from vllm.inputs.data import TokensPrompt
3029
from vllm.usage.usage_lib import UsageContext
3130
from vllm.utils import FlexibleArgumentParser
@@ -107,14 +106,15 @@ def endpoint_overwrite(args):
107106
def __init__(
108107
self,
109108
args: argparse.Namespace,
110-
engine_args: AsyncEngineArgs,
111109
component: Component,
112110
endpoint: Endpoint,
111+
config: Config,
113112
):
114113
self.enable_disagg = args.enable_disagg
115114
self.endpoint = args.endpoint
116115
self.downstream_endpoint = args.downstream_endpoint
117-
self.engine_args = engine_args
116+
self.engine_args = config.engine_args
117+
self.config = config
118118
self.setup_vllm_engine(component, endpoint)
119119

120120
async def async_init(self, runtime: DistributedRuntime):
@@ -142,6 +142,7 @@ def setup_vllm_engine(self, component: Component, endpoint: Endpoint):
142142
self.stats_logger = StatLoggerFactory(
143143
component,
144144
self.engine_args.data_parallel_rank or 0,
145+
metrics_labels=[("model", self.config.model)],
145146
)
146147
self.engine_client = AsyncLLM.from_vllm_config(
147148
vllm_config=vllm_config,
@@ -444,20 +445,24 @@ async def init(runtime: DistributedRuntime, args: argparse.Namespace, config: Co
444445

445446
if args.worker_type in ["prefill", "encode_prefill"]:
446447
handler: VllmBaseWorker = VllmPDWorker(
447-
args, config.engine_args, component, generate_endpoint
448+
args, component, generate_endpoint, config
448449
)
449450
elif args.worker_type == "decode":
450-
handler = VllmDecodeWorker(
451-
args, config.engine_args, component, generate_endpoint
452-
)
451+
handler = VllmDecodeWorker(args, component, generate_endpoint, config)
453452
await handler.async_init(runtime)
454453

455454
logger.info(f"Starting to serve the {args.endpoint} endpoint...")
456455

456+
metrics_labels = [("model", config.model)]
457+
457458
try:
458459
await asyncio.gather(
459-
generate_endpoint.serve_endpoint(handler.generate),
460-
clear_endpoint.serve_endpoint(handler.clear_kv_blocks),
460+
generate_endpoint.serve_endpoint(
461+
handler.generate, metrics_labels=metrics_labels
462+
),
463+
clear_endpoint.serve_endpoint(
464+
handler.clear_kv_blocks, metrics_labels=metrics_labels
465+
),
461466
)
462467
except Exception as e:
463468
logger.error(f"Failed to serve endpoints: {e}")

0 commit comments

Comments
 (0)