Skip to content

Commit d94a226

Browse files
committed
add correctness check for dsa
1 parent 242cb45 commit d94a226

File tree

2 files changed

+52
-8
lines changed

2 files changed

+52
-8
lines changed

examples/deepseek_v32/sparse_mla_fwd.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import torch
33
import tilelang
44
from tilelang import language as T
5+
from utils import assert_tensors_similar
56

67

78
@tilelang.jit(
@@ -253,6 +254,12 @@ def test_sparse_mla_fwd(B=1,
253254

254255
tl_out, tl_lse = sparse_mla_fwd_interface(q, kv, indices)
255256

257+
if SKV <= 4096:
258+
# only run the reference check for small SKV; larger sizes may cause out of memory
259+
ref_out = ref_sparse_mla_fwd_interface(q, kv, indices)
260+
assert_tensors_similar(tl_out, ref_out, eps=1e-2, name="out")
261+
print("assert_tensors_similar passed")
262+
256263
def fn():
257264
return sparse_mla_fwd_interface(q, kv, indices)
258265

@@ -270,4 +277,4 @@ def fn():
270277

271278
if __name__ == "__main__":
272279
test_sparse_mla_fwd(
273-
B=1, S=4096, SKV=32768, H=128, HKV=1, DQK=576, DV=512, topk=2048, dtype=torch.bfloat16)
280+
B=1, S=4096, SKV=4096, H=128, HKV=1, DQK=576, DV=512, topk=2048, dtype=torch.bfloat16)

examples/deepseek_v32/utils.py

Lines changed: 44 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -251,25 +251,62 @@ def generate_random_cu_seqlens(per_cp_seqlen, cp_size=4, cp_rank=3, kv_stride=1,
251251
return ks, ke
252252

253253

254-
def print_red_warning(message):
255-
print(f"\033[31mWARNING: {message}\033[0m")
254+
def calculate_tensor_similarity(x, y, name="tensor"):
255+
"""
256+
Calculate similarity between two tensors using a normalized dot product metric.
257+
258+
Unlike torch.testing.assert_close which uses absolute/relative tolerance based on
259+
element-wise differences, this function computes a global similarity score:
260+
sim = 2 * <x, y> / (||x||^2 + ||y||^2)
261+
262+
This metric is invariant to rescaling both tensors by the same factor and measures a cosine-like similarity normalized
263+
by the magnitude of both tensors. It returns 1 for identical tensors and values
264+
closer to 0 (or negative, when anti-correlated) for dissimilar ones. This is particularly useful for comparing tensors
265+
with varying magnitudes where relative errors matter more than absolute differences.
256266
267+
Args:
268+
x: First tensor to compare
269+
y: Second tensor to compare
270+
name: Name of the tensor for logging purposes
257271
258-
def calc_sim(x, y, name="tensor"):
272+
Returns:
273+
Similarity score in range [-1, 1] where 1 means identical (returns 1 if both tensors are all zero)
274+
"""
259275
x, y = x.data.double(), y.data.double()
260276
denominator = (x * x + y * y).sum()
261277
if denominator == 0:
262-
print_red_warning(f'{name} all zero')
278+
print(f"\033[33mWARNING: {name} all zero\033[0m")
263279
return 1
264280
sim = 2 * (x * y).sum() / denominator
265281
return sim
266282

267283

268-
def assert_similar(x, y, eps=1e-8, name="tensor", raise_assert=True):
269-
sim = calc_sim(x, y, name)
284+
def assert_tensors_similar(x, y, eps=1e-8, name="tensor", raise_assert=True):
285+
"""
286+
Assert that two tensors are similar using a global similarity metric.
287+
288+
Key differences from torch.testing.assert_close:
289+
- torch.testing.assert_close: Uses element-wise comparison with rtol/atol, checking
290+
that |x - y| <= atol + rtol * |y| for each element. It's sensitive to outliers
291+
and requires all elements to satisfy the tolerance.
292+
- assert_tensors_similar: Uses a single global difference score (1 - sim) where sim is the
293+
normalized dot product. It's more robust to outliers and focuses on overall
294+
tensor similarity rather than element-wise precision. This is better suited for
295+
comparing large tensors where a few outlier elements shouldn't fail the test.
296+
297+
Args:
298+
x: First tensor to compare
299+
y: Second tensor to compare
300+
eps: Maximum allowed difference (1 - similarity), default 1e-8
301+
name: Name of the tensor for error messages
302+
raise_assert: Whether to raise assertion error on failure
303+
"""
304+
sim = calculate_tensor_similarity(x, y, name)
270305
diff = 1. - sim
271306
if not (0 <= diff <= eps):
272-
print_red_warning(f'{name} Error: {diff}')
307+
print(
308+
f"\033[31mERROR: {name} similarity check failed, diff={diff:.2e} (threshold={eps:.2e})\033[0m"
309+
)
273310
if raise_assert:
274311
assert False # noqa: B011
275312

0 commit comments

Comments
 (0)