@@ -251,25 +251,62 @@ def generate_random_cu_seqlens(per_cp_seqlen, cp_size=4, cp_rank=3, kv_stride=1,
251251 return ks , ke
252252
253253
254- def print_red_warning (message ):
255- print (f"\033 [31mWARNING: { message } \033 [0m" )
def calculate_tensor_similarity(x, y, name="tensor"):
    """
    Compute a global similarity score between two tensors.

    Unlike torch.testing.assert_close, which applies absolute/relative tolerances to
    every element, this function reduces both tensors to a single number computed
    in float64:
        sim = 2 * <x, y> / (||x||^2 + ||y||^2)

    Properties of the metric (in exact arithmetic):
    - sim == 1 if and only if x == y, and sim == -1 if and only if x == -y, so the
      score lies in [-1, 1].
    - Scaling *both* tensors by the same non-zero factor leaves the score unchanged.
      Scaling only one of them does not: sim(x, 2 * x) == 0.8. A magnitude mismatch
      is therefore penalized, which plain cosine similarity would not do.

    Args:
        x: First tensor to compare.
        y: Second tensor to compare; must be compatible with ``x`` for elementwise
            multiplication.
        name: Label used in the warning printed when both tensors are all zero.

    Returns:
        A 0-dim float64 tensor holding the similarity score. When both inputs are
        entirely zero the score is undefined (0 / 0); a warning is printed and a
        tensor holding 1 is returned, i.e. the inputs are treated as identical.
    """
    # Compare in float64 so low-precision inputs (fp16/bf16) do not lose accuracy in
    # the reductions; detach() keeps the computation out of the autograd graph.
    x, y = x.detach().double(), y.detach().double()
    denominator = (x * x + y * y).sum()
    if denominator == 0:
        print(f"\033[33mWARNING: {name} all zero\033[0m")
        # Return a tensor (not a bare int) so every code path has the same type.
        return denominator.new_ones(())
    return 2 * (x * y).sum() / denominator
266282
267283
def assert_tensors_similar(x, y, eps=1e-8, name="tensor", raise_assert=True):
    """
    Check that two tensors are similar according to a single global score.

    Key differences from torch.testing.assert_close:
    - torch.testing.assert_close compares element by element with rtol/atol, so every
      element must satisfy the tolerance and a single outlier fails the check.
    - assert_tensors_similar computes ``diff = 1 - sim`` where ``sim`` is the
      normalized dot product returned by calculate_tensor_similarity. The check looks
      at the tensors as a whole, so a few deviating elements in a large tensor only
      fail it if they move the global score by more than ``eps``.

    The check passes only when ``0 <= diff <= eps``. A NaN ``diff`` (e.g. NaNs in the
    inputs) or a negative ``diff`` therefore counts as a failure.

    Args:
        x: First tensor to compare.
        y: Second tensor to compare.
        eps: Maximum allowed value of ``1 - similarity``, default 1e-8.
        name: Name of the tensor, used in the failure message.
        raise_assert: If True, raise on failure; if False, only print the message.

    Raises:
        AssertionError: If the check fails and ``raise_assert`` is True.
    """
    sim = calculate_tensor_similarity(x, y, name)
    diff = 1. - sim
    if 0 <= diff <= eps:
        return

    message = f"{name} similarity check failed, diff={diff:.2e} (threshold={eps:.2e})"
    print(f"\033[31mERROR: {message}\033[0m")
    if raise_assert:
        # Raise explicitly instead of `assert False`: assert statements are removed
        # when Python runs with -O, which would silently disable this check.
        raise AssertionError(message)
275312
0 commit comments