@@ -375,9 +375,15 @@ def __init__(
375375 )
376376
377377 self .use_async_scheduling = self .scheduler_config .async_scheduling
378- self .async_output_copy_stream = (
379- torch .cuda .Stream () if self .use_async_scheduling else None
380- )
378+ # Separate CUDA stream for overlapping transfer of sampled token ids from
379+ # GPU to CPU when async scheduling is enabled.
380+ self .async_output_copy_stream : torch .cuda .Stream | None = None
381+ # CUDA event to synchronize use of reused CPU tensors between steps
382+ # when async scheduling is enabled.
383+ self .prepare_inputs_event : torch .cuda .Event | None = None
384+ if self .use_async_scheduling :
385+ self .async_output_copy_stream = torch .cuda .Stream ()
386+ self .prepare_inputs_event = torch .cuda .Event ()
381387
382388 # TODO(woosuk): Provide an option to tune the max cudagraph batch size.
383389 # The convention is different.
@@ -444,14 +450,6 @@ def __init__(
444450 (3 , self .max_num_tokens + 1 ), dtype = torch .int64
445451 )
446452
447- # CUDA event to synchronize use of reused CPU tensors between steps
448- # when async scheduling is enabled.
449- self .prepare_inputs_event : torch .cuda .Event | None = None
450- if self .use_async_scheduling :
451- self .prepare_inputs_event = torch .cuda .Event ()
452- # Start in a completed state.
453- self .prepare_inputs_event .record (torch .cuda .default_stream ())
454-
455453 # None in the first PP rank. The rest are set after load_model.
456454 self .intermediate_tensors : IntermediateTensors | None = None
457455
0 commit comments