@LeiWang1999
Member

This pull request refactors how tensor data types are handled and simplifies the codebase by removing redundant functions. The main changes are importing a shared utility function in place of duplicated local copies and extending the Cython JIT adapter to track buffer data types.

Refactoring and code simplification:

  • testing/python/kernel/test_tilelang_kernel_bf16_gemm_mma.py, testing/python/kernel/test_tilelang_kernel_fp8_gemm_mma.py, testing/python/kernel/test_tilelang_kernel_fp8_gemv_simt.py, testing/python/kernel/test_tilelang_kernel_gemm_mma_intrinsic.py, testing/python/kernel/test_tilelang_kernel_gemv_simt.py: Imported the map_torch_type function from tilelang.utils.tensor and removed the local definitions of the same function.

Enhancements to Cython JIT adapter:

Utility function enhancement:

  • tilelang/utils/tensor.py: Added type annotation to the map_torch_type function to specify the return type as torch.dtype.

- Update debug_print_buffer_value template specialization for unsigned char
- Modify test_tilelang_debug_print.py to include additional dtype tests
- Add test case for uint8 dtype in debug print buffer function
- Improve code formatting for debug_print_buffer_value template specialization
- Adjust line breaks and indentation for better readability
- Maintain consistent code style with other template specializations
- Move map_torch_type function from multiple test files to a centralized location
- Import map_torch_type from tilelang.utils.tensor in kernel test files
- Improve code reusability by creating a shared utility function for type mapping
- Introduce buffer_dtype_map in CythonKernelAdapter to track buffer variable dtypes
- Add _process_buffer_dtype method to extract dtype information from TIR function
- Update CythonKernelWrapper to support setting and validating buffer dtypes
- Enhance type checking during kernel execution with dtype verification
- Improve logging message for Cython JIT adapter compilation
@LeiWang1999
Member Author

Traceback (most recent call last):
  File "/root/tilelang/debug/quickstart.py", line 61, in <module>
    c = jit_kernel(a, b)
  File "/root/tilelang/tilelang/jit/kernel.py", line 106, in __call__
    return self.torch_function(*args, **kwds)
  File "/root/tilelang/tilelang/jit/adapter/cython/adapter.py", line 270, in lambda_forward
    return self.cython_wrapper.forward([*args], stream=stream)
  File "tilelang/jit/adapter/cython/cython_wrapper.pyx", line 39, in tilelang.jit.adapter.cython.cython_wrapper.CythonKernelWrapper.forward
    cpdef forward(self, list inputs, int64_t stream = -1):
  File "tilelang/jit/adapter/cython/cython_wrapper.pyx", line 88, in tilelang.jit.adapter.cython.cython_wrapper.CythonKernelWrapper.forward
    raise ValueError(f"Buffer dtype mismatch for parameter {param}: expected {torch_dtype}, got {tensor_list[buffer_idx].dtype}")
ValueError: Buffer dtype mismatch for parameter A: expected torch.float16, got torch.float32
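
The check that produces this error can be illustrated with a minimal, torch-free sketch. Everything here is an assumption made for illustration: `FakeTensor`, `check_buffer_dtypes`, and the `{param: (buffer_idx, expected_dtype)}` layout of the map are hypothetical stand-ins, not the actual code in cython_wrapper.pyx. Only the wording of the error message is taken from the traceback above.

```python
from dataclasses import dataclass


@dataclass
class FakeTensor:
    """Hypothetical stand-in for a tensor; only the dtype attribute is used."""
    dtype: str


def check_buffer_dtypes(buffer_dtype_map, tensor_list):
    """Raise ValueError if a tensor's dtype differs from the expected dtype.

    Assumed layout: buffer_dtype_map maps a parameter name to
    (index into tensor_list, expected dtype).
    """
    for param, (buffer_idx, expected) in buffer_dtype_map.items():
        actual = tensor_list[buffer_idx].dtype
        if actual != expected:
            raise ValueError(
                f"Buffer dtype mismatch for parameter {param}: "
                f"expected {expected}, got {actual}")


dtype_map = {"A": (0, "torch.float16"), "B": (1, "torch.float16")}

# Matching dtypes: the check returns without raising.
check_buffer_dtypes(
    dtype_map, [FakeTensor("torch.float16"), FakeTensor("torch.float16")])

# Mismatched dtype for A: the check raises, as in the traceback.
try:
    check_buffer_dtypes(
        dtype_map, [FakeTensor("torch.float32"), FakeTensor("torch.float16")])
except ValueError as err:
    mismatch_message = str(err)
```

In this sketch the map is built once and the per-call work is one attribute read and one comparison per buffer, which is why a check of this kind is usually cheap relative to a kernel launch.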

- Introduce static_shape_map in CythonKernelAdapter to track buffer variable static shapes
- Add _process_static_shape method to extract static shape information from TIR function
- Update CythonKernelWrapper to support setting and validating static shapes
- Enhance type checking during kernel execution with static shape verification
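
A static shape check along the lines of the commits above might look like the following sketch. The names `FakeTensor` and `check_static_shapes`, the `{param: (buffer_idx, [(dim_idx, expected_size), ...])}` layout, and the error message are all assumptions chosen for illustration; the real adapter may store and report this differently. Dimensions that are absent from the list are treated as dynamic and are not checked.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass
class FakeTensor:
    """Hypothetical stand-in for a tensor; only the shape attribute is used."""
    shape: Tuple[int, ...]


def check_static_shapes(static_shape_map, tensor_list):
    """Raise ValueError if a statically known dimension has the wrong size.

    Assumed layout: static_shape_map maps a parameter name to
    (index into tensor_list, list of (dimension index, expected size)).
    """
    for param, (buffer_idx, static_dims) in static_shape_map.items():
        shape = tensor_list[buffer_idx].shape
        for dim_idx, expected in static_dims:
            if shape[dim_idx] != expected:
                raise ValueError(
                    f"Static shape mismatch for parameter {param}: "
                    f"expected {expected} at dim {dim_idx}, got {shape[dim_idx]}")


# A has both dimensions fixed; B has only dim 1 fixed (dim 0 is dynamic).
shape_map = {"A": (0, [(0, 1024), (1, 512)]), "B": (1, [(1, 512)])}

# B's dynamic dim 0 may take any size, so this passes.
check_static_shapes(shape_map, [FakeTensor((1024, 512)), FakeTensor((77, 512))])

# A's dim 1 is wrong, so this raises.
try:
    check_static_shapes(shape_map, [FakeTensor((1024, 256)), FakeTensor((77, 512))])
except ValueError as err:
    shape_message = str(err)
```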
@LeiWang1999
Member Author

Shape checks are now supported as well, but I'm curious about the runtime overhead of this type checking.

@LeiWang1999
Member Author

Kernel latency: 0.03315455
RT latency - with runtime check: 0.03686400130391121
RT latency - without runtime check: 0.03686400130391121
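
One way to compare call latency with and without a runtime check is sketched below using the standard library's `timeit`. This is not the script that produced the numbers above: `kernel_stub`, `call_with_check`, and `call_without_check` are hypothetical placeholders, and a real measurement would time the JIT kernel call on the GPU with proper synchronization.

```python
import timeit


def kernel_stub(inputs):
    """Hypothetical placeholder for the real kernel launch."""
    return sum(inputs)


def call_without_check(inputs):
    return kernel_stub(inputs)


def call_with_check(inputs, expected_len=4):
    # Hypothetical runtime check standing in for dtype/shape validation.
    if len(inputs) != expected_len:
        raise ValueError(f"expected {expected_len} inputs, got {len(inputs)}")
    return kernel_stub(inputs)


inputs = [1.0, 2.0, 3.0, 4.0]
n = 10_000

# Take the minimum over several repeats to reduce scheduling noise,
# then divide by n to get seconds per call.
t_without = min(timeit.repeat(lambda: call_without_check(inputs), number=n, repeat=5)) / n
t_with = min(timeit.repeat(lambda: call_with_check(inputs), number=n, repeat=5)) / n

print(f"without check: {t_without:.3e} s/call")
print(f"with check:    {t_with:.3e} s/call")
```

When the check costs far less than the call itself, the two figures can be indistinguishable at the timer's resolution, so it helps to repeat the measurement many times before drawing conclusions.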

- Implement comprehensive test for Multi-Head Attention backward pass
- Support both causal and non-causal attention scenarios
- Add reference implementation for comparing kernel outputs
- Test different batch sizes, head counts, sequence lengths, and head dimensions
- Verify forward and backward pass correctness using torch.testing.assert_close
- Add random seed initialization for consistent test reproducibility
- Use tilelang.testing.set_random_seed(42) to ensure deterministic test results
@LeiWang1999 LeiWang1999 merged commit c7d1966 into tile-ai:main Mar 5, 2025
2 checks passed