Commit 6c1ef96

Fix prefill_inference (#13885)
Summary: Fixes bugs in the prefill_inference function.

Differential Revision: D81532886
1 parent 579b91e commit 6c1ef96

File tree

1 file changed: +3 -6 lines changed


examples/qualcomm/oss_scripts/llama/decoder_utils.py

Lines changed: 3 additions & 6 deletions
@@ -83,7 +83,6 @@ def _model_call(self, inps):
             inps,
             self._model,
             self._tokenizer,
-            self.ar_len,
             self.max_seq_length,
             use_i64_token=self.use_i64_token,
             collect_logits=True,
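For reference, here is a minimal sketch of what the call site inside _model_call might look like after this change. The callee name (prefill_inference, inferred from the commit title) and the surrounding method body are not shown in the diff and are assumptions; only the argument list, now without self.ar_len, is taken from the hunk above.

```python
# Sketch only: the callee name and return handling are assumptions; the diff
# shows just the argument list, with self.ar_len no longer being passed.
def _model_call(self, inps):
    result_logits = prefill_inference(
        inps,
        self._model,
        self._tokenizer,
        self.max_seq_length,
        use_i64_token=self.use_i64_token,
        collect_logits=True,
    )
    return result_logits
```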
@@ -458,15 +457,13 @@ def prefill_inference(
             logits, new_k_caches, new_v_caches = results
         elif len(results) == 1:
             logits = results
-        logits = torch.argmax(logits[:, pos - 1], dim=-1).item()
-        token_list.append(logits)
+        token = torch.argmax(logits[:, pos - 1], dim=-1).item()
+        token_list.append(token)
         if collect_logits:
-            result_logits.append(logits)
+            result_logits = logits[:, :pos]
         pos += 1

     logging.info(f"prefill inference result:\n{tokenizer.decode(token_list)}")
-    if collect_logits:
-        result_logits = torch.cat(result_logits, dim=1)
     return result_logits
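Putting the '+' lines together, the following is a minimal sketch of the corrected tail of the prefill loop. The loop scaffolding, the forward call (module(...)), the branch condition, and the variable initialization are assumptions reconstructed from the context lines; only the added lines are verbatim from the hunk above.

```python
import logging

import torch


def prefill_loop_sketch(module, tokenizer, prompt_tokens, max_seq_length, collect_logits=False):
    # Sketch only: loop bounds and the forward invocation are assumed.
    token_list = list(prompt_tokens)
    result_logits = None
    pos = len(token_list)
    while pos < max_seq_length:
        results = module(token_list)  # assumed prefill forward pass
        if len(results) > 1:  # branch condition assumed; unpacking is from the diff context
            logits, new_k_caches, new_v_caches = results
        elif len(results) == 1:
            logits = results
        # Fix 1: keep the logits tensor intact; store the sampled token in its
        # own variable instead of overwriting `logits` with a Python int.
        token = torch.argmax(logits[:, pos - 1], dim=-1).item()
        token_list.append(token)
        # Fix 2: slice the logits up to the current position directly, instead
        # of appending per-step values and concatenating them after the loop.
        if collect_logits:
            result_logits = logits[:, :pos]
        pos += 1

    logging.info(f"prefill inference result:\n{tokenizer.decode(token_list)}")
    return result_logits
```

The net effect of the two fixes: result_logits now holds a single logits tensor sliced to the prefilled positions, so the post-loop torch.cat over a list of scalar argmax values (the source of the original bug) is no longer needed.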
