@@ -1106,7 +1106,8 @@ def __init__(
             max_queue_size: Maximum size of the request queue (0 = unlimited)
             streaming: Whether to stream tokens as they are generated
         """
-        self.model = model
+        self.model = model.eval()
+        generation_config = model.generation_config if generation_config is None else generation_config
         self.generation_config = generation_config
         self.input_queue = queue.Queue(maxsize=max_queue_size)
         self.output_queue = queue.Queue()
@@ -1118,7 +1119,6 @@ def __init__(
         self._request_lock = threading.Lock()
         self.model.generation_config.top_p = None
         self.do_sample = getattr(generation_config, "do_sample", True)
-        generation_config = model.generation_config if generation_config is None else generation_config
         self.logit_processor = self.model._get_logits_processor(generation_config)
         self.use_cuda_graph = getattr(generation_config, "use_cuda_graph", True)
         self.profile = getattr(generation_config, "profile", False)
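The two `__init__` hunks move the `generation_config` defaulting above its first use: previously, a caller passing `generation_config=None` had `self.generation_config` stored as `None`, and `getattr(None, "do_sample", True)` silently returned the fallback before the model's default config was ever resolved. The same hunk also stores `model.eval()`, so dropout and other training-mode layers are disabled during serving. A minimal before/after sketch of the ordering fix (hypothetical stripped-down class, not the real signature):

```python
class Before:
    def __init__(self, model, generation_config=None):
        self.generation_config = generation_config  # can still be None here
        self.do_sample = getattr(generation_config, "do_sample", True)  # reads from None
        # Defaulting happened only after the reads above:
        generation_config = model.generation_config if generation_config is None else generation_config

class After:
    def __init__(self, model, generation_config=None):
        # Resolve the default first, then store and read from it.
        generation_config = model.generation_config if generation_config is None else generation_config
        self.generation_config = generation_config
        self.do_sample = getattr(generation_config, "do_sample", True)
```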
@@ -1242,15 +1242,15 @@ def __iter__(self):

     @traced
     def warmup(self, batch_processor):
-        stream = torch.cuda.Stream()
+        stream = torch.cuda.Stream(device=self.model.device)
         stream.wait_stream(torch.cuda.current_stream())
         with torch.cuda.stream(stream):
             # Warmup the model with a dummy forward pass
             self._generation_step(batch_processor)
         torch.cuda.current_stream().wait_stream(stream)

         self.graph = torch.cuda.CUDAGraph()
-        with torch.cuda.graph(self.graph):
+        with torch.cuda.graph(self.graph, stream=stream):
             self._generation_step(batch_processor)

     @traced
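These two changes pin the warmup stream to the model's device and then capture on that same stream, so the warmup allocations and the graph capture see identical stream state. For context, a self-contained sketch of the standard side-stream warmup-then-capture pattern (stand-in model and shapes, not this file's):

```python
import torch

model = torch.nn.Linear(512, 512).cuda().eval()
static_input = torch.randn(8, 512, device="cuda")

with torch.no_grad():
    # Warm up on a side stream so lazy allocations and autotuning happen
    # outside the capture.
    stream = torch.cuda.Stream(device=static_input.device)
    stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(stream):
        static_output = model(static_input)
    torch.cuda.current_stream().wait_stream(stream)

    # Capture on the same stream, then replay later with fresh data copied
    # into the static input buffer.
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph, stream=stream):
        static_output = model(static_input)

static_input.copy_(torch.randn(8, 512, device="cuda"))
graph.replay()  # recomputes static_output in place
```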
@@ -1326,7 +1326,7 @@ def _run_generation_loop(self):
         is_first = True

         if self.profile:
-            tracing_schedule = schedule(skip_first=2, warmup=3, active=200, repeat=100, wait=1)
+            tracing_schedule = schedule(skip_first=2, warmup=1, active=1, repeat=3, wait=1)
             trace_handler = tensorboard_trace_handler(
                 dir_name="/fsx/arthur/transformers", use_gzip=True, worker_name="paged_compile"
             )
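The new schedule shrinks each profiling cycle to wait 1 / warmup 1 / active 1, repeated 3 times after skipping the first 2 steps, so traces stay small instead of recording 200 active steps per cycle. A minimal sketch of how such a schedule drives `torch.profiler` in a step loop (generic loop body and output path, not this file's):

```python
from torch.profiler import ProfilerActivity, profile, schedule, tensorboard_trace_handler

tracing_schedule = schedule(skip_first=2, wait=1, warmup=1, active=1, repeat=3)

with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    schedule=tracing_schedule,
    on_trace_ready=tensorboard_trace_handler("./traces", use_gzip=True),
) as prof:
    for _ in range(12):
        do_generation_step()  # placeholder for one loop iteration
        prof.step()           # advances the profiler schedule
```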