
[Bug]: External Launcher producing NaN outputs on Large Models when Collocating with Model Training #14443

Description

@fabianlim

Your current environment

The output of `python collect_env.py`
PyTorch version: 2.5.1+cu124
Is debug build: False
CUDA used to build PyTorch: 12.4
ROCM used to build PyTorch: N/A

OS: CentOS Stream 9 (x86_64)
GCC version: (GCC) 11.5.0 20240719 (Red Hat 11.5.0-2)
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.34

Python version: 3.10.16 (main, Feb  5 2025, 05:53:54) [GCC 11.5.0 20240719 (Red Hat 11.5.0-2)] (64-bit runtime)
Python platform: Linux-5.14.0-284.73.1.el9_2.x86_64-x86_64-with-glibc2.34
Is CUDA available: True
CUDA runtime version: 12.4.131
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA A100-SXM4-80GB
GPU 1: NVIDIA A100-SXM4-80GB
GPU 2: NVIDIA A100-SXM4-80GB
GPU 3: NVIDIA A100-SXM4-80GB
Nvidia driver version: 550.54.14


Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] nvidia-cublas-cu12==12.4.5.8
[pip3] nvidia-cuda-cupti-cu12==12.4.127
[pip3] nvidia-cuda-nvrtc-cu12==12.4.127
[pip3] nvidia-cuda-runtime-cu12==12.4.127
[pip3] nvidia-cudnn-cu12==9.1.0.70
[pip3] nvidia-cufft-cu12==11.2.1.3
[pip3] nvidia-curand-cu12==10.3.5.147
[pip3] nvidia-cusolver-cu12==11.6.1.9
[pip3] nvidia-cusparse-cu12==12.3.1.170
[pip3] nvidia-cusparselt-cu12==0.6.2
[pip3] nvidia-nccl-cu12==2.21.5
[pip3] nvidia-nvjitlink-cu12==12.4.127
[pip3] nvidia-nvtx-cu12==12.4.127
[pip3] pyzmq==26.2.1
[pip3] torch==2.5.1
[pip3] torchaudio==2.5.1
[pip3] torchvision==0.20.1
[pip3] transformers==4.49.0
[pip3] triton==3.1.0

🐛 Describe the bug

We are observing the following behavior when collocating a relatively large model (e.g., 72B) served through vLLM's external launcher alongside a model-sharding framework (e.g., FSDP) used for training.

Motivation: for RLHF, collocation is desirable, but when the model is large we need to shard both the vLLM model and the model in training so that everything fits in GPU memory:

  • for vLLM, this means TP (or PP) — see the sketch after this list
  • for the model in training, this means a framework such as FSDP or DeepSpeed
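For reference, here is a minimal sketch of the vLLM side of this setup, assuming the script is launched with something like `torchrun --nproc-per-node=4`; the model name and memory fraction are placeholders, and the gist referenced below remains the authoritative reproduction:

```python
# Minimal sketch (not the gist itself): every rank runs this same script under
# torchrun, and vLLM shards the model with TP across all ranks via the external
# launcher backend. Model name and gpu_memory_utilization are illustrative.
import os
import torch
import torch.distributed as dist
from vllm import LLM

dist.init_process_group(backend="nccl")            # FSDP needs this process group;
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))  # vLLM's external launcher is expected to reuse it

llm = LLM(
    model="Qwen/Qwen2.5-72B-Instruct",             # placeholder for the ~72B model
    tensor_parallel_size=dist.get_world_size(),    # TP across all 4 GPUs
    distributed_executor_backend="external_launcher",
    gpu_memory_utilization=0.5,                    # leave headroom for the training model
)
```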

Bug: we are observing NaNs coming from the TP-sharded vLLM model when the model size becomes large. We have a reproduction script in this gist which demonstrates the following behavior:

  • The script demonstrates two scenarios: (i) the model in training is sharded with FSDP, and (ii) the model is left unsharded. This is controlled by setting shard_train_model=True or False, respectively.
  • For demonstration purposes, the model in training is a very small OPT model. In real use cases the model in training is usually the same as the vLLM model, but a small model is used here deliberately, both to keep the script runnable on 4 x 80GB GPUs and to show that the observation is independent of the size of the model in training. A sketch of this toggle follows the list.
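A sketch of the training-model side and the shard_train_model toggle; the OPT checkpoint mirrors the demonstration setup, and the FSDP arguments here are illustrative rather than necessarily those used in the gist:

```python
# Sketch of the toggle: the small OPT training model is either wrapped in FSDP
# (shard_train_model=True, the failing case) or left unsharded (the working case).
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from transformers import AutoModelForCausalLM

shard_train_model = True  # True reproduces the garbled outputs below, False does not

train_model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-125m", torch_dtype=torch.bfloat16
).cuda()
if shard_train_model:
    train_model = FSDP(train_model, use_orig_params=True)
```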

With shard_train_model=False, the outputs are correct:

# text (0):  We can visualize the spheres with radii 11, 13, and 19
# text (1):  First, let's denote $c = \log_{2^b}(2^{100
# text (2): First, we need to find the dimensions of the rhombus. A key property of the rh
# text (3): First, I’ll start by simplifying the given equation. We are given the equation: $\sqrt

With shard_train_model=True, the outputs become garbled because NaNs creep into the tensors:

# - the strange characters are caused by NaNs coming from the TP reduce
# text (0):  We can visualize the spheres with radii 11, 13, and 19
# text (1): !!!!!!!!!!!!!!!!!!!!
# text (2): !!!!!!!!!!!!!!!!!!!!
# text (3): !!!!!!!!!!!!!!!!!!!!

Reproduction code found in this gist.

Work-around

In the actual use case, turning off FSDP for the model in training is out of the question. I have found another workaround, which is to set max_num_seqs=1. This slows down my inference but at least still lets me fit all models on the GPUs. It is not ideal, however, because continuous batching is effectively lost. A sketch of the workaround follows.
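For reference, a sketch of that workaround, assuming the same engine arguments as in the sketch above (`max_num_seqs` is the vLLM `LLM` constructor argument; the other values are placeholders):

```python
# Sketch of the workaround: cap the scheduler at one sequence per step.
from vllm import LLM

llm = LLM(
    model="Qwen/Qwen2.5-72B-Instruct",             # placeholder
    tensor_parallel_size=4,
    distributed_executor_backend="external_launcher",
    max_num_seqs=1,  # one sequence at a time: slower, no continuous batching, but no NaNs
)
```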

