diff --git a/.github/workflows/wheels.yaml b/.github/workflows/wheels.yaml
index da60507..cb674f2 100644
--- a/.github/workflows/wheels.yaml
+++ b/.github/workflows/wheels.yaml
@@ -17,7 +17,7 @@ jobs:
       matrix:
         os: ['ubuntu-20.04']
         python-version: ['3.8', '3.9', '3.10', '3.11']
-        pytorch-version: ['2.3.1']
+        pytorch-version: ['2.4.0']
         cuda-version: ['11.8', '12.1']
 
     steps:
diff --git a/flute/__init__.py b/flute/__init__.py
index 778fb98..7ac099b 100644
--- a/flute/__init__.py
+++ b/flute/__init__.py
@@ -7,7 +7,7 @@
 from . import _C
 from . import ops
 
-__version__ = "0.0.6"
+__version__ = "0.0.7"
 
 QGEMM_SIMPLE_TYPE = Callable[
     [
diff --git a/flute/integrations/base.py b/flute/integrations/base.py
index 60ecbf5..ba001d4 100644
--- a/flute/integrations/base.py
+++ b/flute/integrations/base.py
@@ -10,11 +10,14 @@
 from accelerate.hooks import (
     ModelHook,
     add_hook_to_module)
+from bitsandbytes.nn import (
+    Linear4bit as BNBLinear4bit)
 from typing import Optional, Dict
 
 import flute
 import flute.utils
 import flute.nf_utils
+import flute.integrations.bitsandbytes
 
 
 def get_accelerate_hook(name: str, module: torch.nn.Module, allow: bool) -> Optional[ModelHook]:
@@ -42,11 +45,18 @@ def prepare_model_flute(
     group_size: int,
     fake: bool,
     handle_hooks: bool = False,
+    prepare_bnb_layers: bool = True,
+    default_bnb_dtype: Optional[torch.dtype] = None,
     custom_scales_dict: Optional[Dict[str, torch.Tensor]] = None,
 ) -> None:
 
     warnings.warn(f"Quantization always happen on 1st GPU")
 
+    # BNB layers always assume the unquantized weights are in FP32, regardless
+    # of the actual dtype of the weights. Hence we cannot infer the dtype.
+    if default_bnb_dtype is None:
+        default_bnb_dtype = torch.float16
+
     def _replace_linear(_name: str, _module: torch.nn.Module) -> None:
         for child_name, child in _module.named_children():
 
@@ -54,10 +64,22 @@ def _replace_linear(_name: str, _module: torch.nn.Module) -> None:
 
             if isinstance(child, torch.nn.Linear):
 
-                if child.weight.dtype not in [torch.float16, torch.bfloat16]:
-                    raise NotImplementedError
+                if isinstance(child, BNBLinear4bit):
+                    if child.weight.dtype not in [torch.uint8]:
+                        raise NotImplementedError
+                    if prepare_bnb_layers is False:
+                        raise ValueError
+                    if num_bits != 4:
+                        raise ValueError
+                    if group_size != child.weight.quant_state.blocksize:
+                        raise ValueError
+                else:
+                    if child.weight.dtype not in [torch.float16, torch.bfloat16]:
+                        raise NotImplementedError
 
                 if fake is True:
+                    if isinstance(child, BNBLinear4bit):
+                        raise NotImplementedError
                     # we primarily use the fake quantization to
                     # check the outputs of the quantized model
                     new_weight = flute.nf_utils.nf_quantize_2(
@@ -85,6 +107,14 @@ def _replace_linear(_name: str, _module: torch.nn.Module) -> None:
                     # the replacement will remove the accelerate hooks
                     maybe_hook = get_accelerate_hook(child_name, child, allow=True)
 
+                    if not isinstance(child, BNBLinear4bit):
+                        flute_dtype = child.weight.dtype
+                    else:
+                        flute_dtype = child.weight.quant_state.dtype
+                        if flute_dtype == torch.float32:
+                            flute_dtype = default_bnb_dtype
+                            warnings.warn(f"BNB's `dtype` is `torch.float32`, changed to `{flute_dtype}`")
+
                     setattr(
                         _module,
                         child_name,
@@ -95,18 +125,24 @@ def _replace_linear(_name: str, _module: torch.nn.Module) -> None:
                             group_size=group_size,
                             bias=(child.bias is not None),
                             device=child.weight.device,
-                            dtype=child.weight.dtype))
+                            dtype=flute_dtype))
 
                     if custom_scales_dict is not None:
                         custom_scales = custom_scales_dict[child_full_name]
                     else:
                         custom_scales = None
 
-                    _, _Q, scales, qmap = flute.nf_utils.nf_quantize(
-                        W=child.weight.to(device="cuda"),
-                        num_bits=num_bits,
-                        group_size=group_size,
-                        custom_scales=custom_scales)
+                    if not isinstance(child, BNBLinear4bit):
+                        _, _Q, scales, qmap = flute.nf_utils.nf_quantize(
+                            W=child.weight.to(device="cuda"),
+                            num_bits=num_bits,
+                            group_size=group_size,
+                            custom_scales=custom_scales)
+                    else:
+                        _Q, scales, qmap = flute.integrations.bitsandbytes.convert_BNBLinear4bit(
+                            bnb_module=child,
+                            verify=True)
+
                     Q = flute.utils.pack(
                         _Q.T.contiguous(),
                         num_bits=num_bits,
@@ -122,6 +158,8 @@ def _replace_linear(_name: str, _module: torch.nn.Module) -> None:
                     new_child.scales.copy_(scales)
                     new_child.tables.copy_(qmap)
                     new_child.tables2.copy_(qmap2)
+                    if new_child.bias is not None:
+                        new_child.bias.copy_(child.bias)
 
                     # add the accelerate hook back
                     if handle_hooks is True:
@@ -165,8 +203,6 @@ def __init__(
             raise NotImplementedError
         if not isinstance(device, torch.device):
             raise NotImplementedError
-        if bias:
-            raise NotImplementedError
 
         super().__init__()
 
@@ -190,9 +226,13 @@ def __init__(
         self.register_buffer("scales", torch.ones((N, G), dtype=dtype, device=device))
         self.register_buffer("tables", tables)
         self.register_buffer("tables2", flute.utils.make_qmap2_from_qmap(tables))
+        if bias:
+            self.bias = torch.nn.Parameter(torch.empty(out_features, **factory_kwargs))
+        else:
+            self.register_parameter("bias", None)
 
     def forward(self, inputs: torch.Tensor) -> torch.Tensor:
-        return flute.qgemm_simple(
+        output = flute.qgemm_simple(
             inputs,
             self.weight,
             self.scales,
@@ -203,6 +243,11 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
             self.group_size,
         )
 
+        if self.bias is not None:
+            output.add_(self.bias)  # In-place add
+
+        return output
+
     def extra_repr(self) -> str:
         return (f"in_features={self.in_features}, "
                 f"out_features={self.out_features}, "
diff --git a/flute/integrations/bitsandbytes.py b/flute/integrations/bitsandbytes.py
new file mode 100644
index 0000000..4541206
--- /dev/null
+++ b/flute/integrations/bitsandbytes.py
@@ -0,0 +1,91 @@
+import torch
+from typing import Tuple
+from bitsandbytes.nn import (
+    Linear4bit as BNBLinear4bit)
+from bitsandbytes.functional import (
+    dequantize_4bit,
+    dequantize_blockwise)
+
+
+def convert_BNBLinear4bit(
+    bnb_module: BNBLinear4bit,
+    verify: bool = True,
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+
+    if not isinstance(bnb_module, BNBLinear4bit):
+        raise TypeError
+
+    bnb_qweight = bnb_module.weight
+    bnb_quant_state = bnb_module.weight.quant_state
+    bnb_quant_table = bnb_module.weight.quant_state.code
+    bnb_quant_dtype = bnb_module.weight.quant_state.dtype
+
+    if not all([
+        bnb_qweight.ndim == 2,
+        bnb_qweight.shape[1] == 1,
+        bnb_qweight.bnb_quantized is True,
+        bnb_qweight.requires_grad is False,
+        bnb_qweight.dtype == torch.uint8,
+        bnb_qweight.quant_storage == torch.uint8,
+        bnb_qweight.blocksize == bnb_quant_state.blocksize,
+        bnb_qweight.quant_type == bnb_quant_state.quant_type,
+        bnb_qweight.compress_statistics == bnb_quant_state.nested,
+        bnb_module.quant_state is bnb_quant_state]):
+        raise NotImplementedError
+
+    # unpacked quantized weights
+    qweight = torch.cat([
+        (bnb_qweight.data >> 4) & 0b1111,
+        (bnb_qweight.data >> 0) & 0b1111], dim=1)
+    qweight = qweight.view(
+        bnb_quant_state.shape)
+
+    # get the scales
+    if bnb_quant_state.nested:
+        scales = dequantize_blockwise(
+            A=bnb_quant_state.absmax,
+            quant_state=bnb_quant_state.state2)
+        scales = scales + bnb_quant_state.offset
+    else:
+        scales = bnb_quant_state.absmax
+
+    # convert to the correct dtype
+    if scales.dtype != bnb_quant_dtype:
+        scales_casted = scales.to(dtype=bnb_quant_dtype)
+    else:
+        scales_casted = scales
+
+    if bnb_quant_table.dtype != bnb_quant_dtype:
+        bnb_quant_table_casted = bnb_quant_table.to(dtype=bnb_quant_dtype)
+    else:
+        bnb_quant_table_casted = bnb_quant_table
+
+    if not all([
+        scales.ndim == 1,
+        scales.dtype == torch.float32,
+        scales_casted.dtype == bnb_quant_dtype,
+        bnb_quant_table.dtype == torch.float32,
+        bnb_quant_table_casted.dtype == bnb_quant_dtype]):
+        raise ValueError
+
+    # double check that the conversion is lossless
+    if verify is True:
+        broadcasted_scales = (
+            scales
+            .unsqueeze(dim=-1)
+            .expand(scales.shape[0], bnb_quant_state.blocksize)
+            .reshape(qweight.shape))
+        weight = (
+            # `dequantize_4bit` function always performs dequantization in FP16
+            bnb_quant_table[qweight.to(dtype=torch.int)] *
+            broadcasted_scales).to(dtype=bnb_quant_dtype)
+        weight_ = dequantize_4bit(
+            A=bnb_qweight,
+            quant_state=bnb_quant_state,
+            # unused
+            blocksize=bnb_quant_state.blocksize,
+            quant_type=bnb_quant_state.quant_type)
+        if not (weight == weight_).all():
+            raise ValueError
+
+    return qweight, scales_casted, bnb_quant_table_casted
diff --git a/requirements.txt b/requirements.txt
index dfb1cca..73e47af 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -5,3 +5,4 @@ matplotlib
 transformers
 accelerate
 vllm >= 0.5.3.post1
+bitsandbytes
diff --git a/tests/kernel.py b/tests/kernel.py
index b1a307d..e9a5168 100644
--- a/tests/kernel.py
+++ b/tests/kernel.py
@@ -100,7 +100,7 @@ def test_integer(
 
     if identity is True:
         if equal is not True:
-            raise ValueError(message)
+            click.secho(message, bg="red")
     else:
         if dtype == torch.float16:
             threshold = FP16_ERROR_THRESHOLD
@@ -126,10 +126,11 @@ def test_integer(
                 "error": error,
                 "error_": error_,
             }
-            torch.save(data_to_save, f"{message}.pth")
+            # torch.save(data_to_save, f"{message}.pth")
            click.secho(message, fg="red")
 
 
+@torch.no_grad()
 def run_tests(num: int) -> None:
     for index in range(num):
         torch.manual_seed(index)
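Usage note (not part of the diff): a minimal, hypothetical sketch of how the new bitsandbytes path added above might be exercised. It assumes `prepare_model_flute` also accepts `name`, `module`, and `num_bits` arguments as in the existing FLUTE README, and the model id is only a placeholder. Per the checks in `flute/integrations/base.py` above, the BNB path requires `num_bits == 4` and `group_size` equal to bitsandbytes' `quant_state.blocksize` (64 by default).

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

import flute.integrations.base

# Load an NF4-quantized checkpoint through bitsandbytes (placeholder model id).
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct",
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16),
    device_map="auto")

# Replace the BNB Linear4bit layers with FluteLinear layers in place.
flute.integrations.base.prepare_model_flute(
    name="model.model.layers",        # assumed argument, as in the FLUTE README
    module=model.model.layers,
    num_bits=4,                       # BNB path requires 4 bits
    group_size=64,                    # must equal quant_state.blocksize
    fake=False,
    prepare_bnb_layers=True,
    default_bnb_dtype=torch.float16)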