diff --git a/tests/test_aot_eager.py b/tests/test_aot_eager.py index 86d994b7..5b22d879 100644 --- a/tests/test_aot_eager.py +++ b/tests/test_aot_eager.py @@ -53,10 +53,16 @@ def test_aot_eager_bitwise_equivalent(llama3_debug_model): x = torch.randint(0, vocab_size, (batch_size, seqlen), device="cuda") torch.manual_seed(3999) r1 = llama3_debug_model(x) - grads1 = torch.autograd.grad(r1.sum(), llama3_debug_model.parameters()) + with torch.profiler.profile(with_stack=True) as prof1: + grads1 = torch.autograd.grad(r1.sum(), llama3_debug_model.parameters()) + prof1.export_chrome_trace("/tmp/profile/prof1.json") torch.manual_seed(3999) r2 = torch.compile(backend="aot_eager")(llama3_debug_model)(x) - grads2 = torch.autograd.grad(r2.sum(), llama3_debug_model.parameters()) + with torch.profiler.profile() as prof2: + grads2 = torch.autograd.grad(r2.sum(), llama3_debug_model.parameters()) + from torch.fx.traceback import populate_stack_traces_to_kineto_trace + prof2.export_chrome_trace("/tmp/profile/prof2.json") + populate_stack_traces_to_kineto_trace("/tmp/profile/prof2.json") assert torch.equal(r1, r2) # bitwise equal for g1, g2 in zip(grads1, grads2): assert torch.equal(g1, g2)