From 12d19cad1dca0b2e17667dceb4a8fe609f66d62d Mon Sep 17 00:00:00 2001 From: Guokai Ma Date: Mon, 30 Sep 2024 11:07:07 +0800 Subject: [PATCH 01/16] print graph log --- .github/workflows/xpu-compile.yml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.github/workflows/xpu-compile.yml b/.github/workflows/xpu-compile.yml index c2392091012f..b6c6058a6e2b 100644 --- a/.github/workflows/xpu-compile.yml +++ b/.github/workflows/xpu-compile.yml @@ -57,3 +57,7 @@ jobs: export ZE_AFFINITY_MASK=0,1 deepspeed test_compile.py --deepspeed_config ds_config.json 2>&1 | tee log.txt cat log.txt | grep "'graph_breaks'" | sed 's/,/ /g' | awk '{print $2}' >> $GITHUB_STEP_SUMMARY + # print log.txt content for debugging + echo "vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv" + cat log.txt + echo "^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^" From 45704ae4857786355e296a110372f589caad7821 Mon Sep 17 00:00:00 2001 From: Guokai Ma Date: Mon, 30 Sep 2024 11:48:48 +0800 Subject: [PATCH 02/16] add more logging code in the test --- tests/torch_compile/test_compile.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/torch_compile/test_compile.py b/tests/torch_compile/test_compile.py index 529ca56ae0a8..c054d9d37c1b 100644 --- a/tests/torch_compile/test_compile.py +++ b/tests/torch_compile/test_compile.py @@ -97,3 +97,5 @@ def forward(self, data, residual): if comm.get_rank() == 0: print(dynamo_stats) + print(torch._dynamo.utils.counters["stats"]["calls_captured"]) + print(torch._dynamo.utils.log) From 5f61c31cf6c0d22b3e6bacef37d9d091ec7c651b Mon Sep 17 00:00:00 2001 From: Guokai Ma Date: Mon, 30 Sep 2024 14:52:27 +0800 Subject: [PATCH 03/16] force torch dynamo dump more logs --- tests/torch_compile/test_compile.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/torch_compile/test_compile.py b/tests/torch_compile/test_compile.py index c054d9d37c1b..1157b738fbfa 100644 --- a/tests/torch_compile/test_compile.py +++ b/tests/torch_compile/test_compile.py @@ -16,6 +16,10 @@ import collections +import os +os.environ["TORCH_COMPILE_DEBUG"] = "1" +os.environ["TORCHDYNAMO_VERBOSE"] = "1" +os.environ["TORCH_LOGS"] = "+graph_breaks" def get_dynamo_stats(): # TODO: consider deepcopy'ing the entire counters struct and @@ -97,5 +101,4 @@ def forward(self, data, residual): if comm.get_rank() == 0: print(dynamo_stats) - print(torch._dynamo.utils.counters["stats"]["calls_captured"]) - print(torch._dynamo.utils.log) + print(torch._dynamo.utils.counters) From 3848abc73c87be1ee1c319ee955572c62a0de56d Mon Sep 17 00:00:00 2001 From: Guokai Ma Date: Mon, 30 Sep 2024 16:44:46 +0800 Subject: [PATCH 04/16] print detaild counters --- .github/workflows/xpu-compile.yml | 3 --- tests/torch_compile/test_compile.py | 31 +++++++++++++---------------- 2 files changed, 14 insertions(+), 20 deletions(-) diff --git a/.github/workflows/xpu-compile.yml b/.github/workflows/xpu-compile.yml index b6c6058a6e2b..f81d95eaa57f 100644 --- a/.github/workflows/xpu-compile.yml +++ b/.github/workflows/xpu-compile.yml @@ -57,7 +57,4 @@ jobs: export ZE_AFFINITY_MASK=0,1 deepspeed test_compile.py --deepspeed_config ds_config.json 2>&1 | tee log.txt cat log.txt | grep "'graph_breaks'" | sed 's/,/ /g' | awk '{print $2}' >> $GITHUB_STEP_SUMMARY - # print log.txt content for debugging - echo "vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv" cat log.txt - echo "^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^" diff --git a/tests/torch_compile/test_compile.py b/tests/torch_compile/test_compile.py index 1157b738fbfa..912632621800 100644 --- a/tests/torch_compile/test_compile.py +++ b/tests/torch_compile/test_compile.py @@ -16,24 +16,20 @@ import collections -import os -os.environ["TORCH_COMPILE_DEBUG"] = "1" -os.environ["TORCHDYNAMO_VERBOSE"] = "1" -os.environ["TORCH_LOGS"] = "+graph_breaks" - def get_dynamo_stats(): # TODO: consider deepcopy'ing the entire counters struct and # adding a helper to do subtraction on it - return collections.Counter({ - "calls_captured": torch._dynamo.utils.counters["stats"]["calls_captured"], - "unique_graphs": torch._dynamo.utils.counters["stats"]["unique_graphs"], - "graph_breaks": sum(torch._dynamo.utils.counters["graph_break"].values()), - # NB: The plus removes zero counts - "unique_graph_breaks": len(+torch._dynamo.utils.counters["graph_break"]), - "autograd_captures": torch._dynamo.utils.counters["compiled_autograd"]["captures"], - "autograd_compiles": torch._dynamo.utils.counters["compiled_autograd"]["compiles"], - "cudagraph_skips": torch._dynamo.utils.counters["inductor"]["cudagraph_skips"], - }) + return torch._dynamo.utils.counters["graph_break"] + #return collections.Counter({ + #"calls_captured": torch._dynamo.utils.counters["stats"]["calls_captured"], + #"unique_graphs": torch._dynamo.utils.counters["stats"]["unique_graphs"], + #"graph_breaks": sum(torch._dynamo.utils.counters["graph_break"].values()), + ## NB: The plus removes zero counts + #"unique_graph_breaks": len(+torch._dynamo.utils.counters["graph_break"]), + #"autograd_captures": torch._dynamo.utils.counters["compiled_autograd"]["captures"], + #"autograd_compiles": torch._dynamo.utils.counters["compiled_autograd"]["compiles"], + #"cudagraph_skips": torch._dynamo.utils.counters["inductor"]["cudagraph_skips"], + #}) class RandomDataset(Dataset): @@ -100,5 +96,6 @@ def forward(self, data, residual): dynamo_stats.subtract(start_stats) if comm.get_rank() == 0: - print(dynamo_stats) - print(torch._dynamo.utils.counters) + print(dynamo_stats['graph_breaks']) + for item in dynamo_states.items(): + print(item) From fe700cfcae4d87f4d9b49d12f1f6e5c875987904 Mon Sep 17 00:00:00 2001 From: Guokai Ma Date: Mon, 30 Sep 2024 18:43:24 +0800 Subject: [PATCH 05/16] fix typo --- tests/torch_compile/test_compile.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/torch_compile/test_compile.py b/tests/torch_compile/test_compile.py index 912632621800..b7c0d8e5dabf 100644 --- a/tests/torch_compile/test_compile.py +++ b/tests/torch_compile/test_compile.py @@ -97,5 +97,5 @@ def forward(self, data, residual): if comm.get_rank() == 0: print(dynamo_stats['graph_breaks']) - for item in dynamo_states.items(): + for item in dynamo_stats.items(): print(item) From ab3f72e4ba7cda6b388619de19b2dde8526452df Mon Sep 17 00:00:00 2001 From: Guokai Ma Date: Mon, 30 Sep 2024 18:54:26 +0800 Subject: [PATCH 06/16] no print 'graph_breaks' in dynamo_stats --- tests/torch_compile/test_compile.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/torch_compile/test_compile.py b/tests/torch_compile/test_compile.py index b7c0d8e5dabf..4b11c9ba9791 100644 --- a/tests/torch_compile/test_compile.py +++ b/tests/torch_compile/test_compile.py @@ -96,6 +96,6 @@ def forward(self, data, residual): dynamo_stats.subtract(start_stats) if comm.get_rank() == 0: - print(dynamo_stats['graph_breaks']) + #print(dynamo_stats['graph_breaks']) for item in dynamo_stats.items(): print(item) From eabc834bf1a400ad21187b2888bcd619f1e296b7 Mon Sep 17 00:00:00 2001 From: Guokai Ma Date: Mon, 30 Sep 2024 20:37:59 +0800 Subject: [PATCH 07/16] avoid grep graph breaks in the workflow --- .github/workflows/xpu-compile.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/xpu-compile.yml b/.github/workflows/xpu-compile.yml index f81d95eaa57f..6d8968ee3c53 100644 --- a/.github/workflows/xpu-compile.yml +++ b/.github/workflows/xpu-compile.yml @@ -56,5 +56,5 @@ jobs: cd tests/torch_compile export ZE_AFFINITY_MASK=0,1 deepspeed test_compile.py --deepspeed_config ds_config.json 2>&1 | tee log.txt - cat log.txt | grep "'graph_breaks'" | sed 's/,/ /g' | awk '{print $2}' >> $GITHUB_STEP_SUMMARY + #cat log.txt | grep "'graph_breaks'" | sed 's/,/ /g' | awk '{print $2}' >> $GITHUB_STEP_SUMMARY cat log.txt From b4be3025d8b5d14a903d3ab3363c885fb0ff135d Mon Sep 17 00:00:00 2001 From: Guokai Ma Date: Mon, 30 Sep 2024 21:01:50 +0800 Subject: [PATCH 08/16] don't substract --- tests/torch_compile/test_compile.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/torch_compile/test_compile.py b/tests/torch_compile/test_compile.py index 4b11c9ba9791..9d1bd181d922 100644 --- a/tests/torch_compile/test_compile.py +++ b/tests/torch_compile/test_compile.py @@ -82,6 +82,11 @@ def forward(self, data, residual): start_stats = get_dynamo_stats() +if comm.get_rank() == 0: + #print(dynamo_stats['graph_breaks']) + for item in dynamo_stats.items(): + print(item) + for step, batch in enumerate(rand_loader): if step % 10 == 0 and comm.get_rank() == 0: print(f'step={step}') @@ -93,7 +98,6 @@ def forward(self, data, residual): model_engine.step() dynamo_stats = get_dynamo_stats() -dynamo_stats.subtract(start_stats) if comm.get_rank() == 0: #print(dynamo_stats['graph_breaks']) From dadc42cd92860090e6505db9be237419706298aa Mon Sep 17 00:00:00 2001 From: Guokai Ma Date: Mon, 30 Sep 2024 21:17:25 +0800 Subject: [PATCH 09/16] fix typo --- tests/torch_compile/test_compile.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/torch_compile/test_compile.py b/tests/torch_compile/test_compile.py index 9d1bd181d922..5899ec5ddf06 100644 --- a/tests/torch_compile/test_compile.py +++ b/tests/torch_compile/test_compile.py @@ -84,7 +84,7 @@ def forward(self, data, residual): if comm.get_rank() == 0: #print(dynamo_stats['graph_breaks']) - for item in dynamo_stats.items(): + for item in start_stats.items(): print(item) for step, batch in enumerate(rand_loader): From 20022e1632ee6491e137c403040f96e54bc0d9f4 Mon Sep 17 00:00:00 2001 From: Guokai Ma Date: Sun, 6 Oct 2024 11:56:36 +0800 Subject: [PATCH 10/16] print summary in markdown table format --- tests/torch_compile/test_compile.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/tests/torch_compile/test_compile.py b/tests/torch_compile/test_compile.py index 5899ec5ddf06..680ff7b0c70f 100644 --- a/tests/torch_compile/test_compile.py +++ b/tests/torch_compile/test_compile.py @@ -101,5 +101,11 @@ def forward(self, data, residual): if comm.get_rank() == 0: #print(dynamo_stats['graph_breaks']) + # print break down of graph break stats with markdown, print in table format, start with number of graph breaks, which is second item of each tuple. Then follow by reason, which is first item of each tuple + # print a table head first + # then print total number at the end of the table + print("| Reason | Count |") + print("| ------ | ----- |") for item in dynamo_stats.items(): - print(item) + print(f"| {item[0]} | {item[1]} |") + print(f"| Total | {sum(dynamo_stats.values())} |") From 05fa0a85b84dc682e3c46aa8e720ac27c7bc9d80 Mon Sep 17 00:00:00 2001 From: Guokai Ma Date: Sun, 6 Oct 2024 12:38:54 +0800 Subject: [PATCH 11/16] output to git summary file --- .github/workflows/xpu-compile.yml | 4 ++-- tests/torch_compile/test_compile.py | 10 ++++++---- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/.github/workflows/xpu-compile.yml b/.github/workflows/xpu-compile.yml index 6d8968ee3c53..eee52bee8faf 100644 --- a/.github/workflows/xpu-compile.yml +++ b/.github/workflows/xpu-compile.yml @@ -56,5 +56,5 @@ jobs: cd tests/torch_compile export ZE_AFFINITY_MASK=0,1 deepspeed test_compile.py --deepspeed_config ds_config.json 2>&1 | tee log.txt - #cat log.txt | grep "'graph_breaks'" | sed 's/,/ /g' | awk '{print $2}' >> $GITHUB_STEP_SUMMARY - cat log.txt + # for each line start with 'dynamo_output', extract the second field and following fields and append to GITHUB_STEP_SUMMARY using awk + cat log.txt | awk '/^dynamo_output/ {$1=""; print $0}' >> $GITHUB_STEP_SUMMARY diff --git a/tests/torch_compile/test_compile.py b/tests/torch_compile/test_compile.py index 680ff7b0c70f..49586d9363d4 100644 --- a/tests/torch_compile/test_compile.py +++ b/tests/torch_compile/test_compile.py @@ -104,8 +104,10 @@ def forward(self, data, residual): # print break down of graph break stats with markdown, print in table format, start with number of graph breaks, which is second item of each tuple. Then follow by reason, which is first item of each tuple # print a table head first # then print total number at the end of the table - print("| Reason | Count |") - print("| ------ | ----- |") + # print a tag 'dynamo_output' before each line to allow post processing + print("dynamo_output | Reason | Count |") + print("dynamo_output | ------ | ----- |") for item in dynamo_stats.items(): - print(f"| {item[0]} | {item[1]} |") - print(f"| Total | {sum(dynamo_stats.values())} |") + print(f"dynamo_output | {item[0]} | {item[1]} |") + print("dynamo_output | ------ | ----- |") + print(f"dynamo_output | Total | {sum(dynamo_stats.values())} |") From c17a2c3327dec41d291d46ca76f7605d4068c745 Mon Sep 17 00:00:00 2001 From: Guokai Ma Date: Sun, 6 Oct 2024 12:53:17 +0800 Subject: [PATCH 12/16] debug output format --- tests/torch_compile/test_compile.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/torch_compile/test_compile.py b/tests/torch_compile/test_compile.py index 49586d9363d4..4d9e862b0592 100644 --- a/tests/torch_compile/test_compile.py +++ b/tests/torch_compile/test_compile.py @@ -105,9 +105,9 @@ def forward(self, data, residual): # print a table head first # then print total number at the end of the table # print a tag 'dynamo_output' before each line to allow post processing - print("dynamo_output | Reason | Count |") - print("dynamo_output | ------ | ----- |") + #print("dynamo_output | Reason | Count |") + #print("dynamo_output | ------ | ----- |") for item in dynamo_stats.items(): - print(f"dynamo_output | {item[0]} | {item[1]} |") - print("dynamo_output | ------ | ----- |") - print(f"dynamo_output | Total | {sum(dynamo_stats.values())} |") + #print(f"dynamo_output | {item[0]} | {item[1]} |") + print(f"dynamo_output {item}") + #print(f"dynamo_output | Total | {sum(dynamo_stats.values())} |") From c5e9985e9df8a8eafe6855c2f1e0a5bb0b0a9f04 Mon Sep 17 00:00:00 2001 From: Guokai Ma Date: Sun, 6 Oct 2024 13:08:43 +0800 Subject: [PATCH 13/16] save | literal in table --- tests/torch_compile/test_compile.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/tests/torch_compile/test_compile.py b/tests/torch_compile/test_compile.py index 4d9e862b0592..d483b42c3baf 100644 --- a/tests/torch_compile/test_compile.py +++ b/tests/torch_compile/test_compile.py @@ -105,9 +105,10 @@ def forward(self, data, residual): # print a table head first # then print total number at the end of the table # print a tag 'dynamo_output' before each line to allow post processing - #print("dynamo_output | Reason | Count |") - #print("dynamo_output | ------ | ----- |") + print("dynamo_output | Reason | Count |") + print("dynamo_output | ------ | ----- |") for item in dynamo_stats.items(): - #print(f"dynamo_output | {item[0]} | {item[1]} |") - print(f"dynamo_output {item}") - #print(f"dynamo_output | Total | {sum(dynamo_stats.values())} |") + # replace '|' in item[0] with a literal '|' to avoid mess with table format + item = (item[0].replace('|', r'\|'), item[1]) + print(f"dynamo_output | {item[0]} | {item[1]} |") + print(f"dynamo_output | Total | {sum(dynamo_stats.values())} |") From 72b633b6fcdf117e0419cc8eda77c921f0147fe1 Mon Sep 17 00:00:00 2001 From: Guokai Ma Date: Sun, 6 Oct 2024 13:51:25 +0800 Subject: [PATCH 14/16] cleanup --- tests/torch_compile/test_compile.py | 15 +-------------- 1 file changed, 1 insertion(+), 14 deletions(-) diff --git a/tests/torch_compile/test_compile.py b/tests/torch_compile/test_compile.py index d483b42c3baf..e42d65d9c642 100644 --- a/tests/torch_compile/test_compile.py +++ b/tests/torch_compile/test_compile.py @@ -20,16 +20,6 @@ def get_dynamo_stats(): # TODO: consider deepcopy'ing the entire counters struct and # adding a helper to do subtraction on it return torch._dynamo.utils.counters["graph_break"] - #return collections.Counter({ - #"calls_captured": torch._dynamo.utils.counters["stats"]["calls_captured"], - #"unique_graphs": torch._dynamo.utils.counters["stats"]["unique_graphs"], - #"graph_breaks": sum(torch._dynamo.utils.counters["graph_break"].values()), - ## NB: The plus removes zero counts - #"unique_graph_breaks": len(+torch._dynamo.utils.counters["graph_break"]), - #"autograd_captures": torch._dynamo.utils.counters["compiled_autograd"]["captures"], - #"autograd_compiles": torch._dynamo.utils.counters["compiled_autograd"]["compiles"], - #"cudagraph_skips": torch._dynamo.utils.counters["inductor"]["cudagraph_skips"], - #}) class RandomDataset(Dataset): @@ -100,10 +90,7 @@ def forward(self, data, residual): dynamo_stats = get_dynamo_stats() if comm.get_rank() == 0: - #print(dynamo_stats['graph_breaks']) - # print break down of graph break stats with markdown, print in table format, start with number of graph breaks, which is second item of each tuple. Then follow by reason, which is first item of each tuple - # print a table head first - # then print total number at the end of the table + # print break down of graph break stats with markdown, print in table format, start with reason, then count # print a tag 'dynamo_output' before each line to allow post processing print("dynamo_output | Reason | Count |") print("dynamo_output | ------ | ----- |") From 80d0a44f2e46af57dbfad0aad34245e7c6106a6d Mon Sep 17 00:00:00 2001 From: Guokai Ma Date: Sun, 6 Oct 2024 15:45:24 +0800 Subject: [PATCH 15/16] fix formatting --- tests/torch_compile/test_compile.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/torch_compile/test_compile.py b/tests/torch_compile/test_compile.py index e42d65d9c642..711876fa4b28 100644 --- a/tests/torch_compile/test_compile.py +++ b/tests/torch_compile/test_compile.py @@ -14,7 +14,6 @@ torch._dynamo.config.cache_size_limit = 100 -import collections def get_dynamo_stats(): # TODO: consider deepcopy'ing the entire counters struct and From dcf7af9a505efa6ba9985dcdbef1f5114a398c15 Mon Sep 17 00:00:00 2001 From: "Ma, Guokai" Date: Thu, 10 Oct 2024 15:40:32 +0800 Subject: [PATCH 16/16] remove unneeded comments and add zero 2 --- .github/workflows/xpu-compile.yml | 9 ++++- tests/torch_compile/ds_config_z2.json | 40 +++++++++++++++++++ .../{ds_config.json => ds_config_z3.json} | 0 tests/torch_compile/test_compile.py | 4 +- 4 files changed, 48 insertions(+), 5 deletions(-) create mode 100644 tests/torch_compile/ds_config_z2.json rename tests/torch_compile/{ds_config.json => ds_config_z3.json} (100%) diff --git a/.github/workflows/xpu-compile.yml b/.github/workflows/xpu-compile.yml index eee52bee8faf..e095e089fc30 100644 --- a/.github/workflows/xpu-compile.yml +++ b/.github/workflows/xpu-compile.yml @@ -51,10 +51,15 @@ jobs: - name: Compile Status shell: bash run: | + echo "# torch.compile graph breaks" >> $GITHUB_STEP_SUMMARY export FI_HMEM=system ulimit -n 1048575 cd tests/torch_compile export ZE_AFFINITY_MASK=0,1 - deepspeed test_compile.py --deepspeed_config ds_config.json 2>&1 | tee log.txt + echo "## ZeRO stage 3" >> $GITHUB_STEP_SUMMARY + deepspeed test_compile.py --deepspeed_config ds_config_z3.json 2>&1 | tee log_z3.txt # for each line start with 'dynamo_output', extract the second field and following fields and append to GITHUB_STEP_SUMMARY using awk - cat log.txt | awk '/^dynamo_output/ {$1=""; print $0}' >> $GITHUB_STEP_SUMMARY + cat log_z3.txt | awk '/^dynamo_output/ {$1=""; print $0}' >> $GITHUB_STEP_SUMMARY + echo "## ZeRO stage 2" >> $GITHUB_STEP_SUMMARY + deepspeed test_compile.py --deepspeed_config ds_config_z2.json 2>&1 | tee log_z2.txt + cat log_z2.txt | awk '/^dynamo_output/ {$1=""; print $0}' >> $GITHUB_STEP_SUMMARY diff --git a/tests/torch_compile/ds_config_z2.json b/tests/torch_compile/ds_config_z2.json new file mode 100644 index 000000000000..30e1237c558c --- /dev/null +++ b/tests/torch_compile/ds_config_z2.json @@ -0,0 +1,40 @@ +{ + "train_batch_size": 8, + "steps_per_print": 2000, + "optimizer": { + "type": "Adam", + "params": { + "lr": 0.001, + "betas": [ + 0.8, + 0.999 + ], + "eps": 1e-8, + "weight_decay": 3e-7 + } + }, + "scheduler": { + "type": "WarmupLR", + "params": { + "warmup_min_lr": 0, + "warmup_max_lr": 0.001, + "warmup_num_steps": 1000 + } + }, + "gradient_clipping": 1.0, + "prescale_gradients": false, + "bf16": { + "enabled": true, + "loss_scale": 0, + "loss_scale_window": 500, + "hysteresis": 2, + "min_loss_scale": 1, + "initial_scale_power": 15 + }, + "wall_clock_breakdown": false, + "zero_optimization": { + "stage": 2, + "overlap_comm": false, + "contiguous_gradients": false + } +} diff --git a/tests/torch_compile/ds_config.json b/tests/torch_compile/ds_config_z3.json similarity index 100% rename from tests/torch_compile/ds_config.json rename to tests/torch_compile/ds_config_z3.json diff --git a/tests/torch_compile/test_compile.py b/tests/torch_compile/test_compile.py index 711876fa4b28..adbf6eaa947a 100644 --- a/tests/torch_compile/test_compile.py +++ b/tests/torch_compile/test_compile.py @@ -16,8 +16,6 @@ def get_dynamo_stats(): - # TODO: consider deepcopy'ing the entire counters struct and - # adding a helper to do subtraction on it return torch._dynamo.utils.counters["graph_break"] @@ -59,7 +57,7 @@ def forward(self, data, residual): parser.add_argument('--local_rank', type=int, default=-1, help='local rank passed from distributed launcher') parser.add_argument('--deepspeed_config', type=str, - default='ds_config.json', + default='ds_config_z3.json', help='path to DeepSpeed configuration file') cmd_args = parser.parse_args()