[bugfix] TPU + all_gather + SingleTPU shouldn't call xm.all_gather (#6296)

* resolve an issue with TPU

* update

* add changelog
tchaton authored and Borda committed Mar 9, 2021
1 parent 92c545b commit 8578ffa
Showing 2 changed files with 6 additions and 1 deletion.
CHANGELOG.md (2 additions, 0 deletions)
@@ -30,6 +30,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed `trainer.test` from `best_path` hangs after calling `trainer.fit` ([#6272](https://github.com/PyTorchLightning/pytorch-lightning/pull/6272))


+- Fixed `SingleTPU` calling `all_gather` ([#6296](https://github.com/PyTorchLightning/pytorch-lightning/pull/6296))

## [1.2.2] - 2021-03-02

### Added
pytorch_lightning/accelerators/tpu.py (4 additions, 1 deletion)
@@ -40,4 +40,7 @@ def all_gather(self, tensor: Union[torch.Tensor], group: Optional[Any] = None, s
         Return:
             A tensor of shape (world_size, batch, ...)
         """
-        return xm.all_gather(tensor, group=group, sync_grads=sync_grads)
+        # todo: Add support for backward with all_gather
+        if torch.distributed.is_initialized():
+            return xm.all_gather(tensor, group=group, sync_grads=sync_grads)
+        return tensor
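
The effect of the guard: on a `SingleTPU` setup `torch.distributed` is never initialized, so `all_gather` now returns the input tensor unchanged instead of calling `xm.all_gather`, which is what the fix avoids on a single TPU core. Below is a minimal standalone sketch of that behavior; the `tpu_all_gather` helper name and the example shapes are illustrative only, not part of the Lightning API.

```python
import torch


def tpu_all_gather(tensor: torch.Tensor) -> torch.Tensor:
    """Standalone sketch of the guarded all_gather above (illustrative only)."""
    if torch.distributed.is_initialized():
        # Multi-core path: gather across TPU devices via torch_xla.
        import torch_xla.core.xla_model as xm
        return xm.all_gather(tensor)
    # SingleTPU / no process group: pass the tensor through unchanged.
    return tensor


# Without an initialized process group the input simply comes back as-is.
out = tpu_all_gather(torch.ones(2, 3))
assert out.shape == (2, 3)
```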
