diff --git a/examples/deepseek_v32/sparse_mla_fwd.py b/examples/deepseek_v32/sparse_mla_fwd.py index b1bce065f..ccd560346 100644 --- a/examples/deepseek_v32/sparse_mla_fwd.py +++ b/examples/deepseek_v32/sparse_mla_fwd.py @@ -2,6 +2,7 @@ import torch import tilelang from tilelang import language as T +from utils import assert_tensors_similar @tilelang.jit( @@ -253,6 +254,12 @@ def test_sparse_mla_fwd(B=1, tl_out, tl_lse = sparse_mla_fwd_interface(q, kv, indices) + if SKV <= 4096: + # otherwise may cause out of memory + ref_out = ref_sparse_mla_fwd_interface(q, kv, indices) + assert_tensors_similar(tl_out, ref_out, eps=1e-2, name="out") + print("assert_tensors_similar passed") + def fn(): return sparse_mla_fwd_interface(q, kv, indices) @@ -270,4 +277,4 @@ def fn(): if __name__ == "__main__": test_sparse_mla_fwd( - B=1, S=4096, SKV=32768, H=128, HKV=1, DQK=576, DV=512, topk=2048, dtype=torch.bfloat16) + B=1, S=4096, SKV=4096, H=128, HKV=1, DQK=576, DV=512, topk=2048, dtype=torch.bfloat16) diff --git a/examples/deepseek_v32/utils.py b/examples/deepseek_v32/utils.py index c94d382d4..2ea34b14a 100644 --- a/examples/deepseek_v32/utils.py +++ b/examples/deepseek_v32/utils.py @@ -251,25 +251,62 @@ def generate_random_cu_seqlens(per_cp_seqlen, cp_size=4, cp_rank=3, kv_stride=1, return ks, ke -def print_red_warning(message): - print(f"\033[31mWARNING: {message}\033[0m") +def calculate_tensor_similarity(x, y, name="tensor"): + """ + Calculate similarity between two tensors using a normalized dot product metric. + + Unlike torch.testing.assert_close which uses absolute/relative tolerance based on + element-wise differences, this function computes a global similarity score: + sim = 2 * / (||x||^2 + ||y||^2) + + This metric is scale-invariant and measures the cosine-like similarity normalized + by the magnitude of both tensors. It returns 1 for identical tensors and values + closer to 0 for dissimilar ones. This is particularly useful for comparing tensors + with varying magnitudes where relative errors matter more than absolute differences. + Args: + x: First tensor to compare + y: Second tensor to compare + name: Name of the tensor for logging purposes -def calc_sim(x, y, name="tensor"): + Returns: + Similarity score in range [0, 1] where 1 means identical + """ x, y = x.data.double(), y.data.double() denominator = (x * x + y * y).sum() if denominator == 0: - print_red_warning(f'{name} all zero') + print(f"\033[33mWARNING: {name} all zero\033[0m") return 1 sim = 2 * (x * y).sum() / denominator return sim -def assert_similar(x, y, eps=1e-8, name="tensor", raise_assert=True): - sim = calc_sim(x, y, name) +def assert_tensors_similar(x, y, eps=1e-8, name="tensor", raise_assert=True): + """ + Assert that two tensors are similar using a global similarity metric. + + Key differences from torch.testing.assert_close: + - torch.testing.assert_close: Uses element-wise comparison with rtol/atol, checking + that |x - y| <= atol + rtol * |y| for each element. It's sensitive to outliers + and requires all elements to satisfy the tolerance. + - assert_tensors_similar: Uses a single global similarity score (1 - sim) where sim is the + normalized dot product. It's more robust to outliers and focuses on overall + tensor similarity rather than element-wise precision. This is better suited for + comparing large tensors where a few outlier elements shouldn't fail the test. + + Args: + x: First tensor to compare + y: Second tensor to compare + eps: Maximum allowed difference (1 - similarity), default 1e-8 + name: Name of the tensor for error messages + raise_assert: Whether to raise assertion error on failure + """ + sim = calculate_tensor_similarity(x, y, name) diff = 1. - sim if not (0 <= diff <= eps): - print_red_warning(f'{name} Error: {diff}') + print( + f"\033[31mERROR: {name} similarity check failed, diff={diff:.2e} (threshold={eps:.2e})\033[0m" + ) if raise_assert: assert False # noqa: B011