Refactor: kernel --> gptq_kernel
Andrei-Aksionov committed Mar 1, 2024
1 parent 18c9ca7 commit 58e4c74
Showing 1 changed file with 9 additions and 9 deletions.
18 changes: 9 additions & 9 deletions generate/base.py
@@ -116,7 +116,7 @@ def main(
"gptq.int8",
]
] = None,
- kernel: Optional[Literal["cuda_old", "cuda", "exllama", "exllamav2", "triton", "marlin"]] = None,
+ gptq_kernel: Optional[Literal["cuda_old", "cuda", "exllama", "exllamav2", "triton", "marlin"]] = None,
precision: Optional[str] = None,
compile: bool = False,
) -> None:
@@ -135,7 +135,7 @@ def main(
- bnb.int8: 8-bit quantization from bitsandbytes
- gptq.int[bits]: inference with AutoGPTQ. Select the same `bits` value as was used during quantization.
for more details, see https://github.com/Lightning-AI/lit-gpt/blob/main/tutorials/quantize.md
- kernel: Choose a kernel to apply with the quantized weights. If set to None, the same kernel used during
+ gptq_kernel: Choose a kernel to apply with the quantized weights. If set to None, the same kernel used during
quantization will be selected.
precision: Indicates the Fabric precision setting to use.
compile: Whether to compile the model.
@@ -170,8 +170,8 @@ def main(

quantized_model_dir = checkpoint_dir / f"quantized/{bits}bit"
quantize_config = QuantizeConfig.load_config(quantized_model_dir / "quantize_config.json")
- kernel = kernel or quantize_config.kernel
- if kernel == "marlin" and quantize_config.marlin_cached:
+ gptq_kernel = gptq_kernel or quantize_config.kernel
+ if gptq_kernel == "marlin" and quantize_config.marlin_cached:
model_file = "marlin_cache.pth"
else:
model_file = "lit_model_gptq.pth"
@@ -208,9 +208,9 @@ def main(
# If Marlin is selected and was cached, convert directly to Marlin (this allows loading the cached weights).
# If it wasn't cached, first convert to the kernel from the config, load the weights (in the code below), and
# the conversion to Marlin will be done later (in the `convert_quantized_to_marlin` method).
- if (_kernel := kernel) == "marlin" and not quantize_config.marlin_cached:
- _kernel = quantize_config.kernel
- autogptq.convert_to_quantized(_kernel, device=fabric.device)
+ if (_gptq_kernel := gptq_kernel) == "marlin" and not quantize_config.marlin_cached:
+ _gptq_kernel = quantize_config.kernel
+ autogptq.convert_to_quantized(_gptq_kernel, fabric.device)

# Loading weights
t0 = time.perf_counter()
@@ -227,12 +227,12 @@ def main(
# Marlin conversion and post_init
if gptq_selected:
# Marlin conversion happens after the model is quantized to one of the other kernels
- if kernel == "marlin":
+ if gptq_kernel == "marlin":
autogptq.convert_quantized_to_marlin(quantized_model_dir)
# obligatory post init: initializes kernel's buffers
autogptq.post_init()
# the last step is to print quantize config with the selected kernel
- quantize_config.kernel = kernel
+ quantize_config.kernel = gptq_kernel
fabric.print(f"Running GPTQ model with {quantize_config}", file=sys.stderr)

if compile:
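For reference, here is a minimal, self-contained sketch of the kernel fallback and checkpoint-file selection that the renamed `gptq_kernel` argument drives. `QuantizeConfig` below is a stand-in dataclass and `select_model_file` is a hypothetical helper, not code from the repository; only the branching mirrors the diff above.

from dataclasses import dataclass
from typing import Optional, Tuple

@dataclass
class QuantizeConfig:
    kernel: str
    marlin_cached: bool

def select_model_file(gptq_kernel: Optional[str], config: QuantizeConfig) -> Tuple[str, str]:
    # Fall back to the kernel recorded at quantization time when none is given.
    gptq_kernel = gptq_kernel or config.kernel
    # The cached Marlin checkpoint is only usable when Marlin was requested and cached.
    if gptq_kernel == "marlin" and config.marlin_cached:
        model_file = "marlin_cache.pth"
    else:
        model_file = "lit_model_gptq.pth"
    return gptq_kernel, model_file

# Example: no kernel passed explicitly, Marlin weights already cached.
print(select_model_file(None, QuantizeConfig(kernel="marlin", marlin_cached=True)))
# -> ('marlin', 'marlin_cache.pth')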

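The hunk at old line 208 defers the Marlin conversion when no cache exists: the model is first converted to the kernel recorded in the config so the quantized checkpoint can be loaded, and the conversion to Marlin happens afterwards. Below is a rough, runnable sketch of that ordering, using a stub in place of the AutoGPTQ wrapper from the file; the method names mirror the diff, but the class and its behaviour are assumptions for illustration.

class AutoGPTQStub:
    def convert_to_quantized(self, kernel: str, device: str) -> None:
        print(f"building {kernel} layers on {device}")

    def convert_quantized_to_marlin(self, model_dir: str) -> None:
        print(f"converting loaded weights to Marlin (cache dir: {model_dir})")

    def post_init(self) -> None:
        print("initializing kernel buffers")


def load_quantized(gptq_kernel: str, config_kernel: str, marlin_cached: bool) -> None:
    autogptq = AutoGPTQStub()
    _gptq_kernel = gptq_kernel
    if gptq_kernel == "marlin" and not marlin_cached:
        # Load with the kernel used at quantization time; Marlin conversion happens below.
        _gptq_kernel = config_kernel
    autogptq.convert_to_quantized(_gptq_kernel, "cuda")
    # ... the quantized state dict would be loaded here ...
    if gptq_kernel == "marlin":
        autogptq.convert_quantized_to_marlin("quantized/4bit")
    autogptq.post_init()


load_quantized("marlin", config_kernel="exllama", marlin_cached=False)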