GPT-NeoX

This document explains how to build the GPT-NeoX model using TensorRT-LLM and run it on a single GPU or on two GPUs with 2-way tensor parallelism.

Overview

The TensorRT-LLM GPT-NeoX implementation can be found in tensorrt_llm/models/gptneox/model.py. The TensorRT-LLM GPT-NeoX example code is located in examples/gptneox. There are three main files in that folder:

Support Matrix

  • FP16
  • Tensor Parallel

Usage

1. Download weights from HuggingFace (HF) Transformers

# Weights & config
sh get_weights.sh

2. Build TensorRT engine(s)

TensorRT-LLM builds TensorRT engine(s) from a HF checkpoint, whose directory is passed with --model_dir. If no checkpoint directory is specified, TensorRT-LLM builds the engine(s) using dummy weights.
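
To make the checkpoint versus dummy-weights choice explicit in a script, one option is a small wrapper that prints the build.py command and adds --model_dir only when the checkpoint directory exists. The sketch below is a hypothetical helper (build_cmd is our name, not part of the example code); it only prints the command instead of running it, and it uses just a subset of the flags shown in the invocations that follow.

```shell
#!/bin/sh
# Hypothetical helper (not part of examples/gptneox): print, rather than run,
# a build.py command line. --model_dir is added only when the checkpoint
# directory exists; otherwise build.py would fall back to dummy weights.
build_cmd() {
  model_dir=$1   # e.g. gptneox_model
  out_dir=$2     # e.g. gptneox_engine
  cmd="python3 build.py --dtype=float16 --output_dir=$out_dir"
  if [ -n "$model_dir" ] && [ -d "$model_dir" ]; then
    cmd="$cmd --model_dir=$model_dir"
  else
    # Warn on stderr so stdout stays a clean command line.
    echo "note: no checkpoint at '$model_dir', engine will use dummy weights" >&2
  fi
  echo "$cmd"
}
```

For example, build_cmd gptneox_model gptneox_engine prints a command you can review, extend with the plugin flags from the examples, and then run.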

Examples of build invocations:

# Build a float16 engine using a single GPU and HF weights.
# Enable several TensorRT-LLM plugins to increase runtime performance. It also helps with build time.
python3 build.py --dtype=float16                    \
                 --log_level=verbose                \
                 --use_gpt_attention_plugin float16 \
                 --use_gemm_plugin float16          \
                 --use_layernorm_plugin float16     \
                 --max_batch_size=16                \
                 --max_input_len=1024               \
                 --max_output_len=1024              \
                 --output_dir=gptneox_engine        \
                 --model_dir=gptneox_model 2>&1 | tee build.log

# Build a float16 engine using a single GPU and dummy weights.
# Using dummy weights is useful for performance tests.
# Enable several TensorRT-LLM plugins to increase runtime performance. It also helps with build time.
python3 build.py --dtype=float16                    \
                 --log_level=verbose                \
                 --use_gpt_attention_plugin float16 \
                 --use_gemm_plugin float16          \
                 --use_layernorm_plugin float16     \
                 --max_batch_size=16                \
                 --max_input_len=1024               \
                 --max_output_len=1024              \
                 --output_dir=gptneox_engine_dummy_weights 2>&1 | tee build.log

# Build a float16 engine using 2-way tensor parallelism and HF weights.
# Enable several TensorRT-LLM plugins to increase runtime performance. It also helps with build time.
python3 build.py --dtype=float16                    \
                 --log_level=verbose                \
                 --use_gpt_attention_plugin float16 \
                 --use_gemm_plugin float16          \
                 --use_layernorm_plugin float16     \
                 --max_batch_size=16                \
                 --max_input_len=1024               \
                 --max_output_len=1024              \
                 --world_size=2                     \
                 --output_dir=gptneox_engine_tp2    \
                 --model_dir=gptneox_model 2>&1 | tee build_tp2.log

Fused MultiHead Attention (FMHA)

You can enable the FMHA kernels for GPT by adding --enable_context_fmha to the invocation of build.py. Note that they are disabled by default because of possible accuracy issues due to the use of Flash Attention.

If the default fp16 accumulation used by --enable_context_fmha does not meet your accuracy requirements, you can enable fp32 accumulation by adding --enable_context_fmha_fp32_acc. Expect a performance drop in that case.

Note that both --enable_context_fmha and --enable_context_fmha_fp32_acc must be used together with --use_gpt_attention_plugin float16.
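
Because the FMHA flags only work together with the GPT attention plugin, a wrapper script can reject an invalid combination before starting a long build. The function below (check_fmha_flags, a hypothetical name) scans the argument list; it is a sketch that checks only that --use_gpt_attention_plugin is present, not that its value is float16.

```shell
#!/bin/sh
# Hypothetical pre-flight check for build.py arguments: the context FMHA
# flags are only valid together with --use_gpt_attention_plugin.
# Note: this does not validate the plugin's dtype value (e.g. float16).
check_fmha_flags() {
  fmha=0
  plugin=0
  for arg in "$@"; do
    case "$arg" in
      --enable_context_fmha|--enable_context_fmha_fp32_acc) fmha=1 ;;
      --use_gpt_attention_plugin|--use_gpt_attention_plugin=*) plugin=1 ;;
    esac
  done
  if [ "$fmha" -eq 1 ] && [ "$plugin" -eq 0 ]; then
    echo "error: context FMHA requires --use_gpt_attention_plugin float16" >&2
    return 1
  fi
  return 0
}
```

A wrapper could then call check_fmha_flags "$@" && python3 build.py "$@" so that a misconfigured build fails immediately.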

3. Run

To run a TensorRT-LLM GPT-NeoX model using the engines generated by build.py:

# For a single GPU
python3 run.py --max_output_len=50 --engine_dir=gptneox_engine

# For 2-way tensor parallelism
mpirun -n 2 --allow-run-as-root python3 run.py --max_output_len=50 --engine_dir=gptneox_engine_tp2
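
The two invocations differ only in the launcher: with tensor parallelism, run.py is started under mpirun with one rank per GPU. The sketch below (run_cmd is a hypothetical helper) prints the matching command for a given parallelism degree. It assumes the degree you pass equals the --world_size the engine was built with, as in the 2-way examples, which pair --world_size=2 with mpirun -n 2.

```shell
#!/bin/sh
# Hypothetical helper: print the run.py command for a tensor-parallel degree.
# Assumption: tp matches the --world_size used when building engine_dir.
run_cmd() {
  tp=$1
  engine_dir=$2
  if [ "$tp" -le 1 ]; then
    # Single GPU: no MPI launcher needed.
    echo "python3 run.py --max_output_len=50 --engine_dir=$engine_dir"
  else
    # One MPI rank per GPU / tensor-parallel shard.
    echo "mpirun -n $tp --allow-run-as-root python3 run.py --max_output_len=50 --engine_dir=$engine_dir"
  fi
}
```

For example, eval "$(run_cmd 2 gptneox_engine_tp2)" runs the 2-way command shown above.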

4. Summarization using the GPT-NeoX model

The following section describes how to run a TensorRT-LLM GPT-NeoX model to summarize the articles from the cnn_dailymail dataset. For each summary, the script can compute the ROUGE scores and use the ROUGE-1 score to validate the implementation. The script can also perform the same summarization using the HF GPT-NeoX model.
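
ROUGE-1 measures unigram (single-word) overlap between a generated summary and a reference summary. The sketch below computes a simplified ROUGE-1 F-measure on a 0 to 100 scale, which we assume is the scale the --tensorrt_llm_rouge1_threshold 14 value refers to. It uses lowercase whitespace tokenization with no stemming or punctuation handling, so its values will not match the scores reported by summarize.py exactly; rouge1 is a hypothetical name.

```shell
#!/bin/sh
# Simplified ROUGE-1 F-measure (0-100) between a reference and a candidate.
# Assumptions: lowercase, whitespace tokenization, no stemming.
rouge1() {
  awk -v ref="$1" -v cand="$2" 'BEGIN {
    nr = split(tolower(ref), r)    # reference unigrams
    nc = split(tolower(cand), c)   # candidate unigrams
    for (i = 1; i <= nr; i++) rc[r[i]]++
    for (i = 1; i <= nc; i++) cc[c[i]]++
    # Clipped overlap: each word counts at most as often as in the other text.
    overlap = 0
    for (w in cc) if (w in rc) overlap += (cc[w] < rc[w] ? cc[w] : rc[w])
    if (overlap == 0) f = 0
    else { p = overlap / nc; rr = overlap / nr; f = 2 * p * rr / (p + rr) }
    printf "%.2f\n", 100 * f
  }'
}
```

For example, rouge1 'the cat sat on the mat' 'the cat lay on the mat' prints 83.33 (5 of 6 unigrams match in each direction), which would pass a threshold of 14.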

As explained above, the first step is to build the TensorRT engine using HF weights. You also need to install the requirements:

pip install -r requirements.txt

The summarization can be done using the summarize.py script as follows:

# Run the summarization task using a TensorRT-LLM model and a single GPU.
python3 summarize.py --engine_dir gptneox_engine        \
                     --model_dir gptneox_model          \
                     --batch_size 1                     \
                     --test_trt_llm                     \
                     --tensorrt_llm_rouge1_threshold 14 \
                     --data_type fp16                   \
                     --check_accuracy 2>&1 | tee summary_trt_llm.log

# Run the summarization task using a HF model and a single GPU.
python3 summarize.py --engine_dir gptneox_engine        \
                     --model_dir gptneox_model          \
                     --batch_size 1                     \
                     --test_hf                          \
                     --tensorrt_llm_rouge1_threshold 14 \
                     --data_type fp16                   \
                     --check_accuracy 2>&1 | tee summary_hf.log

# Run the summarization task using a TensorRT-LLM model and 2-way tensor parallelism.
mpirun -n 2 --allow-run-as-root                         \
python3 summarize.py --engine_dir gptneox_engine_tp2    \
                     --model_dir gptneox_model          \
                     --batch_size 1                     \
                     --test_trt_llm                     \
                     --tensorrt_llm_rouge1_threshold 14 \
                     --data_type fp16                   \
                     --check_accuracy 2>&1 | tee summary_trt_llm_tp2.log

Apply groupwise quantization (GPTQ)

1. Download weights from HuggingFace (HF)

# Weights & config
sh get_weights.sh

2. Generating quantized weights

In this example, the weights are quantized using GPTQ-for-LLaMa. Note that the --act-order parameter, which controls whether the activation-order GPTQ heuristic is applied, is not supported by TensorRT-LLM.

sh gptq_convert.sh
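
Groupwise quantization stores low-bit weights plus one scale for each group of consecutive weights along the input dimension of every output column. Reading 4bit and gs128 in the file name gptneox-20b-4bit-gs128.safetensors as 4-bit weights with a group size of 128 (our interpretation of the name), the sketch below estimates storage for one K x N weight matrix. It counts only fp16 scales; GPTQ checkpoints typically also store per-group zero points, which are omitted here. gptq_bytes is a hypothetical name.

```shell
#!/bin/sh
# Hypothetical estimate of storage for one K x N weight matrix under 4-bit
# groupwise quantization. Prints: int4_weight_bytes fp16_scale_bytes fp16_dense_bytes
# Zero points are ignored in this sketch.
gptq_bytes() {
  k=$1; n=$2; gs=$3
  if [ $((k % gs)) -ne 0 ]; then
    echo "error: K must be a multiple of the group size" >&2
    return 1
  fi
  q=$((k * n / 2))          # two 4-bit values per byte
  s=$((k / gs * n * 2))     # one 2-byte fp16 scale per group per output column
  echo "$q $s $((k * n * 2))"
}
```

For a 6144 x 6144 matrix (6144 is the GPT-NeoX-20B hidden size), gptq_bytes 6144 6144 128 prints 18874368 589824 75497472: about 18.9 MB of 4-bit weights plus 0.6 MB of scales, versus about 75.5 MB in fp16.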

3. Build TensorRT engine(s)

# Build an engine applying INT4 GPTQ quantization using a single GPU and the generated quantized weights.
# Enable several TensorRT-LLM plugins to increase runtime performance. It also helps with build time.
# Set --use_weight_only_groupwise_quant_matmul_plugin to enable GPTQ
python3 build.py --dtype=float16                                                                     \
                 --log_level=verbose                                                                 \
                 --use_gpt_attention_plugin float16                                                  \
                 --use_gemm_plugin float16                                                           \
                 --use_layernorm_plugin float16                                                      \
                 --use_weight_only_groupwise_quant_matmul_plugin float16                             \
                 --groupwise_quant_safetensors_path=gptneox_model/gptneox-20b-4bit-gs128.safetensors \
                 --max_batch_size=16                                                                 \
                 --max_input_len=1024                                                                \
                 --max_output_len=1024                                                               \
                 --output_dir=gptneox_engine_gptq                                                    \
                 --model_dir=gptneox_model 2>&1 | tee build_gptq.log

# Build an engine applying INT4 GPTQ quantization using 2-way tensor parallelism and the generated quantized weights.
# Enable several TensorRT-LLM plugins to increase runtime performance. It also helps with build time.
# Set --use_weight_only_groupwise_quant_matmul_plugin to enable GPTQ
python3 build.py --dtype=float16                                                                     \
                 --log_level=verbose                                                                 \
                 --use_gpt_attention_plugin float16                                                  \
                 --use_gemm_plugin float16                                                           \
                 --use_layernorm_plugin float16                                                      \
                 --use_weight_only_groupwise_quant_matmul_plugin float16                             \
                 --groupwise_quant_safetensors_path=gptneox_model/gptneox-20b-4bit-gs128.safetensors \
                 --max_batch_size=16                                                                 \
                 --max_input_len=1024                                                                \
                 --max_output_len=1024                                                               \
                 --world_size=2                                                                      \
                 --output_dir=gptneox_engine_gptq_tp2                                                \
                 --model_dir=gptneox_model 2>&1 | tee build_gptq_tp2.log
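
Both GPTQ builds depend on the file passed with --groupwise_quant_safetensors_path, which should exist after running gptq_convert.sh. A short pre-flight check avoids starting a build against a missing or empty file. The function below is a hypothetical sketch (check_gptq_weights is our name); it checks only the file extension and that the file is non-empty, not the file's contents.

```shell
#!/bin/sh
# Hypothetical pre-flight check before a GPTQ build: confirm the quantized
# weights file exists, is non-empty, and has a .safetensors extension.
check_gptq_weights() {
  path=$1
  case "$path" in
    *.safetensors) ;;
    *) echo "error: expected a .safetensors file, got '$path'" >&2; return 1 ;;
  esac
  if [ ! -s "$path" ]; then
    echo "error: '$path' is missing or empty; run gptq_convert.sh first" >&2
    return 1
  fi
  return 0
}
```

For example, check_gptq_weights gptneox_model/gptneox-20b-4bit-gs128.safetensors && python3 build.py ... runs the build only when the quantized weights are in place.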

4. Run the GPTQ quantized GPT-NeoX model

# For a single GPU
python3 run.py --max_output_len=50 --engine_dir=gptneox_engine_gptq

# For 2-way tensor parallelism
mpirun -n 2 --allow-run-as-root python3 run.py --max_output_len=50 --engine_dir=gptneox_engine_gptq_tp2

5. Summarize with the GPTQ quantized GPT-NeoX model

Install the requirements first.

pip install -r requirements.txt

Then run the summarization task with the summarize.py script.

# Run the summarization task using a TensorRT-LLM model and a single GPU.
python3 summarize.py --engine_dir gptneox_engine_gptq     \
                     --model_dir gptneox_model            \
                     --batch_size 1                       \
                     --test_trt_llm                       \
                     --tensorrt_llm_rouge1_threshold 14   \
                     --data_type fp16                     \
                     --check_accuracy 2>&1 | tee summary_trt_llm_gptq.log

# Run the summarization task using a TensorRT-LLM model and 2-way tensor parallelism.
mpirun -n 2 --allow-run-as-root                           \
python3 summarize.py --engine_dir gptneox_engine_gptq_tp2 \
                     --model_dir gptneox_model            \
                     --batch_size 1                       \
                     --test_trt_llm                       \
                     --tensorrt_llm_rouge1_threshold 14   \
                     --data_type fp16                     \
                     --check_accuracy 2>&1 | tee summary_trt_llm_gptq_tp2.log