Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Quantized checkpoint support in export and deploy modules #8859

Merged
merged 19 commits into from
Apr 23, 2024
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ RUN pip install flash-attn
# install numba for latest containers
RUN pip install numba>=0.57.1
# install ammo
RUN pip install nvidia-ammo~=0.7.0 --extra-index-url https://pypi.nvidia.com --no-cache-dir
RUN pip install nvidia-ammo~=0.9.0 --extra-index-url https://pypi.nvidia.com --no-cache-dir

# copy nemo source into a scratch image
FROM scratch as nemo-src
Expand Down
2 changes: 1 addition & 1 deletion Jenkinsfile
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ pipeline {

stage('AMMO installation') {
steps {
sh 'pip install nvidia-ammo~=0.7.0 --extra-index-url https://pypi.nvidia.com --no-cache-dir'
sh 'pip install nvidia-ammo~=0.9.0 --extra-index-url https://pypi.nvidia.com --no-cache-dir'
}
}

Expand Down
1 change: 1 addition & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ For more information, browse the developer docs for your area of interest in the
nlp/models
nlp/machine_translation/machine_translation
nlp/megatron_onnx_export
nlp/quantization
nlp/api


Expand Down
46 changes: 31 additions & 15 deletions docs/source/nlp/quantization.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,35 +3,35 @@
Quantization
==========================

Post Training Quantization (PTQ)
Post-Training Quantization (PTQ)
--------------------------------

PTQ enables deploying a model in a low-precision format -- FP8, INT4 or INT8 -- for efficient serving. Different quantization methods are available including FP8 quantization, INT8 SmoothQuant and INT4 AWQ.
PTQ enables deploying a model in a low-precision format -- FP8, INT4, or INT8 -- for efficient serving. Different quantization methods are available including FP8 quantization, INT8 SmoothQuant, and INT4 AWQ.

Model quantization has two primary benefits: reduced model memory requirements and increased inference throughput.

In NeMo, quantization is enabled by the Nvidia AMMO library -- a unified algorithmic model optimization & deployment toolkit.

The quantization process consists of the following steps:

1. Loading a model checkpoint using appropriate parallelism strategy for evaluation
1. Loading a model checkpoint using an appropriate parallelism strategy
2. Calibrating the model to obtain appropriate algorithm-specific scaling factors
3. Producing output directory or .qnemo tarball with model config (json), quantized weights (safetensors) and tokenizer config (yaml).
3. Producing an output directory or .qnemo tarball with model config (json), quantized weights (safetensors) and tokenizer config (yaml).

Loading models requires using AMMO spec defined in `megatron.core.deploy.gpt.model_specs module <https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/deploy/gpt/model_specs.py>`_. Typically the calibration step is lightweight and uses a small dataset to obtain appropriate statistics for scaling tensors. The output directory produced (or a .qnemo tarball) is ready to be used to build a serving engine with the Nvidia TensorRT-LLM library. The engine build step is also soon to be the part of NeMo project and ``nemo.deploy`` and ``nemo.export`` modules, see https://github.com/NVIDIA/NeMo/pull/8690.
Loading models requires using an AMMO spec defined in `megatron.core.inference.gpt.model_specs.py <https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/inference/gpt/model_specs.py>`_ module. Typically the calibration step is lightweight and uses a small dataset to obtain appropriate statistics for scaling tensors. The output directory produced (or a .qnemo tarball) is ready to be used to build a serving engine with the Nvidia TensorRT-LLM library. The engine build step is also available in NeMo project in ``nemo.deploy`` and ``nemo.export`` modules.

Quantization algorithm can also be conveniently set to ``"null"`` to perform only the weights export step using default precision for TensorRT-LLM deployment. This is useful to obtain baseline performance and accuracy results for comparison.


Example
^^^^^^^
The example below shows how to quantize the Llama2 70b model into FP8 precision, using tensor parallelism of 8 on a single DGX H100 node. The quantized model is intended for serving using 2 GPUs specified with ``export.inference_tensor_parallel`` parameter.
The example below shows how to quantize the Llama2 70b model into FP8 precision, using tensor parallelism of 8 on a single DGX H100 node. The quantized model is designed for serving using 2 GPUs specified with the ``export.inference_tensor_parallel`` parameter.

The script should be launched correctly with the number of processes equal to tensor parallelism. This is achieved with the ``mpirun`` command below.
The script must be launched correctly with the number of processes equal to tensor parallelism. This is achieved with the ``torchrun`` command below:

.. code-block:: bash

mpirun -n 8 python examples/nlp/language_modeling/megatron_llama_quantization.py \
torchrun --nproc-per-node 8 examples/nlp/language_modeling/megatron_llama_quantization.py \
model_file=llama2-70b-base-bf16.nemo \
tensor_model_parallel_size=8 \
pipeline_model_parallel_size=1 \
Expand All @@ -57,31 +57,47 @@ The output directory stores the following files:
└── tokenizer_config.yaml


The TensorRT-LLM engine can be build with ``trtllm-build`` command, see `TensorRT-LLM documentation <https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/llama#fp8-post-training-quantization>`_.
The TensorRT-LLM engine can be conveniently built and run using ``TensorRTLLM`` class available in ``nemo.export`` submodule:

.. code-block:: python

from nemo.export import TensorRTLLM


trt_llm_exporter = TensorRTLLM(model_dir="/path/to/trt_llm_engine_folder")
trt_llm_exporter.export(
nemo_checkpoint_path="llama2-70b-base-fp8-qnemo",
model_type="llama",
)
trt_llm_exporter.forward(["Hi, how are you?", "I am good, thanks, how about you?"])


Alternatively, it can also be built directly using ``trtllm-build`` command, see `TensorRT-LLM documentation <https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/llama#fp8-post-training-quantization>`_:

.. code-block:: bash

trtllm-build \
--checkpoint_dir llama2-70b-base-fp8-qnemo \
--output_dir engine_dir \
--output_dir /path/to/trt_llm_engine_folder \
--max_batch_size 8 \
--max_input_len 2048 \
--max_output_len 512

--max_output_len 512 \
--strongly_typed


Known issues
^^^^^^^^^^^^
* Currently in NeMo quantizing and building TensorRT-LLM engines is limited to single-node use cases.
* Supported and tested model family is Llama2. Quantizing other model types is experimental and may not be fully supported.
* For INT8 SmoothQuant ``quantization.algorithm=int8_sq``, the TensorRT-LLM engine cannot be build with CLI ``trtllm-build`` command -- Python API and ``tensorrt_llm.builder`` should be used instead.
* Currently in NeMo, quantizing and building TensorRT-LLM engines is limited to single-node use cases.
* The supported and tested model family is Llama2. Quantizing other model types is experimental and may not be fully supported.


Please refer to the following papers for more details on quantization techniques.

References
----------

`Integer Quantization for Deep Learning Inference: Principles and Empirical Evaluation, 2020 <https://arxiv.org/abs/2004.09602>`_

`FP8 Formats for Deep Learning, 2022 <https://arxiv.org/abs/2209.05433>`_

`SmoothQuant: Accurate and Efficient Post-Training Quantization for Large Language Models, 2022 <https://arxiv.org/abs/2211.10438>`_
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,13 @@ quantization:
algorithm: fp8 # int8_sq, fp8, int8, int4_awq, null
calib_dataset: cnn_dailymail # wikitext, cnn_dailymail, or a local dataset
num_calib_size: 512 # number of samples used for calibration
awq_block_size: 128 # block size for scaling factors in AWQ algorithm

export:
decoder_type: llama # gptnext, gpt2, llama
inference_tensor_parallel: 1 # Default using 1 TP for inference
inference_pipeline_parallel: 1 # Default using 1 PP for inference
dtype: 16 # Default precision data type
export_tensorrt_llm_config: true # export config to build TRT-LLM engine directly

model_file: llama2-7b-fp16.nemo # Nemo file path
model_save: llama2-7b-fp8.qnemo # Path where the quantized model will be saved
Expand Down
44 changes: 36 additions & 8 deletions nemo/export/quantize/quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from contextlib import nullcontext
from typing import List, Optional

import torch
import torch.distributed as dist
from megatron.core import parallel_state
from megatron.core.transformer.module import Float16Module
Expand All @@ -34,7 +35,7 @@

try:
import ammo.torch.quantization as atq
from ammo.torch.export import export_model_config
from ammo.torch.export import export_tensorrt_llm_checkpoint

HAVE_AMMO = True

Expand Down Expand Up @@ -80,7 +81,7 @@ def __init__(
trainer_config: DictConfig,
):
if not HAVE_AMMO:
raise RuntimeError("nvidia-ammo>=0.7 is needed to use Quantizer") from HAVE_AMMO_ERROR
raise RuntimeError("nvidia-ammo is needed to use Quantizer") from HAVE_AMMO_ERROR
QUANT_CFG_CHOICES = {
"int8": atq.INT8_DEFAULT_CFG,
"int8_sq": atq.INT8_SMOOTHQUANT_CFG,
Expand All @@ -97,10 +98,21 @@ def __init__(
self.trainer_config = trainer_config
if quantization_config.algorithm is not None:
atq_config = QUANT_CFG_CHOICES[quantization_config.algorithm]
if quantization_config.algorithm != "fp8":
# disable quantization for the last output layer
atq_config = copy.deepcopy(atq_config)
atq_config["quant_cfg"]["*.output_layer.*"] = {"enable": False}

if "awq" in quantization_config.algorithm:
weight_quantizer = atq_config["quant_cfg"]["*weight_quantizer"]
if isinstance(weight_quantizer, list):
weight_quantizer = weight_quantizer[0]
weight_quantizer["block_sizes"][-1] = quantization_config.awq_block_size

# Always turn on FP8 kv cache to save memory footprint.
# For int8_sq, we use int8 kv cache.
atq_config["quant_cfg"]["*output_quantizer"] = {
"num_bits": 8 if quantization_config.algorithm == "int8_sq" else (4, 3),
"axis": None,
"enable": export_config.decoder_type != "gptnext",
}

self.atq_config = atq_config
else:
self.atq_config = None
Expand Down Expand Up @@ -188,6 +200,22 @@ def forward_loop():
model.predict_step(batch, i)

model = atq.quantize(model, self.atq_config, forward_loop)

if self.export_config == "gptnext":
# We found squared_relu may have an under-calibration problem.
# Clamp the scaling_factor with a min threshold to avoid under-calibration.
maxbound = 0
if self.quantization_config.algorithm == "fp8":
maxbound = 448
elif self.quantization_config.quantization.algorithm == "int8_sq":
maxbound = 127
model = atq.postprocess_amax(
model, "*input_quantizer", lambda amax: torch.clamp(amax, min=0.01 * maxbound)
)

if dist.get_rank() == 0:
atq.print_quant_summary(model)

return model

def export(self, model, model_save: str):
Expand All @@ -206,13 +234,13 @@ def export(self, model, model_save: str):
export_handler = nullcontext(enter_result=model_save)

with export_handler as export_dir:
export_model_config(
export_tensorrt_llm_checkpoint(
model=model,
decoder_type=self.export_config.decoder_type,
dtype=torch_dtype,
export_dir=export_dir,
inference_tensor_parallel=self.export_config.inference_tensor_parallel,
export_tensorrt_llm_config=self.export_config.export_tensorrt_llm_config,
inference_pipeline_parallel=self.export_config.inference_pipeline_parallel,
)
dist.barrier() # Wait until all ranks complete export_model_config step
if dist.get_rank() == 0:
Expand Down
5 changes: 5 additions & 0 deletions nemo/export/tarutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,3 +202,8 @@

def keys(self):
return self._path.iterdir()


def unpack_tarball(archive: str, dest_dir: str):
with tarfile.open(archive, mode="r") as tar:
tar.extractall(path=dest_dir)
Dismissed Show dismissed Hide dismissed
76 changes: 48 additions & 28 deletions nemo/export/tensorrt_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,12 @@
import wrapt

from nemo.deploy import ITritonDeployable
from nemo.export.tarutils import TarPath
from nemo.export.tarutils import TarPath, unpack_tarball
from nemo.export.trt_llm.model_config_trt import model_config_to_tensorrt_llm
from nemo.export.trt_llm.nemo.nemo_ckpt_convert import build_tokenizer
from nemo.export.trt_llm.nemo_utils import get_tokenzier, nemo_llm_model_to_model_config, nemo_llm_to_model_config
from nemo.export.trt_llm.qnemo import qnemo_to_tensorrt_llm
from nemo.export.trt_llm.qnemo.tokenizer_utils import get_nmt_tokenizer
from nemo.export.trt_llm.tensorrt_llm_run import generate, generate_streaming, load, load_refit
from nemo.export.trt_llm.utils import is_nemo_file

Expand Down Expand Up @@ -188,32 +190,50 @@ def export(
tmp_dir = tempfile.TemporaryDirectory()
nemo_export_dir = Path(tmp_dir.name)

model_configs, self.tokenizer = nemo_llm_to_model_config(
in_file=nemo_checkpoint_path,
decoder_type=model_type,
dtype=dtype,
tensor_parallel_size=tensor_parallel_size,
pipeline_parallel_size=pipeline_parallel_size,
nemo_export_dir=nemo_export_dir,
save_nemo_model_config=save_nemo_model_config,
)
if nemo_checkpoint_path.endswith("qnemo"):
if os.path.isdir(nemo_checkpoint_path):
nemo_export_dir = nemo_checkpoint_path
else:
unpack_tarball(nemo_checkpoint_path, tmp_dir.name)
nemo_checkpoint_path = tmp_dir.name
self.tokenizer = get_nmt_tokenizer(nemo_checkpoint_path)

qnemo_to_tensorrt_llm(
nemo_checkpoint_path=nemo_checkpoint_path,
engine_dir=self.model_dir,
max_input_len=max_input_token,
max_output_len=max_output_token,
max_batch_size=max_batch_size,
max_prompt_embedding_table_size=max_prompt_embedding_table_size,
lora_target_modules=lora_target_modules,
)
else:
model_configs, self.tokenizer = nemo_llm_to_model_config(
in_file=nemo_checkpoint_path,
decoder_type=model_type,
dtype=dtype,
tensor_parallel_size=tensor_parallel_size,
pipeline_parallel_size=pipeline_parallel_size,
nemo_export_dir=nemo_export_dir,
save_nemo_model_config=save_nemo_model_config,
)

model_config_to_tensorrt_llm(
model_configs,
self.model_dir,
world_size=tensor_parallel_size * pipeline_parallel_size,
max_input_len=max_input_token,
max_output_len=max_output_token,
max_batch_size=max_batch_size,
max_prompt_embedding_table_size=max_prompt_embedding_table_size,
use_inflight_batching=use_inflight_batching,
paged_kv_cache=paged_kv_cache,
enable_context_fmha=enable_context_fmha,
enable_multi_block_mode=enable_multi_block_mode,
use_lora_plugin=use_lora_plugin,
lora_target_modules=lora_target_modules,
max_lora_rank=max_lora_rank,
)
model_config_to_tensorrt_llm(
model_configs,
self.model_dir,
world_size=tensor_parallel_size * pipeline_parallel_size,
max_input_len=max_input_token,
max_output_len=max_output_token,
max_batch_size=max_batch_size,
max_prompt_embedding_table_size=max_prompt_embedding_table_size,
use_inflight_batching=use_inflight_batching,
paged_kv_cache=paged_kv_cache,
enable_context_fmha=enable_context_fmha,
enable_multi_block_mode=enable_multi_block_mode,
use_lora_plugin=use_lora_plugin,
lora_target_modules=lora_target_modules,
max_lora_rank=max_lora_rank,
)

tokenizer_path = os.path.join(nemo_export_dir, "tokenizer.model")
if os.path.exists(tokenizer_path):
Expand Down Expand Up @@ -700,5 +720,5 @@ def _load(self):
raise Exception(
"Files in the TensorRT-LLM folder is corrupted and "
"model needs to be exported again. "
"Error message: " + str(error)
)
"Error message: " + repr(error)
) from error
16 changes: 16 additions & 0 deletions nemo/export/trt_llm/qnemo/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .align_config import align_config
from .qnemo_to_tensorrt_llm import qnemo_to_tensorrt_llm
Loading
Loading