Skip to content

Commit 10eacfd

Browse files
JasonLi1909gemini-code-assist[bot]angelinalg
authored
Getting Started with PyTorch Fully Sharded Data Parallel (FSDP2) and Ray Train Template (#56298)
A new workspace template that walks users through how to integrate PyTorch's FSDP2 with Ray Train. The purpose of this template is to allow customers to quickly get started with FSDP and Ray Train whether they are coming from using Distributed Data Parallel (DDP) or just getting started with training large models. For a high level overview, this template covers: - A hands-on example of training an image classification model - Model checkpoint saving and loading with PyTorch Distributed Checkpoint (DCP) - Configuring FSDP2 to mitigate out-of-memory (OOM) errors using mixed precision, CPU offloading, sharding granularity, and more - GPU memory profiling with PyTorch Profiler - Loading a distributed model for inference Link to original PR (pivoted to make available on OSS): **anyscale/templates#463 **Testing** - This notebook was tested in an [Anyscale workspace](https://console.anyscale.com/cld_kvedZWag2qA8i5BjxUevf5i7/prj_cz951f43jjdybtzkx1s5sjgz99/workspaces/expwrk_nktjw7a3j2l5c7af9rh3n5rskw?workspace-tab=overview&command-history-section=application_logs&file=%252Fhome%252Fray%252Fdefault%252FREADME.ipynb) and ran as expected **For easy testing:** Simply copy the notebook into an Anyscale workspace (and preferably the image directory), provisioned with two T4 nodes. --------- Signed-off-by: JasonLi1909 <jasli1909@gmail.com> Signed-off-by: Jason Li <57246540+JasonLi1909@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: angelinalg <122562471+angelinalg@users.noreply.github.com>
1 parent f857275 commit 10eacfd

File tree

17 files changed

+1815
-1
lines changed

17 files changed

+1815
-1
lines changed

doc/source/train/examples.yml

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,13 @@ examples:
5555
- natural language processing
5656
contributor: community
5757
link: examples/intel_gaudi/bert
58-
58+
- title: Get started with PyTorch Fully Sharded Data Parallel (FSDP2) and Ray Train
59+
skill_level: intermediate
60+
frameworks:
61+
- pytorch
62+
use_cases:
63+
- computer vision
64+
link: examples/pytorch/pytorch-fsdp/README
5965
- title: Train a text classifier with DeepSpeed
6066
frameworks:
6167
- deepspeed

doc/source/train/examples/pytorch/pytorch-fsdp/README.ipynb

Lines changed: 927 additions & 0 deletions
Large diffs are not rendered by default.

doc/source/train/examples/pytorch/pytorch-fsdp/README.md

Lines changed: 731 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
filegroup(
2+
name = "ci_yamls",
3+
srcs = ["aws.yaml", "gce.yaml"],
4+
visibility = ["//release:__pkg__"],
5+
)
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
cloud_id: {{env["ANYSCALE_CLOUD_ID"]}}
2+
region: us-central1
3+
4+
head_node_type:
5+
name: head_node
6+
instance_type: m5.2xlarge
7+
8+
worker_node_types:
9+
- instance_type: g4dn.xlarge
10+
name: '1xT4:4CPU-16GB'
11+
min_workers: 2
12+
max_workers: 2
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
cloud_id: {{env["ANYSCALE_CLOUD_ID"]}}
2+
region: us-central1
3+
4+
head_node_type:
5+
name: head
6+
instance_type: n2-standard-8
7+
8+
worker_node_types:
9+
- name: gpu_worker
10+
instance_type: n1-standard-8-nvidia-t4-16gb-1
11+
min_workers: 2
12+
max_workers: 2
13+
use_spot: false
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
#!/usr/bin/env python3
2+
import argparse
3+
import nbformat
4+
5+
6+
def convert_notebook(input_path: str, output_path: str) -> None:
7+
"""
8+
Read a Jupyter notebook and write a Python script, converting all %%bash
9+
cells and IPython "!" commands into subprocess.run calls that raise on error.
10+
Cells that load or autoreload extensions are ignored.
11+
"""
12+
nb = nbformat.read(input_path, as_version=4)
13+
with open(output_path, "w") as out:
14+
for cell in nb.cells:
15+
# Only process code cells
16+
if cell.cell_type != "code":
17+
continue
18+
19+
lines = cell.source.splitlines()
20+
# Skip cells that load or autoreload extensions
21+
if any(
22+
l.strip().startswith("%load_ext autoreload")
23+
or l.strip().startswith("%autoreload all")
24+
for l in lines
25+
):
26+
continue
27+
28+
# Detect a %%bash cell
29+
if lines and lines[0].strip().startswith("%%bash"):
30+
bash_script = "\n".join(lines[1:]).rstrip()
31+
out.write("import subprocess\n")
32+
out.write(
33+
f"subprocess.run(r'''{bash_script}''',\n"
34+
" shell=True,\n"
35+
" check=True,\n"
36+
" executable='/bin/bash')\n\n"
37+
)
38+
else:
39+
# Detect any IPython '!' shell commands in code lines
40+
has_bang = any(line.lstrip().startswith("!") for line in lines)
41+
if has_bang:
42+
out.write("import subprocess\n")
43+
for line in lines:
44+
stripped = line.lstrip()
45+
if stripped.startswith("!"):
46+
cmd = stripped[1:].lstrip()
47+
out.write(
48+
f"subprocess.run(r'''{cmd}''',\n"
49+
" shell=True,\n"
50+
" check=True,\n"
51+
" executable='/bin/bash')\n"
52+
)
53+
else:
54+
out.write(line.rstrip() + "\n")
55+
out.write("\n")
56+
else:
57+
# Regular Python cell: dump as-is
58+
out.write(cell.source.rstrip() + "\n\n")
59+
60+
61+
def main() -> None:
62+
parser = argparse.ArgumentParser(
63+
description="Convert a Jupyter notebook to a Python script, preserving bash cells and '!' commands as subprocess calls."
64+
)
65+
parser.add_argument("input_nb", help="Path to the input .ipynb file")
66+
parser.add_argument("output_py", help="Path for the output .py script")
67+
args = parser.parse_args()
68+
convert_notebook(args.input_nb, args.output_py)
69+
70+
71+
if __name__ == "__main__":
72+
main()
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
#!/bin/bash
2+
python ci/nb2py.py README.ipynb README.py # convert notebook to py script
3+
python README.py # run the converted python script
4+
rm README.py # remove the generated script
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
head_node_type:
2+
name: head_node
3+
instance_type: m5.2xlarge
4+
5+
worker_node_types:
6+
- instance_type: g4dn.xlarge
7+
name: '1xT4:4CPU-16GB'
8+
min_workers: 2
9+
max_workers: 2
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
head_node_type:
2+
name: head
3+
instance_type: n2-standard-8
4+
5+
worker_node_types:
6+
- name: gpu_worker
7+
instance_type: n1-standard-8-nvidia-t4-16gb-1
8+
min_workers: 2
9+
max_workers: 2
10+
use_spot: false

0 commit comments

Comments
 (0)