Support 4bit on CPU backend #1206

Merged
3 changes: 2 additions & 1 deletion bitsandbytes/autograd/_functions.py
@@ -572,7 +572,8 @@ def matmul_4bit(
bias=None,
):
assert quant_state is not None
if A.numel() == A.shape[-1] and A.requires_grad == False:
if (A.numel() == A.shape[-1] or A.device.type == "cpu") and A.requires_grad == False:
# CPU backend does not require A to be a vector
if A.shape[-1] % quant_state.blocksize != 0:
warn(
f"Some matrices hidden dimension is not a multiple of {quant_state.blocksize} and efficient inference kernels are not supported for these (slow). Matrix input size found: {A.shape}",
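With this change, a CPU activation takes the 4-bit inference path even when it is batched rather than a single row, since the CPU kernel added below handles full matrices. A minimal sketch of exercising that path, assuming the multi-backend dispatch in this branch routes CPU tensors to the new backend; the shapes are illustrative and the .t() mirrors how Linear4bit calls matmul_4bit internally:

import torch
import bitsandbytes as bnb

# Quantize a weight matrix to NF4 on CPU (double quant is not supported here yet).
W = torch.randn(4096, 4096, dtype=torch.float32)
w_4bit, quant_state = bnb.functional.quantize_4bit(
    W, blocksize=64, compress_statistics=False, quant_type="nf4"
)

# A batched activation: with the `or A.device.type == "cpu"` condition above,
# it now goes through gemv_4bit on CPU instead of requiring a single row.
# The hidden size is a multiple of the blocksize, so the warning is not hit.
A = torch.randn(2, 8, 4096, dtype=torch.float32)
out = bnb.matmul_4bit(A, w_4bit.t(), quant_state=quant_state)
print(out.shape)  # expected: torch.Size([2, 8, 4096])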
16 changes: 13 additions & 3 deletions bitsandbytes/backends/cpu.py
@@ -6,9 +6,12 @@

from .base import Backend
from .cpu_xpu_common import (
dequantize_4bit_impl,
double_quant_impl,
gemm_4bit_impl,
igemmlt_impl,
mm_dequant_impl,
quantize_4bit_impl,
)

Tensor = torch.Tensor
@@ -132,7 +135,9 @@ def quantize_4bit(
quant_type: Literal["fp4", "nf4"] = "fp4",
quant_storage=torch.uint8,
) -> Tuple[torch.Tensor, QuantState]:
raise NotImplementedError("Not yet implemented for CPU backend")
assert_on_cpu([A, absmax, out])
assert quant_storage == torch.uint8, "CPU backend only supports uint8 quant_storage"
return quantize_4bit_impl(A, absmax, out, blocksize, compress_statistics, quant_type)

def dequantize_4bit(
self,
@@ -143,7 +148,8 @@ def dequantize_4bit(
blocksize: int = 64,
quant_type: Literal["fp4", "nf4"] = "fp4",
) -> torch.Tensor:
raise NotImplementedError("Not yet implemented for CPU backend")
assert_on_cpu([A, absmax, out])
return dequantize_4bit_impl(A, quant_state, absmax, out, blocksize, quant_type)

def gemv_4bit(
self,
@@ -154,7 +160,11 @@ def gemv_4bit(
transposed_B=False,
state: QuantState = None,
) -> torch.Tensor:
raise NotImplementedError("Not yet implemented for CPU backend")
assert_on_cpu([A, B, out])
if state is None:
raise ValueError("state cannot be None. gemv_4bit() requires the state from quantize_4bit()")

return gemm_4bit_impl(A, B, out, transposed_A, transposed_B, state)

def dequantize_blockwise(
self,
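Since quantize_4bit, dequantize_4bit, and gemv_4bit now delegate to real implementations instead of raising NotImplementedError, a quantize/dequantize round trip through bitsandbytes.functional is a quick sanity check of the CPU backend. A hedged sketch (shapes are illustrative; note that the CPU dequantize path returns the tensor transposed, per the return out.t() in dequantize_4bit_impl below, hence the .t() before comparing):

import torch
import bitsandbytes.functional as F

w = torch.randn(1024, 1024, dtype=torch.float32)  # stays on CPU

# Packed uint8 tensor plus a QuantState carrying absmax, shape, dtype, blocksize.
w_4bit, state = F.quantize_4bit(w, blocksize=64, compress_statistics=False, quant_type="nf4")

# NF4 is lossy, so expect a small blockwise reconstruction error, not equality.
w_dq = F.dequantize_4bit(w_4bit, quant_state=state).t()
err = (w - w_dq).abs().mean()
print(f"mean abs reconstruction error: {err:.4f}")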
293 changes: 293 additions & 0 deletions bitsandbytes/backends/cpu_xpu_common.py
@@ -1,7 +1,13 @@
from typing import Optional
import warnings

import torch

from bitsandbytes.functional import (
QuantState,
get_4bit_type,
)

try:
# to support Intel CPU/GPU (XPU) backend
import intel_extension_for_pytorch as ipex
@@ -228,3 +234,290 @@ def mm_dequant_impl(
out = out + bias.to(compute_dtype)
out = out.to(output_dtype)
return out


NF4_QUANT_TABLE = [
-1.0 - 1e-2, # 0b0000
-0.8480964004993439, # 0b0001
-0.6106329262256622, # 0b0010
-0.4599952697753906, # 0b0011
-0.33967943489551544, # 0b0100
-0.23460740596055984, # 0b0101
-0.13791173323988914, # 0b0110
-0.045525018125772476, # 0b0111
0.03979014977812767, # 0b1000
0.1202552504837513, # 0b1001
0.2035212516784668, # 0b1010
0.2920137718319893, # 0b1011
0.3893125355243683, # 0b1100
0.5016634166240692, # 0b1101
0.6427869200706482, # 0b1110
0.8614784181118011, # 0b1111
]


FP4_QUANT_TABLE = {
0 - 1e-2: 0, # 0b0000
0.00260417: 1, # 0b0001
0.0859375: 6, # 0b0110
0.20833333: 7, # 0b0111
0.29166667: 4, # 0b0100
0.4166667: 5, # 0b0101
0.583333: 2, # 0b0010
0.8333333: 3, # 0b0011
}


# It's faster not to use torch.compile
def quantize_4bit_impl(
A: Tensor,
absmax: Tensor = None,
out: Tensor = None,
blocksize=64,
compress_statistics=False,
quant_type="nf4",
) -> Tensor:
"""
Quantize tensor A in blocks of 4-bit values.

Quantizes tensor A by dividing it into blocks which are independently quantized to FP4.

Parameters
----------
A : torch.Tensor
The input tensor.
absmax : torch.Tensor
The absmax values.
out : torch.Tensor
The output tensor (8-bit).
blocksize : int
The blocksize used in quantization.
quant_type : str
The 4-bit quantization data type {fp4, nf4}, only nf4 is supported now

Returns
-------
torch.Tensor:
The 8-bit tensor with packed 4-bit values.
tuple(torch.Tensor, torch.Size, torch.dtype, int):
The quantization state to undo the quantization.
"""
if quant_type not in ["nf4", "fp4"]:
raise NotImplementedError(f"4-bit quantization data type {quant_type} is not implemented for CPU/XPU.")
if quant_type == "fp4":
warnings.warn("fp4 quantization is currently slow on CPU/XPU. Please Use nf4 instead for better performance.")
assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64]
n = A.numel()
input_shape = A.shape
blocks = n // blocksize
blocks += 1 if n % blocksize > 0 else 0

if absmax is None:
absmax = torch.zeros((blocks,), device=A.device, dtype=A.dtype)

if out is None:
out = torch.zeros(((n + 1) // 2), dtype=torch.uint8, device=A.device)

rem = n % blocksize
has_rem = rem > 0

# Scale tensor to [-1, 1]
A_reshaped = A.reshape(n)
A_com = A_reshaped[: n - rem]
A_com_reshaped = A_com.reshape(n // blocksize, blocksize)
absmax[: blocks - has_rem] = torch.abs(A_com_reshaped).max(dim=-1)[0]
scaled_A = torch.clamp(A_com_reshaped * (1 / absmax[: blocks - has_rem].view(-1, 1)), -1, 1)
scaled_A = scaled_A.reshape(-1)
if has_rem:
absmax[-1] = torch.abs(A_reshaped[n - rem :]).max()
scaled_A_rem = torch.clamp(A_reshaped[n - rem :] * (1 / absmax[-1]), -1, 1)
scaled_A = torch.cat([scaled_A, scaled_A_rem], dim=0)
# map [-1, 1] to nf4/fp4
out_uint8 = torch.empty(scaled_A.shape, dtype=torch.uint8)
if quant_type == "nf4":
for i in range(len(NF4_QUANT_TABLE)):
out_uint8[scaled_A > NF4_QUANT_TABLE[i]] = i
elif quant_type == "fp4":
sign = scaled_A < 0
abs_scaled_A = torch.abs(scaled_A)
for key, val in FP4_QUANT_TABLE.items():
out_uint8[abs_scaled_A > key] = val
out_uint8 += sign.to(torch.uint8) * 8
if out_uint8.size(-1) % 2:
out_uint8 = torch.nn.functional.pad(out_uint8, (0, 1), value=0)
out[:] = out_uint8[1::2].bitwise_left_shift(4).bitwise_or_(out_uint8[::2])

code = get_4bit_type(quant_type, device=A.device)

if compress_statistics:
raise NotImplementedError("bnb_4bit_use_double_quant is not supported yet for CPU/XPU")
else:
state = QuantState(
absmax=absmax,
shape=input_shape,
dtype=A.dtype,
blocksize=blocksize,
code=code,
quant_type=quant_type,
)

if ipex_cpu and _ipex_cpu_version_prereq(2, 3) and input_shape[1] % blocksize == 0 and quant_type == "nf4":
# lowp_mode: lowest precision for computation
lowp_mode = ipex_cpu.quantization.WoqLowpMode.BF16
state.op_context = torch.ops.ipex_prepack.weight_only_qlinear_prepack(
out.reshape([input_shape[0], input_shape[1] // 2]),
ipex_cpu.quantization.WoqWeightDtype.NF4,
input_shape, # weight shape
absmax.view(input_shape[0], input_shape[1] // blocksize), # scales
None, # zero_points
None, # bias
None, # g_idx
None, # batch_size
blocksize,
int(lowp_mode),
-1, # act_quant_mode. -1 means don't quant activation
)
state.absmax = torch.Tensor()
return torch.Tensor(), state

return out, state


@_maybe_torch_compile
def dequantize_4bit_impl(
A: Tensor,
quant_state=None,
absmax: Tensor = None,
out: Tensor = None,
blocksize: int = 64,
quant_type="nf4",
) -> Tensor:
"""
Dequantizes FP4 blockwise quantized values.

Dequantizes the tensor A with maximum absolute values absmax in blocks of size blocksize.

Parameters
----------
A : torch.Tensor
The input 8-bit tensor (packed 4-bit values).
quant_state : QuantState
object with quantisation stats, incl. absmax values, original tensor shape and original dtype.
absmax : torch.Tensor
The absmax values.
out : torch.Tensor
Dequantized output tensor.
blocksize : int
The blocksize used in quantization.
quant_type : str
The 4-bit quantization data type {fp4, nf4}, only nf4 is supported now


Returns
-------
torch.Tensor:
Dequantized tensor.
"""

if quant_state is None:
assert absmax is not None and out is not None

quant_state = QuantState(
absmax=absmax,
shape=out.shape,
dtype=out.dtype,
blocksize=blocksize,
quant_type=quant_type,
)

else:
absmax = quant_state.absmax

if quant_type not in ["nf4", "fp4"]:
raise NotImplementedError(
f"4-bit quantization data type {quant_state.quant_type} is not implemented for CPU/XPU."
)

if quant_state.nested:
raise NotImplementedError("bnb_4bit_use_double_quant is not supported yet for CPU/XPU")

if ipex_cpu and _ipex_cpu_version_prereq(2, 3) and hasattr(quant_state, "op_context"):
assert quant_state.op_context is not None
A = quant_state.op_context.to_public(quant_state.op_context.get_weight())
A = A.reshape(-1)
absmax = quant_state.op_context.get_scales().reshape(-1)

if out is None:
out = torch.empty(quant_state.shape, dtype=quant_state.dtype, device=A.device)

n = out.numel()
# Map nf4 to [-1, 1]
out_uint8 = torch.empty(A.size(0) * 2, dtype=torch.uint8, device=A.device)
out_uint8[::2] = A.bitwise_and(0xF)
out_uint8[1::2] = A.bitwise_right_shift(4)
out_dq = torch.empty(out_uint8.shape).to(quant_state.dtype)
for i in range(len(quant_state.code)):
out_dq[out_uint8 == i] = quant_state.code[i]
Comment on lines +458 to +460
Contributor: Using index select will be faster: out_dq = quant_state.code[out_uint8.to(torch.int32)].

Author: It looks like the torch.compile result of this code gives wrong results, and removing torch.compile lowers performance. Let's keep this implementation for now.

Reply: A bug in torch.compile? Can you submit a bug report to PyTorch? I will try to fix it.

Author: However, I cannot reproduce the issue with the script below. May need more investigation.

import torch


NF4_DEQUANT_TABLE = torch.Tensor([
  -1.0,
  -0.6961928009986877,
  -0.5250730514526367,
  -0.39491748809814453,
  -0.28444138169288635,
  -0.18477343022823334,
  -0.09105003625154495,
  0.0,
  0.07958029955625534,
  0.16093020141124725,
  0.24611230194568634,
  0.33791524171829224,
  0.44070982933044434,
  0.5626170039176941,
  0.7229568362236023,
  1.0,
])


@torch.compile
def dequant_nf4_compile(t_in: torch.Tensor, out_dtype):
  return NF4_DEQUANT_TABLE[t_in.to(torch.int)].to(out_dtype)


def dequant_nf4_eager(t_in: torch.Tensor, out_dtype):
  return NF4_DEQUANT_TABLE[t_in.to(torch.int)].to(out_dtype)


x = torch.randint(0, 16, (1024, 1024), dtype=torch.uint8)

y1 = dequant_nf4_compile(x, torch.bfloat16)
y1 = dequant_nf4_compile(x, torch.bfloat16)
y2 = dequant_nf4_eager(x, torch.bfloat16)

print(torch.equal(y1, y2))
print("max diff =", torch.abs(y1 - y2).max())


# Apply scales
if out_dq.numel() != n:
assert out_dq.numel() == n + 1
out_dq = torch.narrow(out_dq, 0, 0, n)
blocks = n // blocksize
blocks += 1 if n % blocksize > 0 else 0
rem = n % blocksize
has_rem = rem > 0
out_reshaped = out.reshape(-1)
out_reshaped[: n - rem] = (out_dq[: n - rem].view(-1, blocksize) * absmax[: blocks - has_rem].view(-1, 1)).reshape(
-1
)
if has_rem:
out_reshaped[n - rem :] = out_dq[n - rem :] * absmax[-1]

# take transpose here because weight is transposed (again) for computation
return out.t()


# Do not need torch.compile here as we are calling torch/ipex kernel
def gemm_4bit_impl(
A: torch.Tensor,
B: torch.Tensor,
out: Optional[torch.Tensor] = None,
transposed_A=False,
transposed_B=False,
state: QuantState = None,
) -> torch.Tensor:
"""
Matrix-matrix multiplication with 4-bit quantization.

Parameters
----------
A : torch.Tensor
The first input tensor. Usually the activation tensor.
B : torch.Tensor
The second input tensor. Usually the weight tensor.
out : torch.Tensor
The output tensor.
transposed_A : bool
Whether A is transposed
transposed_B : bool
Whether B is transposed
state : QuantState
Contains quantization info, such as blocksize and dtype

Returns
-------
torch.Tensor:
GEMM output tensor.
"""
if ipex_cpu and _ipex_cpu_version_prereq(2, 3) and hasattr(state, "op_context"):
assert state.op_context is not None
output = torch.ops.torch_ipex.ipex_woq_linear(A, state.op_context.get_data_handle())
else:
dqB = dequantize_4bit_impl(B, state, blocksize=state.blocksize)
output = torch.matmul(A, dqB)
if out is not None:
out.copy_(output)
else:
out = output
return out
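Both quantize_4bit_impl and dequantize_4bit_impl rely on the same nibble-packing convention: two 4-bit codes share one byte, with the code at even positions in the low nibble and the code at odd positions in the high nibble (the bitwise_left_shift/bitwise_or_ line in quantization and the bitwise_and/bitwise_right_shift pair in dequantization). A small self-contained sketch of that round trip, separate from the backend itself:

import torch

codes = torch.tensor([3, 12, 7, 0, 15, 9], dtype=torch.uint8)  # 4-bit values in [0, 15]

# Pack: odd-position codes go to the high nibble, even-position codes to the low nibble.
packed = codes[1::2].bitwise_left_shift(4).bitwise_or(codes[::2])

# Unpack: low nibbles back to even positions, high nibbles to odd positions.
unpacked = torch.empty(packed.numel() * 2, dtype=torch.uint8)
unpacked[::2] = packed.bitwise_and(0xF)
unpacked[1::2] = packed.bitwise_right_shift(4)

assert torch.equal(unpacked, codes)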
7 changes: 5 additions & 2 deletions bitsandbytes/nn/modules.py
@@ -285,7 +285,7 @@ def from_prequantized(
return self

def _quantize(self, device):
w = self.data.contiguous().cuda(device)
w = self.data.contiguous().to(device)
w_4bit, quant_state = bnb.functional.quantize_4bit(
w,
blocksize=self.blocksize,
@@ -303,6 +303,9 @@ def _quantize(self, device):
def cuda(self, device: Optional[Union[int, device, str]] = None, non_blocking: bool = False):
return self.to(device="cuda" if device is None else device, non_blocking=non_blocking)

def cpu(self, non_blocking: bool = False):
return self.to(device="cpu", non_blocking=non_blocking)

@overload
def to(
self: T,
@@ -320,7 +323,7 @@ def to(self: T, tensor: Tensor, non_blocking: bool = ...) -> T: ...
def to(self, *args, **kwargs):
device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)

if device is not None and device.type == "cuda" and not self.bnb_quantized:
if device is not None and device.type in ["cuda", "cpu"] and not self.bnb_quantized:
return self._quantize(device)
else:
if self.quant_state is not None:
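With the Params4bit changes above, moving a module to the CPU now quantizes the weight the same way .cuda() does. A hedged sketch of using Linear4bit on CPU under this PR (sizes are illustrative; compress_statistics is disabled because double quantization is not yet supported on CPU/XPU):

import torch
import bitsandbytes as bnb

layer = bnb.nn.Linear4bit(
    4096, 4096,
    bias=False,
    compute_dtype=torch.float32,
    compress_statistics=False,  # double quant not supported on CPU/XPU yet
    quant_type="nf4",
)

# Previously only .cuda()/.to("cuda") triggered Params4bit._quantize();
# now .to("cpu") (or the new .cpu() override) quantizes through the CPU backend.
layer = layer.to("cpu")

x = torch.randn(1, 8, 4096)
with torch.no_grad():
    y = layer(x)
print(y.shape)  # expected: torch.Size([1, 8, 4096])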