Merged
2 changes: 2 additions & 0 deletions tilelang/jit/adapter/cython/cython_wrapper.pyx
@@ -197,12 +197,14 @@ cdef class CythonKernelWrapper:
tensor = inputs[ins_idx]
ins_idx += 1
# TODO(chenggang): remove this check or rewrite by ourselves?
'''
if isinstance(tensor, torch.Tensor) and tensor._base is not None and not tensor.is_contiguous():
base_tensor = tensor._base.as_strided(tensor._base.shape, tensor.stride())
if torch._debug_has_internal_overlap(base_tensor):
raise ValueError(f"Cannot use an overlapping tensor"
f"(shape={tensor.shape}, strides={tensor.stride()}, "
f"overlap={torch._debug_has_internal_overlap(base_tensor)}) as the kernel input")
'''
Comment on lines 199 to +207
Contributor
Severity: medium

Instead of commenting out this code using a multi-line string, it's better to remove it entirely, along with the now-obsolete TODO comment. This improves code clarity and avoids clutter. If the code needs to be restored later, version control history can be used.

Comment on lines +200 to +207
Contributor

⚠️ Potential issue | 🔴 Critical

Restore the overlapping-tensor guard

Commenting out this block lets overlapping views (e.g., results of expand, as_strided, certain transposes) reach the kernel. Those views alias the same storage, so once the kernel writes, threads race and results become undefined. The underlying check was flawed because of the as_strided call on _base, but we still need a guard; we should just check the tensor directly and treat failures conservatively. Please rewrite instead of removing.

-            '''
-            if isinstance(tensor, torch.Tensor) and tensor._base is not None and not tensor.is_contiguous():
-                base_tensor = tensor._base.as_strided(tensor._base.shape, tensor.stride())
-                if torch._debug_has_internal_overlap(base_tensor):
-                    raise ValueError(f"Cannot use an overlapping tensor"
-                                     f"(shape={tensor.shape}, strides={tensor.stride()}, "
-                                     f"overlap={torch._debug_has_internal_overlap(base_tensor)}) as the kernel input")
-            '''
+            if (
+                isinstance(tensor, torch.Tensor)
+                and tensor._base is not None
+                and not tensor.is_contiguous()
+            ):
+                try:
+                    has_overlap = torch._debug_has_internal_overlap(tensor)
+                except RuntimeError as err:
+                    raise ValueError(
+                        f"Cannot use an overlapping tensor "
+                        f"(shape={tensor.shape}, strides={tensor.stride()}) as the kernel input"
+                    ) from err
+                if has_overlap:
+                    raise ValueError(
+                        f"Cannot use an overlapping tensor "
+                        f"(shape={tensor.shape}, strides={tensor.stride()}, "
+                        f"overlap={has_overlap}) as the kernel input"
+                    )
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

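To illustrate the condition the guard is meant to catch: a strided view has internal overlap when some dimension with size greater than 1 has stride 0, so many logical indices map to the same storage element. This is exactly the layout `expand` produces, and it is the cheap case an overlap check can detect without touching PyTorch internals. The helper below is a hypothetical stdlib-only sketch of that idea, not the real `torch._debug_has_internal_overlap` (which also handles subtler layouts):

```python
def has_internal_overlap(shape, strides):
    # Hypothetical sketch: a dimension with stride 0 and size > 1 makes
    # many logical indices alias a single storage element. The real
    # torch._debug_has_internal_overlap also covers harder layouts.
    return any(size > 1 and stride == 0
               for size, stride in zip(shape, strides))

# A (4,)-element tensor expanded to (3, 4) has strides (0, 1): every row
# aliases the same four elements, so concurrent kernel writes would race.
print(has_internal_overlap((3, 4), (0, 1)))  # True: overlapping expand view
print(has_internal_overlap((3, 4), (4, 1)))  # False: contiguous layout
```

Writes through such a view are precisely the data race the reviewer describes, which is why the suggestion rewrites the guard rather than deleting it.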

tensor_list.append(tensor)

# Convert tensor pointers to C void pointers for kernel call