
Fused DoRA kernels #216

Merged
merged 16 commits into from
May 7, 2024
Conversation

jeromeku
Collaborator

@jeromeku jeromeku commented May 5, 2024

Fused DoRA Kernels

Fused DoRA layer implementation that reduces the number of individual kernels from ~10 to 5.

Contents

Background

DoRA (weight-decomposed low-rank adaptation) is a variant of LoRA that decomposes the weight update into separate magnitude and directional components.

The DoRA layer is roughly as follows:

    dora_out = (x @ base_weight.T + lora_out) * magnitude_scale

where:

    lora_out = lora_B(lora_A(x))
    magnitude_scale = magnitude_vector / (base_weight + lora_B.weight @ lora_A.weight).norm(p=2, dim=1)
  • lora_A and lora_B are linear layers with weight shapes rank x in_features and out_features x rank, respectively.
  • base_weight is the weight of the frozen linear layer, of shape out_features x in_features.
  • magnitude_vector is initialized as the columnwise 2-norm of the frozen weight (shape out_features).
  • x is the input, of shape batch_size x seqlen x in_features.
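
For reference, below is a minimal (unfused) PyTorch sketch of this computation, assuming lora_A / lora_B are plain bias-free nn.Linear modules with the shapes above:

    import torch
    import torch.nn as nn

    def dora_ref(x, base_weight, lora_A: nn.Linear, lora_B: nn.Linear, magnitude_vector):
        # LoRA branch: two small GEMMs
        lora_out = lora_B(lora_A(x))
        # magnitude_vector divided by the 2-norm (dim=1) of the adapted weight
        magnitude_scale = magnitude_vector / (
            base_weight + lora_B.weight @ lora_A.weight
        ).norm(p=2, dim=1)
        # base GEMM + LoRA update, scaled per output feature
        return (x @ base_weight.T + lora_out) * magnitude_scale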

Optimization

After initial profiling, and as outlined above, the DoRA update layer requires multiple kernels.

In order of compute intensity:

  • 4 GEMMs:
    • x @ base_weight.T
    • lora_B(lora_A(x)) (2 GEMMs)
    • lora_B.weight @ lora_A.weight
  • 1 Reduction: 2-norm
  • 4 Elementwise: matrix-matrix additions (2) and broadcasted matrix-vector multiplications (2).

While torch.compile (and CUDA graphs) can partially mitigate the overhead of launching many small kernels and improve the compute efficiency of individual kernels, there remains room for further optimization: reordering the computations to facilitate fusions and, more importantly, exploiting the unique shapes of the GEMMs. Both decrease the number of kernel launches and increase the compute intensity of each kernel.
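
As a rough way to see this breakdown, the unfused forward can be profiled directly with torch.profiler (fn below stands in for any reference implementation, such as the sketch above; the exact kernel count will vary with shapes and backend):

    from torch.profiler import ProfilerActivity, profile

    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
        out = fn(x)  # unfused DoRA forward
    print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=15))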

Key Contributions

1 - Small K Fused Kernel

Note that the lora_B.weight @ lora_A.weight GEMM has a specific shape, where K << {M, N}: lora_B.weight is out_features x lora_rank and lora_A.weight is lora_rank x in_features.

Since lora_rank is typically < 64 while {in,out}_features are typically > 4096 (e.g., Llama MLP / QKV projections), this GEMM is inefficient: each CTA loads its blocks only to perform a few MAC iterations, given the small K.

Moreover, note that the result of this GEMM is not needed -- we only need the 2-norm of this computation.

Combining these two observations, we can write a fused kernel where:

  1. Each CTA computes an entire row of the output matrix, with the key assumption that BLOCK_K = K. That is, each CTA does a single MAC iteration to compute a BLOCK_M x BLOCK_N output, then iterates across dimension N.
  2. Since each block processes entire rows, we can additionally fuse the row-wise reduction along axis=1 into the kernel. In this case, we can directly fold the 2-norm computation into the GEMM.
  3. As an added bonus, we can also include the base_weight elementwise addition and magnitude_vector multiplication into the GEMM epilogue.

Altogether, this allows us to fuse the following computation into a single kernel:

    magnitude_scale = magnitude_vector / (base_weight + lora_B.weight @ lora_A.weight).norm(p=2, dim=1)
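
For intuition, below is a simplified Triton sketch of this pattern. It is illustrative only, not the kernel added in this PR: the names, tiling, masking, and absence of autotuning are assumptions, but it shows the single MAC iteration (BLOCK_K == K), the 2-norm folded into the sweep over N, and the base_weight addition plus magnitude_vector division handled in the same kernel.

    import triton
    import triton.language as tl

    @triton.jit
    def small_k_colnorm_kernel(
        B_ptr, A_ptr, W_ptr, mag_ptr, scale_ptr,    # lora_B.weight, lora_A.weight, base_weight, magnitude, output
        M, N, K,                                    # M = out_features, N = in_features, K = lora_rank
        stride_bm, stride_bk, stride_ak, stride_an, stride_wm, stride_wn,
        BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
    ):
        # launch with grid = (cdiv(M, BLOCK_M),); each program owns BLOCK_M full rows
        rm = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
        rk = tl.arange(0, BLOCK_K)                  # BLOCK_K == K: a single MAC iteration per tile

        # load a (BLOCK_M, K) tile of lora_B.weight once; it is reused for every N tile
        b = tl.load(B_ptr + rm[:, None] * stride_bm + rk[None, :] * stride_bk,
                    mask=(rm[:, None] < M) & (rk[None, :] < K), other=0.0)

        acc = tl.zeros((BLOCK_M,), dtype=tl.float32)    # running sum of squares per row
        for n0 in range(0, N, BLOCK_N):
            rn = n0 + tl.arange(0, BLOCK_N)
            a = tl.load(A_ptr + rk[:, None] * stride_ak + rn[None, :] * stride_an,
                        mask=(rk[:, None] < K) & (rn[None, :] < N), other=0.0)
            w = tl.load(W_ptr + rm[:, None] * stride_wm + rn[None, :] * stride_wn,
                        mask=(rm[:, None] < M) & (rn[None, :] < N), other=0.0)
            tile = tl.dot(b, a) + w                     # base_weight + lora_B.weight @ lora_A.weight
            acc += tl.sum(tile * tile, axis=1)          # fold the 2-norm reduction into the sweep over N

        # epilogue: magnitude_vector / row-wise 2-norm
        mag = tl.load(mag_ptr + rm, mask=rm < M, other=0.0)
        tl.store(scale_ptr + rm, mag / tl.sqrt(acc), mask=rm < M)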

2 - Fused Epilogue GEMM

Additionally, instead of computing the base layer output before the DoRA / LoRA updates, we can compute the latter (LoRA layer and magnitude_scale) first and fold these into the epilogue of the base layer GEMM:

    #DoRA / LoRA updates
    lora_out = lora_B(lora_A(x))
    magnitude_scale = magnitude_vector / (base_weight + lora_B.weight @ lora_A.weight).norm(p=2, dim=1)

    #This is now a single kernel
    final_out = (x @ base_weight.T + lora_out) * magnitude_scale

Usage

The fused kernels can be used to implement DoRA / QDoRA layers.

A reference implementation is provided in dora.dora_layer.DoRALinear, which defines a base QDoRA linear layer (with a stub dequantize method) along with corresponding BNBDoRALinear and HQQDoRALinear subclasses, which override dequantize with their respective methods.
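
As a rough illustration of this extension point, a hypothetical bitsandbytes-backed subclass might look like the following; everything other than the dequantize override (attribute names, exact import path) is an assumption and may differ from the actual implementation:

    import bitsandbytes.functional as F
    from torchao.prototype.dora.dora_layer import DoRALinear

    class MyBNBDoRALinear(DoRALinear):
        def dequantize(self):
            # bitsandbytes keeps the packed 4-bit weight and its quant_state on the weight tensor
            w = self.base_layer.weight  # hypothetical attribute name
            return F.dequantize_4bit(w.data, w.quant_state)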

Example

    import torch
    from bitsandbytes.nn import Linear4bit
    from torchao.prototype.dora.dora_layer import BNBDoRALinear

    bs, seqlen = 1, 512
    dtype = torch.float16
    in_features, out_features, lora_rank = 4096, 4096, 16
    x = torch.randn(bs, seqlen, in_features, dtype=dtype, device="cuda")

    #Construct bitsandbytes QDoRA layer
    base_layer = Linear4bit(
        input_features=in_features,
        output_features=out_features,
        bias=False,
        quant_type="nf4",
        compute_dtype=dtype,
    ).cuda()
    base_layer.quant_state.dtype = base_layer.compute_dtype
    dora_layer = BNBDoRALinear(base_layer, lora_rank)

    #Run reference forward pass
    ref = dora_layer.forward(x)

    #Run fused forward pass
    fused_out = dora_layer.forward_fused(x)
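
    #Sanity check: fused and reference paths should agree (tolerances here are illustrative)
    torch.testing.assert_close(ref, fused_out, atol=1e-2, rtol=1e-2)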

See test/test_dora_layer.py and benchmarks/dora_bench.py for more detailed usage.

Also, note that these are reference implementations and are not fully optimized. See Next Steps for follow-up plans.

Tests

See test/dora/test* for correctness checks of the fused kernels and layers.

Benchmarks

See benchmarks/dora_bench.py.

python benchmarks/dora_bench.py --help

Run with --kernel set to one of {dora-colnorm, dora-mm-epilogue} to benchmark the respective fused kernel against a reference torch / torch.compile implementation, or with --kernel=dora-full to benchmark the entire DoRA computation.

Additionally, passing --kernel=dora-bnb or --kernel=dora-hqq will benchmark a reference QDoRA layer against its fused implementation.
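
For example, to benchmark the fused epilogue GEMM and the BNB QDoRA layer:

python benchmarks/dora_bench.py --kernel=dora-mm-epilogue
python benchmarks/dora_bench.py --kernel=dora-bnb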

Profiling

The reference DoRALinear layer described above also has an instrumented forward pass with annotated regions for each of the DoRA ops.

An example script for running a profiled forward pass is provided in dora/dora_profile.py.

To run with torch.profiler:

python dora_profile.py

which writes a Chrome trace to the default folder dora_profiles.

To run with nsys:

nsys profile --capture-range=cudaProfilerApi ... python dora_profile.py --profiler=nsys

where ... are other desired nsys options.

Note that --capture-range=cudaProfilerApi is required.
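
For example, a typical invocation (the trace and output flags are just one reasonable choice):

nsys profile -t cuda,nvtx --capture-range=cudaProfilerApi -o dora_fwd python dora_profile.py --profiler=nsys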

Next Steps

  • Optimize the end-to-end DoRA module: torch.compile, re-ordering computations, etc.
  • Implement the backward pass as a custom torch.autograd.Function
  • Make compatible with training frameworks and usable in distributed settings (e.g., FSDP-LoRA)
  • Replace the custom autotuner with the updated triton autotuner
  • Refactor! Lots of repeated profiling / kernel functions across galore, hqq, and dora can now be refactored into a single module. Separate PR?

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label May 5, 2024
@msaroufim msaroufim self-requested a review May 6, 2024 03:53
@msaroufim msaroufim merged commit c2657e4 into pytorch:main May 7, 2024
11 of 15 checks passed
HDCharles added a commit that referenced this pull request May 8, 2024
* Composing autoquant with compile

Summary:

this PR rewrites how torchao.autoquant works so that it works with
torch.compile. Previously you had to do:

torchao.autoquant(model, input)
mod=torch.compile(model)
mod(input)

now you can do
torchao.autoquant(torch.compile(model))
model(input)

The new method works with/without compile. Also this is BC so the old
path also works.

We use a forward_prehook to intercept the model call before
torch.compile tracing occurs at which point we do the autoquantization
and clean up all remaining hooks before passing things off to the
normal torch.compile tracing functionality.

note: in the case of multiple inputs, you can also do:

model.forward_log_only(input) to run the model forward with autoquant
shape logging and prevent the torch.compile tracing/autoquant
quantization from occurring.

Test Plan: python test/integration/test_integration.py -k "autoquant"

Reviewers:

Subscribers:

Tasks:

Tags:

* Fused DoRA kernels (#216)

* add dora kernels

* allowing error_on_unseen in autoquant func

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

* Unified AffineQuantizedTensor subclass (#214)

Summary:
Created an `AffineQuantizedTensor` subclass that works for both weight and input (for dynamic quantization), for all granularities (leveraging the recently added choose_qparams_affine, quantize_affine
and dequantize_affine ops)

only verified for 8da4w right now, we can make it work for other types of quantization (mostly the operator dispatching part) later

Test Plan:
python test/quantization/test_quant_api.py -k test_quantized_tensor_subclass_8da4w

Reviewers:

Subscribers:

Tasks:

Tags:

Co-authored-by: Mark Saroufim <marksaroufim@meta.com>

* add expecttest to requirements.txt (#225)

* add expecttest to requirements.txt

* update

* Install dev-requirements.txt in doc build (#224)

Install dev-requirements.txt

---------

Co-authored-by: Mark Saroufim <marksaroufim@meta.com>

* Fix an error in subclass impl (#226)

Summary:
Accidentally changed the device check code for the old subclass instead of the new one; forgot to fix before landing

Test Plan:
CI

Reviewers:

Subscribers:

Tasks:

Tags:

* update readme.md

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

* trying to fix the error in CI on cleanup hooks

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

* correct docs

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

* Some follow up fixes for quant primitives (#220)

Summary:
att

Test Plan:
python test/quantization/test_quant_primitives.py -k test_raises

Reviewers:

Subscribers:

Tasks:

Tags:

* Composing autoquant with compile

Summary:

this PR rewrites how torchao.autoquant works so that it works with
torch.compile. Previously you had to do:

torchao.autoquant(model, input)
mod=torch.compile(model)
mod(input)

now you can do
torchao.autoquant(torch.compile(model))
model(input)

The new method works with/without compile. Also this is BC so the old
path also works.

We use a forward_prehook to intercept the model call before
torch.compile tracing occurs at which point we do the autoquantization
and clean up all remaining hooks before passing things off to the
normal torch.compile tracing functionality.

note: in the case of multiple inputs, you can also do:

model.forward_log_only(input) to run the model forward with autoquant
shape logging and prevent the torch.compile tracing/autoquant
quantization from occurring.

Test Plan: python test/integration/test_integration.py -k "autoquant"

Reviewers:

Subscribers:

Tasks:

Tags:

* allowing error_on_unseen in autoquant func

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

* update readme.md

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

* trying to fix the error in CI on cleanup hooks

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

* correct docs

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

---------

Co-authored-by: jeromeku <jerome.ku@gmail.com>
Co-authored-by: Jerry Zhang <jerryzh168@gmail.com>
Co-authored-by: Mark Saroufim <marksaroufim@meta.com>
Co-authored-by: Svetlana Karslioglu <svekars@meta.com>
dbyoung18 pushed a commit to dbyoung18/ao that referenced this pull request Jul 31, 2024