JAX Hello World Multi-Node GKE H100 with GPUDirectTCPx tutorial #1236 #1237

Draft · wants to merge 8 commits into base: main
14 changes: 14 additions & 0 deletions ai-ml/gke-h100-gpudirecttcpx-jax/Dockerfile
@@ -0,0 +1,14 @@
FROM python:3.10-slim
> **Review comment (Member):** Could you add a workflow that, at minimum, does a dry-run `docker build` to verify the image builds? You can find examples in the `.github/workflows` directory.
>
> https://github.com/GoogleCloudPlatform/kubernetes-engine-samples/blob/main/.github/CONTRIBUTING.md#samples-requirements
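A minimal sketch of such a workflow (the filename, trigger paths, and job name below are assumptions for illustration, not part of this PR):

```yaml
# Hypothetical .github/workflows/jax-gpudirecttcpx-build.yml
name: jax-gpudirecttcpx-docker-build
on:
  pull_request:
    paths:
      - "ai-ml/gke-h100-gpudirecttcpx-jax/**"
jobs:
  docker-build-dry-run:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - name: Build the image without pushing
        run: |
          docker build ai-ml/gke-h100-gpudirecttcpx-jax \
            -f ai-ml/gke-h100-gpudirecttcpx-jax/Dockerfile \
            -t jax-pingpong-tcpx:ci
```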


RUN pip install --no-cache-dir --upgrade pip

RUN pip install --no-cache-dir --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

ENV LD_LIBRARY_PATH=/usr/local/nvidia/lib64

WORKDIR /workspace/

COPY jax_pingpong_tcpx.py .
COPY start_jax_pingpong_tcpx.sh .

ENTRYPOINT ["bash", "start_jax_pingpong_tcpx.sh"]
53 changes: 53 additions & 0 deletions ai-ml/gke-h100-gpudirecttcpx-jax/README.md
@@ -0,0 +1,53 @@
# JAX Multi-Node 'Hello World' on GKE + H100-80GB with GPUDirectTCPx
> **Review comment (Member):** Typically we'll have the instructions live in the Google Cloud docs tutorial and point to that from here. This reduces duplication, so we only have to modify the instructions in one source of truth (the tutorial).

This tutorial shows how to run a simple multi-node JAX program on NVIDIA H100-80GB GPUs in a GKE cluster with GPUDirectTCPx.

## Prerequisites

This guide assumes that you have already created a GKE H100 GPUDirectTCPx cluster with the GPU drivers installed.

## Building the image

Build and push the container image to your registry. This pushes
`gcr.io/<your project>/jax-pingpong-tcpx:latest` and might take a few minutes.

```
$ bash build_and_push_container.sh
```

## Run Multi-Node JAX

In `kubernetes/jobset.yaml`, replace `<<PROJECT>>` with your GCP project ID.
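
One way to script the substitution (a sketch; the project ID and the snippet file below are illustrative stand-ins for the real manifest):

```shell
#!/bin/bash
# Illustrative: substitute the <<PROJECT>> placeholder the way you would
# in kubernetes/jobset.yaml. A throwaway snippet file stands in for the
# real manifest here.
PROJECT="my-sample-project"   # e.g. $(gcloud config list project --format "value(core.project)")
snippet=$(mktemp)
printf 'image: gcr.io/<<PROJECT>>/jax-pingpong-tcpx:latest\n' > "${snippet}"
sed -i "s/<<PROJECT>>/${PROJECT}/g" "${snippet}"
cat "${snippet}"
```

Note that `sed -i` as written is the GNU form; BSD/macOS sed needs `sed -i ''`.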

Run the JAX application on the compute nodes. This will create 2 pods.

```
$ cd kubernetes
$ kubectl apply -k .
```

Use

```
$ kubectl get pods

NAME                   READY   STATUS              RESTARTS   AGE
pingpong-j-0-0-zmcrr   0/2     ContainerCreating   0          5s
pingpong-j-0-1-gw4c5   0/2     ContainerCreating   0          5s
```

to check the status. The status will change from `Pending` to `ContainerCreating` (which can take a few minutes), then `Running`, and finally `Completed`.

Once the job has completed, use `kubectl logs` to see the output from one pod:

```
$ kubectl logs pingpong-j-0-0-zmcrr
. . .
[16. 16. 16. 16. 16. 16. 16. 16.]
Shutting Down . . .

```

The application creates an array of ones, one element per local GPU, on each process and then sums the arrays across all devices with `psum`. On 16 GPUs (2 nodes × 8 GPUs per node), every element of the result should be `16.0`.
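
To see why every element ends up as `16.0`, the reduction can be simulated without GPUs (plain NumPy, with shapes matching the 2-node × 8-GPU layout of this tutorial):

```python
import numpy as np

NNODES, GPUS_PER_NODE = 2, 8
total_devices = NNODES * GPUS_PER_NODE          # 16 devices in total

# Each of the 16 devices holds the value 1.0; psum sums the value across
# every device on the named axis, and each device receives the total.
per_device = np.ones(total_devices)
reduced = np.full_like(per_device, per_device.sum())
print(reduced)   # 16 entries, each equal to 16.0
```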
21 changes: 21 additions & 0 deletions ai-ml/gke-h100-gpudirecttcpx-jax/build_and_push_container.sh
@@ -0,0 +1,21 @@
#!/bin/bash
#
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

export PROJECT=$(gcloud config list project --format "value(core.project)")

docker build . -f Dockerfile -t "gcr.io/${PROJECT}/jax-pingpong-tcpx:latest"

docker push "gcr.io/${PROJECT}/jax-pingpong-tcpx:latest"
47 changes: 47 additions & 0 deletions ai-ml/gke-h100-gpudirecttcpx-jax/jax_pingpong_tcpx.py
@@ -0,0 +1,47 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import jax

# Fail fast with a KeyError if the rendezvous variables are missing,
# instead of silently turning them into the string "None".
COORDINATOR_ADDR = os.environ["COORDINATOR_ADDR"]
COORDINATOR_PORT = os.environ["COORDINATOR_PORT"]

def log(user_str):
    print(user_str, flush=True)

def run():
xs = jax.numpy.ones(jax.local_device_count())
log(xs)
log(jax.pmap(lambda x: jax.lax.psum(x, "i"), axis_name="i")(xs))

def init_processes():
jax.distributed.initialize(
coordinator_address=f"{COORDINATOR_ADDR}:{COORDINATOR_PORT}",
num_processes=int(os.getenv("NNODES")),
process_id=int(os.getenv("NODE_RANK"))
)

log(
f"JAX process {jax.process_index()}/{jax.process_count()} initialized on"
f" {COORDINATOR_ADDR}"
)
log(f"JAX global devices:{jax.devices()}")
log(f"JAX local devices:{jax.local_devices()}")
run()

if __name__ == "__main__":
log("Starting . . .")
init_processes()
log("Shutting Down . . .")
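
The start script guards the required variables with `: "${VAR:?...}"` before launch; the same fail-fast idea in pure Python looks like this (a sketch — the helper name is illustrative and not part of the sample):

```python
import os

def require_env(name: str) -> str:
    """Return the value of an environment variable or fail loudly."""
    value = os.getenv(name)
    if value is None:
        raise RuntimeError(f"required environment variable {name} is not set")
    return value

os.environ["NNODES"] = "2"   # example value, set here only for the demo
print(require_env("NNODES"))
```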
126 changes: 126 additions & 0 deletions ai-ml/gke-h100-gpudirecttcpx-jax/kubernetes/jobset.yaml
@@ -0,0 +1,126 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

apiVersion: jobset.x-k8s.io/v1alpha2
kind: JobSet
metadata:
name: pingpong
spec:
replicatedJobs:
- name: j
template:
spec:
parallelism: 2
completions: 2
backoffLimit: 0
template:
metadata:
annotations:
kubectl.kubernetes.io/default-container: jax-pingpong-tcpx
spec:
tolerations:
- key: "cloud.google.com/impending-node-termination"
operator: "Exists"
hostNetwork: true
dnsPolicy: ClusterFirstWithHostNet
serviceAccountName: "default"
restartPolicy: Never
volumes:
- name: nvidia-install-dir-host
hostPath:
path: /home/kubernetes/bin/nvidia/lib64
- name: tcpx-socket
emptyDir: {}
- name: shared-memory
emptyDir:
medium: "Memory"
sizeLimit: 200Gi
- name: tcpx-nccl-plugin-volume
emptyDir: {}

initContainers:
- name: tcpx-nccl-plugin-installer
image: us-docker.pkg.dev/gce-ai-infra/gpudirect-tcpx/nccl-plugin-gpudirecttcpx-dev:v3.1.7
imagePullPolicy: Always
volumeMounts:
- name: tcpx-nccl-plugin-volume
mountPath: /var/lib/tcpx
resources:
requests:
cpu: 150m
command:
- /bin/sh
- -c
- |
/scripts/container_entry.sh install

containers:
- name: tcpx-daemon
image: us-docker.pkg.dev/gce-ai-infra/gpudirect-tcpx/tcpgpudmarxd-dev:v2.0.11
imagePullPolicy: Always
command:
- "bash"
- "-c"
- |
/tcpgpudmarxd/build/app/tcpgpudmarxd --gpu_nic_preset a3vm \
--uds_path /run/tcpx \
--gpu_shmem_type fd --setup_param "--verbose 128 2 0" &
while [ ! -e "/run/tcpx/workload_terminated" ]; do sleep 10; echo "sleeping"; done
securityContext:
privileged: true
volumeMounts:
- name: nvidia-install-dir-host
mountPath: /usr/local/nvidia/lib64
- name: tcpx-socket
mountPath: /run/tcpx
env:
- name: LD_LIBRARY_PATH
value: /usr/local/nvidia/lib64

- name: jax-pingpong-tcpx
image: gcr.io/<<PROJECT>>/jax-pingpong-tcpx:latest
imagePullPolicy: Always
securityContext:
privileged: true
env:
- name: NODE_RANK
valueFrom:
fieldRef:
fieldPath: metadata.annotations['batch.kubernetes.io/job-completion-index']
# This must match the Job's `completions`/`parallelism`;
# otherwise initialization will time out with connection errors.
- name: NNODES
value: "2"
- name: USE_GPUDIRECT_TCPX
value: "yes"
- name: GPUS_PER_NODE
value: "8"
- name: COORDINATOR_ADDR
value: "pingpong-j-0-0.pingpong"
- name: COORDINATOR_PORT
value: "6002"
- name: LD_LIBRARY_PATH
value: /usr/local/nvidia/lib64
volumeMounts:
- name: nvidia-install-dir-host
mountPath: /usr/local/nvidia/lib64
- name: tcpx-nccl-plugin-volume
mountPath: /usr/local/tcpx
- name: tcpx-socket
mountPath: /run/tcpx
- name: shared-memory
mountPath: /dev/shm
resources:
limits:
nvidia.com/gpu: "8"
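
The `COORDINATOR_ADDR` value in the manifest follows JobSet's pod naming convention, `<jobset>-<replicatedJob>-<job index>-<pod index>.<subdomain>`, where the headless-service subdomain defaults to the JobSet name. A quick sketch of how the value is derived (the helper is illustrative, not part of the sample):

```python
def coordinator_addr(jobset: str, replicated_job: str,
                     job_index: int = 0, pod_index: int = 0) -> str:
    # Pod 0 of job 0 acts as the JAX coordinator; JobSet exposes it via a
    # headless service named after the JobSet.
    return f"{jobset}-{replicated_job}-{job_index}-{pod_index}.{jobset}"

print(coordinator_addr("pingpong", "j"))  # prints pingpong-j-0-0.pingpong
```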
72 changes: 72 additions & 0 deletions ai-ml/gke-h100-gpudirecttcpx-jax/start_jax_pingpong_tcpx.sh
@@ -0,0 +1,72 @@
#!/bin/bash
#
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
set -e
set -x
set -u
set -o pipefail

: "${NNODES:?Must set NNODES}"
: "${NODE_RANK:?Must set NODE_RANK}"
: "${COORDINATOR_PORT:?Must set COORDINATOR_PORT}"
: "${COORDINATOR_ADDR:?Must set COORDINATOR_ADDR}"

export COORDINATOR_PORT=$COORDINATOR_PORT
export COORDINATOR_ADDR=$COORDINATOR_ADDR
export NNODES=$NNODES
export NODE_RANK=$NODE_RANK

set_nccl_gpudirect_tcpx_specific_configuration() {
  if [[ "${USE_GPUDIRECT_TCPX:-no}" == "yes" ]]; then
echo "Using GPUDirect-TCPX"
export NCCL_CROSS_NIC=0
export NCCL_ALGO=Ring
export NCCL_PROTO=Simple
export NCCL_DEBUG=INFO
export NCCL_NET_GDR_LEVEL=PIX
export NCCL_P2P_PXN_LEVEL=0
export NCCL_DEBUG_SUBSYS=INIT,GRAPH,ENV,TUNING,NET,VERSION
export LD_LIBRARY_PATH="${LD_LIBRARY_PATH}:/usr/local/tcpx/lib64"
export NCCL_GPUDIRECTTCPX_UNIX_CLIENT_PREFIX=/run/tcpx
export NCCL_GPUDIRECTTCPX_FORCE_ACK=0
export NCCL_GPUDIRECTTCPX_TX_COMPLETION_NANOSLEEP=1000
export NCCL_DYNAMIC_CHUNK_SIZE=524288
export NCCL_P2P_NET_CHUNKSIZE=524288
export NCCL_P2P_PCI_CHUNKSIZE=524288
export NCCL_P2P_NVL_CHUNKSIZE=1048576
export NCCL_NSOCKS_PERTHREAD=4
export NCCL_SOCKET_NTHREADS=1
export NCCL_MAX_NCHANNELS=12
export NCCL_MIN_NCHANNELS=12
export NCCL_GPUDIRECTTCPX_PROGRAM_FLOW_STEERING_WAIT_MICROS=1000000
export NCCL_SOCKET_IFNAME=eth0
export NCCL_GPUDIRECTTCPX_TX_BINDINGS="eth1:8-21,112-125;eth2:8-21,112-125;eth3:60-73,164-177;eth4:60-73,164-177"
export NCCL_GPUDIRECTTCPX_RX_BINDINGS="eth1:22-35,124-139;eth2:22-35,124-139;eth3:74-87,178-191;eth4:74-87,178-191"
export NCCL_GPUDIRECTTCPX_SOCKET_IFNAME=eth1,eth2,eth3,eth4
export NCCL_GPUDIRECTTCPX_CTRL_DEV=eth0
else
echo "NOT using TCPX"
fi
}

function on_script_completion {
# semaphore to cleanly exit hardware utilization monitor
touch /run/tcpx/workload_terminated
}
trap on_script_completion EXIT

set_nccl_gpudirect_tcpx_specific_configuration

python jax_pingpong_tcpx.py
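
The `workload_terminated` handshake above can be exercised locally (a sketch; a `mktemp` directory stands in for `/run/tcpx`):

```shell
#!/bin/bash
# Demonstrates the semaphore pattern from start_jax_pingpong_tcpx.sh: the
# workload's EXIT trap touches a sentinel file, which the tcpx-daemon
# sidecar polls for before shutting down.
workdir=$(mktemp -d)

(
  on_script_completion() { touch "${workdir}/workload_terminated"; }
  trap on_script_completion EXIT
  echo "workload running"
)   # subshell exits here, so the EXIT trap fires

if [[ -e "${workdir}/workload_terminated" ]]; then
  echo "sidecar sees termination semaphore"
fi
```

This is why the `tcpx-daemon` container's loop (`while [ ! -e "/run/tcpx/workload_terminated" ]; do sleep 10; done`) exits cleanly once the workload finishes.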