9 changes: 8 additions & 1 deletion examples/deepseek_v32/sparse_mla_fwd.py
@@ -2,6 +2,7 @@
import torch
import tilelang
from tilelang import language as T
from utils import assert_tensors_similar


@tilelang.jit(
@@ -253,6 +254,12 @@ def test_sparse_mla_fwd(B=1,

tl_out, tl_lse = sparse_mla_fwd_interface(q, kv, indices)

if SKV <= 4096:
# otherwise the reference implementation may run out of memory
ref_out = ref_sparse_mla_fwd_interface(q, kv, indices)
assert_tensors_similar(tl_out, ref_out, eps=1e-2, name="out")
print("assert_tensors_similar passed")

def fn():
return sparse_mla_fwd_interface(q, kv, indices)

@@ -270,4 +277,4 @@ def fn():

if __name__ == "__main__":
test_sparse_mla_fwd(
B=1, S=4096, SKV=32768, H=128, HKV=1, DQK=576, DV=512, topk=2048, dtype=torch.bfloat16)
B=1, S=4096, SKV=4096, H=128, HKV=1, DQK=576, DV=512, topk=2048, dtype=torch.bfloat16)
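The reference check added above only runs when SKV <= 4096, since the dense reference implementation can exhaust GPU memory at larger context lengths; lowering the default in __main__ from SKV=32768 to SKV=4096 presumably ensures that check executes on a plain run. A minimal usage sketch, assuming a CUDA GPU with tilelang installed and that the script is launched from examples/deepseek_v32 so the imports resolve:

import torch
from sparse_mla_fwd import test_sparse_mla_fwd  # module changed in this PR

# With the new default SKV=4096, the guarded comparison against the reference
# implementation (assert_tensors_similar with eps=1e-2) runs before benchmarking;
# with the old SKV=32768 it would be skipped to avoid running out of memory.
test_sparse_mla_fwd(
    B=1, S=4096, SKV=4096, H=128, HKV=1, DQK=576, DV=512, topk=2048,
    dtype=torch.bfloat16)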
51 changes: 44 additions & 7 deletions examples/deepseek_v32/utils.py
@@ -251,25 +251,62 @@ def generate_random_cu_seqlens(per_cp_seqlen, cp_size=4, cp_rank=3, kv_stride=1,
return ks, ke


def print_red_warning(message):
print(f"\033[31mWARNING: {message}\033[0m")
def calculate_tensor_similarity(x, y, name="tensor"):
"""
Calculate similarity between two tensors using a normalized dot product metric.

Unlike torch.testing.assert_close which uses absolute/relative tolerance based on
element-wise differences, this function computes a global similarity score:
sim = 2 * <x, y> / (||x||^2 + ||y||^2)

This metric penalizes magnitude mismatch between the two tensors (unlike plain cosine
similarity) while staying invariant to a common rescaling of both. It returns 1 only for
identical tensors and values closer to 0 (or negative) for dissimilar ones. This is
particularly useful for comparing tensors with varying magnitudes, where relative errors
matter more than absolute differences.

Args:
x: First tensor to compare
y: Second tensor to compare
name: Name of the tensor for logging purposes

def calc_sim(x, y, name="tensor"):
Returns:
Similarity score in the range [-1, 1], where 1 means identical
"""
x, y = x.data.double(), y.data.double()
denominator = (x * x + y * y).sum()
if denominator == 0:
print_red_warning(f'{name} all zero')
print(f"\033[33mWARNING: {name} all zero\033[0m")
return 1
sim = 2 * (x * y).sum() / denominator
return sim


def assert_similar(x, y, eps=1e-8, name="tensor", raise_assert=True):
sim = calc_sim(x, y, name)
def assert_tensors_similar(x, y, eps=1e-8, name="tensor", raise_assert=True):
"""
Assert that two tensors are similar using a global similarity metric.

Key differences from torch.testing.assert_close:
- torch.testing.assert_close: Uses element-wise comparison with rtol/atol, checking
that |x - y| <= atol + rtol * |y| for each element. It's sensitive to outliers
and requires all elements to satisfy the tolerance.
- assert_tensors_similar: Checks a single global difference, 1 - sim, against eps, where sim is the
normalized dot product. It's more robust to outliers and focuses on overall
tensor similarity rather than element-wise precision. This is better suited for
comparing large tensors where a few outlier elements shouldn't fail the test.

Args:
x: First tensor to compare
y: Second tensor to compare
eps: Maximum allowed difference (1 - similarity), default 1e-8
name: Name of the tensor for error messages
raise_assert: Whether to raise assertion error on failure
"""
sim = calculate_tensor_similarity(x, y, name)
diff = 1. - sim
if not (0 <= diff <= eps):
print_red_warning(f'{name} Error: {diff}')
print(
f"\033[31mERROR: {name} similarity check failed, diff={diff:.2e} (threshold={eps:.2e})\033[0m"
)
if raise_assert:
assert False # noqa: B011

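For reference, a short illustrative example of how the two helpers behave (values are approximate; assumes it is run from examples/deepseek_v32 so that utils is importable):

import torch
from utils import calculate_tensor_similarity, assert_tensors_similar

torch.manual_seed(0)
x = torch.randn(1024)

# Identical tensors: sim == 1, so diff == 0 and the check passes for any eps.
assert_tensors_similar(x, x.clone(), eps=1e-8, name="identical")

# Small additive noise keeps sim very close to 1 (diff on the order of 5e-7 here),
# well within the eps=1e-2 the sparse MLA test uses for bf16 outputs.
assert_tensors_similar(x, x + 1e-3 * torch.randn_like(x), eps=1e-2, name="noisy")

# A 10% global scale mismatch is penalized: sim = 2 * 1.1 / (1 + 1.21) ~= 0.995,
# so diff ~= 4.5e-3, whereas plain cosine similarity would still be exactly 1.
print(calculate_tensor_similarity(x, 1.1 * x, name="scaled"))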