This document explains how to build the GPT-J model using TensorRT-LLM and run it on a single GPU.
The TensorRT-LLM GPT-J implementation can be found in `tensorrt_llm/models/gptj/model.py`. The TensorRT-LLM GPT-J example code is located in `examples/gptj`. There are three main files in that folder:

- `build.py` to build the TensorRT engine(s) needed to run the GPT-J model,
- `run.py` to run inference on an input text,
- `summarize.py` to summarize the articles in the cnn_dailymail dataset using the model.
The GPT-J example supports the following precisions and features:

- FP16
- FP8
- INT4 Weight-Only
- FP8 KV CACHE
First, download the GPT-J 6B weights, configuration, and tokenizer files from Hugging Face:

```bash
# 1. Weights & config
git clone https://huggingface.co/EleutherAI/gpt-j-6b gptj_model
pushd gptj_model && \
  rm -f pytorch_model.bin && \
  wget https://huggingface.co/EleutherAI/gpt-j-6b/resolve/main/pytorch_model.bin && \
popd

# 2. Vocab and merge table
wget https://huggingface.co/EleutherAI/gpt-j-6b/resolve/main/vocab.json
wget https://huggingface.co/EleutherAI/gpt-j-6b/resolve/main/merges.txt
```
TensorRT-LLM builds TensorRT engine(s) from an HF checkpoint. If no checkpoint directory is specified, TensorRT-LLM will build the engine(s) using dummy weights.
Examples of build invocations:
```bash
# Build a float16 engine using HF weights.
# Enable several TensorRT-LLM plugins to increase runtime performance. It also helps with build time.
python3 build.py --dtype=float16 \
                 --log_level=verbose \
                 --enable_context_fmha \
                 --use_gpt_attention_plugin float16 \
                 --use_gemm_plugin float16 \
                 --max_batch_size=32 \
                 --max_input_len=1919 \
                 --max_output_len=128 \
                 --remove_input_padding \
                 --output_dir=gptj_engine \
                 --model_dir=gptj_model 2>&1 | tee build.log

# Build a float16 engine using dummy weights, useful for performance tests.
# Enable several TensorRT-LLM plugins to increase runtime performance. It also helps with build time.
python3 build.py --dtype=float16 \
                 --log_level=verbose \
                 --enable_context_fmha \
                 --use_gpt_attention_plugin float16 \
                 --use_gemm_plugin float16 \
                 --max_batch_size=32 \
                 --max_input_len=1919 \
                 --max_output_len=128 \
                 --remove_input_padding \
                 --output_dir=gptj_engine_dummy_weights 2>&1 | tee build.log

# Build an int4 weight only quantization engine using awq int4 weight only quantized weights.
# Enable several TensorRT-LLM plugins to increase runtime performance. It also helps with build time.
python3 build.py --dtype=float16 \
                 --log_level=verbose \
                 --enable_context_fmha \
                 --use_gpt_attention_plugin float16 \
                 --use_gemm_plugin float16 \
                 --max_batch_size=32 \
                 --max_input_len=1919 \
                 --max_output_len=128 \
                 --remove_input_padding \
                 --output_dir=gptj_engine \
                 --use_weight_only \
                 --per_group \
                 --weight_only_precision=int4 \
                 --model_dir=awq_int4_weight_only_quantized_models 2>&1 | tee build.log
```
The examples below use the NVIDIA AMMO (AlgorithMic Model Optimization) toolkit for the model quantization process.

First make sure the AMMO toolkit is installed (see `examples/quantization/README.md`).

Now quantize the HF GPT-J weights as follows. After successfully running the script, the output should be in `.npz` format, e.g. `quantized_fp8/gptj_tp1_rank0.npz`, where the FP8 scaling factors are stored.
```bash
# Quantize HF GPT-J 6B checkpoint into FP8 format
python quantize.py --model_dir gptj_model \
                   --dtype float16 \
                   --qformat fp8 \
                   --export_path ./quantized_fp8 \
                   --calib_size 512
```
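To quickly check that the quantization step produced the expected output, you can list the exported file and the arrays it contains. This is just a sketch; the exact array names inside the `.npz` depend on the AMMO export format.

```bash
# Verify that the per-rank .npz with the FP8 scaling factors was exported.
ls -lh quantized_fp8/
# List the arrays stored in the file (names depend on the AMMO export format).
python3 -c "import numpy as np; print(list(np.load('quantized_fp8/gptj_tp1_rank0.npz').keys()))"
```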
```bash
# Build GPT-J 6B using original HF checkpoint + PTQ scaling factors
python build.py --model_dir gptj_model \
                --quantized_fp8_model_path ./quantized_fp8/gptj_tp1_rank0.npz \
                --dtype float16 \
                --use_gpt_attention_plugin float16 \
                --enable_context_fmha \
                --enable_two_optimization_profiles \
                --output_dir gptj_engine_fp8_quantized \
                --enable_fp8 \
                --fp8_kv_cache \
                --strongly_typed
```
You can enable the FMHA kernels for GPT by adding `--enable_context_fmha` to the invocation of `build.py`. Note that it is disabled by default because of possible accuracy issues due to the use of Flash Attention.

If you find that the default FP16 accumulation (`--enable_context_fmha`) cannot meet the accuracy requirement, you can try to enable FP32 accumulation by adding `--enable_context_fmha_fp32_acc`. However, a performance drop is expected.

Note that `--enable_context_fmha` / `--enable_context_fmha_fp32_acc` has to be used together with `--use_gpt_attention_plugin float16`.
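For example, a build with FP32 accumulation could look like the following sketch, which reuses the model and plugin options from the FP16 example above (the output directory name is only illustrative):

```bash
# FP16 build with FP32 accumulation for the context FMHA kernels.
python3 build.py --dtype=float16 \
                 --use_gpt_attention_plugin float16 \
                 --use_gemm_plugin float16 \
                 --enable_context_fmha_fp32_acc \
                 --model_dir=gptj_model \
                 --output_dir=gptj_engine_fmha_fp32_acc
```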
One can enable FP8 for the KV cache to reduce the memory footprint used by the KV cache and improve the accuracy over the INT8 KV cache. There are 3 options that need to be added to the invocation of `build.py` for that (a combined sketch follows the list):

- `--enable_fp8` enables FP8 GEMMs in the network.
- `--fp8_kv_cache` enables FP8 accuracy for the KV cache.
- `--quantized_fp8_model_path` provides the path to the quantized model calibrated for FP8. For more details see the quantization docs.
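Putting the three options together, a minimal sketch on top of a plain FP16 build looks as follows; the complete invocation, including the optimization-profile and `--strongly_typed` options, is shown in the FP8 build example above (the output directory name is only illustrative):

```bash
# FP8 GEMMs + FP8 KV cache, using the scaling factors exported by quantize.py.
python3 build.py --model_dir gptj_model \
                 --dtype float16 \
                 --use_gpt_attention_plugin float16 \
                 --enable_fp8 \
                 --fp8_kv_cache \
                 --quantized_fp8_model_path ./quantized_fp8/gptj_tp1_rank0.npz \
                 --output_dir gptj_engine_fp8_kv_cache
```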
One can enable AWQ INT4 weight-only quantization with these 3 options when building the engine with `build.py`:

- `--use_weight_only` enables weight-only GEMMs in the network.
- `--per_group` enables groupwise weight-only quantization. For the GPT-J example, we support AWQ with a default group size of 128.
- `--weight_only_precision=int4` sets the precision of the weight-only quantization. Only INT4 is supported for groupwise weight-only quantization.
The linear layer in the AWQ INT4 weight-only quantized weights should have 3 parameters:

- FP16 `smoothed_weights` (= `weights / pre_quant_scale`) with shape `[n, k]`;
- FP16 `amax` (the max absolute values of the smoothed weights) with shape `[n, k/group_size]`;
- FP16 `pre_quant_scale` (the smoothing scales that the activations are multiplied by) with shape `[k]`.
To run a TensorRT-LLM GPT-J model:
```bash
python3 run.py --max_output_len=50 --engine_dir=gptj_engine
```
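If you want to try your own prompt, the example `run.py` scripts in TensorRT-LLM typically accept an `--input_text` option; check `python3 run.py --help` for the exact arguments supported by this example. A sketch:

```bash
# Run generation on a custom prompt (assumes run.py exposes --input_text).
python3 run.py --max_output_len=50 \
               --engine_dir=gptj_engine \
               --input_text "In a shocking finding, scientists discovered"
```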
The following section describes how to run a TensorRT-LLM GPT-J model to summarize the articles from the cnn_dailymail dataset. For each summary, the script can compute the ROUGE scores and use the ROUGE-1 score to validate the implementation. The script can also perform the same summarization using the HF GPT-J model.

As previously explained, the first step is to build the TensorRT engine using HF weights. You also have to install the requirements:

```bash
pip install -r requirements.txt
```
The summarization can be done using the `summarize.py` script as follows:
```bash
# Run the summarization task.
python3 summarize.py --engine_dir gptj_engine \
                     --model_dir gptj_model \
                     --test_hf \
                     --batch_size 1 \
                     --test_trt_llm \
                     --tensorrt_llm_rouge1_threshold 14 \
                     --data_type fp16 \
                     --check_accuracy
```
- You must enable the LayerNorm plugin to build the engine for GPT-J when using TensorRT 8.6; this constraint is removed in TensorRT 9.0. To enable the LayerNorm plugin, add `--use_layernorm_plugin <float16 or float32>` to the `build.py` invocation, as in the `build.py` command examples above and the sketch below.
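For example, the FP16 build from above would become the following sketch when targeting TensorRT 8.6:

```bash
# FP16 build for TensorRT 8.6: the LayerNorm plugin must be enabled.
python3 build.py --dtype=float16 \
                 --use_gpt_attention_plugin float16 \
                 --use_gemm_plugin float16 \
                 --use_layernorm_plugin float16 \
                 --model_dir=gptj_model \
                 --output_dir=gptj_engine
```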