Fix/jit kernel use target #648
Conversation
👋 Hi! Thank you for contributing to the TileLang project. Please remember to run the required checks before submitting. We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀
Summary of Changes
Hello @meinie0826, I'm Gemini Code Assist[^1]! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request primarily addresses a critical bug in the JITKernel's handling of compilation targets, ensuring that the specified target is correctly applied during kernel compilation. Beyond this fix, the changes introduce substantial improvements across the codebase, including a refined type system for kernel definitions, a simplified API for autotuning and JIT compilation, and the addition of several new high-performance kernels with accompanying benchmarks and documentation. The project's build system and licensing information have also been updated to reflect ongoing development and collaboration.
Highlights
- **Bug Fix:** Corrected an issue in the `JITKernel` class where the `target` parameter was converted but not assigned to `self.target`, preventing its proper utilization in subsequent compilations.
- **API Refinement & Type System Update:** Updated the `tilelang.language` API, transitioning from `T.Buffer` to more semantically precise types like `T.Tensor`, `T.SharedBuffer`, and `T.FragmentBuffer` in kernel definitions and macros. This enhances clarity and enables more advanced optimizations.
- **Autotuning and JIT Compilation Simplification:** Streamlined the usage of the `@autotune` and `@jit` decorators, allowing them to be applied directly to the outer kernel functions, simplifying the API for users.
- **New Kernel Implementations & Benchmarks:** Introduced new high-performance kernels including sparse matrix multiplication (`T.gemm_sp`), FP8 matrix multiplication, and optimized FlashMLA implementations. Accompanying benchmarks and examples demonstrate their performance and usage.
- **Build System & Documentation Enhancements:** Updated CMake configurations to allow non-forceful build type setting and enable compile command export. Added new Dockerfiles for CUDA 12.8 and ROCm. Significantly expanded documentation with new tutorials (e.g., FlashMLA, elementwise, GEMV) and an API reference, along with a copyright update.
- **Submodule Updates:** Updated the `cutlass` and `tvm` submodules to their latest commits, bringing in upstream improvements and fixes.
Using Gemini Code Assist
The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.
Invoking Gemini
You can request assistance from Gemini at any point in your pull request by creating an issue comment (i.e., a comment on the pull request page) using either `/gemini <command>` or `@gemini-code-assist <command>`. Below is a summary of the supported commands.
| Feature | Command | Description |
|---|---|---|
| Code Review | `/gemini review` | Performs a code review for the current pull request in its current state. |
| Pull Request Summary | `/gemini summary` | Provides a summary of the current pull request in its current state. |
| Comment | `@gemini-code-assist` | Responds in comments when explicitly tagged, both in issue comments and review comments. |
| Help | `/gemini help` | Displays a list of available commands. |
Customization
To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by adding files to a `.gemini/` folder at the root of the repository. Detailed instructions can be found here.
Limitations & Feedback
Gemini Code Assist is currently in preview and may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments to provide feedback.
You can also get AI-powered code generation, chat, and code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.
Footnotes
[^1]: Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double-check it and use code with caution. ↩
Code Review
This pull request introduces a fix in `tilelang/jit/jit.py` to correctly utilize the compilation target. It also includes refactoring of autotuner and benchmark scripts for better maintainability, and updates the API from `T.Buffer` to `T.Tensor` for improved code clarity. I have provided feedback on minor inconsistencies with the new refactoring patterns to ensure codebase consistency.
This config-generation section can be improved for readability and conciseness. Consider using a dictionary to define the parameters and then generating the configurations with `itertools.product`. This approach is more scalable and easier to maintain.
Diff context (excerpt of the surrounding docstring):

```python
    M : int
        The dimension M of the matrix multiplication.
    N : int
        The dimension N of the matrix multiplication.
    K : int
        The dimension K of the matrix multiplication.

    Returns
    -------
    (best_latency, best_config, ref_latency)
        best_latency : float
            The best latency found among the tuned configurations.
        best_config : dict
```

Suggested config generation:

```python
iter_params = dict(
    block_M=[64, 128, 256],
    block_N=[64, 128, 256],
    block_K=[32, 64],
    num_stages=[0, 1, 2, 3],
    thread_num=[128, 256],
    policy=[T.GemmWarpPolicy.Square],
    enable_rasterization=[True, False],
)
return [{
    k: v for k, v in zip(iter_params, values)
} for values in itertools.product(*iter_params.values())]
```
This function uses the older pattern of a nested kernel function that is then decorated. This pull request introduces a cleaner pattern in other benchmark files where the main function is directly decorated with `@autotune` and `@jit`, and it returns the `T.prim_func` directly. Refactoring this to match the new pattern would improve consistency across the codebase.
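For readers unfamiliar with the newer pattern referenced above, the sketch below shows an outer kernel factory decorated directly with `@autotune` and `@jit`, returning the `T.prim_func`. It is illustrative only, not code from this repository: the function names (`get_configs`, `matmul`), tile sizes, import paths, and decorator keyword arguments (`configs`, `warmup`, `rep`, `out_idx`) are assumptions based on common TileLang usage and may differ between versions.

```python
import itertools

import tilelang
import tilelang.language as T
from tilelang.autotuner import autotune  # assumed import path


def get_configs():
    # Enumerate tuning configs as the Cartesian product of parameter choices.
    iter_params = dict(block_M=[64, 128], block_N=[64, 128], block_K=[32, 64], num_stages=[2, 3])
    return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]


@autotune(configs=get_configs(), warmup=10, rep=10)  # keyword names are assumptions
@tilelang.jit(out_idx=[-1])                          # treat the last buffer (C) as the output
def matmul(M=4096, N=4096, K=4096, block_M=128, block_N=128, block_K=32,
           num_stages=2, dtype="float16", accum_dtype="float"):

    @T.prim_func
    def main(
            A: T.Tensor((M, K), dtype),
            B: T.Tensor((K, N), dtype),
            C: T.Tensor((M, N), dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_K, block_N), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            T.clear(C_local)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                T.copy(A[by * block_M, k * block_K], A_shared)
                T.copy(B[k * block_K, bx * block_N], B_shared)
                T.gemm(A_shared, B_shared, C_local)
            T.copy(C_local, C[by * block_M, bx * block_N])

    # The prim_func is returned directly; no separately decorated inner wrapper is needed.
    return main
```

The point of the pattern is that tuning parameters live in the outer function's signature, so the autotuner can sweep them without an extra nesting level.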
`docs/deeplearning_operators/gemv.md` (outdated)
This documentation seems to be using the old `T.Buffer` API. This has been updated to `T.Tensor` across most of the codebase in this PR. To maintain consistency, please update these to `T.Tensor`.
```cpp
C_reg[0] = (C_reg[0] + (((float)((half_t*)buf_dyn_shmem)[tk_1]) * ((float)((half_t*)buf_dyn_shmem)[(((((int)threadIdx.x) * 128) + tk_1) + 128)])));
  }
  tl::fence_proxy_async();
  tl::mbarrier_arrive(_mbarrier[1]);
}
```

```python
@T.prim_func
def main(
        A: T.Tensor(A_shape, in_dtype),
        B: T.Tensor(B_shape, in_dtype),
        C: T.Tensor((M, N), out_dtype),
```
The type annotation here should be `SharedBuffer` and `FragmentBuffer` instead of `Buffer`.
```python
K_shared: T.SharedBuffer([block_N, dim], dtype),
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
k: T.int32,
bx: T.int32,
Q_shared: T.SharedBuffer([block_M, dim], dtype),
K_shared: T.SharedBuffer([block_N, dim], dtype),
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
```
The type annotation here should be `SharedBuffer` and `FragmentBuffer` instead of `Buffer`.
```python
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
k: T.int32,
by: T.int32,
V: T.Tensor(shape, dtype),
V_shared: T.SharedBuffer([block_M, dim], dtype),
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
```
The type annotation here should be `FragmentBuffer` instead of `Buffer`.
```python
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
scores_max: T.FragmentBuffer([block_M], accum_dtype),
scores_max_prev: T.FragmentBuffer([block_M], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
scores_sum: T.FragmentBuffer([block_M], accum_dtype),
```

Force-pushed from `c3e5300` to `fec9b93` (Compare).
I'm sorry, I forgot to pull the latest PR and pushed the wrong code. Gemini has generated many comments on it; don't worry.
refactor: Remove unused target attribute and adjust context management in JITKernel
- Removed the unused `target` attribute from the `JITKernel` class.
- Updated the context management in the `compile` method to utilize `self.target`, improving clarity and ensuring proper resource handling during compilation.
@meinie0826 Thank you, and apologies for the inconvenience, as the main branch has been force-updated for compliance purposes. LGTM, merged :)
* fix: Copy Target to self.target
* refactor: Remove unused target attribute and adjust context management in JITKernel
  - Removed the unused `target` attribute from the `JITKernel` class.
  - Updated the context management in the `compile` method to utilize `self.target`, improving clarity and ensuring proper resource handling during compilation.

---------

Co-authored-by: LeiWang1999 <leiwang1999@outlook.com>
In the `JITKernel` class, the `target` is converted through the `Target` class, but it is not assigned to `self.target`, so it is not utilized in subsequent compilations.
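In other words, the converted target needs to be stored on the instance so that the later compile step reads it. Below is a conceptual sketch of the before/after behaviour, not the actual TileLang source: the class name `JITKernelSketch`, the constructor signature, and the `compile` body are simplified assumptions.

```python
from tvm.target import Target


class JITKernelSketch:
    """Illustration of the bug and fix; not the real tilelang.jit.JITKernel."""

    def __init__(self, func, target="cuda"):
        self.func = func
        # Buggy behaviour: the conversion result was dropped, e.g.
        #   Target(target)   # converted, but never assigned to self.target
        # Fixed behaviour: keep the converted Target on the instance.
        self.target = target if isinstance(target, Target) else Target(target)

    def compile(self):
        # With the fix, compilation runs under the target the caller requested
        # instead of silently falling back to a default one.
        with self.target:
            ...  # lowering and building would happen here
```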