
Commit

Tests update
Andrei-Aksionov committed Feb 18, 2024
1 parent e24dd7a commit 0a8a0b1
Showing 3 changed files with 102 additions and 39 deletions.
6 changes: 3 additions & 3 deletions generate/base.py
@@ -230,10 +230,10 @@ def main(

     # --------------------- Convert to Marlin ---------------------
 
-    if kernel == "marlin":
-        autogptq.convert_to_marlin(quantized_model_dir)
-
+    if use_gptq:
+        if kernel == "marlin":
+            autogptq.convert_to_marlin(quantized_model_dir)
 
     # post_init is executed only on a CUDA device
     autogptq.post_init()

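The hunk above changes generate/base.py so that Marlin repacking is only attempted when GPTQ is actually enabled. A minimal sketch of the resulting control flow, assuming `use_gptq`, `kernel`, `quantized_model_dir`, and `autogptq` are objects main() already defines; the placement of the unchanged post_init() lines is an assumption, since the rendered diff does not preserve indentation:

    if use_gptq:
        # Repacking into Marlin's layout only makes sense for a GPTQ-quantized checkpoint.
        if kernel == "marlin":
            autogptq.convert_to_marlin(quantized_model_dir)

    # post_init is executed only on a CUDA device
    # (placement assumed unchanged from the original file)
    autogptq.post_init()
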
8 changes: 5 additions & 3 deletions quantize/autogptq.py
@@ -83,7 +83,9 @@ def validate_config(self, kernel=None):

         kernel = kernel or self.kernel
         if (kernel == "triton" and self.bits == 3) or (kernel in ("exllama", "exllamav2", "marlin") and self.bits != 4):
-            raise ValueError(f"Kernel '{kernel}' doesn't support {self.bits}bit precision.")
+            raise NotImplementedError(f"Kernel '{kernel}' doesn't support {self.bits}bit precision.")
+        if kernel == "marlin" and self.group_size not in (-1, 128):
+            raise NotImplementedError(f"Kernel Marlin doesn't support group_size of {self.group_size}, only -1 or 128.")
 
     @classmethod
     def load_config(cls, path: Path):
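
For illustration, a small sketch of how the stricter validation surfaces to callers. It relies only on what the hunk and the updated tests show: building a QuantizeConfig with an unsupported kernel/precision/group-size combination now raises NotImplementedError instead of ValueError:

    from quantize.autogptq import QuantizeConfig

    # Marlin supports only 4-bit weights and a group_size of -1 or 128,
    # so both of these configurations are rejected up front.
    for bits, group_size in ((8, 128), (4, 64)):
        try:
            QuantizeConfig(bits=bits, group_size=group_size, kernel="marlin")
        except NotImplementedError as err:
            print(err)
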
@@ -184,7 +186,7 @@ def __init__(
     def convert_model_to_quantized(
         self,
         kernel: Literal["cuda_old", "cuda", "exllama", "exllamav2", "triton", "marlin"],
-        device,
+        device=None,
     ) -> None:
         # TODO: add docstring
 
@@ -237,7 +239,7 @@ def convert_to_marlin(self, quantized_model_dir):
         ).QuantLinear
 
         self.model.config.quantization_config = {}  # required by AutoGPTQ
-        convert_to_marlin(self.model, QuantLinear, self.quantize_config, repack=True, strict=True)
+        convert_to_marlin(self.model, QuantLinear, self.quantize_config, repack=True)
 
         marlin_cache_path = quantized_model_dir / "marlin_cache.pth"
         torch.save(self.model.state_dict(), marlin_cache_path)
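
A short usage sketch for the method changed above, assuming `autogptq_model` wraps an already quantized 4-bit model (group_size -1 or 128) and `quantized_model_dir` is the pathlib.Path holding the quantized checkpoint; the assertions mirror the new test_marlin_conversion below:

    # Repack the GPTQ weights into Marlin's layout (the `strict=True` argument is no longer passed).
    autogptq_model.convert_to_marlin(quantized_model_dir)

    # The repacked state_dict is cached next to the quantized checkpoint ...
    assert (quantized_model_dir / "marlin_cache.pth").is_file()
    # ... and the quantize config records that a Marlin variant now exists.
    assert autogptq_model.quantize_config.marlin_cached is True
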
127 changes: 94 additions & 33 deletions tests/test_autogptq.py
@@ -11,9 +11,7 @@
@pytest.mark.parametrize("group_size", [32, 128], ids=[f"{gs}group_size" for gs in (32, 128)])
@pytest.mark.parametrize("use_triton", (True, False), ids=["use_triton", "dont_use_triton"])
@pytest.mark.parametrize("mlp_class", ("GptNeoxMLP", "LLaMAMLP", "LLaMAMoE"))
def test_autogptq_quantization_mlp_layers(
tmp_path, fake_checkpoint_dir, monkeypatch, bits, group_size, use_triton, mlp_class
):
def test_quantization(tmp_path, fake_checkpoint_dir, monkeypatch, bits, group_size, use_triton, mlp_class):
if use_triton and bits == 3:
pytest.skip("Triton doesn't support 3bit precision.")

@@ -70,15 +68,16 @@ def test_autogptq_quantization_mlp_layers(
assert "Quantization time" in stdout.getvalue()

# Assert that the quantized model weights are saved
files = [p.name for p in fake_checkpoint_dir.glob("*")]
assert f"lit_model_gptq.{bits}bit.pth" in files
quantized_model_dir = fake_checkpoint_dir / f"autogptq/{bits}bit"
files = [p.name for p in quantized_model_dir.glob("*")]
assert "lit_model_gptq.pth" in files
# Assert that the quantize config is saved
assert "autogptq_config.json" in files
assert "quantize_config.json" in files
# Assert that the kernel type was saved
assert "kernel" in json.loads(Path(fake_checkpoint_dir / "autogptq_config.json").read_text())
assert "kernel" in json.loads(Path(quantized_model_dir / "quantize_config.json").read_text())

# --- Validate the saved quantized weights ---
quantized_state_dict = torch.load(fake_checkpoint_dir / f"lit_model_gptq.{bits}bit.pth")
quantized_state_dict = torch.load(quantized_model_dir / "lit_model_gptq.pth")

# Create a reference model to check that the saved quantized weights have a proper shape
reference_model = GPT(config)
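
For clarity, a sketch of the on-disk layout these updated assertions encode: quantized artifacts now live in an autogptq/<bits>bit subdirectory of the checkpoint and use kernel-independent file names. The concrete checkpoint path below is a made-up example; only the relative layout and the file names come from the asserts above:

    from pathlib import Path

    checkpoint_dir = Path("checkpoints/some-model")         # hypothetical checkpoint location
    quantized_model_dir = checkpoint_dir / "autogptq/4bit"  # previously: files written into checkpoint_dir itself

    # lit_model_gptq.pth   - quantized weights (previously lit_model_gptq.{bits}bit.pth)
    # quantize_config.json - quantization settings, including the "kernel" key (previously autogptq_config.json)
    expected = {"lit_model_gptq.pth", "quantize_config.json"}
    assert expected.issubset({p.name for p in quantized_model_dir.glob("*")})
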
@@ -121,22 +120,33 @@ def test_autogptq_quantization_mlp_layers(


 @RunIf(min_cuda_gpus=1)
-@pytest.mark.parametrize("kernel", ("cuda", "exllama", "exllamav2", "triton", "marlin"))
+@pytest.mark.parametrize(
+    "kernel",
+    (
+        "cuda_old",
+        "cuda",
+        "exllama",
+        "exllamav2",
+        # due to randomly initialized "quantized" values the triton kernel might throw an error
+        pytest.param("triton", marks=pytest.mark.xfail(raises=ValueError, match="math domain error")),
+        "marlin",
+    ),
+)
 @pytest.mark.parametrize("bits", [2, 3, 4, 8], ids=[f"{bit}bit" for bit in (2, 3, 4, 8)])
 @pytest.mark.parametrize("group_size", [32, 128], ids=[f"{gs}group_size" for gs in (32, 128)])
 @pytest.mark.parametrize("mlp_class", ("GptNeoxMLP", "LLaMAMLP", "LLaMAMoE"))
-def test_autogptq_convert_layers(kernel, bits, group_size, mlp_class):
+def test_layer_conversion(kernel, bits, group_size, mlp_class):
 
     import importlib
     from functools import reduce
 
-    from auto_gptq.modeling._base import BaseQuantizeConfig
-
     from lit_gpt import GPT
     from lit_gpt.config import Config
-    from quantize.autogptq import AutoGPTQ
+    from quantize.autogptq import AutoGPTQ, QuantizeConfig
 
     # Prepare model's config
+    # NOTE: carefully select `n_query_groups` so the dimension of a layer fits
+    # Marlin requirements: in_features divisible by 128 and out_features - by 256
     config = Config(
         padded_vocab_size=10_000,
         n_layer=2,
@@ -152,35 +162,29 @@ def test_autogptq_convert_layers(kernel, bits, group_size, mlp_class):
     # Create a model: it has to be on a GPU and with float16 precision
     device = "cuda:0"
     model = GPT(config).to(device=device, dtype=torch.float16)
-    model.config.model_type = None  # used in .from_pretrained and .from_quantized
-    model.config.pad_token_id = None  # ._prepare_examples_for_quantization
-    model.config.eos_token_id = 0  # _prepare_examples_for_quantization
-    model.config.use_cache = False  # for quantization it's disabled anyway
 
-    # Wrap the model in AutoGPTQ as it allows to convert "nn.Linear" layers to "QuantLinear"
-    quantize_config = BaseQuantizeConfig(bits=bits, group_size=group_size)
-    autogptq_model = AutoGPTQ(model, quantized=True, quantize_config=quantize_config)
-
-    # Some kernels support only specific set of precision. The code has to tell about it.
+    # Some kernels support only specific set of precisions. The code has to tell about it.
     # We should check it.
-    # TODO: refactor this
     skip_test = False
-    if (kernel == "triton" and bits == 3) or (kernel in ("exllama", "exllamav2", "marlin") and bits != 4):
-        skip_test = True
-    elif (kernel == "marlin") and group_size not in (-1, 128):
+    if (
+        (kernel == "triton" and bits == 3)
+        or (kernel in ("exllama", "exllamav2", "marlin") and bits != 4)
+        or (kernel == "marlin" and group_size not in (-1, 128))
+    ):
         skip_test = True
-    with pytest.raises(ValueError, match="[doesn't support|only are supported]") if skip_test else nullcontext():
-        autogptq_model.convert_model_to_quantized(kernel)
+    # TODO: update skip message to include marlin message
+    with pytest.raises(NotImplementedError, match="doesn't support") if skip_test else nullcontext() as e_info:
+        quantize_config = QuantizeConfig(bits=bits, group_size=group_size, kernel=kernel)
     if skip_test:
-        pytest.skip(f"Kernel `{kernel}` doesn't support {bits}bit precision.")
+        pytest.skip(str(e_info))
 
+    # Wrap the model in AutoGPTQ as it allows to convert "nn.Linear" layers to "QuantLinear"
+    autogptq_model = AutoGPTQ(model, quantized=True, quantize_config=quantize_config)
+    autogptq_model.convert_model_to_quantized(kernel, device)
+    # Convert layers and run obligatory "post_init" method: initializes kernel's buffers
     autogptq_model.post_init()
 
-    # Check that all the target layer were converted
-    inside_layer_modules = autogptq_model.inside_layer_modules
-    inside_layer_modules = sum(inside_layer_modules, [])
+    # Check that all the target layers were successfully converted
+    inside_layer_modules = sum(autogptq_model.inside_layer_modules, [])
 
     QuantLinear = importlib.import_module(f"auto_gptq.nn_modules.qlinear.qlinear_{kernel}").QuantLinear

@@ -199,3 +203,60 @@ def test_autogptq_convert_layers(kernel, bits, group_size, mlp_class):
     # Run a forward pass, it should not fail
     x = torch.tensor([[9856, 23, 491, 1536, 304], [23, 345, 65, 123, 321]], dtype=torch.int32, device=device)
     autogptq_model(x)
+
+
+@pytest.mark.parametrize("kernel", ("cuda_old", "cuda", "exllama", "exllamav2", "triton"))
+def test_marlin_conversion(kernel, tmp_path):
+    from functools import reduce
+
+    from auto_gptq.nn_modules.qlinear.qlinear_marlin import QuantLinear
+
+    from lit_gpt import GPT
+    from lit_gpt.config import Config
+    from quantize.autogptq import AutoGPTQ, QuantizeConfig
+
+    # Prepare model's config
+    # NOTE: carefully select `n_query_groups` so the dimension of a layer fits
+    # Marlin requirements: in_features divisible by 128 and out_features - by 256
+    config = Config(
+        padded_vocab_size=10_000,
+        n_layer=2,
+        n_embd=128,
+        n_head=8,
+        n_query_groups=4,
+        intermediate_size=256,
+    )
+
+    # Create a model: it has to be on a GPU and with float16 precision
+    device = "cuda:0"
+    model = GPT(config).to(device=device, dtype=torch.float16)
+
+    quantize_config = QuantizeConfig(bits=4, group_size=128, desc_act=False, kernel=kernel)
+    quantize_config.save_config(tmp_path / "quantize_config.json")
+
+    # Wrap the model in AutoGPTQ as it allows to convert "nn.Linear" layers to "QuantLinear"
+    autogptq_model = AutoGPTQ(model, quantized=True, quantize_config=quantize_config)
+    autogptq_model.convert_model_to_quantized(kernel, device)
+    # Convert layers and run obligatory "post_init" method: initializes kernel's buffers
+    autogptq_model.post_init()
+
+    # Convert to Marlin layers
+    autogptq_model.convert_to_marlin(tmp_path)
+
+    # Assert that all layers were converted
+    inside_layer_modules = sum(autogptq_model.inside_layer_modules, [])
+
+    for layer_name in autogptq_model.model.state_dict():
+        module = reduce(getattr, layer_name.split(".")[:-1], autogptq_model.model)
+        if any(ilm in layer_name for ilm in inside_layer_modules):
+            assert layer_name.endswith((".B", ".s", ".workspace", ".bias")), layer_name
+            assert isinstance(module, QuantLinear), layer_name
+        else:
+            assert layer_name.endswith((".weight", ".bias")), layer_name
+            assert not isinstance(module, QuantLinear), layer_name
+
+    # Assert that the Marlin version of the model is cached
+    assert "marlin_cache.pth" in [p.name for p in tmp_path.glob("*")]
+
+    # Assert that the quantize config now knows that the Marlin was cached
+    assert quantize_config.marlin_cached is True
