Transformer Engine

Quickstart | Installation | User Guide | Examples | Model Support | Integrations | Release notes

Latest News

What is Transformer Engine?

Transformer Engine (TE) is a library for accelerating Transformer models on NVIDIA GPUs, including through 8-bit floating point (FP8) precision on Hopper and Ada GPUs, to provide better performance with lower memory utilization in both training and inference. TE provides a collection of highly optimized building blocks for popular Transformer architectures and an automatic mixed precision-like API that can be used seamlessly with your framework-specific code. TE also includes a framework-agnostic C++ API that can be integrated with other deep learning libraries to enable FP8 support for Transformers.

As the number of parameters in Transformer models continues to grow, training and inference for architectures such as BERT, GPT, and T5 become very memory- and compute-intensive. Most deep learning frameworks train with FP32 by default. However, FP32 is not essential for achieving full accuracy with many deep learning models. Mixed-precision training, which combines single precision (FP32) with a lower-precision format (e.g. FP16) when training a model, yields significant speedups with minimal differences in accuracy compared to FP32 training. The Hopper GPU architecture introduced FP8 precision, which offers improved performance over FP16 with no degradation in accuracy. Although all major deep learning frameworks support FP16, FP8 is not natively supported in frameworks today.

TE addresses the problem of FP8 support by providing APIs that integrate with popular Large Language Model (LLM) libraries. It provides a Python API consisting of modules for easily building a Transformer layer, as well as a framework-agnostic C++ library that includes the structs and kernels needed for FP8 support. Modules provided by TE internally maintain the scaling factors and other values needed for FP8 training, which greatly simplifies mixed-precision training for users.
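To give a feel for what "maintaining scaling factors" involves, here is a minimal pure-Python sketch of per-tensor delayed scaling, assuming the common formulation scale = fp8_max / amax / 2^margin, where amax is the largest absolute value recorded over a window of previous steps. The class ToyDelayedScaler and its method names are hypothetical and are not TE's API; the margin and history-length arguments only loosely mirror the margin and amax_history_len options of recipe.DelayedScaling. Rounding onto the actual FP8 grid is omitted so that only the scaling and saturation logic is visible.

```python
from collections import deque

E4M3_MAX = 448.0  # largest finite value representable in FP8 E4M3


class ToyDelayedScaler:
    """Hypothetical per-tensor scaler illustrating delayed scaling (not TE's API).

    The scale used at step t comes from amax values recorded at earlier steps,
    so a tensor can be cast without first scanning it for its maximum.
    """

    def __init__(self, fp8_max=E4M3_MAX, margin=0, history_len=16):
        self.fp8_max = fp8_max
        self.margin = margin
        self.amax_history = deque(maxlen=history_len)
        self.scale = 1.0

    def quantize(self, values):
        """Scale with the current (stale) scale and saturate to the FP8 range.

        Rounding onto the FP8 grid is omitted; only scaling and clipping are shown.
        """
        scaled = [max(-self.fp8_max, min(self.fp8_max, v * self.scale)) for v in values]
        # Record this step's amax; it only influences *future* scales.
        self.amax_history.append(max(abs(v) for v in values))
        return scaled

    def dequantize(self, scaled):
        # Must be called with the same scale that was used by quantize().
        return [v / self.scale for v in scaled]

    def update_scale(self):
        """Recompute the scale from the largest amax in the history window."""
        amax = max(self.amax_history, default=0.0)
        if amax > 0.0:  # keep the previous scale if nothing non-zero was seen
            self.scale = self.fp8_max / amax / 2 ** self.margin
        return self.scale


scaler = ToyDelayedScaler()
print(scaler.quantize([0.5, -2.0]))  # initial scale is 1.0, values pass through
print(scaler.update_scale())         # 448 / 2.0
print(scaler.quantize([1.0, -3.0]))  # -3.0 * 224 overflows and saturates to -448
```

The last step shows the trade-off of delayed scaling: because the scale lags behind the data, a sudden jump in magnitude saturates until the next update, which is what a non-zero margin is meant to leave headroom for.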

Highlights

  • Easy-to-use modules for building Transformer layers with FP8 support
  • Optimizations (e.g. fused kernels) for Transformer models
  • Support for FP8 on NVIDIA Hopper and NVIDIA Ada GPUs
  • Support for optimizations across all precisions (FP16, BF16) on NVIDIA Ampere GPU architecture generations and later
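The hardware requirements in the highlights can be expressed as CUDA compute-capability thresholds. The sketch below assumes the usual mapping (Ampere is sm_80/sm_86, Ada is sm_89, Hopper is sm_90) and that newer architectures keep these features; the helper names are hypothetical. On a machine with PyTorch and a CUDA GPU, torch.cuda.get_device_capability() returns the (major, minor) tuple these functions expect.

```python
# Assumed minimum compute capabilities for each feature tier.
AMPERE = (8, 0)  # FP16/BF16 optimizations
ADA = (8, 9)     # FP8 (Ada is sm_89, Hopper is sm_90)


def supports_fp8(capability):
    """True for Ada (sm_89), Hopper (sm_90), and newer GPUs."""
    return tuple(capability) >= ADA


def supports_fp16_bf16_optimizations(capability):
    """True for Ampere (sm_80) and newer GPUs."""
    return tuple(capability) >= AMPERE


for name, cap in [("V100", (7, 0)), ("A100", (8, 0)), ("RTX 4090", (8, 9)), ("H100", (9, 0))]:
    print(name, "FP8:", supports_fp8(cap), "FP16/BF16:", supports_fp16_bf16_optimizations(cap))
```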

Examples

PyTorch

import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

# Set dimensions.
in_features = 768
out_features = 3072
hidden_size = 2048

# Initialize model and inputs.
model = te.Linear(in_features, out_features, bias=True)
inp = torch.randn(hidden_size, in_features, device="cuda")

# Create an FP8 recipe. Note: All input args are optional.
fp8_recipe = recipe.DelayedScaling(margin=0, interval=1, fp8_format=recipe.Format.E4M3)

# Enable autocasting for the forward pass
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    out = model(inp)

loss = out.sum()
loss.backward()

JAX

Flax
import flax
import jax
import jax.numpy as jnp
import transformer_engine.jax as te
import transformer_engine.jax.flax as te_flax
from transformer_engine.common import recipe

BATCH = 32
SEQLEN = 128
HIDDEN = 1024

# Initialize RNG and inputs.
rng = jax.random.PRNGKey(0)
init_rng, data_rng = jax.random.split(rng)
inp = jax.random.normal(data_rng, [BATCH, SEQLEN, HIDDEN], jnp.float32)

# Create an FP8 recipe. Note: All input args are optional.
fp8_recipe = recipe.DelayedScaling(margin=0, interval=1, fp8_format=recipe.Format.HYBRID)

# Enable autocasting for the forward pass
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    model = te_flax.DenseGeneral(features=HIDDEN)

    def loss_fn(params, other_vars, inp):
        out = model.apply({'params': params, **other_vars}, inp)
        return jnp.mean(out)

    # Initialize models.
    variables = model.init(init_rng, inp)
    # Split trainable params from the other collections (e.g. FP8 metadata).
    # flax.core.pop works for both FrozenDict and plain-dict variables.
    other_variables, params = flax.core.pop(variables, 'params')

    # Construct the forward and backward function
    fwd_bwd_fn = jax.value_and_grad(loss_fn, argnums=(0, 1))

    for _ in range(10):
        loss, (param_grads, other_grads) = fwd_bwd_fn(params, other_variables, inp)
        # Update FP8 metas
        other_variables = te.update_fp8_metas(other_grads)
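The JAX example uses recipe.Format.HYBRID, which TE documents as E4M3 for forward-pass FP8 tensors and E5M2 for backward-pass tensors (gradients); the PyTorch and TensorFlow examples use E4M3 throughout. The sketch below derives the key properties of both formats from their bit layouts, assuming the OCP/NVIDIA FP8 definitions (E4M3 with bias 7 and no infinities, E5M2 with bias 15 and IEEE-style inf/NaN). The FP8Format class is a hypothetical helper, not part of TE.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class FP8Format:
    name: str
    exp_bits: int
    man_bits: int
    bias: int
    # E5M2 follows IEEE 754 conventions: the all-ones exponent encodes inf/NaN.
    # E4M3 has no infinities; only all-ones exponent *and* mantissa is NaN.
    ieee_specials: bool

    def max_finite(self) -> float:
        if self.ieee_specials:
            exp = (2 ** self.exp_bits - 2) - self.bias  # largest non-special exponent
            mant = 2 - 2.0 ** -self.man_bits           # 1.11...1 in binary
        else:
            exp = (2 ** self.exp_bits - 1) - self.bias  # all-ones exponent is usable
            mant = 2 - 2.0 ** (1 - self.man_bits)      # 1.11...0, since 1.11...1 is NaN
        return mant * 2.0 ** exp

    def min_normal(self) -> float:
        return 2.0 ** (1 - self.bias)

    def min_subnormal(self) -> float:
        return 2.0 ** (1 - self.bias - self.man_bits)

    def epsilon(self) -> float:
        # Gap between 1.0 and the next representable value.
        return 2.0 ** -self.man_bits


E4M3 = FP8Format("E4M3", exp_bits=4, man_bits=3, bias=7, ieee_specials=False)
E5M2 = FP8Format("E5M2", exp_bits=5, man_bits=2, bias=15, ieee_specials=True)

# HYBRID: E4M3 for forward-pass tensors, E5M2 for gradients.
HYBRID = {"forward": E4M3, "backward": E5M2}

for fmt in (E4M3, E5M2):
    print(fmt.name, fmt.max_finite(), fmt.min_normal(), fmt.min_subnormal(), fmt.epsilon())
```

E4M3 tops out at 448 but has finer spacing, while E5M2 reaches 57344 at half the precision; gradients tend to span a wider dynamic range than activations and weights, which is the usual rationale for the hybrid split.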

TensorFlow

import tensorflow as tf
import transformer_engine.tensorflow as te
from transformer_engine.common import recipe

# Set dimensions.
in_features = 768
out_features = 3072
hidden_size = 2048

# Initialize model and inputs.
model = te.Dense(out_features, use_bias=True)
inp = tf.random.normal((hidden_size, in_features))

optimizer = tf.keras.optimizers.Adam(0.001)

# Create FP8 recipe. Note: All input args are optional.
fp8_recipe = recipe.DelayedScaling(margin=0, interval=1, fp8_format=recipe.Format.E4M3)

with tf.GradientTape(persistent=True) as tape:
    # Enables autocasting for the forward pass
    with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
        out = model(inp)
    loss = tf.reduce_sum(out)
grads = tape.gradient(loss, model.trainable_variables)
optimizer.apply_gradients(zip(grads, model.trainable_variables))

Installation

In the NGC container

The quickest way to get started with Transformer Engine is the NGC PyTorch container from the NVIDIA GPU Cloud (NGC) Catalog (versions 22.09 and later).

docker run --gpus all -it --rm nvcr.io/nvidia/pytorch:23.04-py3

Here, 23.04 is the container version in YY.MM format; for example, 23.04 is the April 2023 release.
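Because the tag encodes the release as YY.MM, checking whether a given image meets the "22.09 and later" requirement is a tuple comparison. This is a small illustrative helper that assumes tags of the form shown above (YY.MM-pyN); the function names are hypothetical.

```python
import re

MIN_NGC_VERSION = (22, 9)  # NGC PyTorch containers 22.09 and later include TE


def ngc_version(image):
    """Parse (YY, MM) from a tag like 'nvcr.io/nvidia/pytorch:23.04-py3'."""
    match = re.search(r":(\d{2})\.(\d{2})-py\d+$", image)
    if match is None:
        raise ValueError(f"unrecognized NGC image tag: {image!r}")
    return int(match.group(1)), int(match.group(2))


def includes_transformer_engine(image):
    return ngc_version(image) >= MIN_NGC_VERSION


print(includes_transformer_engine("nvcr.io/nvidia/pytorch:23.04-py3"))
```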

Pre-requisites

  • Linux x86_64
  • CUDA 11.8 or later
  • NVIDIA Driver supporting CUDA 11.8 or later
  • cuDNN 8.1 or later
  • For fused attention, CUDA 12.1 or later, NVIDIA Driver supporting CUDA 12.1 or later, and cuDNN 8.9 or later.
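The version requirements above reduce to two tiers: a baseline for the library and a stricter one for fused attention. This sketch checks CUDA and cuDNN version strings against both tiers; driver and OS checks are out of scope. The helper names are hypothetical, and the version strings are assumed to come from your environment (for example, torch.version.cuda reports the CUDA version PyTorch was built with as a string such as "12.1").

```python
CUDA_MIN = (11, 8)
CUDNN_MIN = (8, 1)
FUSED_ATTN_CUDA_MIN = (12, 1)
FUSED_ATTN_CUDNN_MIN = (8, 9)


def parse_version(version):
    """'8.9.2' -> (8, 9); only major and minor matter for these checks."""
    return tuple(int(part) for part in version.split(".")[:2])


def check_prereqs(cuda, cudnn):
    """Return (baseline_ok, fused_attention_ok) for CUDA/cuDNN version strings."""
    cuda_v, cudnn_v = parse_version(cuda), parse_version(cudnn)
    baseline_ok = cuda_v >= CUDA_MIN and cudnn_v >= CUDNN_MIN
    fused_ok = (
        baseline_ok
        and cuda_v >= FUSED_ATTN_CUDA_MIN
        and cudnn_v >= FUSED_ATTN_CUDNN_MIN
    )
    return baseline_ok, fused_ok


print(check_prereqs("11.8", "8.6.0"))
print(check_prereqs("12.1", "8.9.2"))
```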

From source

See the installation guide.

Compiling with Flash Attention 2

Transformer Engine release v0.11.0 adds support for Flash Attention 2.0 for improved performance. Compiling Flash Attention 2.0 is known to be resource-intensive and to require a large amount of RAM (see bug), which may lead to out-of-memory errors while installing Transformer Engine. To work around this, try setting MAX_JOBS=1 in the environment. If the errors persist, install a supported version of Flash Attention 1 (v1.0.6 to v1.0.9) instead.

Model Support

While the more granular modules in Transformer Engine can be used to build any Transformer architecture, the TransformerLayer API alone is flexible enough to express several major Transformer model architectures.

Transformer Engine supports the following DL frameworks: PyTorch, JAX (Flax, Praxis), and TensorFlow.

NOTE: For simplicity, we only show PyTorch examples below. For TransformerLayer usage in all supported frameworks, refer to the examples.

GPT

The GPT architecture applies LayerNorm on the input side (before the QKV GEMM), and the residual connection is taken from the input of that LayerNorm. In TE, this can be achieved by setting the following arguments in the TransformerLayer API.

transformer_engine.pytorch.TransformerLayer(
        ...,
        ...,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        layer_type="encoder",
)

BERT

The BERT architecture applies LayerNorm on the output side (after the final BiasDropoutAdd), and the residual connection is taken from the output of that LayerNorm. In TE, this can be achieved by setting the following arguments in the TransformerLayer API.

transformer_engine.pytorch.TransformerLayer(
        ...,
        ...,
        apply_residual_connection_post_layernorm=True,
        output_layernorm=True,
        layer_type="encoder",
)
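The GPT and BERT settings differ only in where the LayerNorm sits relative to the residual add. The toy sketch below contrasts the two placements on plain Python lists: GPT-style (pre-LN, output_layernorm=False) normalizes the sublayer's input and adds back the raw input, while BERT-style (post-LN, output_layernorm=True with apply_residual_connection_post_layernorm=True) normalizes after the add, so the normalized output becomes the next sublayer's input and residual. This is an illustration, not TE code: the LayerNorm has no learnable gain or bias, and the hypothetical double() function stands in for attention or the MLP.

```python
import math


def layer_norm(x, eps=1e-5):
    """LayerNorm over one vector, without the learnable gain and bias."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]


def add(a, b):
    return [u + v for u, v in zip(a, b)]


def gpt_style(x, sublayer):
    """Pre-LN: LayerNorm on the input side; the residual is the raw input x."""
    return add(x, sublayer(layer_norm(x)))


def bert_style(x, sublayer):
    """Post-LN: LayerNorm after the residual add; its output feeds the next
    sublayer and also serves as that sublayer's residual."""
    return layer_norm(add(x, sublayer(x)))


def double(v):
    # Hypothetical stand-in for attention or the MLP.
    return [2.0 * t for t in v]


x = [1.0, 2.0, 3.0, 4.0]
print(gpt_style(x, double))   # keeps the input's mean: un-normalized residual stream
print(bert_style(x, double))  # zero-mean: every block output is normalized
```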

T5

The T5 architecture has an additional cross-attention + BiasDropoutAdd + LayerNorm block before the MLP layer. In TE, this block can be added by setting layer_type="decoder" in the TransformerLayer API.

transformer_engine.pytorch.TransformerLayer(
        ...,
        ...,
        layer_type="decoder",
)
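What the extra decoder block computes can be sketched as single-head scaled dot-product cross-attention: queries come from the decoder's hidden states, while keys and values come from the encoder output. The toy version below uses plain Python lists, omits the Q/K/V projections, multiple heads, bias, dropout, and LayerNorm, and keeps only the residual add; all function names are hypothetical and not part of TE.

```python
import math


def softmax(scores):
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def cross_attention(decoder_states, encoder_states):
    """Each decoder position attends over all encoder positions."""
    d = len(decoder_states[0])
    out = []
    for q in decoder_states:
        weights = softmax([dot(q, k) / math.sqrt(d) for k in encoder_states])
        out.append([sum(w * v[j] for w, v in zip(weights, encoder_states)) for j in range(d)])
    return out


def cross_attention_block(decoder_states, encoder_states):
    # Residual add around cross-attention (bias, dropout, and LayerNorm omitted).
    attn = cross_attention(decoder_states, encoder_states)
    return [add_rows for add_rows in ([a + b for a, b in zip(x, y)] for x, y in zip(decoder_states, attn))]


encoder_out = [[1.0, 2.0], [3.0, 4.0]]
decoder_in = [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]]
print(cross_attention_block(decoder_in, encoder_out))  # one output row per decoder position
```

The output always has one row per decoder position regardless of the encoder length, which is why the same block works for source and target sequences of different lengths.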

Integrations

Transformer Engine has been integrated with several popular open-source DL frameworks such as:

Contributing

We welcome contributions to Transformer Engine! To contribute to Transformer Engine and make pull requests, follow the guidelines outlined in the CONTRIBUTING.rst guide.

Papers

Videos