
Commit b227924

Add vLLM TPU example RayService manifest (#3000)
* Add vLLM TPU example RayService
* Add comments
* Lint

Signed-off-by: Ryan O'Leary <ryanaoleary@google.com>
1 parent 9559227 commit b227924

2 files changed: +255 −0 lines changed
@@ -0,0 +1,146 @@
apiVersion: ray.io/v1
kind: RayService
metadata:
  name: vllm-tpu
spec:
  serveConfigV2: |
    applications:
      - name: llm
        import_path: ray-operator.config.samples.vllm.serve_tpu:model
        deployments:
          - name: VLLMDeployment
            num_replicas: 1
        runtime_env:
          working_dir: "https://github.com/ray-project/kuberay/archive/master.zip"
          env_vars:
            MODEL_ID: "$MODEL_ID"
            MAX_MODEL_LEN: "$MAX_MODEL_LEN"
            DTYPE: "$DTYPE"
            TOKENIZER_MODE: "$TOKENIZER_MODE"
            TPU_CHIPS: "8"
  rayClusterConfig:
    headGroupSpec:
      rayStartParams: {}
      template:
        metadata:
          annotations:
            gke-gcsfuse/volumes: "true"
            gke-gcsfuse/cpu-limit: "0"
            gke-gcsfuse/memory-limit: "0"
            gke-gcsfuse/ephemeral-storage-limit: "0"
        spec:
          # replace $KSA_NAME with your Kubernetes Service Account
          serviceAccountName: $KSA_NAME
          containers:
            - name: ray-head
              # replace $VLLM_IMAGE with your vLLM container image
              image: $VLLM_IMAGE
              imagePullPolicy: IfNotPresent
              ports:
                - containerPort: 6379
                  name: gcs
                - containerPort: 8265
                  name: dashboard
                - containerPort: 10001
                  name: client
                - containerPort: 8000
                  name: serve
              env:
                - name: HUGGING_FACE_HUB_TOKEN
                  valueFrom:
                    secretKeyRef:
                      name: hf-secret
                      key: hf_api_token
                - name: VLLM_XLA_CACHE_PATH
                  value: "/data"
              resources:
                limits:
                  cpu: "2"
                  memory: 8G
                requests:
                  cpu: "2"
                  memory: 8G
              volumeMounts:
                - name: gcs-fuse-csi-ephemeral
                  mountPath: /data
                - name: dshm
                  mountPath: /dev/shm
          volumes:
            - name: gke-gcsfuse-cache
              emptyDir:
                medium: Memory
            - name: dshm
              emptyDir:
                medium: Memory
            - name: gcs-fuse-csi-ephemeral
              csi:
                driver: gcsfuse.csi.storage.gke.io
                volumeAttributes:
                  # replace $GSBUCKET with your GCS bucket
                  bucketName: $GSBUCKET
                  mountOptions: "implicit-dirs,file-cache:enable-parallel-downloads:true,file-cache:parallel-downloads-per-file:100,file-cache:max-parallel-downloads:-1,file-cache:download-chunk-size-mb:10,file-cache:max-size-mb:-1"
    workerGroupSpecs:
      - groupName: tpu-group
        replicas: 1
        minReplicas: 1
        maxReplicas: 1
        numOfHosts: 1
        rayStartParams: {}
        template:
          metadata:
            annotations:
              gke-gcsfuse/volumes: "true"
              gke-gcsfuse/cpu-limit: "0"
              gke-gcsfuse/memory-limit: "0"
              gke-gcsfuse/ephemeral-storage-limit: "0"
          spec:
            # replace $KSA_NAME with your Kubernetes Service Account
            serviceAccountName: $KSA_NAME
            containers:
              - name: ray-worker
                # replace $VLLM_IMAGE with your vLLM container image
                image: $VLLM_IMAGE
                imagePullPolicy: IfNotPresent
                resources:
                  limits:
                    cpu: "100"
                    google.com/tpu: "8"
                    ephemeral-storage: 40G
                    memory: 200G
                  requests:
                    cpu: "100"
                    google.com/tpu: "8"
                    ephemeral-storage: 40G
                    memory: 200G
                env:
                  - name: JAX_PLATFORMS
                    value: "tpu"
                  - name: HUGGING_FACE_HUB_TOKEN
                    valueFrom:
                      secretKeyRef:
                        name: hf-secret
                        key: hf_api_token
                  - name: VLLM_XLA_CACHE_PATH
                    value: "/data"
                volumeMounts:
                  - name: gcs-fuse-csi-ephemeral
                    mountPath: /data
                  - name: dshm
                    mountPath: /dev/shm
            volumes:
              - name: gke-gcsfuse-cache
                emptyDir:
                  medium: Memory
              - name: dshm
                emptyDir:
                  medium: Memory
              - name: gcs-fuse-csi-ephemeral
                csi:
                  driver: gcsfuse.csi.storage.gke.io
                  volumeAttributes:
                    # replace $GSBUCKET with your GCS bucket
                    bucketName: $GSBUCKET
                    mountOptions: "implicit-dirs,file-cache:enable-parallel-downloads:true,file-cache:parallel-downloads-per-file:100,file-cache:max-parallel-downloads:-1,file-cache:download-chunk-size-mb:10,file-cache:max-size-mb:-1"
            nodeSelector:
              cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
              cloud.google.com/gke-tpu-topology: 2x4
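
The manifest above references $-style placeholders ($MODEL_ID, $MAX_MODEL_LEN, $DTYPE, $TOKENIZER_MODE, $KSA_NAME, $VLLM_IMAGE, $GSBUCKET) that must be filled in before the RayService is applied. A minimal sketch of one way to render them, assuming the manifest is saved locally as ray-service.vllm-tpu.yaml (a hypothetical filename) and the values are exported as environment variables:

import os
from string import Template

# Placeholders used by this manifest; assumed to be exported in the shell.
PLACEHOLDERS = ["MODEL_ID", "MAX_MODEL_LEN", "DTYPE", "TOKENIZER_MODE",
                "KSA_NAME", "VLLM_IMAGE", "GSBUCKET"]

with open("ray-service.vllm-tpu.yaml") as f:  # hypothetical local path
    manifest = f.read()

# safe_substitute leaves any placeholder without a value untouched instead of raising.
rendered = Template(manifest).safe_substitute(
    {name: os.environ[name] for name in PLACEHOLDERS if name in os.environ})

with open("ray-service.vllm-tpu.rendered.yaml", "w") as f:
    f.write(rendered)

The rendered file can then be applied with kubectl apply -f ray-service.vllm-tpu.rendered.yaml; a shell tool such as envsubst achieves the same result in one step.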
@@ -0,0 +1,109 @@
import os

import json
import logging
from typing import Dict, List, Optional

import ray
from fastapi import FastAPI
from ray import serve
from starlette.requests import Request
from starlette.responses import Response

from vllm import LLM, SamplingParams

logger = logging.getLogger("ray.serve")

app = FastAPI()


@serve.deployment(name="VLLMDeployment")
@serve.ingress(app)
class VLLMDeployment:
    def __init__(
        self,
        model_id,
        num_tpu_chips,
        max_model_len,
        tokenizer_mode,
        dtype,
    ):
        self.llm = LLM(
            model=model_id,
            tensor_parallel_size=num_tpu_chips,
            max_model_len=max_model_len,
            dtype=dtype,
            download_dir=os.environ['VLLM_XLA_CACHE_PATH'],  # Error if not provided.
            tokenizer_mode=tokenizer_mode,
            enforce_eager=True,
        )

    @app.post("/v1/generate")
    async def generate(self, request: Request):
        request_dict = await request.json()
        prompts = request_dict.pop("prompt")
        max_toks = int(request_dict.pop("max_tokens"))
        print("Processing prompt ", prompts)
        sampling_params = SamplingParams(temperature=0.7,
                                         top_p=1.0,
                                         n=1,
                                         max_tokens=max_toks)

        outputs = self.llm.generate(prompts, sampling_params)
        for output in outputs:
            prompt = output.prompt
            generated_text = ""
            token_ids = []
            for completion_output in output.outputs:
                generated_text += completion_output.text
                token_ids.extend(list(completion_output.token_ids))

            print("Generated text: ", generated_text)
            ret = {
                "prompt": prompt,
                "text": generated_text,
                "token_ids": token_ids,
            }

        return Response(content=json.dumps(ret))


def get_num_tpu_chips() -> int:
    if "TPU" not in ray.cluster_resources():
        # Pass in TPU chips when the current Ray cluster resources can't be auto-detected (i.e. for autoscaling).
        if os.environ.get('TPU_CHIPS') is not None:
            return int(os.environ.get('TPU_CHIPS'))
        return 0
    return int(ray.cluster_resources()["TPU"])


def get_max_model_len() -> Optional[int]:
    if 'MAX_MODEL_LEN' not in os.environ or os.environ['MAX_MODEL_LEN'] == "":
        return None
    return int(os.environ['MAX_MODEL_LEN'])


def get_tokenizer_mode() -> str:
    if 'TOKENIZER_MODE' not in os.environ or os.environ['TOKENIZER_MODE'] == "":
        return "auto"
    return os.environ['TOKENIZER_MODE']


def get_dtype() -> str:
    if 'DTYPE' not in os.environ or os.environ['DTYPE'] == "":
        return "auto"
    return os.environ['DTYPE']


def build_app(cli_args: Dict[str, str]) -> serve.Application:
    """Builds the Serve app based on CLI arguments."""
    ray.init(ignore_reinit_error=True, address="ray://localhost:10001")

    model_id = os.environ['MODEL_ID']

    num_tpu_chips = get_num_tpu_chips()
    pg_resources = []
    pg_resources.append({"CPU": 1})  # for the deployment replica
    for i in range(num_tpu_chips):
        pg_resources.append({"CPU": 1, "TPU": 1})  # for the vLLM actors

    # Use PACK strategy since the deployment may use more than one TPU node.
    return VLLMDeployment.options(
        placement_group_bundles=pg_resources,
        placement_group_strategy="PACK").bind(model_id, num_tpu_chips,
                                              get_max_model_len(),
                                              get_tokenizer_mode(), get_dtype())


model = build_app({})
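
Once the RayService is running, the deployment above serves a single POST endpoint, /v1/generate, which expects a JSON body with "prompt" and "max_tokens" and returns the prompt, the generated text, and the generated token IDs. A minimal smoke test, assuming the Serve port has been forwarded locally (for example kubectl port-forward svc/vllm-tpu-serve-svc 8000:8000, where the Service name follows the usual <RayService name>-serve-svc convention) and that the requests package is installed:

import requests

resp = requests.post(
    "http://localhost:8000/v1/generate",
    json={"prompt": "What are TPUs?", "max_tokens": 128},
    timeout=600,  # the first request may be slow while the model warms up
)
resp.raise_for_status()

result = resp.json()
print(result["prompt"])
print(result["text"])            # generated completion
print(len(result["token_ids"]))  # number of generated tokens

Adjust the URL and Service name to match whatever your cluster actually exposes.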
