@@ -97,12 +97,20 @@ def forward(
         position_ids: torch.Tensor,
         hidden_states: torch.Tensor,
         kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
+        is_prompt,
+        block_tables,
+        num_prefills,
+        num_prefill_tokens,
+        num_decode_tokens,
+        slot_mapping,
+        seq_lens,
+        seq_lens_tensor=None,
+        max_decode_seq_len=None,
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.chunk(chunks=3, dim=-1)
         q, k = self.rotary_emb(position_ids, q, k)
-        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
+        attn_output = self.attn(q, k, v, kv_cache, is_prompt, block_tables, num_prefills, num_prefill_tokens, num_decode_tokens, slot_mapping, seq_lens, seq_lens_tensor, max_decode_seq_len)
         attn_output, _ = self.out_proj(attn_output)
         return attn_output

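The hunk above replaces the single `attn_metadata: AttentionMetadata` argument with the individual metadata fields, passed as plain tensors. A plausible motivation (not stated in the commit itself) is that `torch.jit.trace` only keeps tensor arguments as true graph inputs; anything else is recorded as a trace-time constant, so a per-batch metadata dataclass could not be fed to a traced module. A minimal sketch of that constraint, using a hypothetical `Scale` module:

import torch
import torch.nn as nn

class Scale(nn.Module):
    def forward(self, x: torch.Tensor, n: torch.Tensor) -> torch.Tensor:
        return x * n

m = Scale()
example = (torch.ones(2), torch.tensor(3))
traced = torch.jit.trace(m, example, strict=False)
print(traced(torch.ones(2), torch.tensor(5)))  # tensor([5., 5.]): n stays a runtime input

# Had n been a plain Python int (or a field of a dataclass), the tracer would
# have recorded the value 3 as a graph constant rather than a runtime input.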
@@ -166,15 +174,31 @@ def forward(
         position_ids: torch.Tensor,
         hidden_states: torch.Tensor,
         kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
+        is_prompt,
+        block_tables,
+        num_prefills,
+        num_prefill_tokens,
+        num_decode_tokens,
+        slot_mapping,
+        seq_lens,
+        seq_lens_tensor=None,
+        max_decode_seq_len=None,
     ) -> torch.Tensor:
         residual = hidden_states
         hidden_states = self.ln_1(hidden_states)
         attn_output = self.attn(
             position_ids=position_ids,
             hidden_states=hidden_states,
             kv_cache=kv_cache,
-            attn_metadata=attn_metadata,
+            is_prompt=is_prompt,
+            block_tables=block_tables,
+            num_prefills=num_prefills,
+            num_prefill_tokens=num_prefill_tokens,
+            num_decode_tokens=num_decode_tokens,
+            slot_mapping=slot_mapping,
+            seq_lens=seq_lens,
+            seq_lens_tensor=seq_lens_tensor,
+            max_decode_seq_len=max_decode_seq_len,
         )
         mlp_output = self.mlp(hidden_states)
         if self.mlp.fc_out.tp_size <= 1 and not hasattr(self, "ipex_fusion"):
@@ -220,7 +244,15 @@ def forward(
         input_ids: torch.Tensor,
         position_ids: torch.Tensor,
         kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
+        is_prompt,
+        block_tables,
+        num_prefills,
+        num_prefill_tokens,
+        num_decode_tokens,
+        slot_mapping,
+        seq_lens,
+        seq_lens_tensor=None,
+        max_decode_seq_len=None,
     ) -> torch.Tensor:
         hidden_states = self.wte(input_ids)
         for i in range(len(self.h)):
@@ -229,7 +261,15 @@ def forward(
                 position_ids,
                 hidden_states,
                 kv_caches[i],
-                attn_metadata,
+                is_prompt,
+                block_tables,
+                num_prefills,
+                num_prefill_tokens,
+                num_decode_tokens,
+                slot_mapping,
+                seq_lens,
+                seq_lens_tensor,
+                max_decode_seq_len,
             )
         hidden_states = self.ln_f(hidden_states)
         return hidden_states
@@ -255,6 +295,52 @@ def __init__(
         )
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
+        self.trace_first = None
+        self.trace_next = None
+
+    @torch.no_grad
+    def enable_jit(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[torch.Tensor],
+        is_prompt,
+        block_tables,
+        num_prefills,
+        num_prefill_tokens,
+        num_decode_tokens,
+        slot_mapping,
+        seq_lens,
+        seq_lens_tensor=None,
+        max_decode_seq_len=None,
+    ) -> torch.Tensor:
+
+        if is_prompt:
+            self.transformer(input_ids, positions, kv_caches, is_prompt, block_tables, num_prefills, num_prefill_tokens, num_decode_tokens, slot_mapping, seq_lens, seq_lens_tensor, max_decode_seq_len)
+            example_input = (
+                input_ids,
+                positions,
+                kv_caches,
+                is_prompt, block_tables, num_prefills, num_prefill_tokens, num_decode_tokens, slot_mapping, seq_lens
+            )
+            self.trace_first = torch.jit.trace(self.transformer, example_input, check_trace=False, strict=False)
+            self.trace_first = torch.jit.freeze(self.trace_first)
+            self.trace_first(*example_input)
+            self.trace_first(*example_input)
+        else:
+            example_input = (
+                input_ids,
+                positions,
+                kv_caches,
+                is_prompt, block_tables, num_prefills, num_prefill_tokens, num_decode_tokens, slot_mapping, seq_lens, seq_lens_tensor, max_decode_seq_len
+            )
+            self.trace_next = torch.jit.trace(
+                self.transformer, example_input, check_trace=False, strict=False
+            )
+            self.trace_next = torch.jit.freeze(self.trace_next)
+            self.trace_next(*example_input)
+            self.trace_next(*example_input)
+

     def forward(
         self,
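The new `enable_jit` method traces the transformer separately for the prefill path and the decode path, freezes each traced graph, and then runs it twice. Since `torch.jit.trace` specializes on the Python control flow it observes, one trace cannot cover both phases. Below is a self-contained sketch of the same trace / freeze / warm-up pattern on a toy module (the `ToyMLP` name is illustrative, not from the commit):

import torch
import torch.nn as nn

class ToyMLP(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.fc = nn.Linear(8, 8)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.fc(x))

model = ToyMLP().eval()          # torch.jit.freeze requires eval mode
example = (torch.randn(2, 8),)

with torch.no_grad():
    traced = torch.jit.trace(model, example, check_trace=False, strict=False)
    traced = torch.jit.freeze(traced)
    traced(*example)             # warm-up run 1: profiling pass
    traced(*example)             # warm-up run 2: optimized graph is in place

The two post-freeze calls matter because PyTorch's profiling graph executor only produces its optimized execution plan after the first few invocations; warming up here keeps that cost off the first real request.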
@@ -263,8 +349,42 @@ def forward(
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
-        hidden_states = self.transformer(input_ids, positions, kv_caches,
-                                         attn_metadata)
+
+        is_prompt = torch.tensor(attn_metadata.is_prompt)
+        block_tables = attn_metadata.block_tables
+        num_prefills = torch.tensor(attn_metadata.num_prefills)
+        num_prefill_tokens = torch.tensor(attn_metadata.num_prefill_tokens)
+        num_decode_tokens = torch.tensor(attn_metadata.num_decode_tokens)
+        slot_mapping = attn_metadata.slot_mapping
+        seq_lens = torch.tensor(attn_metadata.seq_lens)
+        seq_lens_tensor = attn_metadata.seq_lens_tensor if attn_metadata.seq_lens_tensor is not None else None
+        max_decode_seq_len = torch.tensor(attn_metadata.max_decode_seq_len) if attn_metadata.max_decode_seq_len is not None else None
+        attn_bias = attn_metadata.attn_bias
+
+        if kv_caches[0] is not None:
+            if attn_metadata.is_prompt:
+                if self.trace_first is None:
+                    self.enable_jit(input_ids, positions, kv_caches, is_prompt, block_tables, num_prefills, num_prefill_tokens, num_decode_tokens, slot_mapping, seq_lens)
+                hidden_states = self.trace_first(
+                    input_ids,
+                    positions,
+                    kv_caches,
+                    is_prompt, block_tables, num_prefills, num_prefill_tokens, num_decode_tokens, slot_mapping, seq_lens
+                )
+            else:
+                if self.trace_next is None:
+                    self.enable_jit(input_ids, positions, kv_caches, is_prompt, block_tables, num_prefills, num_prefill_tokens, num_decode_tokens, slot_mapping, seq_lens, seq_lens_tensor, max_decode_seq_len)
+                hidden_states = self.trace_next(
+                    input_ids,
+                    positions,
+                    kv_caches,
+                    is_prompt, block_tables, num_prefills, num_prefill_tokens, num_decode_tokens, slot_mapping, seq_lens, seq_lens_tensor, max_decode_seq_len
+                )
+        else:
+            # TorchSDPAMetadata(seq_lens_tensor=None, max_decode_seq_len=None, block_tables=tensor([]), num_prefills=1, num_prefill_tokens=5, num_decode_tokens=0, slot_mapping=tensor([9344, 9345, 9346, 9347, 9348]), is_prompt=True, seq_lens=[5])
+            # TorchSDPAMetadata(seq_lens_tensor=tensor([6], dtype=torch.int32), max_decode_seq_len=6, block_tables=tensor([[584]], dtype=torch.int32), num_prefills=0, num_prefill_tokens=0, num_decode_tokens=1, slot_mapping=tensor([9349]), is_prompt=False, seq_lens=[6])
+            hidden_states = self.transformer(input_ids, positions, kv_caches, is_prompt, block_tables, num_prefills, num_prefill_tokens, num_decode_tokens, slot_mapping, seq_lens, seq_lens_tensor, max_decode_seq_len)
+
         return hidden_states

     def compute_logits(self, hidden_states: torch.Tensor,
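In the reworked `forward`, the scalar metadata fields are wrapped with `torch.tensor(...)` before being handed to the traced graphs, so they remain runtime inputs rather than trace-time constants. A hedged caller-side sketch of that flattening step (the `flatten_metadata` helper is hypothetical, not part of the commit; the field names follow the diff above):

from typing import Any, List, Optional

import torch

def flatten_metadata(meta: Any) -> List[Optional[torch.Tensor]]:
    """Turn one attention-metadata object into the positional arguments the traced graphs expect."""
    # Python scalars are wrapped as tensors so the tracer treats them as
    # runtime inputs instead of baking them into the graph as constants.
    return [
        torch.tensor(meta.is_prompt),
        meta.block_tables,                      # already a tensor
        torch.tensor(meta.num_prefills),
        torch.tensor(meta.num_prefill_tokens),
        torch.tensor(meta.num_decode_tokens),
        meta.slot_mapping,                      # already a tensor
        torch.tensor(meta.seq_lens),
        meta.seq_lens_tensor,                   # may be None during prefill
        torch.tensor(meta.max_decode_seq_len) if meta.max_decode_seq_len is not None else None,
    ]

Fields that are already tensors on the metadata object (`block_tables`, `slot_mapping`, `seq_lens_tensor`) pass through unchanged, matching what the commit does inline.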