File tree: 1 file changed, +3 -6 lines changed
examples/qualcomm/oss_scripts/llama: 1 file changed, +3 -6 lines changed
Columns: original file line number | diff line number | diff line change
@@ -83,7 +83,6 @@ def _model_call(self, inps):
 83  83               inps,
 84  84               self._model,
 85  85               self._tokenizer,
 86     -             self.ar_len,
 87  86               self.max_seq_length,
 88  87               use_i64_token=self.use_i64_token,
 89  88               collect_logits=True,
@@ -458,15 +457,13 @@ def prefill_inference(
458 457                   logits, new_k_caches, new_v_caches = results
459 458               elif len(results) == 1:
460 459                   logits = results
461     -             logits = torch.argmax(logits[:, pos - 1], dim=-1).item()
462     -             token_list.append(logits)
    460 +             token = torch.argmax(logits[:, pos - 1], dim=-1).item()
    461 +             token_list.append(token)
463 462               if collect_logits:
464     -                 result_logits.append(logits)
    463 +                 result_logits = logits[:, :pos]
465 464               pos += 1
466 465
467 466       logging.info(f"prefill inference result:\n{tokenizer.decode(token_list)}")
468     -     if collect_logits:
469     -         result_logits = torch.cat(result_logits, dim=1)
470 467       return result_logits
471 468
472 469
You can’t perform that action at this time.
0 commit comments