Conversation

@LucasWilkinson
Collaborator

Fix hang caused by improperly sized cudagraph in the dummy runs

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request addresses a critical bug that causes hangs in distributed (DP/EP) environments during CUDA graph dummy runs. The issue stemmed from using a local, unpadded token count (num_tokens) for dispatching CUDA graphs, which could lead to inconsistencies and deadlocks across different ranks. The fix correctly utilizes the synchronized and padded token count (num_tokens_after_padding), ensuring all ranks operate on a consistent batch size for graph dispatching, thereby resolving the hang. This is an important and correct fix for distributed execution stability.
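
To make the failure mode concrete, here is a minimal sketch (plain Python with hypothetical names such as dispatch_key and captured_sizes; not vLLM code) of how rank-local token counts lead different data-parallel ranks to pick different CUDA graphs:

# Illustration only: each data-parallel rank builds a dispatch key from a
# token count and uses it to select a captured CUDA graph.
def dispatch_key(num_tokens: int) -> int:
    # Stand-in for BatchDescriptor-based dispatch: round the count up to the
    # nearest captured graph size (hypothetical capture list).
    captured_sizes = [1, 2, 4, 8, 16, 32, 64]
    return next(s for s in captured_sizes if s >= num_tokens)

# Local, unpadded counts can legitimately differ between DP ranks.
local_num_tokens = {0: 3, 1: 5}

keys = {rank: dispatch_key(n) for rank, n in local_num_tokens.items()}
print(keys)  # {0: 4, 1: 8} -> the ranks would replay differently sized graphs
assert len(set(keys.values())) > 1
# The collectives inside those graphs no longer line up, so the step deadlocks.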


  # filter out the valid batch descriptor
  _cg_mode, batch_descriptor = self.cudagraph_dispatcher.dispatch(
-     BatchDescriptor(num_tokens=num_tokens,
Contributor

critical

Using num_tokens for the BatchDescriptor is incorrect in a distributed setting as it represents the local, unpadded number of tokens, which can differ across data-parallel ranks. This discrepancy can lead to ranks dispatching to different CUDA graphs (or none at all), causing a hang during subsequent collective operations. The change to use num_tokens_after_padding is correct, as this value is synchronized across all ranks, ensuring consistent CUDA graph dispatching and preventing hangs.

+     BatchDescriptor(num_tokens=num_tokens_after_padding,
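
For contrast, a sketch of the same toy dispatch once the dummy run uses the synchronized, padded count (again hypothetical names; in vLLM this value arrives as num_tokens_after_padding):

# Illustration only: the padded count is agreed across DP ranks before
# dispatch, so a rank doing a dummy run still selects the same graph as the
# ranks that have real work.
def dispatch_key(num_tokens: int) -> int:
    captured_sizes = [1, 2, 4, 8, 16, 32, 64]
    return next(s for s in captured_sizes if s >= num_tokens)

local_num_tokens = {0: 3, 1: 5, 2: 0}  # rank 2 is idle -> dummy run

# Hypothetical synchronization step: every rank learns the global maximum,
# then rounds it up to a captured graph size.
num_tokens_after_padding = dispatch_key(max(max(local_num_tokens.values()), 1))

keys = {rank: dispatch_key(num_tokens_after_padding) for rank in local_num_tokens}
print(keys)  # {0: 8, 1: 8, 2: 8} -> every rank replays the identically sized graph
assert len(set(keys.values())) == 1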

@LucasWilkinson LucasWilkinson added this to the v0.11.0 Cherry Picks milestone Sep 29, 2025
@simon-mo simon-mo enabled auto-merge (squash) September 29, 2025 21:14
@github-actions github-actions bot added the ready label Sep 29, 2025
Collaborator

@alexm-redhat alexm-redhat left a comment

Tnx Lucas!

  # filter out the valid batch descriptor
  _cg_mode, batch_descriptor = self.cudagraph_dispatcher.dispatch(
-     BatchDescriptor(num_tokens=num_tokens,
+     BatchDescriptor(num_tokens=num_tokens_after_padding,
Collaborator

Good catch!

LucasWilkinson and others added 5 commits September 29, 2025 15:26
@robertgshaw2-redhat
Collaborator

Confirmed this does not hang on the GCP B200 cluster with FI

@robertgshaw2-redhat
Collaborator

config:

apiVersion: leaderworkerset.x-k8s.io/v1
kind: LeaderWorkerSet
metadata:
    name: vllm-deepseek-ep-decode
spec:
    replicas: 1
    leaderWorkerTemplate:
        size: 2
        restartPolicy: None #RecreateGroupOnPodRestart

        workerTemplate:
            metadata:
              labels:
                app: vllm-deepseek-ep
                component: vllm-deepseek-ep-decode
                server: vllm
              annotations:
                networking.gke.io/default-interface: 'eth0'
                networking.gke.io/interfaces: |
                  [
                    {"interfaceName":"eth0","network":"default"},
                    {"interfaceName":"eth2","network":"rdma-0"},
                    {"interfaceName":"eth3","network":"rdma-1"},
                    {"interfaceName":"eth4","network":"rdma-2"},
                    {"interfaceName":"eth5","network":"rdma-3"},
                    {"interfaceName":"eth6","network":"rdma-4"},
                    {"interfaceName":"eth7","network":"rdma-5"},
                    {"interfaceName":"eth8","network":"rdma-6"},
                    {"interfaceName":"eth9","network":"rdma-7"}
                  ]
            spec:
              containers:
              - name: vllm-worker
                image: gcr.io/claytoncoleman-gke-dev/github.com/smarterclayton/vllm-dp-lws:main_b200_gcp_nixl_06
                imagePullPolicy: Always
                workingDir: /code
                stdin: true
                tty: true
                command: ["/bin/bash","-c"]
                args:
                  - |
                    set -euo pipefail

                    if [[ -n "${DROP_KERNEL_CACHE}" ]]; then
                      find /root/.nv/ComputeCache/ -delete
                    fi

                    # Debugging tools for the environment, libnl required for set_nccl_env.sh
                    if command -v apt-get >/dev/null 2>&1; then
                      # Using https://github.com/smarterclayton/vllm-dp-lws/tree/working_branch
                      apt-get install -y python3.12-dbg dnsutils
                    else
                      dnf install -qy gdb
                      # dnf debuginfo-install -qy python3.12-3.12.9-1.el9_6.1.x86_64
                    fi

                    # Create ~/.bashrc so that kubectl exec -- /bin/bash is ready to run
                    cat <<'EOF' > ~/.bashrc
                    #!/bin/bash

                    bind '"\e[A":history-search-backward'
                    bind '"\e[B":history-search-forward'

                    shopt -s histappend
                    export HISTFILESIZE=1000000
                    export HISTSIZE=1000000
                    export HISTCONTROL=ignoreboth
                    shopt -s cmdhist

                    if command -v apt-get >/dev/null 2>&1; then
                      export VLLM_HOME=/app/venv
                    else
                      # Assume we're using the RH midstream image
                      export VLLM_HOME=/opt/vllm
                      export HF_HUB_OFFLINE=0
                    fi

                    # Configure gIB (assuming v1.10 because of NCCL 2.27 for vLLM main)
                    export PATH=/usr/local/nvidia/bin:${PATH}:/usr/local/gib/bin
                    source /usr/local/gib/scripts/set_nccl_env.sh

                    export START_RANK=$(( ${LWS_WORKER_INDEX:-0} * DP_SIZE_LOCAL ))
                    if [ "${LWS_WORKER_INDEX:-0}" -eq 0 ]; then
                      #################
                      # Leader-only launch
                      #################
                      serve=(
                      ${VLLM_HOME}/bin/vllm serve \
                        ${VLLM_MODEL} \
                        --port 8000 \
                        --disable-log-requests \
                        --data-parallel-hybrid-lb \
                        --speculative-config '$(VLLM_SPECULATIVE_CONFIG)' \
                        --enable-dbo \
                        -O.cudagraph_mode=FULL_DECODE_ONLY \
                        # try to keep 40GB free, https://github.com/vllm-project/vllm/issues/25597
                        --kv-cache-memory=63000000000 \
                        --enable-expert-parallel \
                        --enable-eplb \
                        --num-redundant-experts $DP_SIZE \
                        --eplb-window-size 1000 \
                        --eplb-step-interval 3000 \
                        --eplb-log-balancedness \
                        --tensor-parallel-size $TP_SIZE \
                        --data-parallel-size $DP_SIZE \
                        --data-parallel-size-local $DP_SIZE_LOCAL \
                        --data-parallel-address ${LWS_LEADER_ADDRESS} \
                        --data-parallel-rpc-port 5555 \
                        --data-parallel-start-rank $START_RANK \
                        --kv-transfer-config '{"kv_connector":"NixlConnector", "kv_role":"kv_both"}' \
                        --trust-remote-code
                      )
                    else
                      #################
                      # Worker-only launch
                      #################
                      serve=(
                      ${VLLM_HOME}/bin/vllm serve \
                        ${VLLM_MODEL} \
                        --port 8000 \
                        --disable-log-requests \
                        --data-parallel-hybrid-lb \
                        --speculative-config '$(VLLM_SPECULATIVE_CONFIG)' \
                        --enable-dbo \
                        -O.cudagraph_mode=FULL_DECODE_ONLY \
                        # try to keep 40GB free, https://github.com/vllm-project/vllm/issues/25597
                        --kv-cache-memory=63000000000 \
                        --enable-expert-parallel \
                        --enable-eplb \
                        --num-redundant-experts $DP_SIZE \
                        --eplb-window-size 1000 \
                        --eplb-step-interval 3000 \
                        --eplb-log-balancedness \
                        --tensor-parallel-size $TP_SIZE \
                        --data-parallel-size $DP_SIZE \
                        --data-parallel-size-local $DP_SIZE_LOCAL \
                        --data-parallel-address ${LWS_LEADER_ADDRESS} \
                        --data-parallel-rpc-port 5555 \
                        --data-parallel-start-rank $START_RANK \
                        --kv-transfer-config '{"kv_connector":"NixlConnector", "kv_role":"kv_both"}' \
                        --trust-remote-code
                      )
                    fi
                    eval "serve+=( ${VLLM_DEBUG_ARGS:-} )"

                    DEFAULT_BALANCER_HOST=vllm-deepseek-ep-balancer
                    # Run a benchmark that is decode-only
                    function bench_decode {
                      if [[ -n "${DISAGG-}" ]]; then PORT=8200; prefill_hosts="$( dig SRV +short +search vllm-deepseek-ep-prefill | cut -d ' ' -f 4 | sed 's/local\.$/local/' | sort | awk '{print $1 ":8000"}' | paste -sd ', ' -)"; headers=" --header x-prefiller-host-port=${prefill_hosts} "; fi
                      VIRTUAL_ENV=/app/venv uv run /app/venv/bin/vllm bench serve ${headers:-} --base-url http://${DEFAULT_BALANCER_HOST}:${PORT:-8000} --model ${VLLM_MODEL} --dataset-name random --seed $(date +%M%H%M%S) --ignore-eos --random-input-len ${TOKENS_IN:-2} --random-output-len ${TOKENS_OUT:-100} --max-concurrency ${MAX_CONCURRENCY:-512} --num-prompts ${NUM_PROMPTS:-$((${MAX_CONCURRENCY:-512} * 3))} --ready-check-timeout-sec 0
                    }
                    # Run a benchmark that is prefill-only
                    function bench_prefill {
                      if [[ -n "${DISAGG-}" ]]; then PORT=8200; prefill_hosts="$( dig SRV +short +search vllm-deepseek-ep-prefill | cut -d ' ' -f 4 | sed 's/local\.$/local/' | sort | awk '{print $1 ":8000"}' | paste -sd ', ' -)"; headers=" --header x-prefiller-host-port=${prefill_hosts} "; fi
                      VIRTUAL_ENV=/app/venv uv run /app/venv/bin/vllm bench serve ${headers:-} --base-url http://${DEFAULT_BALANCER_HOST}:${PORT:-8000} --model ${VLLM_MODEL} --dataset-name random --seed $(date +%M%H%M%S) --ignore-eos --random-input-len ${TOKENS_IN:-10000} --random-output-len ${TOKENS_OUT:-1} --max-concurrency ${MAX_CONCURRENCY:-$((${DP_SIZE}*3))} --num-prompts ${NUM_PROMPTS:-$((${MAX_CONCURRENCY:-8} * 3))} --ready-check-timeout-sec 0
                    }
                    # Run a benchmark with mixed prefill and decode
                    function bench_mixed {
                      if [[ -n "${DISAGG-}" ]]; then PORT=8200; prefill_hosts="$( dig SRV +short +search vllm-deepseek-ep-prefill | cut -d ' ' -f 4 | sed 's/local\.$/local/' | sort | awk '{print $1 ":8000"}' | paste -sd ', ' -)"; headers=" --header x-prefiller-host-port=${prefill_hosts} "; fi
                      VIRTUAL_ENV=/app/venv uv run /app/venv/bin/vllm bench serve ${headers:-} --base-url http://${DEFAULT_BALANCER_HOST}:${PORT:-8000} --model ${VLLM_MODEL} --dataset-name random --seed $(date +%M%H%M%S) --ignore-eos --random-input-len ${TOKENS_IN:-1000} --random-output-len ${TOKENS_OUT:-100} --max-concurrency ${MAX_CONCURRENCY:-512} --num-prompts ${NUM_PROMPTS:-$((${MAX_CONCURRENCY:-8} * 3))} --ready-check-timeout-sec 0
                    }
                    # Start Qwen3 for a faster and simpler MoE test harness
                    function serve_simple {
                      ${VLLM_HOME}/bin/vllm serve Qwen/Qwen3-30B-A3B-FP8 --port ${PORT:-8000} --enforce-eager --disable-log-requests --enable-expert-parallel --tensor-parallel-size 1 --data-parallel-size 2 --trust-remote-code
                    }
                    function req_complete {
                      curl -i http://${LWS_LEADER_ADDRESS}:${PORT:-8000}/v1/completions -H 'Content-Type: application/json' -d "{\"temperature\": 0,\"prompt\": \"Write as if you were a critic: San Francisco\",\"max_tokens\": ${TOKENS_OUT:-100},\"model\": \"${VLLM_MODEL}\"}"
                    }
                    # Execute DeepEP's internode latency and tuning test
                    function test_internode {
                      VIRTUAL_ENV=/app/venv WORLD_SIZE=$(( ${DP_SIZE} / ${DP_SIZE_LOCAL} )) MASTER_ADDR=${LWS_LEADER_ADDRESS} RANK=$(( ${START_RANK} / ${DP_SIZE_LOCAL} )) uv run /app/deepep/tests/test_internode.py
                    }
                    function req_disagg {
                      curl -i ${DEFAULT_BALANCER_HOST}:8200/v1/completions \
                        -H "x-prefiller-host-port: ${DEFAULT_PREFILL_HOST}:${PORT:-8000}" \
                        -H 'Content-Type: application/json' -d \
                        '{"temperature": 0,"prompt": "Write as if you were a critic, and provide extensive details of the time you spent with the subject. Don''t forget to provide personal anecdotes.\n\n: San Francisco","max_tokens": 10,"model": "${VLLM_MODEL}"}' "$@"
                    }
                    function collect_vllm_env {
                      wget -O collect_env.py https://raw.githubusercontent.com/vllm-project/vllm/main/vllm/collect_env.py && VIRTUAL_ENV=/app/venv uv run collect_env.py
                    }
                    function print_stacks {
                      pgrep 'VLLM' | xargs -P8 -I {} gdb -p {} --batch --eval-command py-bt | grep -v 'New LWP'
                    }

                    set +e
                    EOF

                    #######################################################
                    # INSTALL Dependencies that have a _BRANCH variable set
                    #######################################################
                    components=( deepep deepgemm flashinfer vllm )
                    for script in "${components[@]}"; do
                      branch="${script^^}_BRANCH"
                      force="${script^^}_INSTALL"
                      if [[ -z "${!branch-}" && -z "${!force-}" ]]; then
                        continue
                      fi
                      for location in /init-scripts /install-scripts; do
                        if [[ -f ${location}/${script}.sh ]]; then
                          ${location}/${script}.sh
                          break
                        fi
                      done
                    done
                    echo
                    for script in "${components[@]}"; do
                      echo "${script} $( git -C "/app/${script}" log --oneline -1 2>&1 || true )"
                    done
                    echo

                    # If set, hold the container before launch to allow debugging
                    if [[ -n "${INTERACTIVE:-}" ]]; then
                      echo "Waiting for /code/launch to run vLLM"
                      while [[ ! -f /code/launch ]]; do
                        sleep 10
                      done
                      rm /code/launch
                    fi

                    source ~/.bashrc
                    env | sort

                    exec "${serve[@]}"
                env:
                  - name: INTERACTIVE
                    value: ""
                  - name: DROP_KERNEL_CACHE
                    value: ""

                  # Uncomment to force vLLM to build at a specific point
                  - name: VLLM_REPO_URL
                    # value: "https://github.com/smarterclayton/vllm.git"
                    # value: "https://github.com/neuralmagic/vllm.git"
                    value: "https://github.com/robertgshaw2-redhat/vllm"
                  - name: VLLM_BRANCH
                    # value: "gcp_tweak_4"
                    # value: "flashinfer-mla-full-cg"
                    # value: "fix_hang"
                    # value: "lwilkinson/fix-hang"
                    value: "gcp_tweak_5"
                  # - name: VLLM_COMMIT 
                  #   value: ""
                  # Set to 1 when running from a build significantly older than vLLM HEAD, as
                  # the vLLM kernels may not be compatible with latest
                  - name: VLLM_USE_PRECOMPILED
                    value: "1"
                  # Prefer the PE mapping of the NIC
                  # - name: DEEPEP_URL
                  #   value: https://github.com/robertgshaw2-redhat/DeepEP.git
                  # - name: DEEPEP_REPO_URL
                  #   value: https://github.com/robertgshaw2-redhat/DeepEP.git
                  # - name: DEEPEP_BRANCH
                  #   value: main
                  - name: FLASHINFER_BRANCH
                    value: "v0.3.1"
                  # When building packages at runtime, ensure we are compiled for hopper and blackwell
                  - name: TORCH_CUDA_ARCH_LIST
                    value: "10.0"

                  # NIXL Configuration
                  - name: VLLM_NIXL_SIDE_CHANNEL_PORT
                    value: "6555"
                  - name: VLLM_NIXL_SIDE_CHANNEL_HOST
                    valueFrom:
                      fieldRef:
                        fieldPath: status.podIP
                  # Force use of TCP for NIXL disaggregation
                  #- name: UCX_TLS
                  #  value: "tcp,cuda_copy"
                  #- name: UCX_NET_DEVICES
                  #  value: "eth0"
                  # GCP RDMA Networking config
                  - name: NVSHMEM_REMOTE_TRANSPORT
                    value: "ibgda"
                  - name: NVSHMEM_IB_ENABLE_IBGDA
                    value: "true"
                  # Currently disabled as we don't mount /dev/gdrdrv into the container
                  - name: NVIDIA_GDRCOPY
                    value: "disabled"
                  - name: NVSHMEM_BOOTSTRAP_UID_SOCK_IFNAME
                    value: "eth0"
                  - name: GLOO_SOCKET_IFNAME
                    value: "eth0"
                  # May not be necessary, but set to verify NVSHMEM uses the
                  # gcp_tweak workaround
                  - name: NVSHMEM_ENABLE_NIC_PE_MAPPING
                    value: "1"
                  - name: NVSHMEM_QP_DEPTH
                    value: "8192"

                  # Debugging logging
                  - name: UCX_LOG_LEVEL
                    value: "info"
                  #- name: UCX_HANDLE_ERRORS
                  #  value: "freeze"
                  #- name: UCX_ERROR_SIGNALS
                  #  value: "SIGSEGV"
                  - name: NVSHMEM_INFO
                    value: "true"
                  - name: NVSHMEM_DEBUG
                    value: "INFO"
                  - name: NVSHMEM_DEBUG_SUBSYS
                    value: "ALL"
                  - name: CUDA_ENABLE_USER_TRIGGERED_COREDUMP
                    value: "1"
                  - name: NCCL_DEBUG
                    value: INFO # trace
                  - name: VLLM_LOGGING_LEVEL
                    value: "DEBUG"
                  - name: HF_HUB_DISABLE_PROGRESS_BARS
                    value: "1"
                  - name: VLLM_DEBUG_ARGS
                    value: ""
                  # Uncomment to debug NCCL hangs or crashes
                  #- name: NCCL_BLOCKING_WAIT
                  #  value: "1"
                  # Uncomment to identify where CPP calls are crashing
                  #- name: TORCH_SHOW_CPP_STACKTRACES
                  #  value: "1"
                  # Uncomment to get more accurate crash traces
                  # - name: CUDA_LAUNCH_BLOCKING
                  #   value: "1"
                  - name: VLLM_TORCH_PROFILER_DIR
                    value: "/code/traces"
                  # - name: CUDA_ENABLE_COREDUMP_ON_EXCEPTION
                  #   value: "1"
                  # - name: CUDA_COREDUMP_GENERATION_FLAGS
                  #   value: "skip_nonrelocated_elf_images,skip_global_memory,skip_shared_memory,skip_local_memory"
                  # - name: CUDA_COREDUMP_FILE
                  #   value: "/huggingface-cache/cuda_coredump_%h.%p.%t"

                  # The model and any specific config
                  - name: VLLM_MODEL
                    value: "deepseek-ai/DeepSeek-V3.1"
                    #value: "nvidia/DeepSeek-R1-FP4"
                  # Use this when running DeepSeek @ FP4
                  #- name: VLLM_USE_FLASHINFER_MOE_FP4
                  #  value: "1"

                  # vLLM performance tuning
                  - name: VLLM_USE_DEEP_GEMM
                    value: "1"
                  # Select an all2all backend optimized for latency or throughput
                  - name: VLLM_ALL2ALL_BACKEND
                    # value: "naive"
                    # value: "pplx"
                    # value: "deepep_high_throughput"
                    value: deepep_low_latency
                  - name: VLLM_ATTENTION_BACKEND
                    value: FLASHINFER_MLA
                    # value: CUTLASS_MLA
                    # value: TRITON_MLA
                  - name: DP_SIZE
                    value: "16"
                  - name: TP_SIZE
                    value: "1"
                  - name: DP_SIZE_LOCAL
                    value: "8"
                  # Values above 512 for decode nodes are unlikely to improve performance
                  # because per rank batch size is typically below 512
                  - name: VLLM_MOE_DP_CHUNK_SIZE
                    value: "256"
                  - name: VLLM_RANDOMIZE_DP_DUMMY_INPUTS
                    value: "1"
                  - name: VLLM_SPECULATIVE_CONFIG
                    value: "" #'{"method":"deepseek_mtp","num_speculative_tokens":1}'

                  # Use cache directories from the host
                  - name: CCACHE_DIR
                    value: /root/.nv/ComputeCache/.ccache
                  - name: VLLM_CACHE_ROOT
                    value: /root/.nv/ComputeCache/.vllm
                  - name: FLASHINFER_WORKSPACE_BASE
                    value: /root/.nv/ComputeCache
                  - name: TRITON_CACHE_DIR
                    value: /root/.nv/ComputeCache/.triton
                  - name: DG_JIT_CACHE_DIR
                    value: /root/.nv/ComputeCache/.deepgemm
                  - name: HF_HUB_CACHE
                    value: /huggingface-cache

                  # Mount secrets for loading images
                  - name: HF_TOKEN
                    valueFrom:
                      secretKeyRef:
                        name: hf-secret
                        key: HF_TOKEN
                        optional: true

                terminationMessagePolicy: FallbackToLogsOnError
                lifecycle:
                  preStop:
                    sleep:
                      seconds: 1
                readinessProbe:
                  httpGet:
                    path: /health
                    port: 8000
                    scheme: HTTP
                  periodSeconds: 1
                  successThreshold: 1
                  failureThreshold: 1
                  timeoutSeconds: 1

                # may not be needed on GKE
                securityContext:
                  runAsUser: 0  # needed for RH image to be able to override files (runs as vllm)
                  # NVSHMEM IBGDA requires CAP_SYS_ADMIN or nvidia.ko to have PeerMappingOverride=1 set
                  # Since PeerMappingOverride allows workloads to potentially impact each other on the
                  # same host, privileged is no less secure.
                  privileged: true # needed to open the host device for UAR ??
                  capabilities:
                    add:
                    - "IPC_LOCK"
                    - "SYS_RAWIO"
                resources:
                  limits:
                    cpu: 200
                    ephemeral-storage: 3Ti
                    nvidia.com/gpu: "8"
                    
                    networking.gke.io.networks/rdma-0: "1"
                    networking.gke.io.networks/rdma-0.IP: "1"
                    networking.gke.io.networks/rdma-1: "1"
                    networking.gke.io.networks/rdma-1.IP: "1"
                    networking.gke.io.networks/rdma-2: "1"
                    networking.gke.io.networks/rdma-2.IP: "1"
                    networking.gke.io.networks/rdma-3: "1"
                    networking.gke.io.networks/rdma-3.IP: "1"
                    networking.gke.io.networks/rdma-4: "1"
                    networking.gke.io.networks/rdma-4.IP: "1"
                    networking.gke.io.networks/rdma-5: "1"
                    networking.gke.io.networks/rdma-5.IP: "1"
                    networking.gke.io.networks/rdma-6: "1"
                    networking.gke.io.networks/rdma-6.IP: "1"
                    networking.gke.io.networks/rdma-7: "1"
                    networking.gke.io.networks/rdma-7.IP: "1"                    
                  requests:
                    cpu: 200
                    ephemeral-storage: 3Ti
                    nvidia.com/gpu: "8"

                    networking.gke.io.networks/rdma-0: "1"
                    networking.gke.io.networks/rdma-0.IP: "1"
                    networking.gke.io.networks/rdma-1: "1"
                    networking.gke.io.networks/rdma-1.IP: "1"
                    networking.gke.io.networks/rdma-2: "1"
                    networking.gke.io.networks/rdma-2.IP: "1"
                    networking.gke.io.networks/rdma-3: "1"
                    networking.gke.io.networks/rdma-3.IP: "1"
                    networking.gke.io.networks/rdma-4: "1"
                    networking.gke.io.networks/rdma-4.IP: "1"
                    networking.gke.io.networks/rdma-5: "1"
                    networking.gke.io.networks/rdma-5.IP: "1"
                    networking.gke.io.networks/rdma-6: "1"
                    networking.gke.io.networks/rdma-6.IP: "1"
                    networking.gke.io.networks/rdma-7: "1"
                    networking.gke.io.networks/rdma-7.IP: "1"
                volumeMounts:
                  - name: shm
                    mountPath: /dev/shm
                  - name: init-scripts-volume
                    mountPath: /init-scripts
                  - name: hf-cache
                    mountPath: /huggingface-cache
                  - name: nv-compute-cache
                    mountPath: /root/.nv/ComputeCache
                  - name: vllm
                    mountPath: /code
                  # Required to access the gIB configuration for NCCL on GKE
                  - mountPath: /usr/local/gib
                    name: gib

              volumes:
                # Volume for the init script from ConfigMap
                - name: init-scripts-volume
                  configMap:
                    name: vllm-init-scripts-config
                    defaultMode: 0755 # Set execute permissions for the script
                    optional: true
                # Needed for NCCL to function
                - name: shm
                  emptyDir:
                    medium: Memory
                    # Verified that we need about 2Gi to start on NCCL 2.27
                    sizeLimit: 3Gi
                # Use a cache directory across pods on the SSD to avoid redownloading large models
                - name: hf-cache
                  hostPath:
                    path: /mnt/stateful_partition/kube-ephemeral-ssd/shared_disk/hfcache/
                    type: DirectoryOrCreate
                # Use a cache directory across pods on the SSD to avoid recompiling kernels
                # Note: there are occasionally bugs in compilation cache hashing that may trigger
                - name: nv-compute-cache
                  hostPath:
                    path: /mnt/stateful_partition/kube-ephemeral-ssd/shared_disk/nv-compute-cache
                    type: DirectoryOrCreate
                # Necessary for gIB
                - name: gib
                  hostPath:
                    path: /home/kubernetes/bin/gib
                    type: ""
                # Scratch directory for the workload
                - name: vllm
                  emptyDir: {}

              # Add the llm-d routing sidecar for disaggregation testing
              initContainers:
              - args:
                - --port=8200
                - --vllm-port=8000
                - --connector=nixlv2
                - -v=5
                - --secure-proxy=false
                env:
                - name: ENABLE_PREFILLER_SAMPLING
                  value: "1"
                image: gcr.io/claytoncoleman-gke-dev/github.com/smarterclayton/llm-d-routing-sidecar:enable_prefiller_sampling
                imagePullPolicy: Always
                name: sidecar
                ports:
                - containerPort: 8200
                  protocol: TCP
                restartPolicy: Always
                terminationMessagePolicy: FallbackToLogsOnError
---
kind: Service
apiVersion: v1
metadata:
  name: vllm-deepseek-ep-balancer
spec:
  selector:
    app: vllm-deepseek-ep
    component: vllm-deepseek-ep-decode
  ports:
  - port: 8000
    name: vllm
  - port: 8200
    name: vllm-sidecar

@robertgshaw2-redhat
Collaborator

robertgshaw2-redhat commented Sep 30, 2025

nevermind, after running in a loop for a while, I see another hang, it seems we have not yet solved all the problems with b200 attn on dp/ep

@simon-mo simon-mo merged commit 23194d8 into vllm-project:main Sep 30, 2025
42 checks passed
simon-mo pushed a commit that referenced this pull request Oct 1, 2025
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: simon-mo <simon.mo@hey.com>
pdasigi pushed a commit to pdasigi/vllm that referenced this pull request Oct 2, 2025
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
yewentao256 pushed a commit that referenced this pull request Oct 3, 2025
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
tomeras91 pushed a commit to tomeras91/vllm that referenced this pull request Oct 6, 2025
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com>
xuebwang-amd pushed a commit to xuebwang-amd/vllm that referenced this pull request Oct 10, 2025
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: xuebwang-amd <xuebwang@amd.com>
choprahetarth pushed a commit to Tandemn-Labs/vllm that referenced this pull request Oct 11, 2025
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: simon-mo <simon.mo@hey.com>
shyeh25 pushed a commit to shyeh25/vllm that referenced this pull request Oct 14, 2025
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: simon-mo <simon.mo@hey.com>
lywa1998 pushed a commit to lywa1998/vllm that referenced this pull request Oct 20, 2025
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
alhridoy pushed a commit to alhridoy/vllm that referenced this pull request Oct 24, 2025
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
xuebwang-amd pushed a commit to xuebwang-amd/vllm that referenced this pull request Oct 24, 2025
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: xuebwang-amd <xuebwang@amd.com>
Labels

ready, v1