diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index ea021e874..0b18d9b06 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -607,7 +607,7 @@ def from_dict(cls, qs_dict: Dict[str, Any], device: torch.device) -> 'QuantState qs_dict: based on state_dict, with only relevant keys, striped of prefixes. - item with key `quant_state.bitsandbytes__[nf4/fp4]` may contain minor and non-tensor quant state items. + item with key `quant_state.bitsandbytes__[nf4/fp4]` may contain minor and non-tensor quant state items. """ # unpacking tensor with non-tensor components @@ -802,7 +802,7 @@ def dequantize_blockwise( if quant_state is None: quant_state = QuantState(absmax=absmax, code=code, blocksize=blocksize, dtype=torch.float32) - + absmax = quant_state.absmax if quant_state.nested: absmax = dequantize_blockwise(quant_state.absmax, quant_state.state2) @@ -884,13 +884,13 @@ def get_4bit_type(typename, device=None, blocksize=64): return data.to(device) -def quantize_fp4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statistics=False): - return quantize_4bit(A, absmax, out, blocksize, compress_statistics, 'fp4') +def quantize_fp4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statistics=False, quant_storage=torch.uint8): + return quantize_4bit(A, absmax, out, blocksize, compress_statistics, 'fp4', quant_storage) -def quantize_nf4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statistics=False): - return quantize_4bit(A, absmax, out, blocksize, compress_statistics, 'nf4') +def quantize_nf4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statistics=False, quant_storage=torch.uint8): + return quantize_4bit(A, absmax, out, blocksize, compress_statistics, 'nf4', quant_storage) -def quantize_4bit(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statistics=False, quant_type='fp4') -> Tensor: +def quantize_4bit(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statistics=False, quant_type='fp4', quant_storage=torch.uint8) -> Tensor: """ Quantize tensor A in blocks of 4-bit values. @@ -903,7 +903,7 @@ def quantize_4bit(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksiz absmax : torch.Tensor The absmax values. out : torch.Tensor - The output tensor (8-bit). + The output tensor. blocksize : int The blocksize used in quantization. quant_type : str @@ -912,7 +912,7 @@ def quantize_4bit(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksiz Returns ------- torch.Tensor: - The 8-bit tensor with packed 4-bit values. + Tensor with packed 4-bit values. tuple(torch.Tensor, torch.Size, torch.dtype, int): The quantization state to undo the quantization. """ @@ -931,7 +931,8 @@ def quantize_4bit(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksiz if out is None: - out = torch.zeros(((n+1)//2, 1), dtype=torch.uint8, device=A.device) + mod = dtype2bytes[quant_storage] * 2 + out = torch.zeros(((n+1)//mod, 1), dtype=quant_storage, device=A.device) assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64] @@ -985,7 +986,7 @@ def dequantize_4bit(A: Tensor, quant_state: QuantState = None, absmax: Tensor = Parameters ---------- A : torch.Tensor - The input 8-bit tensor (packed 4-bit values). + The input tensor (packed 4-bit values). quant_state : QuantState object with quantisation stats, incl. absmax values, original tensor shape and original dtype. absmax : torch.Tensor @@ -1626,7 +1627,7 @@ def gemv_4bit( ldb = ct.c_int32(ldb) ldc = ct.c_int32(ldc) - if B.dtype == torch.uint8: + if B.dtype in [torch.uint8, torch.bfloat16, torch.float16, torch.float32]: if A.dtype == torch.float16: lib.cgemm_4bit_inference_naive_fp16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(state.code), get_ptr(out), lda, ldb, ldc, ct.c_int32(state.blocksize)) elif A.dtype == torch.bfloat16: diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index 7cce82b91..0b1dc5c6f 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -141,8 +141,18 @@ def forward(self, input: Tensor) -> Tensor: class Params4bit(torch.nn.Parameter): - - def __new__(cls, data: Optional[torch.Tensor] = None, requires_grad=True, quant_state: QuantState = None, blocksize: int = 64, compress_statistics: bool = True, quant_type: str = 'fp4') -> "Params4bit": + def __new__( + cls, + data: Optional[torch.Tensor] = None, + requires_grad=True, + quant_state: QuantState = None, + blocksize: int = 64, + compress_statistics: bool = True, + quant_type: str = 'fp4', + quant_storage: torch.dtype = torch.uint8, + module: Optional["Linear4bit"] = None, + bnb_quantized: bool = False + ) -> "Params4bit": if data is None: data = torch.empty(0) @@ -151,7 +161,10 @@ def __new__(cls, data: Optional[torch.Tensor] = None, requires_grad=True, quant_ self.compress_statistics = compress_statistics self.quant_type = quant_type self.quant_state = quant_state + self.quant_storage = quant_storage + self.bnb_quantized = bnb_quantized self.data = data + self.module = module return self @classmethod @@ -162,16 +175,23 @@ def from_prequantized(cls, data: torch.Tensor, quantized_stats: Dict[str, Any], self.blocksize = self.quant_state.blocksize self.compress_statistics = self.quant_state.nested self.quant_type = self.quant_state.quant_type + self.bnb_quantized = True return self - def cuda(self, device): - w = self.data.contiguous().half().cuda(device) - w_4bit, quant_state = bnb.functional.quantize_4bit(w, blocksize=self.blocksize, compress_statistics=self.compress_statistics, quant_type=self.quant_type) + def _quantize(self, device): + w = self.data.contiguous().cuda(device) + w_4bit, quant_state = bnb.functional.quantize_4bit(w, blocksize=self.blocksize, compress_statistics=self.compress_statistics, + quant_type=self.quant_type, quant_storage=self.quant_storage) self.data = w_4bit self.quant_state = quant_state - + if self.module is not None: + self.module.quant_state = quant_state + self.bnb_quantized = True return self + def cuda(self, device: Optional[Union[int, device, str]] = None, non_blocking: bool = False): + return self.to(device='cuda' if device is None else device, non_blocking=non_blocking) + @overload def to(self: T, device: Optional[Union[int, device]] = ..., dtype: Optional[Union[dtype, str]] = ..., non_blocking: bool = ...,) -> T: ... @@ -187,8 +207,8 @@ def to(self: T, tensor: Tensor, non_blocking: bool = ...) -> T: def to(self, *args, **kwargs): device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs) - if (device is not None and device.type == "cuda" and self.data.device.type == "cpu"): - return self.cuda(device) + if (device is not None and device.type == "cuda" and not self.bnb_quantized): + return self._quantize(device) else: if self.quant_state is not None: self.quant_state.to(device) @@ -203,12 +223,14 @@ def to(self, *args, **kwargs): class Linear4bit(nn.Linear): - def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_type='fp4', device=None): + def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_type='fp4', quant_storage=torch.uint8, device=None): super().__init__(input_features, output_features, bias, device) - self.weight = Params4bit(self.weight.data, requires_grad=False, compress_statistics=compress_statistics, quant_type=quant_type) + self.weight = Params4bit(self.weight.data, requires_grad=False, compress_statistics=compress_statistics, quant_type=quant_type, quant_storage=quant_storage, module=self) # self.persistent_buffers = [] # TODO consider as way to save quant state self.compute_dtype = compute_dtype self.compute_type_is_set = False + self.quant_state = None + self.quant_storage = quant_storage def set_compute_type(self, x): if x.dtype in [torch.float32, torch.bfloat16]: @@ -243,7 +265,15 @@ def forward(self, x: torch.Tensor): self.bias.data = self.bias.data.to(x.dtype) if getattr(self.weight, 'quant_state', None) is None: - print('FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first.') + if getattr(self, 'quant_state', None) is not None: + # the quant state got lost when the parameter got converted. This happens for example for fsdp + # since we registered the module, we can recover the state here + assert self.weight.shape[1] == 1 + if not isinstance(self.weight, Params4bit): + self.weight = Params4bit(self.weight, quant_storage=self.quant_storage) + self.weight.quant_state = self.quant_state + else: + print('FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first.') if not self.compute_type_is_set: self.set_compute_type(x) self.compute_type_is_set = True @@ -261,8 +291,8 @@ def forward(self, x: torch.Tensor): class LinearFP4(Linear4bit): - def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, device=None): - super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'fp4', device) + def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_storage=torch.uint8, device=None): + super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'fp4', quant_storage, device) class LinearNF4(Linear4bit): @@ -276,8 +306,8 @@ class LinearNF4(Linear4bit): Implementation of the NF4 data type in bitsandbytes can be found in the `create_normal_map` function in the `functional.py` file: https://github.com/TimDettmers/bitsandbytes/blob/main/bitsandbytes/functional.py#L236. ''' - def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, device=None): - super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'nf4', device) + def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_storage=torch.uint8, device=None): + super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'nf4', quant_storage, device) class Int8Params(torch.nn.Parameter): diff --git a/tests/test_functional.py b/tests/test_functional.py index f39f676d5..f314dc6e2 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -2370,7 +2370,8 @@ def test_normal_map_tree(): @pytest.mark.parametrize("storage_type", ['nf4', 'fp4'], ids=['nf4', 'fp4']) @pytest.mark.parametrize("kind", ['fc1', 'fc2', 'attn', 'attn_packed'], ids=['fc1', 'fc2', 'attn', 'attn_packed']) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=['fp16', 'bf16', 'fp32']) -def test_gemv_4bit(dtype, storage_type, double_quant, kind): +@pytest.mark.parametrize("quant_storage", [torch.uint8, torch.float16, torch.bfloat16, torch.float32], ids=['uint8', 'fp16', 'bf16', 'fp32']) +def test_gemv_4bit(dtype, storage_type, quant_storage, double_quant, kind): for dim in [128, 256, 512, 1024]: #for dim in [4*1024]: #for dim in [1*16]: @@ -2399,7 +2400,7 @@ def test_gemv_4bit(dtype, storage_type, double_quant, kind): A = torch.randn(1, dim, dtype=dtype, device='cuda') B = torch.randn(dim*3, dim, dtype=dtype, device='cuda')/math.sqrt(dim) - qB, state = F.quantize_4bit(B, quant_type=storage_type, compress_statistics=double_quant) + qB, state = F.quantize_4bit(B, quant_type=storage_type, compress_statistics=double_quant, quant_storage=quant_storage) C3 = torch.matmul(A, B.t()) C2 = F.gemv_4bit(A, qB.t(), state=state) A.requires_grad = True diff --git a/tests/test_linear4bit.py b/tests/test_linear4bit.py index 67d299dea..f6be79a84 100644 --- a/tests/test_linear4bit.py +++ b/tests/test_linear4bit.py @@ -8,13 +8,19 @@ import bitsandbytes as bnb +storage = { + 'uint8': torch.uint8, + 'float16': torch.float16, + 'bfloat16': torch.bfloat16, + 'float32': torch.float32 +} @pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU") @pytest.mark.parametrize( - "quant_type, compress_statistics, bias", - list(product(["nf4", "fp4"], [False, True], [False, True])), + "quant_type, compress_statistics, bias, quant_storage", + list(product(["nf4", "fp4"], [False, True], [False, True], ['uint8', 'float16', 'bfloat16', 'float32'])), ) -def test_linear_serialization(quant_type, compress_statistics, bias): +def test_linear_serialization(quant_type, compress_statistics, bias, quant_storage): original_dtype = torch.float16 compute_dtype = None device = "cuda" @@ -32,7 +38,7 @@ def test_linear_serialization(quant_type, compress_statistics, bias): quant_type=quant_type, device="meta", ) - new_weight = bnb.nn.Params4bit(data=linear.weight, requires_grad=False) + new_weight = bnb.nn.Params4bit(data=linear.weight, quant_type=quant_type, requires_grad=False) linear_q.weight = new_weight if bias: linear_q.bias = torch.nn.Parameter(linear.bias) @@ -65,6 +71,22 @@ def test_linear_serialization(quant_type, compress_statistics, bias): # MATCHING a, b = linear_q.weight, linear_q2.weight + # Quantizing original layer with specified quant_storage type + linear_qs = bnb.nn.Linear4bit( + linear.in_features, + linear.out_features, + bias=bias, + compute_dtype=compute_dtype, + compress_statistics=compress_statistics, + quant_type=quant_type, + quant_storage=storage[quant_storage], + device="meta", + ) + linear_qs.weight = bnb.nn.Params4bit(data=linear.weight, requires_grad=False, quant_type=quant_type, quant_storage=storage[quant_storage]) + if bias: + linear_qs.bias = torch.nn.Parameter(linear.bias) + linear_qs = linear_qs.to(device) + assert a.device == b.device assert a.dtype == b.dtype assert torch.equal(a, b) @@ -96,9 +118,21 @@ def test_linear_serialization(quant_type, compress_statistics, bias): x = torch.rand(42, layer_shape[0], device=device) a = linear_q(x) b = linear_q2(x) + c = linear_qs(x) assert a.device == b.device assert a.dtype == b.dtype + assert a.device == c.device + assert a.dtype == c.dtype assert torch.equal(a, b) + assert torch.equal(a, c) + + # Test moving to CPU and back to GPU + linear_q2.to('cpu') + linear_q2.to(device) + d = linear_qs(x) + assert c.dtype == d.dtype + assert c.device == d.device + assert torch.equal(c, d) # Saved size ratio test. Target set for layer_shape == (300, 400) w/ bias with TemporaryDirectory() as tmpdir: