diff --git a/tests/unittests/bases/test_ddp.py b/tests/unittests/bases/test_ddp.py index d221e7a37f7..2a91585c126 100644 --- a/tests/unittests/bases/test_ddp.py +++ b/tests/unittests/bases/test_ddp.py @@ -108,6 +108,10 @@ def _test_ddp_gather_autograd_different_shape(rank: int, worldsize: int = NUM_PR scalar2 = scalar2 + torch.sum(result[idx] * torch.ones_like(result[idx])) gradient1 = torch.autograd.grad(scalar1, [tensor], retain_graph=True)[0] gradient2 = torch.autograd.grad(scalar2, [tensor])[0] + print("shapes", gradient1.shape, gradient2.shape) + print("g1", gradient1) + print("g2", gradient2) + print("diff", gradient1 - gradient2) assert torch.allclose(gradient1, gradient2)