# nsys-jax.yaml: GitHub Actions workflow "nsys-jax pure-Python CI"
name: nsys-jax pure-Python CI

# Run for pull-request activity (skipped when a PR only touches Markdown
# files) and for every push to main.
on:
  pull_request:
    types: [opened, reopened, ready_for_review, synchronize]
    paths-ignore: ['**.md']
  push:
    branches: [main]

# Group runs per PR head branch; pushes fall back to the unique run id, so they
# never share a group. In-progress runs are cancelled by newer ones in the same
# group, except for runs on refs/heads/main.
concurrency:
  group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
  cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}

# All `run:` steps use bash with command tracing (-x), exit on the first
# failing command (-e) and failure propagation through pipelines (pipefail).
defaults:
  run:
    shell: bash -x -eo pipefail {0}

# Newline-separated list of paths given (unquoted, relying on word splitting)
# to the mypy and ruff invocations; entries must therefore not contain spaces.
env:
  NSYS_JAX_PYTHON_FILES: |
    JAX-Toolbox/.github/container/nsys_jax
    JAX-Toolbox/.github/container/jax-nccl-test
jobs:
  # Static type checking of the nsys-jax sources, including notebooks converted
  # to scripts and stubs generated from the XLA .proto files.
  mypy:
    runs-on: ubuntu-24.04
    steps:
      - name: Check out the repository under ${GITHUB_WORKSPACE}
        uses: actions/checkout@v4
        with:
          path: JAX-Toolbox
          sparse-checkout: |
            .github/container
      - name: "Setup Python 3.10"
        uses: actions/setup-python@v5
        with:
          python-version: '3.10'
      # jax is just a CPU-only build of the latest release for type-checking purposes
      - name: "Install jax / nsys_jax / mypy"
        run: pip install jax -e JAX-Toolbox/.github/container/nsys_jax matplotlib mypy nbconvert types-protobuf types-requests
      - name: "Install protoc"
        # TODO: this could install into the pip prefix as a default
        run: |
          install-protoc local/
          echo "$PWD/local/bin" >> "${GITHUB_PATH}"
      - name: "Fetch XLA .proto files"
        uses: actions/checkout@v4
        with:
          path: xla
          repository: openxla/xla
          sparse-checkout: |
            *.proto
          sparse-checkout-cone-mode: false
      - name: "Compile .proto files"
        run: |
          mkdir compiled_protos compiled_stubs protos
          mv -v xla/third_party/tsl/tsl protos/
          mv -v xla/xla protos/
          python -c "from nsys_jax import compile_protos; compile_protos(proto_dir='protos', output_dir='compiled_protos', output_stub_dir='compiled_stubs')"
          # Mark the generated stubs as typed so mypy will use them
          touch compiled_stubs/py.typed
      - name: "Convert .ipynb to .py"
        run: |
          # NSYS_JAX_PYTHON_FILES is intentionally unquoted: it is a list of paths
          for notebook in $(find ${NSYS_JAX_PYTHON_FILES} -name '*.ipynb'); do
            jupyter nbconvert --to script "${notebook}"
          done
      - name: "Run mypy checks"
        run: |
          export MYPYPATH="${PWD}/compiled_stubs"
          mypy --scripts-are-modules ${NSYS_JAX_PYTHON_FILES}
  # Test nsys-jax-combine and notebook execution; in future perhaps upload the rendered
  # notebook from here too. These input files were generated with something like
  # srun -n 4 --container-name=XXX --container-image=ghcr.io/nvidia/jax:pax-2024-07-06
  #   env NPROC=4 XLA_FLAGS=--xla_gpu_enable_latency_hiding_scheduler\
  #   --xla_gpu_enable_command_buffer= nsys-jax -o ...-fsdp4-4proc-proc%q{SLURM_PROCID}
  #   -- test-pax.sh --steps=5 --fsdp=4 --multiprocess
  # with newer nsys-jax components bind-mounted in.
  combine:
    runs-on: ubuntu-24.04
    steps:
      - name: Check out the repository under ${GITHUB_WORKSPACE}
        uses: actions/checkout@v4
      - name: "Setup Python 3.12"
        uses: actions/setup-python@v5
        with:
          python-version: '3.12'
      - name: Install nsys-jax and dependencies
        run: |
          # Installs nsys-jax-combine; use an editable install here for better coverage.
          # The argument is quoted so the shell never treats [jupyter] as a glob.
          pip install -e '.github/container/nsys_jax[jupyter]'
          # TODO: this could install into the pip prefix as a default
          install-flamegraph local/
          install-protoc local/
          echo "$PWD/local/bin" >> "${GITHUB_PATH}"
      - name: Use nsys-jax-combine to merge profiles from multiple nsys processes
        run: |
          nsys-jax-combine \
            --analysis summary \
            --analysis communication \
            -o pax_fsdp4_4proc.zip \
            .github/workflows/nsys-jax/test_data/pax_fsdp4_4proc_proc*.zip
      - name: Extract the output .zip file
        run: |
          mkdir combined/
          unzip -d combined/ pax_fsdp4_4proc.zip
      - name: Execute the notebook
        run: |
          NOTEBOOK=$(python -c 'from importlib.resources import files; print(files("nsys_jax") / "analyses" / "Analysis.ipynb")')
          # Point to the extracted nsys-jax-combine output
          export NSYS_JAX_DEFAULT_PREFIX="${PWD}/combined"
          # Run with ipython for the sake of getting a clear error message
          ipython "${NOTEBOOK}"
  # This input file was generated with something like
  # srun -n 1 --container-name=XXX --container-image=ghcr.io/nvidia/jax:pax-2024-07-06
  #   env NPROC=4 XLA_FLAGS=--xla_gpu_enable_latency_hiding_scheduler\
  #   --xla_gpu_enable_command_buffer= nsys-jax -o ...-fsdp4-1proc -- test-pax.sh
  #   --steps=5 --fsdp=4
  notebook:
    env:
      # TODO: these could/should be saved in the repository settings instead
      # (quoted so the hex IDs are always read as strings)
      RENDERED_NOTEBOOK_GIST_ID: 'e2cd3520201caab6b67385ed36fad3c1'
      MOCK_RENDERED_NOTEBOOK_GIST_ID: '16698d9e9e52320243165d61b5bb3975'
      # JavaScript regular expression, matched against whole file names in the
      # "Copy rendered notebook" step, selecting which files of the freshly
      # uploaded gist are published to the well-known gist.
      PUBLISH_NOTEBOOK_FILES: '(.*\.ipynb|.*\.svg)'
    runs-on: ubuntu-24.04
    steps:
      - name: Check out the repository under ${GITHUB_WORKSPACE}
        uses: actions/checkout@v4
      - name: Extract the post-processed profile data from a real .zip file (no .nsys-rep)
        run: |
          # Get the actual test data from a real archive, minus the .nsys-rep file
          mkdir profile_data/
          unzip -d profile_data/ .github/workflows/nsys-jax/test_data/pax_fsdp4_1proc.zip
      - name: "Setup Python 3.10"
        uses: actions/setup-python@v5
        with:
          python-version: '3.10'
      - name: Install nsys-jax and dependencies
        run: |
          # Do *not* use an editable install (covered above) for better coverage.
          # The argument is quoted so the shell never treats [jupyter] as a glob.
          pip install '.github/container/nsys_jax[jupyter]'
          # TODO: this could install into the pip prefix as a default
          install-flamegraph local/
          install-protoc local/
          echo "$PWD/local/bin" >> "${GITHUB_PATH}"
      - name: Execute the notebook
        id: exec
        run: |
          NOTEBOOK=$(python -c 'from importlib.resources import files; print(files("nsys_jax") / "analyses" / "Analysis.ipynb")')
          # Point to the extracted profile data
          export NSYS_JAX_DEFAULT_PREFIX="${PWD}/profile_data"
          # Run with ipython for the sake of getting a clear error message
          ipython "${NOTEBOOK}"
          echo "NOTEBOOK=${NOTEBOOK}" >> "${GITHUB_OUTPUT}"
      - name: Render the notebook
        id: render
        env:
          # Hand the step output to the script via the environment instead of
          # splicing it into the script text.
          NOTEBOOK: ${{ steps.exec.outputs.NOTEBOOK }}
        run: |
          workdir=$(mktemp -d)
          export NSYS_JAX_DEFAULT_PREFIX="${PWD}/profile_data"
          jupyter nbconvert --execute --inplace "${NOTEBOOK}"
          cp "${NOTEBOOK}" *.svg "${workdir}"
          echo "WORKDIR=${workdir}" >> "${GITHUB_OUTPUT}"
      - name: Upload rendered notebook to Gist
        id: upload
        uses: actions/github-script@v7
        env:
          WORKDIR: ${{ steps.render.outputs.WORKDIR }}
        with:
          github-token: ${{ secrets.NVJAX_GIST_TOKEN }}
          script: |
            const currentDateTime = new Date().toISOString();
            const gistDescription =
              `Rendered IPython notebook from workflow: ${{ github.workflow }}, ` +
              `Run ID: ${{ github.run_id }}, ` +
              `Repository: ${{ github.repository }}, ` +
              `Event: ${{ github.event_name }}, ` +
              `Created: ${currentDateTime}`;
            const fs = require('fs').promises;
            const workdir = process.env.WORKDIR;
            const files = await fs.readdir(workdir);
            // Upload every file from the render directory as one secret gist.
            const gist = await github.rest.gists.create({
              description: gistDescription,
              public: false,
              files: Object.fromEntries(
                await Promise.all(
                  files.map(async (filename) => {
                    const content = await fs.readFile(`${workdir}/${filename}`, 'utf8');
                    return [filename, { content }];
                  })
                )
              )
            });
            console.log(gist);
            return gist.data.id;
      - name: Copy rendered notebook to Gist with well-known ID
        uses: actions/github-script@v7
        with:
          github-token: ${{ secrets.NVJAX_GIST_TOKEN }}
          script: |
            // The upload step's result is JSON-encoded (github-script default),
            // so this expands to a quoted string literal.
            const srcId = ${{ steps.upload.outputs.result }};
            const dstId = "${{ github.ref == 'refs/heads/main' && env.RENDERED_NOTEBOOK_GIST_ID || env.MOCK_RENDERED_NOTEBOOK_GIST_ID }}";
            const { PUBLISH_NOTEBOOK_FILES } = process.env;
            // Fetch existing files from destination gist
            const { data: dstData } = await github.rest.gists.get({
              gist_id: dstId
            });
            // Mark existing files in destination gist for deletion
            const filesToUpdate = {};
            for (const filename of Object.keys(dstData.files)) {
              filesToUpdate[filename] = null;
            }
            // Fetch files from source gist
            const { data: srcData } = await github.rest.gists.get({
              gist_id: srcId
            });
            // Add or update files whose *whole* name matches the pattern
            const pattern = new RegExp(`^(?:${PUBLISH_NOTEBOOK_FILES})$`);
            for (const [filename, fileObj] of Object.entries(srcData.files)) {
              if (!pattern.test(filename)) {
                continue;
              }
              let content = fileObj.content;
              // If the total gist size is too large, not all the content will have
              // been returned and we need some extra requests.
              if (fileObj.truncated) {
                const { data } = await github.request(fileObj.raw_url);
                // Depending on the response content type the body may already be
                // a string; only binary (ArrayBuffer) bodies need decoding.
                content = typeof data === 'string' ? data : new TextDecoder().decode(data);
              }
              filesToUpdate[filename] = { content };
            }
            // Update files in destination gist
            await github.rest.gists.update({
              gist_id: dstId,
              files: filesToUpdate
            });
            console.log("Files copied successfully.");
            console.log(Object.keys(filesToUpdate));
  ruff:
    runs-on: ubuntu-24.04
    steps:
      - name: Check out the repository under ${GITHUB_WORKSPACE}
        uses: actions/checkout@v4
        with:
          path: JAX-Toolbox
          sparse-checkout: |
            .github/container
      - name: "Setup Python 3.10"
        uses: actions/setup-python@v5
        with:
          python-version: '3.10'
      - name: "Install ruff"
        run: pip install ruff
      - name: "Run ruff checks"
        run: |
          # The workflow's default shell runs with -e. A bare failing command
          # would therefore abort the script before the format check ran.
          # `cmd || status=$?` records the exit code without tripping -e, so
          # both lint and format results are always reported before failing.
          check_status=0
          ruff check ${NSYS_JAX_PYTHON_FILES} || check_status=$?
          format_status=0
          ruff format --check ${NSYS_JAX_PYTHON_FILES} || format_status=$?
          if [[ ${format_status} != 0 || ${check_status} != 0 ]]; then
            exit 1
          fi