
Commit 6793994

symint and other fixes
1 parent e14e255 commit 6793994

5 files changed: +40 -128 lines changed


tests/kernels/test_layernorm.py

Lines changed: 25 additions & 0 deletions
@@ -55,6 +55,13 @@ def test_rms_norm(
    else:
        torch.testing.assert_close(out, ref_out, atol=1e-2, rtol=1e-2)

+    if residual is not None:
+        opcheck(torch.ops._C.fused_add_rms_norm,
+                (x, residual, layer.weight.data, layer.variance_epsilon))
+    else:
+        opcheck(torch.ops._C.rms_norm,
+                (out, x, layer.weight.data, layer.variance_epsilon))
+

@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)

@@ -119,6 +126,15 @@ def test_rms_norm_quant(
    if add_residual:
        assert torch.allclose(residual1, residual2, atol=1e-3)

+    if add_residual:
+        opcheck(torch.ops._C.add_residual_rms_norm_quant,
+                (out2, x_, residual2, tmp, layer.weight.data, scale2,
+                 layer.variance_epsilon))
+    else:
+        opcheck(
+            torch.ops._C.rms_norm_quant,
+            (out2, x_, tmp, layer.weight.data, scale2, layer.variance_epsilon))
+

@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)

@@ -180,3 +196,12 @@ def test_rms_norm_quant2(
    assert torch.allclose(out1, out2, atol=2.0)
    if add_residual:
        assert torch.allclose(residual1, residual2, atol=1e-3)
+
+    if add_residual:
+        opcheck(torch.ops._C.add_residual_rms_norm_quant,
+                (out2, x_, residual2, tmp, layer.weight.data, scale1,
+                 layer.variance_epsilon))
+    else:
+        opcheck(
+            torch.ops._C.rms_norm_quant,
+            (out2, x_, tmp, layer.weight.data, scale1, layer.variance_epsilon))
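Note: the opcheck helper used by these tests is presumably a thin wrapper around torch.library.opcheck. A minimal, self-contained sketch of the pattern (assumes PyTorch >= 2.4; "mylib::scale" is a hypothetical op used only for illustration, not part of this commit):

import torch

torch.library.define("mylib::scale", "(Tensor x, float alpha) -> Tensor")


@torch.library.impl("mylib::scale", "CompositeExplicitAutograd")
def _scale(x: torch.Tensor, alpha: float) -> torch.Tensor:
    return x * alpha


@torch.library.register_fake("mylib::scale")
def _scale_fake(x: torch.Tensor, alpha: float) -> torch.Tensor:
    # Fake (meta) impl: only shape/dtype/device matter here.
    return torch.empty_like(x)


# opcheck exercises the schema, the fake-tensor path, and autograd
# registration for the given sample arguments.
torch.library.opcheck(torch.ops.mylib.scale, (torch.randn(4, 8), 2.0))

The tests above run the same kind of check against vLLM's compiled _C ops, so a schema or fake-kernel mismatch fails loudly instead of only surfacing under torch.compile.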

vllm/attention/backends/flash_attn.py

Lines changed: 0 additions & 113 deletions
@@ -121,116 +121,6 @@ def _(
    return torch.empty_like(decode_query)


-
-@torch.library.impl("vllm::flash_attn_varlen_func", "cuda")
-def _flash_attn_varlen_func(
-    out_shape,
-    q,
-    k,
-    v,
-    cu_seqlens_q,
-    cu_seqlens_k,
-    max_seqlen_q,
-    max_seqlen_k,
-    softmax_scale,
-    causal,
-    window_size,
-    alibi_slopes,
-    block_table,
-):
-    return flash_attn_varlen_func(
-        q=q,
-        k=k,
-        v=v,
-        cu_seqlens_q=cu_seqlens_q,
-        cu_seqlens_k=cu_seqlens_k,
-        max_seqlen_q=max_seqlen_q,
-        max_seqlen_k=max_seqlen_k,
-        softmax_scale=softmax_scale,
-        causal=causal,
-        window_size=window_size,
-        alibi_slopes=alibi_slopes,
-        block_table=block_table,
-    )
-
-
-@torch.library.impl_abstract("vllm::flash_attn_varlen_func")
-def _flash_attn_varlen_func_meta(
-    out_shape,
-    q,
-    k,
-    v,
-    cu_seqlens_q,
-    cu_seqlens_k,
-    max_seqlen_q,
-    max_seqlen_k,
-    softmax_scale,
-    causal,
-    window_size,
-    alibi_slopes,
-    block_table,
-):
-    # TODO: is this always correct?
-    return torch.empty(out_shape,
-                       dtype=q.dtype,
-                       layout=q.layout,
-                       device=q.device)
-
-
-torch.library.define("vllm::flash_attn_with_kvcache", ("(int[] out_shape, "
-                                                       "Tensor q, "
-                                                       "Tensor k, "
-                                                       "Tensor v, "
-                                                       "Tensor block_table, "
-                                                       "Tensor cache_seqlens, "
-                                                       "float softmax_scale, "
-                                                       "bool causal, "
-                                                       "float[]? alibi_slopes"
-                                                       ") -> Tensor"))
-
-
-@torch.library.impl("vllm::flash_attn_with_kvcache", "cuda")
-def _flash_attn_with_kvcache(
-    out_shape,
-    decode_query,
-    key_cache,
-    value_cache,
-    block_table,
-    cache_seqlens,
-    softmax_scale,
-    causal,
-    alibi_slopes,
-):
-    return flash_attn_with_kvcache(
-        decode_query,
-        key_cache,
-        value_cache,
-        block_table=block_table,
-        cache_seqlens=cache_seqlens,
-        softmax_scale=softmax_scale,
-        causal=causal,
-        alibi_slopes=alibi_slopes,
-    )
-
-
-@torch.library.impl_abstract("vllm::flash_attn_with_kvcache")
-def _flash_attn_with_kvcache_meta(
-    out_shape,
-    decode_query,
-    key_cache,
-    value_cache,
-    block_table,
-    cache_seqlens,
-    softmax_scale,
-    causal,
-    alibi_slopes,
-):
-    return torch.empty(out_shape,
-                       dtype=decode_query.dtype,
-                       layout=decode_query.layout,
-                       device=decode_query.device)
-
-
class FlashAttentionBackend(AttentionBackend):

    @staticmethod

@@ -779,7 +669,6 @@ def forward(
                # When block_tables are not filled, it means q and k are the
                # prompt, and they have the same length.
                out = torch.ops.vllm.flash_attn_varlen_func(
-                    out_shape=output[:num_prefill_tokens].size(),
                    q=query,
                    k=key,
                    v=value,

@@ -817,10 +706,8 @@ def forward(

        if decode_meta := attn_metadata.decode_metadata:
            # Decoding run.
-            output_shape = output[num_prefill_tokens:].squeeze(1).size()
            output[
                num_prefill_tokens:] = torch.ops.vllm.flash_attn_with_kvcache(
-                    output_shape,
                    decode_query.unsqueeze(1),
                    key_cache,
                    value_cache,
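With out_shape removed from the op signatures, the remaining fake (meta) registrations are left to derive the output shape from the inputs themselves, which is what lets symbolic sizes (SymInt) flow through torch.compile without threading an explicit int[] argument. A hedged, self-contained sketch of that pattern (the op name and schema below are made up for illustration; they are not this file's code):

import torch

# Hypothetical decode-attention op: the fake kernel infers the output from
# the query tensor, so callers never pass an out_shape argument.
torch.library.define(
    "mylib::attn_decode",
    "(Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_table) -> Tensor")


@torch.library.register_fake("mylib::attn_decode")
def _attn_decode_fake(q, k_cache, v_cache, block_table):
    # Output is shaped like the query, so the shape stays correct even when
    # the sizes are symbolic under dynamic-shape compilation.
    return torch.empty_like(q)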

vllm/model_executor/model_optimizer/fused_op_generator_utils.py

Lines changed: 13 additions & 14 deletions
@@ -68,27 +68,26 @@ def arg_schema_type(n: torch.fx.node.Argument,
    """
    Get the schema or C++ type for a fused op argument.
    """
-    if isinstance(n, float):
-        return "float"
-    elif isinstance(n, int):
-        return "int"
+    if n.type is not None:
+        ty = n.type.__name__
+    elif n.meta.get('type') and n.meta.get('type').__name__ != 'FakeTensor':
+        ty = n.meta.get('type').__name__
+        print(f"meta type {ty}")
+        if ty == 'Size':
+            return 'std::vector<int64_t>' if add_prefix else 'int[]'
    else:
-        if n.type is not None:
-            ty = n.type.__name__
-        elif n.meta.get(
-                'type') and n.meta.get('type').__name__ != 'FakeTensor':
-            ty = n.meta.get('type').__name__
-            if ty == 'Size':
-                return 'std::vector<int64_t> const' if add_prefix else 'int[]'
-        else:
-            # this default is a bit sketchy
-            ty = "Tensor"
+        # this default is a bit sketchy
+        ty = "Tensor"

    builtin_types = {"int": "int64_t", "float": "double"}

    if add_prefix and ty in builtin_types:
        return builtin_types[ty]

+    print(f"arg_schema_type {ty}")
+    if ty == "SymInt" and add_prefix:
+        return "int64_t"
+
    return ty if not add_prefix else f"torch::{ty}"
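For reference, the net effect of arg_schema_type after this change is roughly the mapping sketched below. This is a hedged restatement for illustration only (the names here are not the real helper); in particular it is the part that makes SymInt sizes from dynamic-shape graphs land as plain int64_t in the generated C++ fused-op signature.

# Simplified sketch of the schema-type -> C++-type mapping on the
# add_prefix=True path; see arg_schema_type above for the real logic,
# which also pulls the type from n.type or n.meta['type'].
SCHEMA_TO_CPP = {
    "int": "int64_t",
    "float": "double",
    "SymInt": "int64_t",             # symbolic ints are passed as int64_t
    "Size": "std::vector<int64_t>",  # torch.Size arguments
}


def cpp_arg_type(ty: str) -> str:
    # Anything not in the table is treated as a torch type, e.g. torch::Tensor.
    return SCHEMA_TO_CPP.get(ty, f"torch::{ty}")


assert cpp_arg_type("SymInt") == "int64_t"
assert cpp_arg_type("Tensor") == "torch::Tensor"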

vllm/model_executor/model_optimizer/model_optimizer.py

Lines changed: 1 addition & 1 deletion
@@ -88,7 +88,7 @@ def __init__(self, backend: Optional[str] = 'inductor'):
    def __call__(self, gm: torch.fx.GraphModule,
                 example_inputs: List[torch.Tensor]) -> Callable:
        # Temporarily disable optimizer so we can collect dynamo issues.
-        return gm
+        #return gm

        logger.info("Graph optimizer start")

vllm/model_executor/model_optimizer/naive_fused_op_generator.py

Lines changed: 1 addition & 0 deletions
@@ -320,6 +320,7 @@ def make_fused_op(
        f"{arg_schema_type(inp, True)}" for inp in inputs.values()
    ]
    logger.debug("fused op argument types: %s", arg_types)
+    print(f"fused op argument types: {str(arg_types)}")
    for i, name in enumerate(inputs.keys()):
        # Don't use const refs here so inputs can be deleted when no
        # longer needed.
