diff --git a/pytorch_lightning/utilities/distributed.py b/pytorch_lightning/utilities/distributed.py
index be5d781939c04..2a0b989e9b9cd 100644
--- a/pytorch_lightning/utilities/distributed.py
+++ b/pytorch_lightning/utilities/distributed.py
@@ -202,6 +202,7 @@ def all_gather_ddp_if_available(
     Return:
         A tensor of shape (world_size, batch, ...)
     """
+    group = group if group is not None else torch.distributed.group.WORLD
     if torch.distributed.is_available() and torch.distributed.is_initialized():
         if sync_grads:
             return AllGatherGrad.apply(tensor, group)
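
For context, a minimal sketch of the patched function's behavior. Only the added line and the context lines above come from the hunk; the signature, the non-grad branch, and the fallback `return tensor` are assumptions about the surrounding code not shown in this diff.

```python
import torch
import torch.distributed

from pytorch_lightning.utilities.distributed import AllGatherGrad


def all_gather_ddp_if_available(tensor, group=None, sync_grads=False):
    # Added by this patch: a caller passing group=None now gets the
    # default (WORLD) process group explicitly, rather than forwarding
    # None into AllGatherGrad and the underlying collectives.
    group = group if group is not None else torch.distributed.group.WORLD
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        if sync_grads:
            return AllGatherGrad.apply(tensor, group)
        ...  # non-grad gather path elided; not shown in this hunk
    return tensor  # assumption: no-op fallback when distributed is unavailable
```

Resolving the default eagerly keeps a single normalization point: every code path below the added line can assume `group` is a concrete process group handle.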