
Commit

added bitsandbytes support
HanGuo97 committed Aug 26, 2024
1 parent 84dc870 commit f22ecd3
Showing 6 changed files with 153 additions and 15 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/wheels.yaml
@@ -17,7 +17,7 @@ jobs:
matrix:
os: ['ubuntu-20.04']
python-version: ['3.8', '3.9', '3.10', '3.11']
pytorch-version: ['2.3.1']
pytorch-version: ['2.4.0']
cuda-version: ['11.8', '12.1']

steps:
2 changes: 1 addition & 1 deletion flute/__init__.py
@@ -7,7 +7,7 @@
from . import _C
from . import ops

__version__ = "0.0.6"
__version__ = "0.0.7"

QGEMM_SIMPLE_TYPE = Callable[
[
67 changes: 56 additions & 11 deletions flute/integrations/base.py
@@ -10,11 +10,14 @@
from accelerate.hooks import (
ModelHook,
add_hook_to_module)
from bitsandbytes.nn import (
Linear4bit as BNBLinear4bit)
from typing import Optional, Dict

import flute
import flute.utils
import flute.nf_utils
import flute.integrations.bitsandbytes


def get_accelerate_hook(name: str, module: torch.nn.Module, allow: bool) -> Optional[ModelHook]:
@@ -42,22 +45,41 @@ def prepare_model_flute(
group_size: int,
fake: bool,
handle_hooks: bool = False,
prepare_bnb_layers: bool = True,
default_bnb_dtype: Optional[torch.dtype] = None,
custom_scales_dict: Optional[Dict[str, torch.Tensor]] = None,
) -> None:

warnings.warn("Quantization always happens on the 1st GPU")

# BNB layers always assume the unquantized weights are in FP32, regardless
# of the actual dtype of the weights. Hence we cannot infer the dtype.
if default_bnb_dtype is None:
default_bnb_dtype = torch.float16

def _replace_linear(_name: str, _module: torch.nn.Module) -> None:
for child_name, child in _module.named_children():

child_full_name = f"{_name}.{child_name}"

if isinstance(child, torch.nn.Linear):

if child.weight.dtype not in [torch.float16, torch.bfloat16]:
raise NotImplementedError
if isinstance(child, BNBLinear4bit):
if child.weight.dtype not in [torch.uint8]:
raise NotImplementedError
if prepare_bnb_layers is False:
raise ValueError
if num_bits != 4:
raise ValueError
if group_size != child.weight.quant_state.blocksize:
raise ValueError
else:
if child.weight.dtype not in [torch.float16, torch.bfloat16]:
raise NotImplementedError

if fake is True:
if isinstance(child, BNBLinear4bit):
raise NotImplementedError
# we primarily use the fake quantization to
# check the outputs of the quantized model
new_weight = flute.nf_utils.nf_quantize_2(
@@ -85,6 +107,14 @@ def _replace_linear(_name: str, _module: torch.nn.Module) -> None:
# the replacement will remove the accelerate hooks
maybe_hook = get_accelerate_hook(child_name, child, allow=True)

if not isinstance(child, BNBLinear4bit):
flute_dtype = child.weight.dtype
else:
flute_dtype = child.weight.quant_state.dtype
if flute_dtype == torch.float32:
flute_dtype = default_bnb_dtype
warnings.warn(f"BNB's `dtype` is `torch.float32`, changed to `{flute_dtype}`")

setattr(
_module,
child_name,
@@ -95,18 +125,24 @@ def _replace_linear(_name: str, _module: torch.nn.Module) -> None:
group_size=group_size,
bias=(child.bias is not None),
device=child.weight.device,
dtype=child.weight.dtype))
dtype=flute_dtype))

if custom_scales_dict is not None:
custom_scales = custom_scales_dict[child_full_name]
else:
custom_scales = None

_, _Q, scales, qmap = flute.nf_utils.nf_quantize(
W=child.weight.to(device="cuda"),
num_bits=num_bits,
group_size=group_size,
custom_scales=custom_scales)
if not isinstance(child, BNBLinear4bit):
_, _Q, scales, qmap = flute.nf_utils.nf_quantize(
W=child.weight.to(device="cuda"),
num_bits=num_bits,
group_size=group_size,
custom_scales=custom_scales)
else:
_Q, scales, qmap = flute.integrations.bitsandbytes.convert_BNBLinear4bit(
bnb_module=child,
verify=True)

Q = flute.utils.pack(
_Q.T.contiguous(),
num_bits=num_bits,
@@ -122,6 +158,8 @@ def _replace_linear(_name: str, _module: torch.nn.Module) -> None:
new_child.scales.copy_(scales)
new_child.tables.copy_(qmap)
new_child.tables2.copy_(qmap2)
if new_child.bias is not None:
new_child.bias.copy_(child.bias)

# add the accelerate hook back
if handle_hooks is True:
@@ -165,8 +203,6 @@ def __init__(
raise NotImplementedError
if not isinstance(device, torch.device):
raise NotImplementedError
if bias:
raise NotImplementedError

super().__init__()

@@ -190,9 +226,13 @@ def __init__(
self.register_buffer("scales", torch.ones((N, G), dtype=dtype, device=device))
self.register_buffer("tables", tables)
self.register_buffer("tables2", flute.utils.make_qmap2_from_qmap(tables))
if bias:
self.bias = torch.nn.Parameter(torch.empty(out_features, **factory_kwargs))
else:
self.register_parameter("bias", None)

def forward(self, inputs: torch.Tensor) -> torch.Tensor:
return flute.qgemm_simple(
output = flute.qgemm_simple(
inputs,
self.weight,
self.scales,
@@ -203,6 +243,11 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
self.group_size,
)

if self.bias is not None:
output.add_(self.bias) # In-place add

return output

def extra_repr(self) -> str:
return (f"in_features={self.in_features}, "
f"out_features={self.out_features}, "
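
For orientation, here is a minimal sketch (not part of this commit) of how the new bitsandbytes path in `prepare_model_flute` might be exercised end to end: load an NF4-quantized Hugging Face checkpoint and let the function swap each `Linear4bit` for a FLUTE layer via the converter introduced below. The checkpoint name and the leading `name`/`module` arguments are assumptions, since they sit outside the hunks shown here; what the diff does pin down is that `num_bits` must be 4, `group_size` must equal the bitsandbytes blocksize (64 by default), and `fake=True` is not supported for BNB layers.

    import torch
    from transformers import AutoModelForCausalLM, BitsAndBytesConfig

    import flute.integrations.base

    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-hf",  # placeholder checkpoint
        quantization_config=BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16,
        ),
        torch_dtype=torch.float16,
        device_map="auto",
    )

    # Replace every Linear / Linear4bit under the given module with a FluteLinear.
    flute.integrations.base.prepare_model_flute(
        name="model.model.layers",        # assumed leading arguments, not visible in this diff
        module=model.model.layers,
        num_bits=4,                       # the bitsandbytes path only supports 4 bits
        group_size=64,                    # must match `weight.quant_state.blocksize`
        fake=False,                       # fake quantization raises for BNB layers
        handle_hooks=True,
        prepare_bnb_layers=True,
        default_bnb_dtype=torch.float16,  # used when BNB reports an FP32 `quant_state.dtype`
    )

One apparent benefit of this path is that the model arrives on the GPU already 4-bit quantized, so the full half-precision weight matrices never have to be materialized just to be re-quantized for FLUTE.
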
91 changes: 91 additions & 0 deletions flute/integrations/bitsandbytes.py
@@ -0,0 +1,91 @@
import torch
from typing import Tuple
from bitsandbytes.nn import (
    Linear4bit as BNBLinear4bit)
from bitsandbytes.functional import (
    dequantize_4bit,
    dequantize_blockwise)


def convert_BNBLinear4bit(
    bnb_module: BNBLinear4bit,
    verify: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:

    if not isinstance(bnb_module, BNBLinear4bit):
        raise TypeError

    bnb_qweight = bnb_module.weight
    bnb_quant_state = bnb_module.weight.quant_state
    bnb_quant_table = bnb_module.weight.quant_state.code
    bnb_quant_dtype = bnb_module.weight.quant_state.dtype

    if not all([
            bnb_qweight.ndim == 2,
            bnb_qweight.shape[1] == 1,
            bnb_qweight.bnb_quantized is True,
            bnb_qweight.requires_grad is False,
            bnb_qweight.dtype == torch.uint8,
            bnb_qweight.quant_storage == torch.uint8,
            bnb_qweight.blocksize == bnb_quant_state.blocksize,
            bnb_qweight.quant_type == bnb_quant_state.quant_type,
            bnb_qweight.compress_statistics == bnb_quant_state.nested,
            bnb_module.quant_state is bnb_quant_state]):
        raise NotImplementedError

    # unpacked quantized weights
    qweight = torch.cat([
        (bnb_qweight.data >> 4) & 0b1111,
        (bnb_qweight.data >> 0) & 0b1111], dim=1)
    qweight = qweight.view(
        bnb_quant_state.shape)

    # get the scales
    if bnb_quant_state.nested:
        scales = dequantize_blockwise(
            A=bnb_quant_state.absmax,
            quant_state=bnb_quant_state.state2)
        scales = scales + bnb_quant_state.offset
    else:
        scales = bnb_quant_state.absmax

    # convert to the correct dtype
    if scales.dtype != bnb_quant_dtype:
        scales_casted = scales.to(dtype=bnb_quant_dtype)
    else:
        scales_casted = scales

    if bnb_quant_table.dtype != bnb_quant_dtype:
        bnb_quant_table_casted = bnb_quant_table.to(dtype=bnb_quant_dtype)
    else:
        bnb_quant_table_casted = bnb_quant_table

    if not all([
            scales.ndim == 1,
            scales.dtype == torch.float32,
            scales_casted.dtype == bnb_quant_dtype,
            bnb_quant_table.dtype == torch.float32,
            bnb_quant_table_casted.dtype == bnb_quant_dtype]):
        raise ValueError

    # double check that the conversion is lossless
    if verify is True:
        broadcasted_scales = (
            scales
            .unsqueeze(dim=-1)
            .expand(scales.shape[0], bnb_quant_state.blocksize)
            .reshape(qweight.shape))
        weight = (
            # `dequantize_4bit` function always performs dequantization in FP16
            bnb_quant_table[qweight.to(dtype=torch.int)] *
            broadcasted_scales).to(dtype=bnb_quant_dtype)
        weight_ = dequantize_4bit(
            A=bnb_qweight,
            quant_state=bnb_quant_state,
            # unused
            blocksize=bnb_quant_state.blocksize,
            quant_type=bnb_quant_state.quant_type)
        if not (weight == weight_).all():
            raise ValueError

    return qweight, scales_casted, bnb_quant_table_casted
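
As a standalone illustration (again a sketch under stated assumptions, not part of the commit), the converter can be exercised on a toy `Linear4bit`: bitsandbytes packs two 4-bit codes into each storage byte of a column-shaped uint8 buffer, high nibble first, and `convert_BNBLinear4bit` unpacks them into one code per weight plus the per-block `absmax` scales and the 16-entry lookup table. The layer size is arbitrary, the weight is cast to FP16 before quantization so that `quant_state.dtype` matches the usual Hugging Face loading path, and quantization is triggered by moving the module to the GPU, which is standard bitsandbytes behavior.

    import torch
    from bitsandbytes.nn import Linear4bit

    from flute.integrations.bitsandbytes import convert_BNBLinear4bit

    bnb_linear = Linear4bit(
        4096, 4096,
        bias=False,
        compute_dtype=torch.float16,
        compress_statistics=False,  # keep `absmax` un-nested for simplicity
        quant_type="nf4",
    )
    bnb_linear = bnb_linear.to(dtype=torch.float16)  # recorded as `quant_state.dtype`
    bnb_linear = bnb_linear.cuda()                   # quantization happens on device transfer

    qweight, scales, qmap = convert_BNBLinear4bit(bnb_module=bnb_linear, verify=True)

    # Two 4-bit codes were packed into every storage byte, so the unpacked tensor
    # has twice as many elements as the packed buffer and recovers the logical shape.
    assert qweight.numel() == 2 * bnb_linear.weight.data.numel()
    assert qweight.shape == (4096, 4096) and qweight.dtype == torch.uint8
    assert qmap.numel() == 16  # NF4 code book
    assert scales.numel() == qweight.numel() // bnb_linear.weight.blocksize

With `verify=True`, the function re-runs bitsandbytes' own `dequantize_4bit` and checks that the unpacked codes, scales, and table reproduce it exactly, so a mismatch in packing conventions surfaces as a `ValueError` rather than silently corrupted weights.
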
1 change: 1 addition & 0 deletions requirements.txt
@@ -5,3 +5,4 @@ matplotlib
transformers
accelerate
vllm >= 0.5.3.post1
bitsandbytes
5 changes: 3 additions & 2 deletions tests/kernel.py
@@ -100,7 +100,7 @@ def test_integer(

if identity is True:
if equal is not True:
raise ValueError(message)
click.secho(message, bg="red")
else:
if dtype == torch.float16:
threshold = FP16_ERROR_THRESHOLD
@@ -126,10 +126,11 @@ def test_integer(
"error": error,
"error_": error_,
}
torch.save(data_to_save, f"{message}.pth")
# torch.save(data_to_save, f"{message}.pth")
click.secho(message, fg="red")


@torch.no_grad()
def run_tests(num: int) -> None:
for index in range(num):
torch.manual_seed(index)
