Add vLLM TPU example RayService manifest #3000

Merged · 3 commits · Feb 12, 2025
@@ -0,0 +1,146 @@
apiVersion: ray.io/v1
kind: RayService
metadata:
  name: vllm-tpu
spec:
  serveConfigV2: |
    applications:
      - name: llm
        import_path: ray-operator.config.samples.vllm.serve_tpu:model
        deployments:
          - name: VLLMDeployment
            num_replicas: 1
        runtime_env:
          working_dir: "https://github.com/ray-project/kuberay/archive/master.zip"
          env_vars:
            MODEL_ID: "$MODEL_ID"
            MAX_MODEL_LEN: "$MAX_MODEL_LEN"
            DTYPE: "$DTYPE"
            TOKENIZER_MODE: "$TOKENIZER_MODE"
            TPU_CHIPS: "8"
  rayClusterConfig:
    headGroupSpec:
      rayStartParams: {}
      template:
        metadata:
          annotations:
            gke-gcsfuse/volumes: "true"
            gke-gcsfuse/cpu-limit: "0"
            gke-gcsfuse/memory-limit: "0"
            gke-gcsfuse/ephemeral-storage-limit: "0"
        spec:
          # replace $KSA_NAME with your Kubernetes Service Account
          serviceAccountName: $KSA_NAME
          containers:
            - name: ray-head
              # replace $VLLM_IMAGE with your vLLM container image
              image: $VLLM_IMAGE
              imagePullPolicy: IfNotPresent
              ports:
                - containerPort: 6379
                  name: gcs
                - containerPort: 8265
                  name: dashboard
                - containerPort: 10001
                  name: client
                - containerPort: 8000
                  name: serve
              env:
                - name: HUGGING_FACE_HUB_TOKEN
                  valueFrom:
                    secretKeyRef:
                      name: hf-secret
                      key: hf_api_token
                - name: VLLM_XLA_CACHE_PATH
                  value: "/data"
              resources:
                limits:
                  cpu: "2"
                  memory: 8G
                requests:
                  cpu: "2"
                  memory: 8G
              volumeMounts:
                - name: gcs-fuse-csi-ephemeral
                  mountPath: /data
                - name: dshm
                  mountPath: /dev/shm
          volumes:
            - name: gke-gcsfuse-cache
              emptyDir:
                medium: Memory
            - name: dshm
              emptyDir:
                medium: Memory
            - name: gcs-fuse-csi-ephemeral
              csi:
                driver: gcsfuse.csi.storage.gke.io
                volumeAttributes:
                  # replace $GSBUCKET with your GCS bucket
                  bucketName: $GSBUCKET
                  mountOptions: "implicit-dirs,file-cache:enable-parallel-downloads:true,file-cache:parallel-downloads-per-file:100,file-cache:max-parallel-downloads:-1,file-cache:download-chunk-size-mb:10,file-cache:max-size-mb:-1"
    workerGroupSpecs:
      - groupName: tpu-group
        replicas: 1
        minReplicas: 1
        maxReplicas: 1
        numOfHosts: 1
        rayStartParams: {}
        template:
          metadata:
            annotations:
              gke-gcsfuse/volumes: "true"
              gke-gcsfuse/cpu-limit: "0"
              gke-gcsfuse/memory-limit: "0"
              gke-gcsfuse/ephemeral-storage-limit: "0"
          spec:
            # replace $KSA_NAME with your Kubernetes Service Account
            serviceAccountName: $KSA_NAME
            containers:
              - name: ray-worker
                # replace $VLLM_IMAGE with your vLLM container image
                image: $VLLM_IMAGE
                imagePullPolicy: IfNotPresent
                resources:
                  limits:
                    cpu: "100"
                    google.com/tpu: "8"
                    ephemeral-storage: 40G
                    memory: 200G
                  requests:
                    cpu: "100"
                    google.com/tpu: "8"
                    ephemeral-storage: 40G
                    memory: 200G
                env:
                  - name: JAX_PLATFORMS
                    value: "tpu"
                  - name: HUGGING_FACE_HUB_TOKEN
                    valueFrom:
                      secretKeyRef:
                        name: hf-secret
                        key: hf_api_token
                  - name: VLLM_XLA_CACHE_PATH
                    value: "/data"
                volumeMounts:
                  - name: gcs-fuse-csi-ephemeral
                    mountPath: /data
                  - name: dshm
                    mountPath: /dev/shm
            volumes:
              - name: gke-gcsfuse-cache
                emptyDir:
                  medium: Memory
              - name: dshm
                emptyDir:
                  medium: Memory
              - name: gcs-fuse-csi-ephemeral
                csi:
                  driver: gcsfuse.csi.storage.gke.io
                  volumeAttributes:
                    # replace $GSBUCKET with your GCS bucket
                    bucketName: $GSBUCKET
                    mountOptions: "implicit-dirs,file-cache:enable-parallel-downloads:true,file-cache:parallel-downloads-per-file:100,file-cache:max-parallel-downloads:-1,file-cache:download-chunk-size-mb:10,file-cache:max-size-mb:-1"
            nodeSelector:
              cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
              cloud.google.com/gke-tpu-topology: 2x4
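
The manifest above leaves `$MODEL_ID`, `$MAX_MODEL_LEN`, `$DTYPE`, `$TOKENIZER_MODE`, `$KSA_NAME`, `$VLLM_IMAGE`, and `$GSBUCKET` as placeholders that must be resolved before the RayService is created. As one possible way to render them (not something this PR prescribes), here is a minimal Python sketch using `string.Template`; the local file name `ray-service.tpu.yaml` is an assumption:

```python
# Hypothetical helper: substitute the $PLACEHOLDERS in the sample manifest from
# environment variables, then pipe the rendered YAML to `kubectl apply -f -`.
import os
import string
import subprocess

with open("ray-service.tpu.yaml") as f:  # assumed local copy of the manifest
    template = string.Template(f.read())

rendered = template.substitute(
    MODEL_ID=os.environ["MODEL_ID"],
    MAX_MODEL_LEN=os.environ.get("MAX_MODEL_LEN", ""),
    DTYPE=os.environ.get("DTYPE", ""),
    TOKENIZER_MODE=os.environ.get("TOKENIZER_MODE", ""),
    KSA_NAME=os.environ["KSA_NAME"],
    VLLM_IMAGE=os.environ["VLLM_IMAGE"],
    GSBUCKET=os.environ["GSBUCKET"],
)

subprocess.run(["kubectl", "apply", "-f", "-"], input=rendered, text=True, check=True)
```

Plain `envsubst` piped into `kubectl apply -f -` accomplishes the same thing; the empty-string defaults are safe because `serve_tpu.py` falls back to sensible values when `MAX_MODEL_LEN`, `DTYPE`, or `TOKENIZER_MODE` are unset.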
109 changes: 109 additions & 0 deletions ray-operator/config/samples/vllm/serve_tpu.py
@@ -0,0 +1,109 @@
import os

import json
import logging
from typing import Dict, List, Optional

import ray
from fastapi import FastAPI
from ray import serve
from starlette.requests import Request
from starlette.responses import Response

from vllm import LLM, SamplingParams

logger = logging.getLogger("ray.serve")

app = FastAPI()

@serve.deployment(name="VLLMDeployment")
@serve.ingress(app)
class VLLMDeployment:
    def __init__(
        self,
        model_id,
        num_tpu_chips,
        max_model_len,
        tokenizer_mode,
        dtype,
    ):
        self.llm = LLM(
            model=model_id,
            tensor_parallel_size=num_tpu_chips,
            max_model_len=max_model_len,
            dtype=dtype,
            download_dir=os.environ['VLLM_XLA_CACHE_PATH'],  # Error if not provided.
            tokenizer_mode=tokenizer_mode,
            enforce_eager=True,
        )

    @app.post("/v1/generate")
    async def generate(self, request: Request):
        request_dict = await request.json()
        prompts = request_dict.pop("prompt")
        max_toks = int(request_dict.pop("max_tokens"))
        print("Processing prompt ", prompts)
        sampling_params = SamplingParams(temperature=0.7,
                                         top_p=1.0,
                                         n=1,
                                         max_tokens=max_toks)

        outputs = self.llm.generate(prompts, sampling_params)
        for output in outputs:
            prompt = output.prompt
            generated_text = ""
            token_ids = []
            for completion_output in output.outputs:
                generated_text += completion_output.text
                token_ids.extend(list(completion_output.token_ids))

            print("Generated text: ", generated_text)
            ret = {
                "prompt": prompt,
                "text": generated_text,
                "token_ids": token_ids,
            }

        return Response(content=json.dumps(ret))

def get_num_tpu_chips() -> int:
    if "TPU" not in ray.cluster_resources():
        # Pass in TPU chips when the current Ray cluster resources can't be auto-detected (i.e. for autoscaling).
        if os.environ.get('TPU_CHIPS') is not None:
            return int(os.environ.get('TPU_CHIPS'))
        return 0
    return int(ray.cluster_resources()["TPU"])

def get_max_model_len() -> Optional[int]:
    if 'MAX_MODEL_LEN' not in os.environ or os.environ['MAX_MODEL_LEN'] == "":
        return None
    return int(os.environ['MAX_MODEL_LEN'])

def get_tokenizer_mode() -> str:
    if 'TOKENIZER_MODE' not in os.environ or os.environ['TOKENIZER_MODE'] == "":
        return "auto"
    return os.environ['TOKENIZER_MODE']

def get_dtype() -> str:
    if 'DTYPE' not in os.environ or os.environ['DTYPE'] == "":
        return "auto"
    return os.environ['DTYPE']

def build_app(cli_args: Dict[str, str]) -> serve.Application:
    """Builds the Serve app based on CLI arguments."""
    ray.init(ignore_reinit_error=True, address="ray://localhost:10001")
Collaborator:

Is the address="ray://localhost:10001" required?

Contributor Author:

I think I needed to add that to get it to work with the RayCluster when submitting it with `serve run`; otherwise it consistently started a local cluster (even though I was port-forwarding to localhost) instead of detecting the remote cluster.
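
Not part of this PR, just an illustration of the point above: the hard-coded Ray Client address could be made overridable through an environment variable while keeping the port-forwarded localhost endpoint as the default. The `RAY_CLIENT_ADDRESS` name below is hypothetical:

```python
# Hypothetical variant of the ray.init(...) call above, for illustration only:
# read the Ray Client address from an environment variable and fall back to the
# port-forwarded localhost endpoint used during local testing with `serve run`.
import os
import ray

ray_address = os.environ.get("RAY_CLIENT_ADDRESS", "ray://localhost:10001")
ray.init(ignore_reinit_error=True, address=ray_address)
```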


    model_id = os.environ['MODEL_ID']

    num_tpu_chips = get_num_tpu_chips()
    pg_resources = []
    pg_resources.append({"CPU": 1})  # for the deployment replica
    for i in range(num_tpu_chips):
        pg_resources.append({"CPU": 1, "TPU": 1})  # for the vLLM actors

    # Use PACK strategy since the deployment may use more than one TPU node.
    return VLLMDeployment.options(
        placement_group_bundles=pg_resources,
        placement_group_strategy="PACK").bind(model_id, num_tpu_chips, get_max_model_len(), get_tokenizer_mode(), get_dtype())

model = build_app({})
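
Once the RayService is up, the `/v1/generate` route defined in `serve_tpu.py` accepts a JSON body with `prompt` and `max_tokens` and returns the prompt, generated text, and token IDs. A minimal client sketch, assuming the Serve port (8000) has been port-forwarded to localhost and that the `requests` package is available:

```python
# Minimal client for the /v1/generate route in serve_tpu.py.
# Assumes the Serve port (8000) is reachable on localhost, e.g. via
# `kubectl port-forward`; the prompt and token budget are arbitrary.
import requests

resp = requests.post(
    "http://localhost:8000/v1/generate",
    json={"prompt": "What are TPUs?", "max_tokens": 128},
    timeout=300,
)
resp.raise_for_status()
result = resp.json()
print(result["text"])            # generated completion
print(len(result["token_ids"]))  # number of generated token IDs
```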