
Add FP6-LLM doc and move FP6-LLM to prototype #358


Merged · 3 commits · Jun 13, 2024
1 change: 1 addition & 0 deletions README.md
@@ -124,6 +124,7 @@ To learn more try out our APIs, you can check out API examples in
4. [Bleeding Edge Kernels](./torchao/prototype/) for experimental kernels without backwards compatibility guarantees
- [GaLore](https://github.com/pytorch/ao/tree/main/torchao/prototype/galore) for memory efficient finetuning
- [fused HQQ Gemm Kernel](https://github.com/pytorch/ao/tree/main/torchao/prototype/hqq) for compute bound workloads
- [FP6-LLM](torchao/prototype/fp6_llm) FP16 x FP6 mixed matmul kernel for IO-bound workloads

## Our Goals

2 changes: 1 addition & 1 deletion benchmarks/benchmark_fp6_llm.py
@@ -1,6 +1,6 @@
import torch
from torch import nn
from torchao.quantization.fp6_llm import Fp6LlmLinear, from_tc_float6_e3m2
from torchao.prototype.fp6_llm.fp6_llm import Fp6LlmLinear, from_tc_float6_e3m2
from torch.utils.benchmark import Timer
import pandas as pd
from tqdm import tqdm
@@ -7,7 +7,7 @@
parametrize,
run_tests,
)
from torchao.quantization.fp6_llm import (
from torchao.prototype.fp6_llm.fp6_llm import (
to_tc_float6_e3m2,
from_tc_float6_e3m2,
_to_tc_float6_e3m2_ref,
12 changes: 6 additions & 6 deletions test/test_ops.py
@@ -2,7 +2,7 @@
from torch.testing._internal.common_utils import TestCase, IS_FBCODE
from torch.testing._internal.optests import opcheck
import torchao
from torchao.quantization.fp6_llm import from_tc_float6_e3m2
from torchao.prototype.fp6_llm.fp6_llm import from_tc_float6_e3m2
import unittest
from parameterized import parameterized
import pytest
@@ -26,27 +26,27 @@ def _create_fp6_inputs(self, BS: int, OC: int, IC: int, device):
return fp6_weight.to(device), fp16_scale.to(device), fp16_activation.to(device)

@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
def test_fp16act_fp6weight_linear(self):
def test_fp6_llm_linear(self):
BS = 2
OC = 256
IC = 256
splitK = 1
fp6_weight, fp16_scale, fp16_activation = self._create_fp6_inputs(BS, OC, IC, "cuda")

# smoke test
torchao.ops.fp16act_fp6weight_linear(fp16_activation, fp6_weight, fp16_scale, splitK)
torchao.ops.fp6_llm_linear(fp16_activation, fp6_weight, fp16_scale, splitK)

# comprehensive testing
test_utils = ["test_schema", "test_autograd_registration", "test_faketensor", "test_aot_dispatch_dynamic"]
opcheck(torch.ops.torchao.fp16act_fp6weight_linear, (fp16_activation, fp6_weight, fp16_scale, splitK), test_utils=test_utils)
opcheck(torch.ops.torchao.fp6_llm_linear, (fp16_activation, fp6_weight, fp16_scale, splitK), test_utils=test_utils)

# adapted from https://github.com/usyd-fsalab/fp6_llm/blob/main/tests/python/kernel_test.py
@parameterized.expand([(1, 2048, 4096, 5), (2, 8192, 8192, 6)])
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
def test_fp6_matmul_correctness(self, BS, OC, IC, splitK):
def test_fp6_llm_linear_correctness(self, BS, OC, IC, splitK):
fp6_weight, fp16_scale, fp16_activation = self._create_fp6_inputs(BS, OC, IC, "cuda")

results_fp6 = torchao.ops.fp16act_fp6weight_linear(fp16_activation, fp6_weight, fp16_scale, splitK)
results_fp6 = torchao.ops.fp6_llm_linear(fp16_activation, fp6_weight, fp16_scale, splitK)

fp16_weight = from_tc_float6_e3m2(fp6_weight.view(torch.uint8), dtype=torch.float16) * fp16_scale[:, None]
results_fp16 = fp16_activation @ fp16_weight.T
2 changes: 1 addition & 1 deletion torchao/csrc/cuda/fp6_llm/fp6_linear.cu
@@ -178,7 +178,7 @@ torch::Tensor fp6_linear_forward_cuda(torch::Tensor _in_feats,
}

TORCH_LIBRARY_IMPL(torchao, CUDA, m) {
m.impl("torchao::fp16act_fp6weight_linear", &fp6_linear_forward_cuda);
m.impl("torchao::fp6_llm_linear", &fp6_linear_forward_cuda);
}

} // namespace torchao
2 changes: 1 addition & 1 deletion torchao/csrc/fp6_llm.cpp
@@ -4,5 +4,5 @@

TORCH_LIBRARY_FRAGMENT(torchao, m) {
m.impl_abstract_pystub("torchao.ops");
m.def("fp16act_fp6weight_linear(Tensor _in_feats, Tensor _weights, Tensor _scales, int splitK) -> Tensor");
m.def("fp6_llm_linear(Tensor _in_feats, Tensor _weights, Tensor _scales, int splitK) -> Tensor");
}
6 changes: 3 additions & 3 deletions torchao/ops.py
@@ -12,7 +12,7 @@ def decorator(func):
return decorator


def fp16act_fp6weight_linear(_in_feats: Tensor, _weights: Tensor, _scales: Tensor, splitK: int = 1) -> Tensor:
def fp6_llm_linear(_in_feats: Tensor, _weights: Tensor, _scales: Tensor, splitK: int = 1) -> Tensor:
"""
FP6-LLM linear layer A @ W.T. See https://arxiv.org/abs/2401.14112 for more details.

@@ -25,10 +25,10 @@ def fp16act_fp6weight_linear(_in_feats: Tensor, _weights: Tensor, _scales: Tenso
Returns
output of linear layer
"""
return torch.ops.torchao.fp16act_fp6weight_linear.default(_in_feats, _weights, _scales, splitK)
return torch.ops.torchao.fp6_llm_linear.default(_in_feats, _weights, _scales, splitK)


@register_custom_op("torchao::fp16act_fp6weight_linear")
@register_custom_op("torchao::fp6_llm_linear")
def _(_in_feats, _weights, _scales, splitK = 1):
torch._check(_in_feats.dim() == 2, lambda: f"input should be a 2d tensor, got {_in_feats.dim()}D")
torch._check(_in_feats.dtype is torch.float16, lambda: f"weight must be FP16, got {_in_feats.dtype}")
1 change: 1 addition & 0 deletions torchao/prototype/README.md
@@ -9,6 +9,7 @@
- `galore` - fused kernels for memory-efficient pre-training / fine-tuning per the [GaLore algorithm](https://arxiv.org/abs/2403.03507)
- `galore/kernels` - `triton` kernels that fuse various steps of the `GaLore` algorithm
- `galore/docs` - implementation notes and discussion of issues faced in kernel design.
- [`fp6_llm`](fp6_llm) - FP16 x FP6 mixed matmul kernel per [FP6-LLM](https://arxiv.org/abs/2401.14112)

#### Roadmap

44 changes: 44 additions & 0 deletions torchao/prototype/fp6_llm/README.md
@@ -0,0 +1,44 @@
# FP6-LLM

This is an FP16 x FP6 mixed matmul kernel optimized for IO-bound workloads, per [FP6-LLM](https://arxiv.org/abs/2401.14112). The actual CUDA kernel is located under [csrc/cuda/fp6_llm/](../../csrc/cuda/fp6_llm/). This module provides helper functions to quantize FP32 weights to FP6, as well as a utility to convert existing models to FP6.

## Usage

```python
from torchao.prototype.fp6_llm import convert_fp6_llm

model = ...
convert_fp6_llm(model) # convert model in-place, replacing nn.Linear modules with Fp6LlmLinear

# fully compatible with torch.compile()
model.compile(mode="max-autotune", fullgraph=True)
```
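
After conversion the model is used as usual: `Fp6LlmLinear.forward()` casts activations to FP16 internally and casts the output back to the input dtype. A minimal sketch with a toy FP32 model (dimensions chosen as multiples of 64 to stay compatible with the tensor-core packing):

```python
import torch
from torch import nn
from torchao.prototype.fp6_llm import convert_fp6_llm

# a toy FP32 model; any model containing nn.Linear modules is handled the same way
model = nn.Sequential(nn.Linear(512, 1024), nn.ReLU(), nn.Linear(1024, 512)).cuda()
convert_fp6_llm(model)  # nn.Linear -> Fp6LlmLinear, weights quantized and packed to FP6

x = torch.randn(8, 512, device="cuda")
with torch.no_grad():
    y = model(x)  # activations are cast to FP16 inside Fp6LlmLinear; y keeps x's dtype
```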

It's also possible to pre-process the weight and call the kernel directly.

```python
import torch
from torchao.prototype.fp6_llm import to_scaled_tc_float6_e3m2
from torchao.ops import fp6_llm_linear

fp32_weight = torch.randn(1024, 512).cuda()

# Pre-process the weight. This will quantize the weight to FP6 and pack it in a special
# layout for tensor cores. Refer to the paper for more details.
fp6_weight, scales = to_scaled_tc_float6_e3m2(fp32_weight)

fp16_act = torch.randn(1, 512).cuda().half()
outputs = fp6_llm_linear(fp16_act, fp6_weight, scales) # shape (1, 1024)
```
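
As a sanity check, the packed weight can be dequantized back to FP16 with `from_tc_float6_e3m2` (from the `fp6_llm` submodule) and the kernel output compared against a plain FP16 matmul. A minimal sketch mirroring the correctness test in `test/test_ops.py`:

```python
import torch
from torchao.prototype.fp6_llm import to_scaled_tc_float6_e3m2
from torchao.prototype.fp6_llm.fp6_llm import from_tc_float6_e3m2
from torchao.ops import fp6_llm_linear

fp32_weight = torch.randn(1024, 512).cuda()
fp6_weight, scales = to_scaled_tc_float6_e3m2(fp32_weight)

fp16_act = torch.randn(1, 512).cuda().half()
out_fp6 = fp6_llm_linear(fp16_act, fp6_weight, scales)

# dequantize the packed FP6 weight back to FP16 and run a reference matmul
fp16_weight = from_tc_float6_e3m2(fp6_weight.view(torch.uint8), dtype=torch.float16) * scales[:, None]
out_fp16 = fp16_act @ fp16_weight.T

# the two results should agree to within FP6 quantization error
print((out_fp6 - out_fp16).abs().max())
```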

## TODO

- [ ] Compile CUDA kernel for Windows
- [ ] Merge FP5 from upstream

## Credits

Credits to the FP6-LLM authors:

- Paper: https://arxiv.org/abs/2401.14112
- Code: https://github.com/usyd-fsalab/fp6_llm
1 change: 1 addition & 0 deletions torchao/prototype/fp6_llm/__init__.py
@@ -0,0 +1 @@
from .fp6_llm import Fp6LlmLinear, convert_fp6_llm, to_scaled_tc_float6_e3m2
@@ -5,7 +5,7 @@
from torch import nn, Tensor
from torchao.prototype.mx_formats.custom_cast import f32_to_f6_e3m2_unpacked, f6_e3m2_unpacked_to_f32
from torchao.prototype.mx_formats.constants import F6_E3M2_MAX
from torchao.ops import fp16act_fp6weight_linear
from torchao.ops import fp6_llm_linear


def _pack_2bit(x: Tensor) -> Tensor:
@@ -273,7 +273,7 @@ def __init__(self, weight: Tensor, scales: Tensor, bias: Optional[Tensor] = None

def forward(self, x: Tensor) -> Tensor:
splitK = self.get_split_k(math.prod(x.shape[:-1]), self.out_features)
out = fp16act_fp6weight_linear(x.view(-1, self.in_features).half(), self.weight, self.scales, splitK=splitK)
out = fp6_llm_linear(x.view(-1, self.in_features).half(), self.weight, self.scales, splitK=splitK)
if self.bias is not None:
out = out + self.bias
return out.view(*x.shape[:-1], self.out_features).to(x.dtype)
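
A direct, module-level use of `Fp6LlmLinear` might look like the sketch below. It assumes the constructor shown above accepts the packed weight and scales produced by `to_scaled_tc_float6_e3m2` (for whole models, `convert_fp6_llm` is the documented entry point):

```python
import torch
from torchao.prototype.fp6_llm import Fp6LlmLinear, to_scaled_tc_float6_e3m2

linear = torch.nn.Linear(512, 1024, bias=False).cuda()

# quantize and pack the FP32 weight, then wrap it in the FP6 linear module
# (assumes Fp6LlmLinear stores the packed weight/scales as-is, as forward() suggests)
fp6_weight, scales = to_scaled_tc_float6_e3m2(linear.weight.detach().float())
fp6_linear = Fp6LlmLinear(fp6_weight, scales)

x = torch.randn(4, 512, device="cuda", dtype=torch.half)
y = fp6_linear(x)  # shape (4, 1024); output is cast back to x.dtype
```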