-
Notifications
You must be signed in to change notification settings - Fork 333
[Enhancement] Enable runtime tensor data type validation #146
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Conversation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
- Update debug_print_buffer_value template specialization for unsigned char - Modify test_tilelang_debug_print.py to include additional dtype tests - Add test case for uint8 dtype in debug print buffer function
- Improve code formatting for debug_print_buffer_value template specialization - Adjust line breaks and indentation for better readability - Maintain consistent code style with other template specializations
- Move map_torch_type function from multiple test files to a centralized location - Import map_torch_type from tilelang.utils.tensor in kernel test files - Improve code reusability by creating a shared utility function for type mapping
- Introduce buffer_dtype_map in CythonKernelAdapter to track buffer variable dtypes - Add _process_buffer_dtype method to extract dtype information from TIR function - Update CythonKernelWrapper to support setting and validating buffer dtypes - Enhance type checking during kernel execution with dtype verification - Improve logging message for Cython JIT adapter compilation
Member
Author
Traceback (most recent call last):
File "/root/tilelang/debug/quickstart.py", line 61, in <module>
c = jit_kernel(a, b)
File "/root/tilelang/tilelang/jit/kernel.py", line 106, in __call__
return self.torch_function(*args, **kwds)
File "/root/tilelang/tilelang/jit/adapter/cython/adapter.py", line 270, in lambda_forward
return self.cython_wrapper.forward([*args], stream=stream)
File "tilelang/jit/adapter/cython/cython_wrapper.pyx", line 39, in tilelang.jit.adapter.cython.cython_wrapper.CythonKernelWrapper.forward
cpdef forward(self, list inputs, int64_t stream = -1):
File "tilelang/jit/adapter/cython/cython_wrapper.pyx", line 88, in tilelang.jit.adapter.cython.cython_wrapper.CythonKernelWrapper.forward
raise ValueError(f"Buffer dtype mismatch for parameter {param}: expected {torch_dtype}, got {tensor_list[buffer_idx].dtype}")
ValueError: Buffer dtype mismatch for parameter A: expected torch.float16, got torch.float32 |
- Introduce static_shape_map in CythonKernelAdapter to track buffer variable static shapes - Add _process_static_shape method to extract static shape information from TIR function - Update CythonKernelWrapper to support setting and validating static shapes - Enhance type checking during kernel execution with static shape verification
Member
Author
|
Also support shape test, but I'm curious about it's runtime overhead of type checking. |
Member
Author
|
Kernel latency: 0.03315455 |
2 tasks
- Implement comprehensive test for Multi-Head Attention backward pass - Support both causal and non-causal attention scenarios - Add reference implementation for comparing kernel outputs - Test different batch sizes, head counts, sequence lengths, and head dimensions - Verify forward and backward pass correctness using torch.testing.assert_close
- Add random seed initialization for consistent test reproducibility - Use tilelang.testing.set_random_seed(42) to ensure deterministic test results
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
This pull request primarily focuses on refactoring the code to improve the handling of tensor data types and simplifying the codebase by removing redundant functions. The most important changes include importing a common utility function, removing duplicated code, and enhancing the Cython JIT adapter to manage buffer data types.
Refactoring and code simplification:
testing/python/kernel/test_tilelang_kernel_bf16_gemm_mma.py,testing/python/kernel/test_tilelang_kernel_fp8_gemm_mma.py,testing/python/kernel/test_tilelang_kernel_fp8_gemv_simt.py,testing/python/kernel/test_tilelang_kernel_gemm_mma_intrinsic.py,testing/python/kernel/test_tilelang_kernel_gemv_simt.py: Imported themap_torch_typefunction fromtilelang.utils.tensorand removed local definitions of the same function. [1] [2] [3] [4] [5] [6] [7] [8] [9] [10]Enhancements to Cython JIT adapter:
tilelang/jit/adapter/cython/adapter.py: Addedbuffer_dtype_mapto theCythonKernelAdapterclass to map buffer variables to their corresponding data types and updated the initialization process to include this mapping. [1] [2] [3] [4] [5]tilelang/jit/adapter/cython/cython_wrapper.pyx: Updated theCythonKernelWrapperclass to include a buffer dtype map and added methods to set this map. Also, added validation to check for buffer dtype mismatches during kernel execution. [1] [2]Utility function enhancement:
tilelang/utils/tensor.py: Added type annotation to themap_torch_typefunction to specify the return type astorch.dtype.