-
Notifications
You must be signed in to change notification settings - Fork 690
[GPU] Support int4 in cuDNN GEMM fusions. #33794
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Closed
Closed
+30
−7
Conversation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
f609d3d to
e1b8dc7
Compare
dimitar-asenov
approved these changes
Nov 11, 2025
copybara-service bot
pushed a commit
that referenced
this pull request
Nov 11, 2025
Imported from GitHub PR #33794 📝 Summary of Changes Support int4 in cuDNN GEMM fusions. 🎯 Justification Accelerates some int4 GEMM fusions (under the flag xla_gpu_cudnn_gemm_fusion_level). 🚀 Kind of Contribution ⚡️ Performance Improvement 📊 Benchmark (for Performance Improvements) > Please measure and include speedups for one of the public HLOs in `compiler/xla/tools/benchmarks/hlo/`. These do not use int4. 🧪 Unit Tests: yes 🧪 Execution Tests: yes Copybara import of the project: -- e1b8dc7 by Ilia Sergachev <isergachev@nvidia.com>: [GPU] Support int4 in cuDNN GEMM fusions. Merging this change closes #33794 FUTURE_COPYBARA_INTEGRATE_REVIEW=#33794 from openxla:cudnn_gemm_int4 e1b8dc7 PiperOrigin-RevId: 830894321
copybara-service bot
pushed a commit
to tensorflow/tensorflow
that referenced
this pull request
Nov 11, 2025
Imported from GitHub PR openxla/xla#33794 📝 Summary of Changes Support int4 in cuDNN GEMM fusions. 🎯 Justification Accelerates some int4 GEMM fusions (under the flag xla_gpu_cudnn_gemm_fusion_level). 🚀 Kind of Contribution ⚡️ Performance Improvement 📊 Benchmark (for Performance Improvements) > Please measure and include speedups for one of the public HLOs in `compiler/xla/tools/benchmarks/hlo/`. These do not use int4. 🧪 Unit Tests: yes 🧪 Execution Tests: yes Copybara import of the project: -- e1b8dc7daff4963b93152d2a5c81c4d91a9f14d8 by Ilia Sergachev <isergachev@nvidia.com>: [GPU] Support int4 in cuDNN GEMM fusions. Merging this change closes #33794 FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#33794 from openxla:cudnn_gemm_int4 e1b8dc7daff4963b93152d2a5c81c4d91a9f14d8 PiperOrigin-RevId: 830894321
copybara-service bot
pushed a commit
that referenced
this pull request
Nov 12, 2025
Imported from GitHub PR #33794 📝 Summary of Changes Support int4 in cuDNN GEMM fusions. 🎯 Justification Accelerates some int4 GEMM fusions (under the flag xla_gpu_cudnn_gemm_fusion_level). 🚀 Kind of Contribution ⚡️ Performance Improvement 📊 Benchmark (for Performance Improvements) > Please measure and include speedups for one of the public HLOs in `compiler/xla/tools/benchmarks/hlo/`. These do not use int4. 🧪 Unit Tests: yes 🧪 Execution Tests: yes Copybara import of the project: -- e1b8dc7 by Ilia Sergachev <isergachev@nvidia.com>: [GPU] Support int4 in cuDNN GEMM fusions. Merging this change closes #33794 FUTURE_COPYBARA_INTEGRATE_REVIEW=#33794 from openxla:cudnn_gemm_int4 e1b8dc7 PiperOrigin-RevId: 830894321
copybara-service bot
pushed a commit
to tensorflow/tensorflow
that referenced
this pull request
Nov 12, 2025
Imported from GitHub PR openxla/xla#33794 📝 Summary of Changes Support int4 in cuDNN GEMM fusions. 🎯 Justification Accelerates some int4 GEMM fusions (under the flag xla_gpu_cudnn_gemm_fusion_level). 🚀 Kind of Contribution ⚡️ Performance Improvement 📊 Benchmark (for Performance Improvements) > Please measure and include speedups for one of the public HLOs in `compiler/xla/tools/benchmarks/hlo/`. These do not use int4. 🧪 Unit Tests: yes 🧪 Execution Tests: yes Copybara import of the project: -- e1b8dc7daff4963b93152d2a5c81c4d91a9f14d8 by Ilia Sergachev <isergachev@nvidia.com>: [GPU] Support int4 in cuDNN GEMM fusions. Merging this change closes #33794 FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#33794 from openxla:cudnn_gemm_int4 e1b8dc7daff4963b93152d2a5c81c4d91a9f14d8 PiperOrigin-RevId: 830894321
copybara-service bot
pushed a commit
to tensorflow/tensorflow
that referenced
this pull request
Nov 12, 2025
Imported from GitHub PR openxla/xla#33794 📝 Summary of Changes Support int4 in cuDNN GEMM fusions. 🎯 Justification Accelerates some int4 GEMM fusions (under the flag xla_gpu_cudnn_gemm_fusion_level). 🚀 Kind of Contribution ⚡️ Performance Improvement 📊 Benchmark (for Performance Improvements) > Please measure and include speedups for one of the public HLOs in `compiler/xla/tools/benchmarks/hlo/`. These do not use int4. 🧪 Unit Tests: yes 🧪 Execution Tests: yes Copybara import of the project: -- e1b8dc7daff4963b93152d2a5c81c4d91a9f14d8 by Ilia Sergachev <isergachev@nvidia.com>: [GPU] Support int4 in cuDNN GEMM fusions. Merging this change closes #33794 PiperOrigin-RevId: 831264661
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
📝 Summary of Changes
Support int4 in cuDNN GEMM fusions.
🎯 Justification
Accelerates some int4 GEMM fusions (under the flag xla_gpu_cudnn_gemm_fusion_level).
🚀 Kind of Contribution
⚡️ Performance Improvement
📊 Benchmark (for Performance Improvements)
These do not use int4.
🧪 Unit Tests:
yes
🧪 Execution Tests:
yes