
Fused HQQ Quantization Gemm #153

Merged: 19 commits into pytorch:main on Apr 25, 2024

Conversation

@jeromeku (Collaborator) commented Apr 22, 2024

@msaroufim

Fused int4 / fp16 Quant Matmul

A fused kernel that combines asymmetric dequantization and GEMM. It is useful primarily for compute-bound (M > 16) scenarios, not for memory-bound / inference scenarios.

The kernel fuses two ops:

  • Dequantization: upcasts u4 / s4 weights to float16 / bfloat16, then applies groupwise scaling and shifting by scales / zero points
  • GEMM: matmul on the dequantized weights and the activations (a plain-PyTorch reference for these two steps follows below)
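
For orientation, the unfused computation these two steps describe can be written in plain PyTorch roughly as follows. This is a minimal sketch; the tensor shapes and the group-wise reshape are illustrative assumptions, not the kernel's actual code.

```python
import torch

def dequant_matmul_reference(x, w_q, scales, zeros, group_size):
    # x:      (M, K) float16 / bfloat16 activations
    # w_q:    (N, K) unpacked 4-bit weights stored in an integer dtype
    # scales: (N, K // group_size) group-wise scales
    # zeros:  (N, K // group_size) group-wise zero points
    N, K = w_q.shape
    w = w_q.to(x.dtype).reshape(N, K // group_size, group_size)
    # Asymmetric dequantization: shift by the zero point, then scale.
    w = (w - zeros.unsqueeze(-1)) * scales.unsqueeze(-1)
    w = w.reshape(N, K)
    # GEMM on the dequantized weights (torch.nn.Linear convention: y = x @ W^T).
    return x @ w.t()
```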

Tested and benchmarked for HQQ but could theoretically be used for any asymmetric quantization scheme.

NOTE: The benchmark below is indicative only of performance on consumer-grade Ampere GPUs (an A6000 specifically). When tested on an H100, performance is on par with, or marginally worse than, native / compiled torch.
The intended use is therefore fine-tuning / training models on non-datacenter GPUs (80 <= compute capability < 90). If you are interested in optimizing the kernel for other architectures, please drop a note in the CUDA-MODE Discord channel.
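
For reference, a quick way to check whether the current GPU falls in that compute-capability range (a small sketch; this guard is not part of the PR itself):

```python
import torch

major, minor = torch.cuda.get_device_capability()
cc = 10 * major + minor  # e.g. 86 for an RTX A6000, 90 for an H100
if not (80 <= cc < 90):
    print(f"Compute capability {cc}: this kernel is tuned for consumer Ampere (sm80-sm89)")
```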

Implementation Details

  • Bitpacking is a simple row interleave; no extensive preprocessing is needed (unlike, e.g., tinygemm or fastertransformer). A sketch of one such packing scheme is shown after this list.
  • Tested for float16 / bfloat16 activations, scales, and zeros
  • Autotuned for both compute-bound and memory-bound configs
  • Assumes operand B of the GEMM is the quantized type.
  • Requires quantization along in-features, i.e., the K dimension (axis=1) of torch.nn.Linear.weight.
  • The implementation handles both transposed and non-transposed quantized weights, which is useful for the forward / backward training passes.
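
The PR does not spell out the packing layout beyond "simple row interleave"; one plausible reading is packing two adjacent rows into the low and high nibbles of a uint8, sketched below as an illustration only (not the kernel's actual packing code):

```python
import torch

def pack_rows_interleaved(w_q):
    # w_q: (N, K) tensor of 4-bit values (0..15) stored as uint8.
    # Pack pairs of adjacent rows into one uint8 row: row 2i in the low
    # nibble, row 2i+1 in the high nibble, yielding an (N // 2, K) tensor.
    assert w_q.shape[0] % 2 == 0
    lo = w_q[0::2] & 0xF
    hi = w_q[1::2] & 0xF
    return (hi << 4) | lo

def unpack_rows_interleaved(packed):
    # Inverse of the packing above, back to (N, K) 4-bit values in uint8.
    lo = packed & 0xF
    hi = (packed >> 4) & 0xF
    return torch.stack((lo, hi), dim=1).reshape(-1, packed.shape[1])
```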

Performance

Initial benchmarking (on A6000) demonstrates promising results, scaling well for compute-bound workloads:

| M | N | K | group_size | dtype | hqq_ref | triton | tinygemm |
|-----:|-----:|-----:|-----------:|:---------------|--------:|-------:|---------:|
| 16 | 4096 | 4096 | 128 | torch.bfloat16 | 0.2675 | 0.0633 | 0.0382 |
| 32 | 4096 | 4096 | 128 | torch.bfloat16 | 0.2669 | 0.0704 | 0.0649 |
| 128 | 4096 | 4096 | 128 | torch.bfloat16 | 0.2689 | 0.0960 | 0.2523 |
| 256 | 4096 | 4096 | 128 | torch.bfloat16 | 0.3268 | 0.1355 | 0.5192 |
| 512 | 4096 | 4096 | 128 | torch.bfloat16 | 0.3628 | 0.2369 | 1.0892 |
| 1024 | 4096 | 4096 | 128 | torch.bfloat16 | 0.5133 | 0.4753 | 2.2016 |
  • Times are in ms, see benchmarks/benchmark_hqq.py.
  • hqq_ref is the base HQQ_Linear module, run unfused (dequantization followed by a call to torch.matmul).
  • tinygemm calls torch.ops.aten._weight_int4pack_mm. The implementation is a custom HQQLinear layer that wraps the preprocessing necessary for this kernel, adapted from a benchmark script posted by @mobicham in the CUDA-MODE Discord discussions.

GPU details:

_CudaDeviceProperties(name='NVIDIA RTX A6000', major=8, minor=6, total_memory=48676MB, multi_processor_count=84)

NOTE

This implementation requires triton >= 3.0.0.

  • Running tests / benchmarks requires installation of hqq:

    pip install hqq
    

TODO

  • clean up / refactor test to use pytest
  • add transposed matmul
  • torch.compile benchmarking
  • introduce additional prologue / epilogue fusions
  • create standalone dequant kernel
  • test across range of gpus
  • adapt for other quant methods (other than hqq)

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Apr 22, 2024
@jeromeku jeromeku marked this pull request as draft April 22, 2024 02:53
@msaroufim msaroufim (Member) left a comment

first pass

@@ -0,0 +1,134 @@
import torch
from termcolor import colored
msaroufim (Member):

We'll need to get rid of this dependency before merge. I'm fine with adding colors, though, so something like this should work:

RED = "\033[31m"
GREEN = "\033[32m"
YELLOW = "\033[33m"
BLUE = "\033[34m"
MAGENTA = "\033[35m"
CYAN = "\033[36m"
WHITE = "\033[37m"
RESET = "\033[0m"  # Resets the color to default.

name = "Alice"
print(f"{GREEN}Hello, {name}!{RESET}")

from triton import cdiv
import triton.language as tl
from .kernels import mixed_mm_kernel_compute_bound, mixed_mm_kernel_max_autotune
#credit jlebar
msaroufim (Member):

Could you link to the original code as well? We'll need to double-check that the LICENSE allows us to copy-paste.

@@ -0,0 +1,101 @@
import itertools
msaroufim (Member):

Both the test and the benchmark will require skips if triton is older than 3.0 (which is fine because nightlies now ship with 3.0.0) and if hqq is not installed.

For hqq, I'm fine with adding it as a dev dependency for now.
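
A minimal sketch of such skip guards with pytest (illustrative; not necessarily the exact guards added in this PR):

```python
import pytest

# Skip the whole module unless triton >= 3.0.0 and hqq are importable.
triton = pytest.importorskip("triton", minversion="3.0.0")
hqq = pytest.importorskip("hqq")
```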

return ref_time, tt_time, int4_time if dtype == torch.bfloat16 else None


SHAPES = [
msaroufim (Member):

@cpuhrsch I guess these shapes are fine for now, but are there specific shapes we're more interested in tracking on an ongoing basis? If so, I wish we could just make them part of our benchmark or test utilities.



df = pd.DataFrame(data, columns=HEADERS)
df.to_csv("benchmark_triton.csv", index=False)
@msaroufim msaroufim (Member) commented Apr 22, 2024:

We will lose this CSV on CI unless it's saved to a GitHub artifact, so unless this file is huge, let's just print it for now.
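
Something along these lines would do it (a sketch, not the exact change made in the PR):

```python
# Print the benchmark table to stdout instead of writing a CSV that CI discards.
print(df.to_string(index=False))
```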

jeromeku (Collaborator, Author):

fixed

GPU details:

```
_CudaDeviceProperties(name='NVIDIA RTX A6000', major=8, minor=6, total_memory=48676MB, multi_processor_count=84)
```
msaroufim (Member):

Once we figure out the installation issues, I'll check whether the results repro on an H100.

msaroufim (Member):

Apologies, I meant pip freeze.

jeromeku (Collaborator, Author) replied.

@jeromeku jeromeku (Collaborator, Author) commented Apr 24, 2024:

@msaroufim

When you run on the H100, can you run once with DISABLE_MMA_V3=1? It toggles Hopper-specific specializations in triton. I'm curious to see how the performance changes.

@@ -0,0 +1,43 @@
## Fused `int4 / fp16` Quant Matmul

Fused gemm for asymmetric quantized weights. Tested and benchmarked for `HQQ` but could theoretically be used for any asymmetric quantization scheme.
msaroufim (Member):

Sounds like one of the two "asymmetric"s should be "symmetric"?

The kernel packs `u4 / s4` weights and fuses dequantization with the matmul.

- tested for `float16 / bfloat16` activations, scales, and zeros
- autotuned for both compute-bound and io-bound configs
msaroufim (Member):

Nit: could use memory-bandwidth-bound terminology instead.


Fused gemm for asymmetric quantized weights. Tested and benchmarked for `HQQ` but could theoretically be used for any asymmetric quantization scheme.

The kernel packs `u4 / s4` weights and fuses dequantization with the matmul.
msaroufim (Member):

n00b q: why can't we generically do this with torch.compile? @HDCharles

A collaborator replied:

It does work with torch.compile and there's a good speed-up (up to 4x compared to PyTorch), but a dequantize() CUDA kernel + torch.matmul is a bit faster.
I think the bitpacking should be done in such a way that torch.compile can fully optimize it.
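
For context, the unfused path described here looks roughly like the following sketch. The dequantize() call is assumed from the hqq layer API referenced above; check hqq for the exact method name and signature.

```python
import torch

def unfused_forward(x, hqq_layer):
    # Unfused path: dequantize to a dense fp16 / bf16 weight, then a regular matmul.
    w = hqq_layer.dequantize()  # assumed hqq API; verify against the installed version
    return torch.matmul(x, w.t())

# torch.compile can fuse the surrounding ops, but the dequantization itself only
# fuses well if the bitpacking is expressed in compiler-friendly ops.
unfused_forward_compiled = torch.compile(unfused_forward)
```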



@triton.jit
def _mixed_mm_kernel(
msaroufim (Member):

I see a lot of similarities between this code and what you contributed for GaLore; can we start modularizing?


@jeromeku jeromeku (Collaborator, Author) commented:

@msaroufim

  • Added a transpose implementation for backwards passes
  • Cleaned up test
    • skips unless triton >= 3.0 + hqq
    • refactored to use pytest
    • added transpose test
    • removed colored outputs
  • Benchmark
    • cleaned up imports
    • removed colors
    • print benchmark results to terminal rather than saving as csv
  • Attached my pip env (see reply above). TLDR: torch 2.2.2, triton 3.0.0

@jeromeku jeromeku marked this pull request as ready for review April 24, 2024 20:19
@msaroufim msaroufim self-requested a review April 24, 2024 23:15
@msaroufim msaroufim (Member) commented Apr 24, 2024

Results on an H100

  1. tinygemm degrades substantially relative to your triton implementation as the sizes increase
  2. The triton results are always worse than the hqq implementation

It's very curious how different the pattern is on an A6000, though.

(x) [marksaroufim@devvm17057.vll0 ~/ao (hqq_mixed_gemm)]$ python benchmarks/benchmark_hqq.py 
_CudaDeviceProperties(name='NVIDIA H100', major=9, minor=0, total_memory=97320MB, multi_processor_count=132)
shape=[16, 4096, 4096] group_size=128 dtype=torch.bfloat16:
Ref: 0.1194 Triton: 0.1678 Torch int4mm: 0.0248

shape=[32, 4096, 4096] group_size=128 dtype=torch.bfloat16:
Ref: 0.1203 Triton: 0.1567 Torch int4mm: 0.0442

shape=[128, 4096, 4096] group_size=128 dtype=torch.bfloat16:
Ref: 0.1253 Triton: 0.1607 Torch int4mm: 0.1564

shape=[256, 4096, 4096] group_size=128 dtype=torch.bfloat16:
Ref: 0.1269 Triton: 0.1515 Torch int4mm: 0.3025

shape=[512, 4096, 4096] group_size=128 dtype=torch.bfloat16:
Ref: 0.1344 Triton: 0.1477 Torch int4mm: 0.5926

shape=[1024, 4096, 4096] group_size=128 dtype=torch.bfloat16:
Ref: 0.1499 Triton: 0.1585 Torch int4mm: 1.1743

M,N,K,group_size,dtype,ref,triton,tinygemm
16,4096,4096,128,torch.bfloat16,0.11943938583135605,0.1677887886762619,0.024766016751527786
32,4096,4096,128,torch.bfloat16,0.12034595012664795,0.15668845176696777,0.044162582606077194
128,4096,4096,128,torch.bfloat16,0.1253334879875183,0.16072532534599304,0.15640245378017426
256,4096,4096,128,torch.bfloat16,0.1268536001443863,0.1515069156885147,0.3024669885635376
512,4096,4096,128,torch.bfloat16,0.13444344699382782,0.14771266281604767,0.5926073789596558
1024,4096,4096,128,torch.bfloat16,0.14986468851566315,0.15848150849342346,1.1743202209472656

(x) [marksaroufim@devvm17057.vll0 ~/ao (hqq_mixed_gemm)]$ DISABLE_MMA_V3=1 python benchmarks/benchmark_hqq.py 
_CudaDeviceProperties(name='NVIDIA H100', major=9, minor=0, total_memory=97320MB, multi_processor_count=132)
shape=[16, 4096, 4096] group_size=128 dtype=torch.bfloat16:
Ref: 0.1192 Triton: 0.1538 Torch int4mm: 0.0248

shape=[32, 4096, 4096] group_size=128 dtype=torch.bfloat16:
Ref: 0.1202 Triton: 0.1587 Torch int4mm: 0.0441

shape=[128, 4096, 4096] group_size=128 dtype=torch.bfloat16:
Ref: 0.1253 Triton: 0.1532 Torch int4mm: 0.1567

shape=[256, 4096, 4096] group_size=128 dtype=torch.bfloat16:
Ref: 0.1264 Triton: 0.1610 Torch int4mm: 0.3028

shape=[512, 4096, 4096] group_size=128 dtype=torch.bfloat16:
Ref: 0.1334 Triton: 0.1618 Torch int4mm: 0.5928

shape=[1024, 4096, 4096] group_size=128 dtype=torch.bfloat16:
Ref: 0.1496 Triton: 0.1632 Torch int4mm: 1.1754

M,N,K,group_size,dtype,ref,triton,tinygemm
16,4096,4096,128,torch.bfloat16,0.11921614408493042,0.15375947952270508,0.02484913170337677
32,4096,4096,128,torch.bfloat16,0.12021129578351974,0.15867114067077637,0.04411611706018448
128,4096,4096,128,torch.bfloat16,0.12525483965873718,0.15317335724830627,0.1566852480173111
256,4096,4096,128,torch.bfloat16,0.12637671828269958,0.16101239621639252,0.30275848507881165
512,4096,4096,128,torch.bfloat16,0.1334434598684311,0.16184988617897034,0.5928144454956055
1024,4096,4096,128,torch.bfloat16,0.14962367713451385,0.16317297518253326,1.1754025220870972

@msaroufim msaroufim (Member) left a comment

Thank you!

@msaroufim msaroufim merged commit e148244 into pytorch:main Apr 25, 2024
13 checks passed
dbyoung18 pushed a commit to dbyoung18/ao that referenced this pull request Jul 31, 2024
* add test / benchmark

* add kernels

* update readme

* more readme edits

* edit readme

* add transpose test

* transpose test pass

* refactor test

* add checks for CI

* add more comments for transpose kernel

* remove import in test

* clean up benchmark

* fix test import order

* minor README edits

* additional readme edits

* update readme

* update readme

* add note about cudamode

---------

Co-authored-by: Mark Saroufim <marksaroufim@meta.com>