Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions .github/workflows/xpu-compile.yml
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,15 @@ jobs:
- name: Compile Status
shell: bash
run: |
echo "# torch.compile graph breaks" >> $GITHUB_STEP_SUMMARY
export FI_HMEM=system
ulimit -n 1048575
cd tests/torch_compile
export ZE_AFFINITY_MASK=0,1
deepspeed test_compile.py --deepspeed_config ds_config.json 2>&1 | tee log.txt
cat log.txt | grep "'graph_breaks'" | sed 's/,/ /g' | awk '{print $2}' >> $GITHUB_STEP_SUMMARY
echo "## ZeRO stage 3" >> $GITHUB_STEP_SUMMARY
deepspeed test_compile.py --deepspeed_config ds_config_z3.json 2>&1 | tee log_z3.txt
# for each line start with 'dynamo_output', extract the second field and following fields and append to GITHUB_STEP_SUMMARY using awk
cat log_z3.txt | awk '/^dynamo_output/ {$1=""; print $0}' >> $GITHUB_STEP_SUMMARY
echo "## ZeRO stage 2" >> $GITHUB_STEP_SUMMARY
deepspeed test_compile.py --deepspeed_config ds_config_z2.json 2>&1 | tee log_z2.txt
cat log_z2.txt | awk '/^dynamo_output/ {$1=""; print $0}' >> $GITHUB_STEP_SUMMARY
40 changes: 40 additions & 0 deletions tests/torch_compile/ds_config_z2.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
{
"train_batch_size": 8,
"steps_per_print": 2000,
"optimizer": {
"type": "Adam",
"params": {
"lr": 0.001,
"betas": [
0.8,
0.999
],
"eps": 1e-8,
"weight_decay": 3e-7
}
},
"scheduler": {
"type": "WarmupLR",
"params": {
"warmup_min_lr": 0,
"warmup_max_lr": 0.001,
"warmup_num_steps": 1000
}
},
"gradient_clipping": 1.0,
"prescale_gradients": false,
"bf16": {
"enabled": true,
"loss_scale": 0,
"loss_scale_window": 500,
"hysteresis": 2,
"min_loss_scale": 1,
"initial_scale_power": 15
},
"wall_clock_breakdown": false,
"zero_optimization": {
"stage": 2,
"overlap_comm": false,
"contiguous_gradients": false
}
}
33 changes: 16 additions & 17 deletions tests/torch_compile/test_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,22 +14,9 @@

torch._dynamo.config.cache_size_limit = 100

import collections


def get_dynamo_stats():
# TODO: consider deepcopy'ing the entire counters struct and
# adding a helper to do subtraction on it
return collections.Counter({
"calls_captured": torch._dynamo.utils.counters["stats"]["calls_captured"],
"unique_graphs": torch._dynamo.utils.counters["stats"]["unique_graphs"],
"graph_breaks": sum(torch._dynamo.utils.counters["graph_break"].values()),
# NB: The plus removes zero counts
"unique_graph_breaks": len(+torch._dynamo.utils.counters["graph_break"]),
"autograd_captures": torch._dynamo.utils.counters["compiled_autograd"]["captures"],
"autograd_compiles": torch._dynamo.utils.counters["compiled_autograd"]["compiles"],
"cudagraph_skips": torch._dynamo.utils.counters["inductor"]["cudagraph_skips"],
})
return torch._dynamo.utils.counters["graph_break"]


class RandomDataset(Dataset):
Expand Down Expand Up @@ -70,7 +57,7 @@ def forward(self, data, residual):
parser.add_argument('--local_rank', type=int, default=-1, help='local rank passed from distributed launcher')
parser.add_argument('--deepspeed_config',
type=str,
default='ds_config.json',
default='ds_config_z3.json',
help='path to DeepSpeed configuration file')
cmd_args = parser.parse_args()

Expand All @@ -82,6 +69,11 @@ def forward(self, data, residual):

start_stats = get_dynamo_stats()

if comm.get_rank() == 0:
#print(dynamo_stats['graph_breaks'])
for item in start_stats.items():
print(item)

for step, batch in enumerate(rand_loader):
if step % 10 == 0 and comm.get_rank() == 0:
print(f'step={step}')
Expand All @@ -93,7 +85,14 @@ def forward(self, data, residual):
model_engine.step()

dynamo_stats = get_dynamo_stats()
dynamo_stats.subtract(start_stats)

if comm.get_rank() == 0:
print(dynamo_stats)
# print break down of graph break stats with markdown, print in table format, start with reason, then count
# print a tag 'dynamo_output' before each line to allow post processing
print("dynamo_output | Reason | Count |")
print("dynamo_output | ------ | ----- |")
for item in dynamo_stats.items():
# replace '|' in item[0] with a literal '|' to avoid mess with table format
item = (item[0].replace('|', r'\|'), item[1])
print(f"dynamo_output | {item[0]} | {item[1]} |")
print(f"dynamo_output | Total | {sum(dynamo_stats.values())} |")
Loading