diff --git a/ray-operator/config/samples/vllm/ray-service.vllm-tpu-v6e-singlehost.yaml b/ray-operator/config/samples/vllm/ray-service.vllm-tpu-v6e-singlehost.yaml
new file mode 100644
index 00000000000..daafbe9e333
--- /dev/null
+++ b/ray-operator/config/samples/vllm/ray-service.vllm-tpu-v6e-singlehost.yaml
@@ -0,0 +1,146 @@
+apiVersion: ray.io/v1
+kind: RayService
+metadata:
+  name: vllm-tpu
+spec:
+  serveConfigV2: |
+    applications:
+    - name: llm
+      import_path: ray-operator.config.samples.vllm.serve_tpu:model
+      deployments:
+      - name: VLLMDeployment
+        num_replicas: 1
+      runtime_env:
+        working_dir: "https://github.com/ray-project/kuberay/archive/master.zip"
+        env_vars:
+          MODEL_ID: "$MODEL_ID"
+          MAX_MODEL_LEN: "$MAX_MODEL_LEN"
+          DTYPE: "$DTYPE"
+          TOKENIZER_MODE: "$TOKENIZER_MODE"
+          TPU_CHIPS: "8"
+  rayClusterConfig:
+    headGroupSpec:
+      rayStartParams: {}
+      template:
+        metadata:
+          annotations:
+            gke-gcsfuse/volumes: "true"
+            gke-gcsfuse/cpu-limit: "0"
+            gke-gcsfuse/memory-limit: "0"
+            gke-gcsfuse/ephemeral-storage-limit: "0"
+        spec:
+          # replace $KSA_NAME with your Kubernetes Service Account
+          serviceAccountName: $KSA_NAME
+          containers:
+          - name: ray-head
+            # replace $VLLM_IMAGE with your vLLM container image
+            image: $VLLM_IMAGE
+            imagePullPolicy: IfNotPresent
+            ports:
+            - containerPort: 6379
+              name: gcs
+            - containerPort: 8265
+              name: dashboard
+            - containerPort: 10001
+              name: client
+            - containerPort: 8000
+              name: serve
+            env:
+            - name: HUGGING_FACE_HUB_TOKEN
+              valueFrom:
+                secretKeyRef:
+                  name: hf-secret
+                  key: hf_api_token
+            - name: VLLM_XLA_CACHE_PATH
+              value: "/data"
+            resources:
+              limits:
+                cpu: "2"
+                memory: 8G
+              requests:
+                cpu: "2"
+                memory: 8G
+            volumeMounts:
+            - name: gcs-fuse-csi-ephemeral
+              mountPath: /data
+            - name: dshm
+              mountPath: /dev/shm
+          volumes:
+          - name: gke-gcsfuse-cache
+            emptyDir:
+              medium: Memory
+          - name: dshm
+            emptyDir:
+              medium: Memory
+          - name: gcs-fuse-csi-ephemeral
+            csi:
+              driver: gcsfuse.csi.storage.gke.io
+              volumeAttributes:
+                # replace $GSBUCKET with your GCS bucket
+                bucketName: $GSBUCKET
+                mountOptions: "implicit-dirs,file-cache:enable-parallel-downloads:true,file-cache:parallel-downloads-per-file:100,file-cache:max-parallel-downloads:-1,file-cache:download-chunk-size-mb:10,file-cache:max-size-mb:-1"
+    workerGroupSpecs:
+    - groupName: tpu-group
+      replicas: 1
+      minReplicas: 1
+      maxReplicas: 1
+      numOfHosts: 1
+      rayStartParams: {}
+      template:
+        metadata:
+          annotations:
+            gke-gcsfuse/volumes: "true"
+            gke-gcsfuse/cpu-limit: "0"
+            gke-gcsfuse/memory-limit: "0"
+            gke-gcsfuse/ephemeral-storage-limit: "0"
+        spec:
+          # replace $KSA_NAME with your Kubernetes Service Account
+          serviceAccountName: $KSA_NAME
+          containers:
+          - name: ray-worker
+            # replace $VLLM_IMAGE with your vLLM container image
+            image: $VLLM_IMAGE
+            imagePullPolicy: IfNotPresent
+            resources:
+              limits:
+                cpu: "100"
+                google.com/tpu: "8"
+                ephemeral-storage: 40G
+                memory: 200G
+              requests:
+                cpu: "100"
+                google.com/tpu: "8"
+                ephemeral-storage: 40G
+                memory: 200G
+            env:
+            - name: JAX_PLATFORMS
+              value: "tpu"
+            - name: HUGGING_FACE_HUB_TOKEN
+              valueFrom:
+                secretKeyRef:
+                  name: hf-secret
+                  key: hf_api_token
+            - name: VLLM_XLA_CACHE_PATH
+              value: "/data"
+            volumeMounts:
+            - name: gcs-fuse-csi-ephemeral
+              mountPath: /data
+            - name: dshm
+              mountPath: /dev/shm
+          volumes:
+          - name: gke-gcsfuse-cache
+            emptyDir:
+              medium: Memory
+          - name: dshm
+            emptyDir:
+              medium: Memory
+          - name: gcs-fuse-csi-ephemeral
+            csi:
+              driver: gcsfuse.csi.storage.gke.io
+              volumeAttributes:
+                # replace $GSBUCKET with your GCS bucket
+                bucketName: $GSBUCKET
+                mountOptions: "implicit-dirs,file-cache:enable-parallel-downloads:true,file-cache:parallel-downloads-per-file:100,file-cache:max-parallel-downloads:-1,file-cache:download-chunk-size-mb:10,file-cache:max-size-mb:-1"
+          nodeSelector:
+            cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
+            cloud.google.com/gke-tpu-topology: 2x4
diff --git a/ray-operator/config/samples/vllm/serve_tpu.py b/ray-operator/config/samples/vllm/serve_tpu.py
new file mode 100644
index 00000000000..27d2c918372
--- /dev/null
+++ b/ray-operator/config/samples/vllm/serve_tpu.py
@@ -0,0 +1,109 @@
+import os
+
+import json
+import logging
+from typing import Dict, List, Optional
+
+import ray
+from fastapi import FastAPI
+from ray import serve
+from starlette.requests import Request
+from starlette.responses import Response
+
+from vllm import LLM, SamplingParams
+
+logger = logging.getLogger("ray.serve")
+
+app = FastAPI()
+
+@serve.deployment(name="VLLMDeployment")
+@serve.ingress(app)
+class VLLMDeployment:
+  def __init__(
+      self,
+      model_id,
+      num_tpu_chips,
+      max_model_len,
+      tokenizer_mode,
+      dtype,
+  ):
+    self.llm = LLM(
+        model=model_id,
+        tensor_parallel_size=num_tpu_chips,
+        max_model_len=max_model_len,
+        dtype=dtype,
+        download_dir=os.environ['VLLM_XLA_CACHE_PATH'], # Error if not provided.
+        tokenizer_mode=tokenizer_mode,
+        enforce_eager=True,
+    )
+
+  @app.post("/v1/generate")
+  async def generate(self, request: Request):
+    request_dict = await request.json()
+    prompts = request_dict.pop("prompt")
+    max_toks = int(request_dict.pop("max_tokens"))
+    print("Processing prompt ", prompts)
+    sampling_params = SamplingParams(temperature=0.7,
+                                     top_p=1.0,
+                                     n=1,
+                                     max_tokens=max_toks)
+
+    outputs = self.llm.generate(prompts, sampling_params)
+    for output in outputs:
+      prompt = output.prompt
+      generated_text = ""
+      token_ids = []
+      for completion_output in output.outputs:
+        generated_text += completion_output.text
+        token_ids.extend(list(completion_output.token_ids))
+
+      print("Generated text: ", generated_text)
+      ret = {
+          "prompt": prompt,
+          "text": generated_text,
+          "token_ids": token_ids,
+      }
+
+    return Response(content=json.dumps(ret))
+
+def get_num_tpu_chips() -> int:
+  if "TPU" not in ray.cluster_resources():
+    # Pass in TPU chips when the current Ray cluster resources can't be auto-detected (i.e. for autoscaling).
+    if os.environ.get('TPU_CHIPS') is not None:
+      return int(os.environ.get('TPU_CHIPS'))
+    return 0
+  return int(ray.cluster_resources()["TPU"])
+
+def get_max_model_len() -> Optional[int]:
+  if 'MAX_MODEL_LEN' not in os.environ or os.environ['MAX_MODEL_LEN'] == "":
+    return None
+  return int(os.environ['MAX_MODEL_LEN'])
+
+def get_tokenizer_mode() -> str:
+  if 'TOKENIZER_MODE' not in os.environ or os.environ['TOKENIZER_MODE'] == "":
+    return "auto"
+  return os.environ['TOKENIZER_MODE']
+
+def get_dtype() -> str:
+  if 'DTYPE' not in os.environ or os.environ['DTYPE'] == "":
+    return "auto"
+  return os.environ['DTYPE']
+
+def build_app(cli_args: Dict[str, str]) -> serve.Application:
+  """Builds the Serve app based on CLI arguments."""
+  ray.init(ignore_reinit_error=True, address="ray://localhost:10001")
+
+  model_id = os.environ['MODEL_ID']
+
+  num_tpu_chips = get_num_tpu_chips()
+  pg_resources = []
+  pg_resources.append({"CPU": 1}) # for the deployment replica
+  for i in range(num_tpu_chips):
+    pg_resources.append({"CPU": 1, "TPU": 1}) # for the vLLM actors
+
+  # Use PACK strategy since the deployment may use more than one TPU node.
+  return VLLMDeployment.options(
+      placement_group_bundles=pg_resources,
+      placement_group_strategy="PACK").bind(model_id, num_tpu_chips, get_max_model_len(), get_tokenizer_mode(), get_dtype())
+
+model = build_app({})
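The two files above are everything this change adds. As a quick smoke test of the /v1/generate route that serve_tpu.py exposes, a small client along the following lines could be used once the RayService reports healthy. This is a sketch and not part of the change: it assumes the Serve HTTP port (8000) has been made reachable on localhost, for example by port-forwarding the Serve service that KubeRay creates for the RayService (conventionally named vllm-tpu-serve-svc for this sample).

# client_example.py -- hypothetical smoke-test client for the /v1/generate
# route defined in serve_tpu.py above. Assumes the Serve port is forwarded
# locally, e.g.:
#   kubectl port-forward svc/vllm-tpu-serve-svc 8000:8000
import json
import urllib.request

# The handler reads exactly these two keys from the request body; the prompt
# text and token budget here are placeholders.
payload = json.dumps({
    "prompt": "What are the top 5 most popular programming languages?",
    "max_tokens": 128,
}).encode("utf-8")

req = urllib.request.Request(
    "http://localhost:8000/v1/generate",
    data=payload,
    headers={"Content-Type": "application/json"},
    method="POST",
)

with urllib.request.urlopen(req) as resp:
    # serve_tpu.py responds with {"prompt": ..., "text": ..., "token_ids": [...]}.
    result = json.loads(resp.read())
    print(result["text"])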