Commit 76db283

Enable jit mlperf (vllm-project#28)
* enable jit
* enable jit for mixtral
* fix gptj jit acc
* refine mixtral jit
* fix jit in warmup with None KV cache
1 parent 4d743a8 commit 76db283

File tree

7 files changed: +442 −57 lines changed


vllm/_custom_ops.py

Lines changed: 17 additions & 3 deletions
@@ -1,7 +1,7 @@
 from typing import Optional, Tuple, Type
 
 import torch
-
+import intel_extension_for_pytorch as ipex
 try:
     from vllm._C import cache_ops as vllm_cache_ops
     from vllm._C import ops as vllm_ops
@@ -98,9 +98,23 @@ def rotary_embedding(
     cos_sin_cache: torch.Tensor,
     is_neox: bool,
 ) -> None:
-    vllm_ops.rotary_embedding(positions, query, key, head_size, cos_sin_cache,
-                              is_neox)
 
+    # vllm_ops.rotary_embedding(positions, query, key, head_size, cos_sin_cache,
+    #                           is_neox)
+
+    rotary_dim = cos_sin_cache.size(1)
+    query = query.view(*query.shape[:-1], -1, head_size)
+    key = key.view(*key.shape[:-1], -1, head_size)
+
+    query_rot = query[..., :rotary_dim]
+    key_rot = key[..., :rotary_dim]
+
+    cos_sin = cos_sin_cache[positions.long()]
+    cos, sin = cos_sin.chunk(2, dim=-1)
+    cos = cos.repeat(1, 2)
+    sin = sin.repeat(1, 2)
+
+    ipex.llm.functional.rotary_embedding(query_rot, key_rot, sin, cos, rotary_dim, is_neox)
 
 def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor,
                              key: torch.Tensor, head_size: int,
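For reference, here is a minimal pure-PyTorch sketch of the NeoX-style rotation that the wrapper above hands off to ipex.llm.functional.rotary_embedding. The helper name apply_rotary_ref and the half-split rotation layout are illustrative assumptions, not part of this commit.

import torch

def apply_rotary_ref(x_rot: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # x_rot: [num_tokens, num_heads, rotary_dim]; cos/sin: [num_tokens, rotary_dim]
    # (the diff repeats each half of cos_sin_cache so cos/sin span the full rotary_dim).
    x1, x2 = x_rot.chunk(2, dim=-1)
    rotated = torch.cat((-x2, x1), dim=-1)  # NeoX-style rotate_half
    return x_rot * cos.unsqueeze(1) + rotated * sin.unsqueeze(1)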

vllm/attention/backends/torch_sdpa.py

Lines changed: 23 additions & 14 deletions
@@ -143,8 +143,17 @@ def forward(
         key: torch.Tensor,
         value: torch.Tensor,
         kv_cache: Optional[torch.Tensor],
-        attn_metadata: TorchSDPAMetadata,  # type: ignore
+        is_prompt,
+        block_tables,
+        num_prefills,
+        num_prefill_tokens,
+        num_decode_tokens,
+        slot_mapping,
+        seq_lens,
+        seq_lens_tensor=None,
+        max_decode_seq_len=None,
         kv_scale: float = 1.0,
+        attn_bias=None,
     ) -> torch.Tensor:
         """Forward pass with torch SDPA and PagedAttention.
 
@@ -169,29 +178,29 @@ def forward(
                 kv_cache, self.num_kv_heads, self.head_size)
             PagedAttention.write_to_paged_cache(key, value, key_cache,
                                                 value_cache,
-                                                attn_metadata.slot_mapping,
+                                                slot_mapping,
                                                 self.kv_cache_dtype, kv_scale)
 
-        if attn_metadata.is_prompt:
-            assert attn_metadata.seq_lens is not None
-            if (kv_cache is None or attn_metadata.block_tables.numel() == 0):
+        if is_prompt:
+            assert seq_lens is not None
+            if (kv_cache is None or block_tables.numel() == 0):
                 if self.num_kv_heads != self.num_heads:
                     key = key.repeat_interleave(self.num_queries_per_kv, dim=1)
                     value = value.repeat_interleave(self.num_queries_per_kv,
                                                     dim=1)
 
-                if attn_metadata.attn_bias is None:
+                if attn_bias is None:
                     if self.alibi_slopes is not None:
                         att_masks = _make_alibi_bias(
                             self.alibi_slopes, query.dtype,
                             attn_metadata.seq_lens)  # type: ignore
                     elif self.sliding_window is not None:
                         att_masks = _make_sliding_window_bias(
-                            attn_metadata.seq_lens, self.sliding_window,
+                            seq_lens, self.sliding_window,
                             query.dtype)  # type: ignore
                     else:
-                        att_masks = [None] * len(attn_metadata.seq_lens)
-                    attn_metadata.attn_bias = att_masks
+                        att_masks = [None] * len(seq_lens)
+                    attn_bias = att_masks
 
                 query = query.movedim(0, query.dim() - 2)
                 key = key.movedim(0, key.dim() - 2)
@@ -201,8 +210,8 @@ def forward(
                 output = torch.empty(
                     (num_tokens, self.num_heads, self.head_size),
                     dtype=query.dtype)
-                for seq_len, mask in zip(attn_metadata.seq_lens,
-                                         attn_metadata.attn_bias):
+                for seq_len, mask in zip(seq_lens,
+                                         attn_bias):
                     end = start + seq_len
                     sub_out = scaled_dot_product_attention(
                         query[None, :, start:end, :],
@@ -226,9 +235,9 @@ def forward(
                 query,
                 key_cache,
                 value_cache,
-                attn_metadata.block_tables,
-                attn_metadata.seq_lens_tensor,
-                attn_metadata.max_decode_seq_len,
+                block_tables,
+                seq_lens_tensor,
+                max_decode_seq_len,
                 self.kv_cache_dtype,
                 self.num_kv_heads,
                 self.scale,
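As background for the prompt branch above, this is a small, self-contained sketch of running scaled_dot_product_attention one sequence at a time over a packed token dimension, which is the pattern the loop over seq_lens and attn_bias follows. Shapes and variable names here are illustrative assumptions, not the backend's actual tensors.

import torch
import torch.nn.functional as F

num_heads, head_size = 4, 8
seq_lens = [5, 3]                 # tokens per sequence, packed back to back
total_tokens = sum(seq_lens)
q = torch.randn(num_heads, total_tokens, head_size)
k = torch.randn(num_heads, total_tokens, head_size)
v = torch.randn(num_heads, total_tokens, head_size)

out = torch.empty_like(q)
start = 0
for seq_len in seq_lens:
    end = start + seq_len
    # A per-sequence bias (ALiBi or sliding-window) could be passed via
    # attn_mask instead of is_causal=True.
    out[:, start:end, :] = F.scaled_dot_product_attention(
        q[None, :, start:end, :],
        k[None, :, start:end, :],
        v[None, :, start:end, :],
        is_causal=True,
    ).squeeze(0)
    start = end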

vllm/attention/layer.py

Lines changed: 10 additions & 2 deletions
@@ -84,9 +84,17 @@ def forward(
         key: torch.Tensor,
         value: torch.Tensor,
         kv_cache: Optional[torch.Tensor],
-        attn_metadata: AttentionMetadata,
+        is_prompt,
+        block_tables,
+        num_prefills,
+        num_prefill_tokens,
+        num_decode_tokens,
+        slot_mapping,
+        seq_lens,
+        seq_lens_tensor=None,
+        max_decode_seq_len=None,
     ) -> torch.Tensor:
-        return self.impl.forward(query, key, value, kv_cache, attn_metadata,
+        return self.impl.forward(query, key, value, kv_cache, is_prompt, block_tables,num_prefills,num_prefill_tokens,num_decode_tokens,slot_mapping,seq_lens,seq_lens_tensor,max_decode_seq_len,
                                  self._kv_scale)
 
     def extra_repr(self) -> str:
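The point of replacing the single attn_metadata argument with individual tensors is that torch.jit.trace only records tensor inputs (and nested tuples/lists of tensors), so a custom metadata object has to be decomposed before the attention stack can be traced. A minimal sketch of that constraint, using a made-up TinyAttention module rather than vLLM's real classes:

import torch
import torch.nn as nn

class TinyAttention(nn.Module):
    # Stand-in for a layer that used to take a metadata object and now takes
    # plain tensors so torch.jit.trace can record the call.
    def forward(self, query: torch.Tensor, slot_mapping: torch.Tensor,
                seq_lens: torch.Tensor) -> torch.Tensor:
        return query + seq_lens.to(query.dtype).mean()

model = TinyAttention().eval()
example = (torch.randn(4, 8), torch.arange(4), torch.tensor([4]))
traced = torch.jit.trace(model, example, check_trace=False, strict=False)
traced = torch.jit.freeze(traced)
out = traced(*example)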

vllm/model_executor/layers/layernorm.py

Lines changed: 7 additions & 7 deletions
@@ -6,6 +6,8 @@
 
 from vllm.model_executor.custom_op import CustomOp
 
+from vllm import _custom_ops as ops
+import intel_extension_for_pytorch as ipex
 
 class RMSNorm(CustomOp):
     """Root mean square normalization.
@@ -48,19 +50,17 @@ def forward_cuda(
         x: torch.Tensor,
         residual: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
-        from vllm import _custom_ops as ops
-
         if residual is not None:
-            ops.fused_add_rms_norm(
-                x,
+            x = ipex.llm.functional.add_rms_norm(
                 residual,
+                x,
                 self.weight.data,
+                None,
                 self.variance_epsilon,
+                True
             )
             return x, residual
-        out = torch.empty_like(x)
-        ops.rms_norm(
-            out,
+        out = ipex.llm.functional.rms_norm(
             x,
             self.weight.data,
             self.variance_epsilon,
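For comparison with the ipex.llm.functional.rms_norm / add_rms_norm calls above, a plain-PyTorch sketch of the math RMSNorm computes, with the residual add folded in. This only illustrates the computation; the exact meaning of the extra None and True arguments to the IPEX kernels is not spelled out in this diff and is not assumed here.

import torch

def rms_norm_ref(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    # Root-mean-square normalization over the last dimension.
    variance = x.float().pow(2).mean(dim=-1, keepdim=True)
    return (x.float() * torch.rsqrt(variance + eps)).to(x.dtype) * weight

def add_rms_norm_ref(residual: torch.Tensor, x: torch.Tensor,
                     weight: torch.Tensor, eps: float) -> torch.Tensor:
    # Residual add followed by the normalization, as in the forward_cuda path.
    return rms_norm_ref(x + residual, weight, eps)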

vllm/model_executor/models/gpt_j.py

Lines changed: 128 additions & 8 deletions
@@ -97,12 +97,20 @@ def forward(
         position_ids: torch.Tensor,
         hidden_states: torch.Tensor,
         kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
+        is_prompt,
+        block_tables,
+        num_prefills,
+        num_prefill_tokens,
+        num_decode_tokens,
+        slot_mapping,
+        seq_lens,
+        seq_lens_tensor=None,
+        max_decode_seq_len=None,
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.chunk(chunks=3, dim=-1)
         q, k = self.rotary_emb(position_ids, q, k)
-        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
+        attn_output = self.attn(q, k, v, kv_cache, is_prompt, block_tables,num_prefills,num_prefill_tokens,num_decode_tokens,slot_mapping,seq_lens,seq_lens_tensor,max_decode_seq_len)
         attn_output, _ = self.out_proj(attn_output)
         return attn_output
 
@@ -166,15 +174,31 @@ def forward(
         position_ids: torch.Tensor,
         hidden_states: torch.Tensor,
         kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
+        is_prompt,
+        block_tables,
+        num_prefills,
+        num_prefill_tokens,
+        num_decode_tokens,
+        slot_mapping,
+        seq_lens,
+        seq_lens_tensor=None,
+        max_decode_seq_len=None,
     ) -> torch.Tensor:
         residual = hidden_states
         hidden_states = self.ln_1(hidden_states)
         attn_output = self.attn(
             position_ids=position_ids,
             hidden_states=hidden_states,
             kv_cache=kv_cache,
-            attn_metadata=attn_metadata,
+            is_prompt=is_prompt,
+            block_tables=block_tables,
+            num_prefills=num_prefills,
+            num_prefill_tokens=num_prefill_tokens,
+            num_decode_tokens=num_decode_tokens,
+            slot_mapping=slot_mapping,
+            seq_lens=seq_lens,
+            seq_lens_tensor=seq_lens_tensor,
+            max_decode_seq_len=max_decode_seq_len,
         )
         mlp_output = self.mlp(hidden_states)
         if self.mlp.fc_out.tp_size <=1 and not hasattr(self, "ipex_fusion"):
@@ -220,7 +244,15 @@ def forward(
         input_ids: torch.Tensor,
         position_ids: torch.Tensor,
         kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
+        is_prompt,
+        block_tables,
+        num_prefills,
+        num_prefill_tokens,
+        num_decode_tokens,
+        slot_mapping,
+        seq_lens,
+        seq_lens_tensor=None,
+        max_decode_seq_len=None,
     ) -> torch.Tensor:
         hidden_states = self.wte(input_ids)
         for i in range(len(self.h)):
@@ -229,7 +261,15 @@ def forward(
                 position_ids,
                 hidden_states,
                 kv_caches[i],
-                attn_metadata,
+                is_prompt,
+                block_tables,
+                num_prefills,
+                num_prefill_tokens,
+                num_decode_tokens,
+                slot_mapping,
+                seq_lens,
+                seq_lens_tensor,
+                max_decode_seq_len,
             )
         hidden_states = self.ln_f(hidden_states)
         return hidden_states
@@ -255,6 +295,52 @@ def __init__(
         )
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
+        self.trace_first=None
+        self.trace_next=None
+
+    @torch.no_grad
+    def enable_jit(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[torch.Tensor],
+        is_prompt,
+        block_tables,
+        num_prefills,
+        num_prefill_tokens,
+        num_decode_tokens,
+        slot_mapping,
+        seq_lens,
+        seq_lens_tensor=None,
+        max_decode_seq_len=None,
+    ) -> torch.Tensor:
+
+        if is_prompt:
+            self.transformer(input_ids, positions, kv_caches, is_prompt, block_tables,num_prefills,num_prefill_tokens,num_decode_tokens,slot_mapping,seq_lens,seq_lens_tensor,max_decode_seq_len)
+            example_input = (
+                input_ids,
+                positions,
+                kv_caches,
+                is_prompt, block_tables,num_prefills,num_prefill_tokens,num_decode_tokens,slot_mapping,seq_lens
+            )
+            self.trace_first = torch.jit.trace(self.transformer, example_input, check_trace=False, strict=False)
+            self.trace_first = torch.jit.freeze(self.trace_first)
+            self.trace_first(*example_input)
+            self.trace_first(*example_input)
+        else:
+            example_input = (
+                input_ids,
+                positions,
+                kv_caches,
+                is_prompt, block_tables,num_prefills,num_prefill_tokens,num_decode_tokens,slot_mapping,seq_lens,seq_lens_tensor,max_decode_seq_len
+            )
+            self.trace_next = torch.jit.trace(
+                self.transformer, example_input, check_trace=False, strict=False
+            )
+            self.trace_next = torch.jit.freeze(self.trace_next)
+            self.trace_next(*example_input)
+            self.trace_next(*example_input)
+
 
     def forward(
         self,
@@ -263,8 +349,42 @@ def forward(
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
-        hidden_states = self.transformer(input_ids, positions, kv_caches,
-                                         attn_metadata)
+
+        is_prompt=torch.tensor(attn_metadata.is_prompt)
+        block_tables=attn_metadata.block_tables
+        num_prefills=torch.tensor(attn_metadata.num_prefills)
+        num_prefill_tokens=torch.tensor(attn_metadata.num_prefill_tokens)
+        num_decode_tokens=torch.tensor(attn_metadata.num_decode_tokens)
+        slot_mapping = attn_metadata.slot_mapping
+        seq_lens=torch.tensor(attn_metadata.seq_lens)
+        seq_lens_tensor=attn_metadata.seq_lens_tensor if attn_metadata.seq_lens_tensor is not None else None
+        max_decode_seq_len=torch.tensor(attn_metadata.max_decode_seq_len) if attn_metadata.max_decode_seq_len is not None else None
+        attn_bias = attn_metadata.attn_bias
+
+        if kv_caches[0] is not None:
+            if attn_metadata.is_prompt:
+                if self.trace_first is None:
+                    self.enable_jit(input_ids, positions, kv_caches, is_prompt, block_tables,num_prefills,num_prefill_tokens,num_decode_tokens,slot_mapping,seq_lens)
+                hidden_states = self.trace_first(
+                    input_ids,
+                    positions,
+                    kv_caches,
+                    is_prompt, block_tables,num_prefills,num_prefill_tokens,num_decode_tokens,slot_mapping,seq_lens
+                )
+            else:
+                if self.trace_next is None:
+                    self.enable_jit(input_ids, positions, kv_caches, is_prompt, block_tables,num_prefills,num_prefill_tokens,num_decode_tokens,slot_mapping,seq_lens,seq_lens_tensor,max_decode_seq_len)
+                hidden_states = self.trace_next(
+                    input_ids,
+                    positions,
+                    kv_caches,
+                    is_prompt, block_tables,num_prefills,num_prefill_tokens,num_decode_tokens,slot_mapping,seq_lens,seq_lens_tensor,max_decode_seq_len
+                )
+        else:
+            # TorchSDPAMetadata(seq_lens_tensor=None, max_decode_seq_len=None, block_tables=tensor([]), num_prefills=1, num_prefill_tokens=5, num_decode_tokens=0, slot_mapping=tensor([9344, 9345, 9346, 9347, 9348]), is_prompt=True, seq_lens=[5])
+            # TorchSDPAMetadata(seq_lens_tensor=tensor([6], dtype=torch.int32), max_decode_seq_len=6, block_tables=tensor([[584]], dtype=torch.int32), num_prefills=0, num_prefill_tokens=0, num_decode_tokens=1, slot_mapping=tensor([9349]), is_prompt=False, seq_lens=[6])
+            hidden_states = self.transformer(input_ids, positions, kv_caches, is_prompt, block_tables,num_prefills,num_prefill_tokens,num_decode_tokens,slot_mapping,seq_lens,seq_lens_tensor,max_decode_seq_len)
+
         return hidden_states
 
     def compute_logits(self, hidden_states: torch.Tensor,
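enable_jit above builds two separate traces, one for the prefill call signature and one for decode, freezes each, and runs two warm-up calls so that later invocations hit the optimized frozen graph. A generic sketch of that lazy trace-and-warm-up pattern, independent of the GPT-J model in this file; the Toy module and get_trace helper are made up for illustration.

import torch
import torch.nn as nn

class Toy(nn.Module):
    def forward(self, x: torch.Tensor, seq_lens: torch.Tensor) -> torch.Tensor:
        return x * seq_lens.to(x.dtype).sum()

model = Toy().eval()
traces = {}

def get_trace(name: str, example: tuple):
    # Trace lazily on first use, freeze, then warm up twice so subsequent
    # calls run the optimized graph.
    if name not in traces:
        traced = torch.jit.trace(model, example, check_trace=False, strict=False)
        traced = torch.jit.freeze(traced)
        traced(*example)
        traced(*example)
        traces[name] = traced
    return traces[name]

prefill = get_trace("prefill", (torch.randn(5, 8), torch.tensor([5])))
decode = get_trace("decode", (torch.randn(1, 8), torch.tensor([6])))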
