- 
                Notifications
    You must be signed in to change notification settings 
- Fork 352
BF16 support for Quant-LLM kernel #1147
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
          
     Merged
      
        
      
    
  
     Merged
                    Changes from all commits
      Commits
    
    
            Show all changes
          
          
            30 commits
          
        
        Select commit
          Hold shift + click to select a range
      
      f50d8d7
              
                Add FP6 benchmark option to use BF16
              
              
                tobiasvanderwerff a714377
              
                Change dequant bit-shifting logic for BF16
              
              
                tobiasvanderwerff 5af3b7e
              
                Modify dequant + tensor core ops for bf16
              
              
                tobiasvanderwerff 125f17c
              
                Template progress
              
              
                tobiasvanderwerff b3c3be0
              
                Modify fpx quant logic to include bf16
              
              
                tobiasvanderwerff f828763
              
                Add tests for FP6 BF16
              
              
                tobiasvanderwerff ff2c6e8
              
                Use type punning for large exponent multiplication
              
              
                tobiasvanderwerff 4304dcc
              
                Fix some TODOs
              
              
                tobiasvanderwerff 2d00a3a
              
                Remove option to add exponent bias directly to the exponent bits
              
              
                tobiasvanderwerff ceaed34
              
                Reformat
              
              
                tobiasvanderwerff b532c51
              
                Cleanup
              
              
                tobiasvanderwerff e89274b
              
                Fix alignment
              
              
                tobiasvanderwerff ac0fbe0
              
                Remove templated input type whenever possible
              
              
                tobiasvanderwerff c1dce42
              
                Remove templated input type whenever possible 2
              
              
                tobiasvanderwerff 4546c8b
              
                Remove templated input type whenever possible 3
              
              
                tobiasvanderwerff bba42cf
              
                Less hacky way to construct a float with a large exponent
              
              
                tobiasvanderwerff e66395e
              
                rtol=1e-2 instead of 1e-3 for bfloat16 test
              
              
                tobiasvanderwerff 7e9350e
              
                Guards for SM75
              
              
                tobiasvanderwerff 401559f
              
                Remove redundant `__CUDA_ARCH` guards in host code
              
              
                tobiasvanderwerff 5d52e5b
              
                Fix consistency in checking for `CUDA_ARCH` versions
              
              
                tobiasvanderwerff 398da5b
              
                Update docs
              
              
                tobiasvanderwerff d38490f
              
                Make float bias a constexpr
              
              
                tobiasvanderwerff 11ac84b
              
                Update docs more
              
              
                tobiasvanderwerff 7bd2833
              
                Fix SM75 support
              
              
                tobiasvanderwerff 69e901d
              
                Compile guard for sm<75
              
              
                tobiasvanderwerff 8747d6d
              
                Check for CUDA synchronous errors after kernel launch
              
              
                tobiasvanderwerff 59f5eb7
              
                Updated compile guard
              
              
                tobiasvanderwerff c96cf18
              
                Fix problematic usage of `__CUDA_ARCH__`
              
              
                tobiasvanderwerff 379bd5e
              
                Fix incorrect CUDA error handling
              
              
                tobiasvanderwerff a6de35a
              
                Make the kernel fail for sm75 + bfloat16 inputs
              
              
                tobiasvanderwerff File filter
Filter by extension
Conversations
          Failed to load comments.   
        
        
          
      Loading
        
  Jump to
        
          Jump to file
        
      
      
          Failed to load files.   
        
        
          
      Loading
        
  Diff view
Diff view
There are no files selected for viewing
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
              
              
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
              
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
              
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
              | Original file line number | Diff line number | Diff line change | 
|---|---|---|
| @@ -1,7 +1,7 @@ | ||
| # FP6-LLM kernel | ||
|  | ||
| This kernel is adapted from https://github.com/usyd-fsalab/fp6_llm. It performs linear op (A @ W.T), where A is in FP16 and W is in FP6 (E3M2 without infinities and NaN). | ||
| This kernel is adapted from https://github.com/usyd-fsalab/fp6_llm. It performs linear op (A @ W.T), where A is in FP16 or BF16 and W is in FP6 (E3M2 without infinities and NaN). | ||
|  | ||
| On most hardware, this kernel is faster than FP16 linear for batch size from 1 to 128, and slower for batch size larger than or equal to 256. See https://github.com/usyd-fsalab/fp6_llm/issues/8 for a detailed discussion. | ||
|  | ||
| See https://github.com/pytorch/ao/pull/223 for some benchmark results. | ||
| See https://github.com/pytorch/ao/pull/223 and and https://github.com/pytorch/ao/pull/1147 for some benchmark results. | 
      
      Oops, something went wrong.
        
    
  
      
      Oops, something went wrong.
        
    
  
  Add this suggestion to a batch that can be applied as a single commit.
  This suggestion is invalid because no changes were made to the code.
  Suggestions cannot be applied while the pull request is closed.
  Suggestions cannot be applied while viewing a subset of changes.
  Only one suggestion per line can be applied in a batch.
  Add this suggestion to a batch that can be applied as a single commit.
  Applying suggestions on deleted lines is not supported.
  You must change the existing code in this line in order to create a valid suggestion.
  Outdated suggestions cannot be applied.
  This suggestion has been applied or marked resolved.
  Suggestions cannot be applied from pending reviews.
  Suggestions cannot be applied on multi-line comments.
  Suggestions cannot be applied while the pull request is queued to merge.
  Suggestion cannot be applied right now. Please check back later.
  
    
  
    
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just curious. I saw that generally when BF16 is used, tolerance is quite higher than FP16. From your experience working on this, you do suspect any part of the code might result in this loss of precision? e.g. perhaps some parts are computed in BF16 instead of FP32. Or maybe it's just the way it is.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
All I know is that BF16 has fewer bits for the fraction (mantissa) than FP16 (10 bits vs. 7 bits), so that leads to lower precision for BF16 compared to FP16. I can't think of any part of the FP6 kernel that would inherently lead to more loss of precision for BF16.