@@ -67,6 +67,7 @@ def forward(
                 Shape is [num_tokens, vocab_size]. Here, probabilities from
                 different requests are flattened into a single tensor because
                 this is the shape of the output logits.
+                NOTE: `target_logits` can be updated in place to save memory.
             bonus_token_ids_tensor (torch.Tensor):
                 A tensor containing bonus tokens. Shape is [batch_size, 1].
                 Bonus tokens are added to the end of the sequence if all
@@ -83,6 +84,8 @@ def forward(
         '''
         assert metadata.max_spec_len <= MAX_SPEC_LEN
         # [num_tokens, vocab_size]
+        # NOTE(woosuk): `target_logits` can be updated in place inside the
+        # `compute_probs` function.
         target_probs = compute_probs(
             target_logits,
             metadata.cu_num_draft_tokens,
@@ -252,8 +255,8 @@ def compute_probs(
         replace_from=GREEDY_TEMPERATURE,
         replace_to=1,
     )
-    # TODO(woosuk): Consider using in-place op to reduce memory usage.
-    logits = logits / temperature.unsqueeze(-1)
+    # NOTE(woosuk): Update `logits` in place to avoid allocating a new tensor.
+    logits.div_(temperature.unsqueeze(-1))
 
     # Get expanded top_k and top_p tensors.
     top_k = None
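
For context, here is a minimal sketch of what the last hunk changes, assuming standard PyTorch semantics and a `GREEDY_TEMPERATURE` sentinel that `compute_probs` remaps to 1 before dividing; the shapes and values below are illustrative, not from the patch. The out-of-place division allocates a second `[num_tokens, vocab_size]` tensor, while `div_` reuses the existing storage, which is also why the docstring now warns that `target_logits` may be updated in place.

```python
import torch

GREEDY_TEMPERATURE = -1  # assumed sentinel value marking greedy requests

num_tokens, vocab_size = 4, 8
logits = torch.randn(num_tokens, vocab_size)
temperature = torch.tensor([0.7, GREEDY_TEMPERATURE, 1.3, 0.7])

# Greedy rows divide by 1, so the division leaves them unchanged
# (mirroring the replace_from/replace_to step shown above the hunk).
temperature = torch.where(
    temperature == GREEDY_TEMPERATURE,
    torch.ones_like(temperature),
    temperature,
)

# Out-of-place form: allocates a second [num_tokens, vocab_size] tensor.
#   logits = logits / temperature.unsqueeze(-1)
# In-place form: reuses the existing storage, mutating the caller's tensor.
storage_before = logits.data_ptr()
logits.div_(temperature.unsqueeze(-1))
assert logits.data_ptr() == storage_before  # no new allocation
```

Because the in-place form mutates the caller's tensor, code calling `forward` can no longer rely on `target_logits` keeping its original values, which is exactly what the new docstring NOTE makes explicit.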