forked from kingoflolz/swarm-jax
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtpu.yaml
67 lines (59 loc) · 3.23 KB
/
tpu.yaml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
# Example Ray cluster config for running https://github.com/Yard1/swarm-jax
# on GCP TPU VMs.
# Before launching, replace provider.project_id with your GCP project id.
# Once the nodes are up, upload and run the training script with:
# ray submit tpu.yaml swarm_tpu_jax.py swarm-jax/data/enwik8 [NUM_TPUS] [EPOCHS]
# Name shared by the head node and every worker node of this cluster.
cluster_name: tputest

# Upper bound on the number of worker nodes launched on top of the head node.
max_workers: 7
available_node_types:
  # Head node type: a CPU-only Compute Engine VM. No additional nodes of
  # this type are launched (min_workers and max_workers are both 0).
  ray_head_default:
    min_workers: 0
    max_workers: 0
    resources:
      CPU: 2
    node_config:
      machineType: n2-standard-2
      disks:
        - boot: true
          autoDelete: true
          type: PERSISTENT
          initializeParams:
            diskSizeGb: 50
            # Other images: https://cloud.google.com/compute/docs/images
            sourceImage: projects/deeplearning-platform-release/global/images/family/common-cpu
  # TPU worker type: each node is a v2-8 TPU VM advertising one unit of the
  # custom "TPU" resource; request that resource from your tasks/actors.
  ray_tpu:
    min_workers: 7
    resources:
      TPU: 1
    node_config:
      acceleratorType: v2-8
      runtimeVersion: v2-alpha
      # To run on preemptible TPUs, uncomment the following:
      # schedulingConfig:
      #   preemptible: true
# Cloud provider settings: Google Cloud, zone us-central1-f.
provider:
  type: gcp
  region: us-central1
  availability_zone: us-central1-f
  # Placeholder: set this to your GCP project id before launching the cluster.
  project_id: null
# No setup commands are shared by all nodes; the head node and the TPU
# workers each get their own command list.
setup_commands: []

# Which entry of available_node_types hosts the Ray head.
# A TPU node type cannot be used as the head (an exception is raised).
head_node_type: ray_head_default
# The Compute Engine head image ships Python 3.7 while the TPU VMs ship
# Python 3.8, so the head gets a Python 3.8 conda env named "ray" that is
# made the default `python`/`pip` via update-alternatives. JAX (CPU build)
# and the other dependencies are then installed into it.
# Ray may run these commands again on an already-provisioned node (e.g. on a
# repeated `ray up`), so each command is written to be safe to re-run.
head_setup_commands:
  - conda create -y -n "ray" python=3.8.5 && sudo update-alternatives --install /opt/conda/bin/python python /opt/conda/envs/ray/bin/python 10 && sudo update-alternatives --install /opt/conda/bin/pip pip /opt/conda/envs/ray/bin/pip 10
  # Add the PATH line to ~/.bashrc only once so re-runs don't append duplicates.
  - export PATH="$PATH:/opt/conda/envs/ray/bin" && (grep -qxF 'export PATH="$PATH:/opt/conda/envs/ray/bin"' ~/.bashrc || echo 'export PATH="$PATH:/opt/conda/envs/ray/bin"' >> ~/.bashrc)
  - python -m pip install --upgrade "jax[cpu]==0.2.14" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
  # "ray[default]" is quoted so the shell cannot glob-expand the brackets.
  - python -m pip install --upgrade fabric dataclasses optax==0.0.6 git+https://github.com/deepmind/dm-haiku google-api-python-client cryptography tensorboardX "ray[default]"
  - python -m pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-2.0.0.dev0-cp38-cp38-manylinux2014_x86_64.whl
  # Clone only when the checkout is missing: `git clone` fails if the
  # directory already exists, which would abort setup on a re-run.
  - (test -d swarm-jax || git clone https://github.com/Yard1/swarm-jax.git) && cd swarm-jax && python -m pip install .
# Install JAX (TPU build) and the other dependencies on each TPU worker.
# Ray may run these commands again on an already-provisioned node, so each
# command is written to be safe to re-run.
worker_setup_commands:
  - pip3 install --upgrade "jax[tpu]==0.2.14" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
  # "ray[default]" is quoted so the shell cannot glob-expand the brackets.
  - pip3 install --upgrade fabric dataclasses optax==0.0.6 git+https://github.com/deepmind/dm-haiku tensorboardX "ray[default]"
  # Smoke test: exits non-zero (stopping setup here) if JAX cannot be
  # imported or cannot run a trivial op.
  - python3 -c "import jax; jax.device_count(); jax.numpy.add(1, 1)"
  - pip3 install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-2.0.0.dev0-cp38-cp38-manylinux2014_x86_64.whl
  # Clone only when the checkout is missing (`git clone` fails on an existing
  # directory). Install with plain pip3, like every command above, rather than
  # `sudo pip3`: a root install goes to a different site-packages and may
  # leave root-owned build files in the user's checkout.
  - (test -d swarm-jax || git clone https://github.com/Yard1/swarm-jax.git) && cd swarm-jax && pip3 install .