@@ -509,6 +509,19 @@ def get_model(self) -> nn.Module:
     def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
         return self.model_runner.get_supported_tasks()
 
+    def annotate_profile(self, scheduler_output):
+        # add trace annotation so that we can easily distinguish
+        # new/cached request numbers in each iteration
+        if not self.profiler:
+            return nullcontext()
+
+        num_new = len(scheduler_output.scheduled_new_reqs)
+        num_cached = len(scheduler_output.scheduled_cached_reqs.req_ids)
+
+        return torch.profiler.record_function(
+            f"execute_new_{num_new}_cached_{num_cached}"
+        )
+
     @torch.inference_mode()
     def sample_tokens(
         self, grammar_output: "GrammarOutput | None"
@@ -536,9 +549,12 @@ def execute_model(
                 )
             )
 
-        output = self.model_runner.execute_model(scheduler_output, intermediate_tensors)
-        if isinstance(output, (ModelRunnerOutput, NoneType)):
-            return output
+        with self.annotate_profile(scheduler_output):
+            output = self.model_runner.execute_model(
+                scheduler_output, intermediate_tensors
+            )
+        if isinstance(output, (ModelRunnerOutput, NoneType)):
+            return output
 
         assert isinstance(output, IntermediateTensors)
         parallel_config = self.vllm_config.parallel_config
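
A minimal usage sketch (not part of the diff above; the request counts are made up for illustration) of how the record_function label returned by annotate_profile would surface when a worker iteration runs under torch.profiler:

    import torch

    # stand-in for one worker iteration: the annotation wraps the model execution
    with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CPU]) as prof:
        with torch.profiler.record_function("execute_new_3_cached_5"):
            torch.randn(256, 256) @ torch.randn(256, 256)

    # the labelled region shows up as its own event, so new/cached request
    # counts per iteration can be read straight from the trace or summary table
    print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=5))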