Phi

This document explains how to build the Phi model with TensorRT-LLM and run it on a single GPU or, with tensor parallelism, on multiple GPUs. Both Phi-2 and Phi-1.5 are supported.

Overview

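The TensorRT-LLM Phi implementation can be found in tensorrt_llm/models/phi/model.py. The TensorRT-LLM Phi example code is located in examples/phi. There is one main file:

  • build.py to build the TensorRT engine(s) needed to run the Phi model.

In addition, there are two shared files in the parent folder examples for inference and evaluation:

  • ../run.py to run inference on an input text.
  • ../summarize.py to summarize the articles from the cnn_dailymail dataset.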

Support Matrix

  • FP16
  • BF16
  • Tensor Parallel

Usage

1. Download weights from HuggingFace (HF) Transformers

# Make sure you have git-lfs installed (https://git-lfs.com)
git lfs install
# Clone into phi_model, the directory name the build and run commands below expect.
git clone https://huggingface.co/microsoft/phi-2 phi_model
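If git-lfs is missing, the clone still succeeds but the weight files contain small git-lfs pointer stubs instead of tensors, and the build fails later. The bash function below is a minimal sketch of a pre-build check. It assumes the checkpoint keeps its weights in *.safetensors or *.bin files next to config.json (the usual Hugging Face layout); the function name check_hf_checkpoint is hypothetical.

```shell
# check_hf_checkpoint DIR (hypothetical helper)
# Returns 0 if DIR has config.json and at least one weight file that is not an
# unresolved git-lfs pointer; otherwise prints a reason and returns 1.
check_hf_checkpoint() {
  local dir="$1" f found=0
  if [ ! -f "$dir/config.json" ]; then
    echo "missing $dir/config.json" >&2
    return 1
  fi
  for f in "$dir"/*.safetensors "$dir"/*.bin; do
    [ -f "$f" ] || continue   # skip glob patterns that matched nothing
    # A git-lfs pointer file starts with this spec line instead of binary data.
    if head -c 64 "$f" | grep -q '^version https://git-lfs.github.com/spec/v1'; then
      echo "$f is a git-lfs pointer; run 'git lfs pull' in $dir" >&2
      return 1
    fi
    found=1
  done
  if [ "$found" -ne 1 ]; then
    echo "no weight files found in $dir" >&2
    return 1
  fi
}
```

Pass it the directory you cloned into (the commands in this document use phi_model), for example `check_hf_checkpoint phi_model && echo ok`.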

2. Build TensorRT engine(s)

TensorRT-LLM builds TensorRT engine(s) from an HF checkpoint passed with --model_dir. If no checkpoint directory is specified, TensorRT-LLM builds the engine(s) with dummy weights.

Examples of build invocations:

# Build a float16 engine using a single GPU and HF weights.
# Enable the GPT attention and GEMM plugins to improve runtime performance; they also reduce build time.
python3 build.py --dtype=float16                    \
                 --log_level=verbose                \
                 --use_gpt_attention_plugin float16 \
                 --use_gemm_plugin float16          \
                 --max_batch_size=16                \
                 --max_input_len=1024               \
                 --max_output_len=1024              \
                 --output_dir=phi_engine            \
                 --model_dir=phi_model 2>&1 | tee build.log

# Build a bfloat16 engine using a single GPU and HF weights.
# Enable the GPT attention and GEMM plugins to improve runtime performance; they also reduce build time.
python3 build.py --dtype=bfloat16                    \
                 --log_level=verbose                 \
                 --use_gpt_attention_plugin bfloat16 \
                 --use_gemm_plugin bfloat16          \
                 --max_batch_size=16                 \
                 --max_input_len=1024                \
                 --max_output_len=1024               \
                 --output_dir=phi_engine             \
                 --model_dir=phi_model 2>&1 | tee build.log

# Build a float16 engine using a single GPU and dummy weights.
# Using dummy weights is useful for performance tests.
# Enable the GPT attention and GEMM plugins to improve runtime performance; they also reduce build time.
python3 build.py --dtype=float16                    \
                 --log_level=verbose                \
                 --use_gpt_attention_plugin float16 \
                 --use_gemm_plugin float16          \
                 --max_batch_size=16                \
                 --max_input_len=1024               \
                 --max_output_len=1024              \
                 --output_dir=phi_engine_dummy_weights 2>&1 | tee build.log

# Build a float16 engine using 2-way tensor parallelism and HF weights.
# Enable the GPT attention and GEMM plugins to improve runtime performance; they also reduce build time.
python3 build.py --dtype=float16                    \
                 --log_level=verbose                \
                 --use_gpt_attention_plugin float16 \
                 --use_gemm_plugin float16          \
                 --max_batch_size=16                \
                 --max_input_len=1024               \
                 --max_output_len=1024              \
                 --world_size=2                     \
                 --output_dir=phi_engine_tp2        \
                 --model_dir=phi_model 2>&1 | tee build_tp2.log
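The four invocations above differ only in the precision, the tensor-parallel size, and whether --model_dir is given. As a sketch, the hypothetical bash function below prints (but does not run) the matching build.py command using only the flags shown above, keeping both plugin precisions equal to --dtype as every example does.

```shell
# phi_build_cmd DTYPE WORLD_SIZE OUTPUT_DIR [MODEL_DIR]  (hypothetical helper)
# Prints a build.py command line; omitting MODEL_DIR yields a dummy-weight build.
phi_build_cmd() {
  local dtype="$1" world_size="$2" output_dir="$3" model_dir="${4:-}"
  case "$dtype" in
    float16|bfloat16) ;;
    *) echo "unsupported dtype: $dtype" >&2; return 1 ;;
  esac
  local cmd="python3 build.py --dtype=$dtype --log_level=verbose"
  # Plugin precision mirrors --dtype, as in the examples above.
  cmd+=" --use_gpt_attention_plugin $dtype --use_gemm_plugin $dtype"
  cmd+=" --max_batch_size=16 --max_input_len=1024 --max_output_len=1024"
  if [ "$world_size" -gt 1 ]; then
    cmd+=" --world_size=$world_size"   # tensor-parallel build across GPUs
  fi
  cmd+=" --output_dir=$output_dir"
  if [ -n "$model_dir" ]; then
    cmd+=" --model_dir=$model_dir"     # HF weights; absent means dummy weights
  fi
  echo "$cmd"
}
```

To execute a generated command, something like `eval "$(phi_build_cmd bfloat16 1 phi_engine phi_model)" 2>&1 | tee build.log` should work; printing first makes it easy to review the flags before a long build.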

Fused Multi-Head Attention (FMHA)

You can enable the FMHA kernels by adding --enable_context_fmha to the invocation of build.py. Note that it is disabled by default because of possible accuracy issues due to the use of Flash Attention.

If the default FP16 accumulation used by --enable_context_fmha does not meet your accuracy requirements, you can enable FP32 accumulation by adding --enable_context_fmha_fp32_acc, at a possible cost in performance.

Note that --enable_context_fmha and --enable_context_fmha_fp32_acc must be used together with --use_gpt_attention_plugin float16.
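The dependency between the FMHA flags and the attention plugin can be checked before starting a build. The hypothetical bash function below is a sketch that scans a build.py argument list and rejects either FMHA flag when --use_gpt_attention_plugin is absent. It deliberately checks only that the plugin is enabled, not the float16 precision named in the note above.

```shell
# check_fmha_flags ARG...  (hypothetical helper)
# Returns 1 if a context FMHA flag appears without --use_gpt_attention_plugin.
check_fmha_flags() {
  local arg wants_fmha=0 has_plugin=0
  for arg in "$@"; do
    case "$arg" in
      --enable_context_fmha|--enable_context_fmha_fp32_acc) wants_fmha=1 ;;
      --use_gpt_attention_plugin|--use_gpt_attention_plugin=*) has_plugin=1 ;;
    esac
  done
  if [ "$wants_fmha" -eq 1 ] && [ "$has_plugin" -eq 0 ]; then
    echo "context FMHA requires --use_gpt_attention_plugin" >&2
    return 1
  fi
  return 0
}
```

Usage: collect the arguments in an array, for example `args=(--dtype=float16 --use_gpt_attention_plugin float16 --enable_context_fmha)`, then run `check_fmha_flags "${args[@]}" && python3 build.py "${args[@]}"`.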

3. Run

To run a TensorRT-LLM Phi model using the engines generated by build.py:

# For a single GPU
python3 ../run.py --max_output_len=50 --engine_dir=phi_engine --tokenizer_dir=phi_model

# For 2-way tensor parallelism
mpirun -n 2 --allow-run-as-root python3 ../run.py --max_output_len=50 --engine_dir=phi_engine_tp2 --tokenizer_dir=phi_model
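The only difference between the two run commands is the mpirun prefix, whose rank count should match the --world_size the engine was built with (one rank per GPU). A minimal sketch with a hypothetical helper that prints the appropriate command for a given world size:

```shell
# phi_run_cmd WORLD_SIZE ENGINE_DIR TOKENIZER_DIR  (hypothetical helper)
# Prints the ../run.py command, wrapped in mpirun when WORLD_SIZE > 1.
phi_run_cmd() {
  local world_size="$1" engine_dir="$2" tokenizer_dir="$3"
  local cmd="python3 ../run.py --max_output_len=50 --engine_dir=$engine_dir --tokenizer_dir=$tokenizer_dir"
  if [ "$world_size" -gt 1 ]; then
    # -n must equal the --world_size used when building the engine.
    cmd="mpirun -n $world_size --allow-run-as-root $cmd"
  fi
  echo "$cmd"
}
```

For example, `eval "$(phi_run_cmd 2 phi_engine_tp2 phi_model)"` reproduces the 2-way tensor-parallel command above.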

4. Summarization using the Phi model

The following section describes how to run a TensorRT-LLM Phi model to summarize the articles from the cnn_dailymail dataset. For each summary, the script can compute the ROUGE scores and use the ROUGE-1 score to validate the implementation. The script can also perform the same summarization using the HF Phi model.
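ROUGE-1 measures unigram overlap between a generated summary and the reference summary, combining precision and recall into an F1 score. The sketch below is a simplified illustration, not the scorer used by ../summarize.py: it lowercases, splits on whitespace, and applies no stemming or punctuation handling, so its numbers will differ. It reports the score on a 0 to 100 scale, which is presumably the scale of --tensorrt_llm_rouge1_threshold 14, since an F1 fraction can never exceed 1. Both function names are hypothetical.

```shell
# rouge1_f REFERENCE CANDIDATE  (hypothetical, simplified ROUGE-1 F1 on 0-100)
rouge1_f() {
  REF="$1" CAND="$2" awk 'BEGIN {
    nr = split(tolower(ENVIRON["REF"]), r, /[ \t\n]+/)
    nc = split(tolower(ENVIRON["CAND"]), c, /[ \t\n]+/)
    # Count unigrams, skipping empty fields produced by leading whitespace.
    for (i = 1; i <= nr; i++) if (r[i] != "") { rc[r[i]]++; rt++ }
    for (i = 1; i <= nc; i++) if (c[i] != "") { cc[c[i]]++; ct++ }
    # Clipped overlap: each word counts at most as often as in the reference.
    for (w in cc) if (w in rc) overlap += (cc[w] < rc[w] ? cc[w] : rc[w])
    if (rt == 0 || ct == 0 || overlap == 0) { print "0.00"; exit }
    p = overlap / ct; rcl = overlap / rt
    printf "%.2f\n", 100 * 2 * p * rcl / (p + rcl)
  }'
}

# passes_threshold SCORE THRESHOLD: exit status 0 if SCORE >= THRESHOLD.
passes_threshold() {
  awk -v s="$1" -v t="$2" 'BEGIN { exit !(s + 0 >= t + 0) }'
}
```

For instance, the candidate "the cat sat" against the reference "the cat sat on the mat" has precision 1 and recall 0.5, so `rouge1_f` prints 66.67, which `passes_threshold 66.67 14` accepts.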

First, build the TensorRT engine with HF weights as described above. You also need to install the requirements:

pip install -r requirements.txt

The summarization can be done using the ../summarize.py script as follows:

# Run the summarization task using a TensorRT-LLM model and a single GPU.
python3 ../summarize.py --engine_dir phi_engine             \
                        --hf_model_dir phi_model            \
                        --batch_size 1                      \
                        --test_trt_llm                      \
                        --tensorrt_llm_rouge1_threshold 14  \
                        --data_type fp16                    \
                        --check_accuracy 2>&1 | tee summary_trt_llm.log

# Run the summarization task using a HF model and a single GPU.
python3 ../summarize.py --engine_dir phi_engine             \
                        --hf_model_dir phi_model            \
                        --batch_size 1                      \
                        --test_hf                           \
                        --tensorrt_llm_rouge1_threshold 14  \
                        --data_type fp16                    \
                        --check_accuracy 2>&1 | tee summary_hf.log

# Run the summarization task using a TensorRT-LLM model and 2-way tensor parallelism.
mpirun -n 2 --allow-run-as-root                             \
python3 ../summarize.py --engine_dir phi_engine_tp2         \
                        --hf_model_dir phi_model            \
                        --batch_size 1                      \
                        --test_trt_llm                      \
                        --tensorrt_llm_rouge1_threshold 14  \
                        --data_type fp16                    \
                        --check_accuracy 2>&1 | tee summary_trt_llm_tp2.log
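The three summarization commands share every flag except the backend selector (--test_trt_llm or --test_hf), the engine directory, and the optional mpirun prefix. A sketch of a hypothetical helper that prints the corresponding command, using only the flags shown above:

```shell
# phi_summarize_cmd BACKEND ENGINE_DIR [WORLD_SIZE]  (hypothetical helper)
# BACKEND is trt_llm or hf; WORLD_SIZE > 1 adds the mpirun prefix.
phi_summarize_cmd() {
  local backend="$1" engine_dir="$2" world_size="${3:-1}" test_flag
  case "$backend" in
    trt_llm) test_flag="--test_trt_llm" ;;
    hf) test_flag="--test_hf" ;;
    *) echo "backend must be trt_llm or hf" >&2; return 1 ;;
  esac
  local cmd="python3 ../summarize.py --engine_dir $engine_dir --hf_model_dir phi_model"
  cmd+=" --batch_size 1 $test_flag --tensorrt_llm_rouge1_threshold 14"
  cmd+=" --data_type fp16 --check_accuracy"
  if [ "$world_size" -gt 1 ]; then
    cmd="mpirun -n $world_size --allow-run-as-root $cmd"
  fi
  echo "$cmd"
}
```

For example, `eval "$(phi_summarize_cmd trt_llm phi_engine_tp2 2)" 2>&1 | tee summary_trt_llm_tp2.log` should correspond to the last command above.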