Skip to content

Conversation

@LeiWang1999
Copy link
Member

This pull request includes several changes to improve the functionality and structure of the tilelang project. The most important changes include the addition of new test cases for the ctypes kernel, the introduction of a new kernel adapter, and enhancements to the ctypes adapter functionality.

Enhancements to testing:

  • Added new test cases for the ctypes kernel, including run_ctypes_kernel_do_bench, run_ctypes_kernel_multi_stream, and run_ctypes_dynamic_shape in testing/python/jit/test_tilelang_jit_gemm_ctypes.py. These tests ensure that the ctypes kernel performs correctly under various conditions and configurations.

New kernel adapter:

  • Introduced the CythonKernelAdapter in tilelang/jit/adapter/__init__.py, expanding the project's capability to handle different types of kernel adapters.

Enhancements to ctypes adapter:

  • Enhanced the CtypesKernelAdapter in tilelang/jit/adapter/ctypes/adapter.py by adding detailed class-level documentation, handling dynamic shapes, and improving the initialization process. These changes make the adapter more robust and easier to use. [1] [2]
  • Modified the get_dynamic_symbolic_set method in tilelang/jit/adapter/ctypes/wrapper.py to use a list instead of a set for dynamic symbols, ensuring the order is preserved.
  • Improved the legalize_c method in tilelang/jit/adapter/ctypes/wrapper.py to correctly handle the CUDA kernel launch string.

Minor fixes and improvements:

  • Updated the package_data in setup.py to include *pyx files, ensuring all necessary files are packaged.
  • Removed unused imports and reorganized the import statements in testing/python/jit/test_tilelang_jit_gemm_ctypes.py for better readability and maintainability. [1] [2]

…stream execution

- Enhance CtypesKernelAdapter to handle dynamic symbolic shapes
- Add support for multi-stream kernel execution in CTypes backend
- Implement dynamic shape handling in test_tilelang_jit_gemm_ctypes.py
- Add symbolic shape utility function in tilelang.language
- Update profiler to improve flexibility in benchmark selection
- Remove unnecessary `thread_binding` line in GEMM kernel functions
- Clean up code in `examples/gemm/README.md` and `testing/python/kernel/test_tilelang_kernel_int4_gemm_mma.py`
- Enhance code readability by removing redundant thread binding annotation
- Correct indentation for function calls in `test_tilelang_kernel_int4_gemm_mma.py`
- Remove extra indentation in `mma_emitter.ldmatrix_a()` and `mma_emitter.ldmatrix_b()` calls
- Improve code formatting for better readability
…stream execution

- Implement CythonKernelAdapter to handle dynamic symbolic shapes
- Add support for multi-stream kernel execution in Cython backend
- Create comprehensive test suite for Cython GEMM kernel in test_tilelang_jit_gemm_cython.py
- Update JITKernel to include "cython" as a valid execution backend
- Add Cython-specific wrapper and library generation modules
- Update .gitignore to exclude Cython cache directory
- Modify setup.py to include Cython source files in package data
…lation

- Add new `compile()` function in tilelang/jit/__init__.py as a wrapper for JITKernel
- Update multiple test files and examples to use `tilelang.compile()` instead of `tilelang.JITKernel()`
- Modify kernel adapters to support optional kernel-only source retrieval
- Update `__init__.py` to import the new `compile()` function
- Improve kernel source retrieval for different execution backends
- Introduce new `tilelang/contrib/cc.py` module with cross-platform C/C++ compiler utilities
- Add functions to detect and retrieve system C/C++ compilers
- Implement cross-compilation and shared library creation support
- Update Cython JIT kernel to validate C++ compiler availability
- Modify Cython adapter to use detected C++ compiler for library generation
- Move float8_dtype_map inside adapt_torch2tvm function
- Simplify global scope by localizing the dtype mapping
- Maintain existing functionality for converting torch float8 tensors to TVM ndarray
- Move float8_dtype_map inside adapt_torch2tvm function
- Simplify global scope by localizing the dtype mapping
- Maintain existing functionality for converting torch float8 tensors to TVM ndarray
- Add `get_cython_compiler()` function to dynamically locate Cython executable
- Update Cython adapter to use detected Cython compiler instead of hardcoded command
- Raise an exception if no Cython compiler is found
- Update requirements.txt to specify minimum PyTorch version (>=2.2.0)
- Update stream parameter type to int64_t for better compatibility
- Directly use torch.cuda.current_stream().cuda_stream instead of casting
- Improve type safety and precision in Cython kernel wrapper
@LeiWang1999 LeiWang1999 merged commit 2e53bd0 into tile-ai:main Feb 21, 2025
3 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant