File tree Expand file tree Collapse file tree 1 file changed +6
-3
lines changed Expand file tree Collapse file tree 1 file changed +6
-3
lines changed Original file line number Diff line number Diff line change @@ -66,7 +66,10 @@ def merge_attn_states_kernel(
6666 max_lse = tl .maximum (p_lse , s_lse )
6767 p_lse = p_lse - max_lse
6868 s_lse = s_lse - max_lse
69- out_se = (tl .exp (p_lse ) + tl .exp (s_lse ))
69+ # Compute each exponential once; both are reused below for the p_scale / s_scale factors.
70+ p_se = tl .exp (p_lse )
71+ s_se = tl .exp (s_lse )
72+ out_se = (p_se + s_se )
7073
7174 if OUTPUT_LSE :
7275 out_lse = tl .log (out_se ) + max_lse
@@ -84,8 +87,8 @@ def merge_attn_states_kernel(
8487 # NOTE(woosuk): Be careful with the numerical stability.
8588 # We should compute the scale first, and then multiply it with the output.
8689 # Do not multiply the output with tl.exp(p_lse) or tl.exp(s_lse) directly.
87- p_scale = tl . exp ( p_lse ) / out_se
88- s_scale = tl . exp ( s_lse ) / out_se
90+ p_scale = p_se / out_se
91+ s_scale = s_se / out_se
8992 out = p_out * p_scale + s_out * s_scale
9093 tl .store (output + token_idx * num_heads * HEAD_SIZE +
9194 head_idx * HEAD_SIZE + head_arange ,
You can’t perform that action at this time.
0 commit comments