test(qbits): increase coverage for marlin linear
dacorvo committed Oct 7, 2024
1 parent 38215b8 commit bae461d
Showing 1 changed file with 36 additions and 13 deletions.
@@ -17,6 +17,7 @@
 from helpers import assert_similar, device_eq, random_qweight, random_tensor
 
 from optimum.quanto import qint4
+from optimum.quanto.library.extensions import is_extension_available
 from optimum.quanto.tensor.weights import WeightQBitsTensor
 from optimum.quanto.tensor.weights.marlin.int4 import MarlinInt4WeightQBitsTensor

@@ -91,19 +92,15 @@ def test_marlin_int4_weight_qbits_tensor_move(device):
     assert torch.equal(marlinqbt.dequantize().to(device), moved_qbt.dequantize())
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-@pytest.mark.parametrize("batch_size", [1, 2])
-@pytest.mark.parametrize("tokens", [256, 512])
-@pytest.mark.parametrize("embeddings", [256, 512, 1024, 4096])
-@pytest.mark.parametrize("use_bias", [True, False], ids=["bias", "no-bias"])
-def test_marlin_int4_weight_qbits_tensor_linear(batch_size, tokens, embeddings, use_bias):
+def _test_marlin_int4_weight_qbits_tensor_linear(
+    dtype, weight_qtype, group_size, batch_size, tokens, in_features, out_features, use_bias
+):
     device = torch.device("cuda")
-    dtype = torch.float16
-    weight_qtype = qint4
-    group_size = 128
-    inputs = torch.rand((batch_size,) + (tokens, embeddings), dtype=dtype, device=device)
+    inputs = torch.rand((batch_size, tokens, in_features), dtype=dtype, device=device)
     # Create an MarlinInt4WeightQBitsTensor from a QBitsTensor on CUDA
-    qbt = random_qweight((tokens, embeddings), weight_qtype, dtype, group_size=group_size, device=torch.device("cuda"))
+    qbt = random_qweight(
+        (out_features, in_features), weight_qtype, dtype, group_size=group_size, device=torch.device("cuda")
+    )
     marlin_qweight = MarlinInt4WeightQBitsTensor(
         qtype=qbt.qtype,
         axis=qbt.axis,
@@ -114,10 +111,34 @@ def test_marlin_int4_weight_qbits_tensor_linear(batch_size, tokens, embeddings,
         scale=qbt._scale,
         shift=qbt._shift,
     )
-    bias = random_tensor((tokens,), dtype=dtype).to(device) if use_bias else None
+    bias = random_tensor((out_features,), dtype=dtype, device=device) if use_bias else None
     qout = torch.nn.functional.linear(inputs, marlin_qweight, bias)
     out = torch.nn.functional.linear(inputs, qbt.dequantize(), bias)
+    # Verify global alignment
     assert_similar(out, qout)
+    # Also look for outliers
+    max_val = out.abs().max()
+    max_err = (out - qout).abs().max()
+    rel_max_err = max_err / max_val
+    assert rel_max_err < 5e-2
+
+
+@pytest.mark.skipif(
+    not is_extension_available("quanto_cuda") or torch.cuda.get_device_capability()[0] < 8,
+    reason="CUDA >= sm80 not available",
+)
+@pytest.mark.parametrize("batch_size", [1, 2])
+@pytest.mark.parametrize("tokens", [16, 32, 48, 64])
+@pytest.mark.parametrize("in_features", [1024, 4096, 16384])
+@pytest.mark.parametrize("out_features", [1024, 2048, 4096])
+@pytest.mark.parametrize("use_bias", [True, False], ids=["bias", "no-bias"])
+def test_marlin_int4_weight_qbits_tensor_linear(batch_size, tokens, in_features, out_features, use_bias):
+    dtype = torch.float16
+    weight_qtype = qint4
+    group_size = 128
+    _test_marlin_int4_weight_qbits_tensor_linear(
+        dtype, weight_qtype, group_size, batch_size, tokens, in_features, out_features, use_bias
+    )
 
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
Expand All @@ -131,7 +152,9 @@ def test_marlin_int4_weight_qbits_tensor_linear_bug(tokens):
out_features = 2048
inputs = torch.rand((tokens, in_features), dtype=dtype, device=device)
# Create a MarlinInt4WeightQBitsTensor from a QBitsTensor on CUDA
qbt = random_qweight((out_features, in_features), weight_qtype, dtype, group_size=group_size, device=torch.device("cuda"))
qbt = random_qweight(
(out_features, in_features), weight_qtype, dtype, group_size=group_size, device=torch.device("cuda")
)
marlin_qweight = MarlinInt4WeightQBitsTensor(
qtype=qbt.qtype,
axis=qbt.axis,

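A note on the new assertions: the diff's own comments describe assert_similar as verifying global alignment, while the added rel_max_err bound looks for isolated outliers that an aggregate comparison could miss. Below is a minimal, self-contained sketch of that relative max error metric; the tensors are hypothetical stand-ins, not the test fixtures above.

import torch

# Hypothetical reference output and a slightly perturbed "quantized" output.
out = torch.randn(4, 8)
qout = out + 1e-3 * torch.randn_like(out)

# Largest absolute deviation, normalized by the largest reference magnitude,
# mirroring the rel_max_err check added in this commit.
max_val = out.abs().max()
max_err = (out - qout).abs().max()
rel_max_err = max_err / max_val
assert rel_max_err < 5e-2  # flags single-element outliers even when the average error is tiny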