
Commit a115b67

Fix tp cb (#39838)
* fixes
* one more
1 parent 2c0af41 commit a115b67

File tree

2 files changed: +7 -6


src/transformers/generation/continuous_batching.py

Lines changed: 5 additions & 5 deletions
```diff
@@ -1106,7 +1106,8 @@ def __init__(
             max_queue_size: Maximum size of the request queue (0 = unlimited)
             streaming: Whether to stream tokens as they are generated
         """
-        self.model = model
+        self.model = model.eval()
+        generation_config = model.generation_config if generation_config is None else generation_config
         self.generation_config = generation_config
         self.input_queue = queue.Queue(maxsize=max_queue_size)
         self.output_queue = queue.Queue()
@@ -1118,7 +1119,6 @@ def __init__(
         self._request_lock = threading.Lock()
         self.model.generation_config.top_p = None
         self.do_sample = getattr(generation_config, "do_sample", True)
-        generation_config = model.generation_config if generation_config is None else generation_config
         self.logit_processor = self.model._get_logits_processor(generation_config)
         self.use_cuda_graph = getattr(generation_config, "use_cuda_graph", True)
         self.profile = getattr(generation_config, "profile", False)
```
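The `__init__` change does two things: `model.eval()` switches off dropout and other train-time behavior before serving, and the `generation_config` fallback now runs *before* its first use. Previously, `getattr(generation_config, "do_sample", True)` could read `generation_config` while it was still `None`, silently returning the default instead of the model's configured value. A minimal sketch of that ordering bug (the class and names here are illustrative, not the actual transformers code):

```python
class _ManagerSketch:
    """Illustrative stand-in for the continuous-batching manager."""

    def __init__(self, model, generation_config=None):
        # Fixed order: resolve the None fallback first, so the getattr
        # below sees the model's config. Before the fix, it could read
        # getattr(None, "do_sample", True) and silently return True.
        generation_config = model.generation_config if generation_config is None else generation_config
        self.do_sample = getattr(generation_config, "do_sample", True)
        # eval() disables dropout etc.; generation never wants train mode.
        self.model = model.eval()
```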
```diff
@@ -1242,15 +1242,15 @@ def __iter__(self):
 
     @traced
     def warmup(self, batch_processor):
-        stream = torch.cuda.Stream()
+        stream = torch.cuda.Stream(device=self.model.device)
         stream.wait_stream(torch.cuda.current_stream())
         with torch.cuda.stream(stream):
             # Warmup the model with a dummy forward pass
             self._generation_step(batch_processor)
         torch.cuda.current_stream().wait_stream(stream)
 
         self.graph = torch.cuda.CUDAGraph()
-        with torch.cuda.graph(self.graph):
+        with torch.cuda.graph(self.graph, stream=stream):
            self._generation_step(batch_processor)
 
     @traced
```
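For context, this hunk follows the standard PyTorch side-stream capture recipe. Pinning the stream to the model's device matters under tensor parallelism, where the current device may not be the one the model lives on, and `torch.cuda.graph` accepts a `stream=` argument so capture happens on the same stream used for warmup. A runnable sketch with a placeholder model (not the transformers code):

```python
import torch

device = torch.device("cuda")
model = torch.nn.Linear(16, 16).to(device).eval()
static_input = torch.randn(8, 16, device=device)

# torch.cuda.Stream() with no argument is created on the *current*
# device; passing device= pins it to the model's device explicitly.
stream = torch.cuda.Stream(device=device)
stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(stream):
    for _ in range(3):  # warmup allocates workspaces outside capture
        model(static_input)
torch.cuda.current_stream().wait_stream(stream)

# Capture on the same side stream used for warmup, then replay.
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph, stream=stream):
    static_output = model(static_input)
graph.replay()  # re-runs the captured kernels on the static buffers
```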
```diff
@@ -1326,7 +1326,7 @@ def _run_generation_loop(self):
         is_first = True
 
         if self.profile:
-            tracing_schedule = schedule(skip_first=2, warmup=3, active=200, repeat=100, wait=1)
+            tracing_schedule = schedule(skip_first=2, warmup=1, active=1, repeat=3, wait=1)
             trace_handler = tensorboard_trace_handler(
                 dir_name="/fsx/arthur/transformers", use_gzip=True, worker_name="paged_compile"
             )
```
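The old schedule (`active=200, repeat=100`) would have traced up to 20,000 steps; the new one records a single step per cycle, for three cycles. A sketch of how such a schedule drives `torch.profiler` (the loop body and output directory are placeholders, not the real generation loop):

```python
import torch
from torch.profiler import ProfilerActivity, profile, schedule, tensorboard_trace_handler

def run_one_generation_step():
    torch.randn(256, 256).softmax(dim=-1)  # stand-in for real decode work

# Each cycle: wait 1 step, warm up 1, trace 1 active step; 3 cycles,
# after skipping the first 2 steps entirely.
tracing_schedule = schedule(skip_first=2, warmup=1, active=1, repeat=3, wait=1)
prof = profile(
    activities=[ProfilerActivity.CPU],
    schedule=tracing_schedule,
    on_trace_ready=tensorboard_trace_handler(dir_name="./traces", use_gzip=True),
)
prof.start()
for _ in range(12):
    run_one_generation_step()
    prof.step()  # advances the schedule's step counter
prof.stop()
```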

src/transformers/integrations/flash_paged.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -63,5 +63,6 @@ def paged_attention_forward(
         # block_table=block_tables, -> torch.Tensor
         # **kwargs,
     )
-
+    if isinstance(attn_output, tuple):
+        attn_output = attn_output[0]
     return attn_output, None
```
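This guard normalizes the kernel's return value: depending on backend and flags, flash-attention-style kernels may return either a bare output tensor or a tuple such as `(output, softmax_lse, ...)`. A minimal sketch of the same defensive unwrap, with a stand-in attention function rather than the real kernel:

```python
import torch

def call_attention(attention_fn, q, k, v):
    attn_output = attention_fn(q, k, v)
    if isinstance(attn_output, tuple):
        attn_output = attn_output[0]  # keep only the output tensor
    return attn_output

# A backend that returns (output, aux) is handled the same as one
# that returns a bare tensor:
q = k = v = torch.ones(4)
out = call_attention(lambda q, k, v: (q + k + v, None), q, k, v)
```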
