
Add RayService vLLM TPU Inference script #1467

Open
wants to merge 18 commits into main from multihost-example

Conversation

ryanaoleary
Contributor

Description

This PR adds a simple inference script to be used for a Ray multi-host TPU example serving Meta-Llama-3-70B. Similar to the other scripts in the /llm/ folder, serve_tpu.py builds a Ray Serve deployment for vLLM, which can then be queried with text prompts to generate output. This script will be used as part of a tutorial in the GKE and Ray docs.
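For readers following along, here is a minimal sketch of the general shape of such a script, assuming a FastAPI ingress and the synchronous vLLM LLM class; it is illustrative only and not the actual contents of serve_tpu.py (the model name, parallelism, and sampling settings are placeholders):

```python
# Illustrative sketch of a Ray Serve + vLLM deployment; not the exact serve_tpu.py.
# Model name, parallelism, and sampling settings are placeholders.
from fastapi import FastAPI
from ray import serve
from vllm import LLM, SamplingParams

app = FastAPI()


@serve.deployment
@serve.ingress(app)
class VLLMDeployment:
    def __init__(self, model_id: str, tensor_parallel_size: int):
        # vLLM shards the model weights across the accelerators in the slice.
        self.llm = LLM(model=model_id, tensor_parallel_size=tensor_parallel_size)

    @app.post("/generate")
    def generate(self, prompt: str) -> str:
        params = SamplingParams(temperature=0.7, max_tokens=256)
        outputs = self.llm.generate([prompt], params)
        return outputs[0].outputs[0].text


# `serve run <module>:deployment` or a RayService spec can deploy this graph.
deployment = VLLMDeployment.bind(
    model_id="meta-llama/Meta-Llama-3-70B",
    tensor_parallel_size=8,
)
```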

Tasks

  • The contributing guide has been read and followed.
  • The samples added / modified have been fully tested.
  • Workflow files have been added / modified, if applicable.
  • Region tags have been properly added, if new samples.
  • All dependencies are set to up-to-date versions, as applicable.
  • Merge this pull request for me once it is approved.

@ryanaoleary ryanaoleary force-pushed the multihost-example branch 2 times, most recently from 3ff9092 to 717b6ef on September 25, 2024 at 00:24
Signed-off-by: Ryan O'Leary <ryanaoleary@google.com>

bug fixes

Signed-off-by: Ryan O'Leary <ryanaoleary@google.com>

remove extra ray init

Signed-off-by: Ryan O'Leary <ryanaoleary@google.com>

Read hf token from os

Signed-off-by: Ryan O'Leary <ryanaoleary@google.com>

Fix bugs

Signed-off-by: Ryan O'Leary <ryanaoleary@google.com>

Remove hf token logic

Signed-off-by: Ryan O'Leary <ryanaoleary@google.com>

Fix serve script

Signed-off-by: Ryan O'Leary <ryanaoleary@google.com>
@andrewsykim
Collaborator

Do we need a RayService YAML in the repo with region tags that you can reference in the GCP docs?

ai-ml/gke-ray/rayserve/llm/serve_tpu.py (outdated review thread, resolved)
# See the License for the specific language governing permissions and
# limitations under the License.

# NOTE: this file was inspired from: https://github.com/richardsliu/vllm/blob/rayserve/examples/rayserve_tpu.py
Collaborator

@richardsliu can we get this example merged into the vllm repo?


I opened this one a while back: vllm-project/vllm#8038

I'll ping them again on it.

Signed-off-by: Ryan O'Leary <ryanaoleary@google.com>
Signed-off-by: Ryan O'Leary <ryanaoleary@google.com>

snippet-bot bot commented Sep 27, 2024

Here is the summary of changes.

You are about to add 4 region tags.

This comment is generated by snippet-bot.
If you find problems with this result, please file an issue at:
https://github.com/googleapis/repo-automation-bots/issues.

@ryanaoleary
Contributor Author

Do we need a RayService YAML in the repo with region tags that you can reference in the GCP docs?

Yeah, that sounds good. I'm still testing out the 405B RayService, but I added the 8B and 70B ones in fe6440c; we can then use envsubst to replace the image variable.

Signed-off-by: Ryan O'Leary <ryanaoleary@google.com>
@ryanaoleary
Contributor Author

I've tried running Llama-3.1-405B with TPU slice sizes up to 4x4x8 v4 and 8x16 v5e and ran into a few issues:

  1. As slice sizes grow larger, the time needed for vLLM initialization and memory profiling becomes extremely long.
  2. Attempting to run inference with smaller topologies than the aforementioned slice sizes leads to errors like RuntimeError: Bad StatusOr access: RESOURCE_EXHAUSTED: XLA:TPU compile permanent error. Ran out of memory in memory space hbm. Used 20.44G of 15.75G hbm. Exceeded hbm capacity by 4.70G., since v5e and v4 TPUs only have 16 GiB and 32 GiB of HBM per chip respectively. The relatively small HBM capacity (compared to GPUs) means we need much larger slice sizes to fit the sharded weights.
  3. Even when loading a model that will be sharded (i.e. with tensor parallelism > 1), vLLM still downloads the entire model on each worker and only afterwards writes the relevant shard weights out to new files on each worker. This means that larger slice sizes require an extremely large amount of total disk space when loading large models.
  4. Larger multi-host slice sizes lead to ValueError: Too large swap space. errors, where vLLM attempts to allocate more than the total available CPU memory to the swap space. I've worked around this by simply setting swap_space=0 in the vLLM EngineArgs (see the sketch at the end of this comment), but I'm worried this slows down model loading.
  5. vLLM lacks support for running across multiple multi-host TPU slices (i.e. by specifying pipeline parallelism > 1).
  6. The vLLM TPU backend lacks support for loading quantized models.

If the user has sufficient quota for TPU chips and SSD in their region, a v4 4x4x8 or v5e 8x16 slice is large enough to run multi-host inference with Llama-3.1-405B. However, I'm wondering whether I'm missing anything obvious here (given the current level of TPU support in vLLM) that could allow us to (a) load the model faster and (b) require less disk space when initializing the model.
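
For reference, a simplified sketch of how swap_space=0 can be set when constructing the engine; this is not the exact code in serve_tpu.py, and the model name, parallelism value, and context length are placeholders:

```python
# Illustrative only: shows the swap_space=0 workaround for the
# "ValueError: Too large swap space." errors on large multi-host slices.
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine

engine_args = AsyncEngineArgs(
    model="meta-llama/Llama-3.1-405B",  # placeholder model name
    tensor_parallel_size=128,           # placeholder: total TPU chips in the slice
    swap_space=0,                       # disable CPU swap allocation entirely
    max_model_len=4096,                 # placeholder: bound the KV cache size
)
engine = AsyncLLMEngine.from_engine_args(engine_args)
```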

cc: @richardsliu @andrewsykim

ryanaoleary and others added 4 commits October 2, 2024 05:18
Signed-off-by: Ryan O'Leary <ryanaoleary@google.com>