Commit 5552503

Fix TPU Spawn gather (#6896)
1 parent 2e53fd3 commit 5552503

4 files changed, +21 -22 lines changed


CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -231,6 +231,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed `EarlyStopping` logic when `min_epochs` or `min_steps` requirement is not met ([#6705](https://github.com/PyTorchLightning/pytorch-lightning/pull/6705))


+- Fixed TPU Spawn all gather ([#6896](https://github.com/PyTorchLightning/pytorch-lightning/pull/6896))
+
+
 - Fixed `--gpus` default for parser returned by `Trainer.add_argparse_args` ([#6898](https://github.com/PyTorchLightning/pytorch-lightning/pull/6898))

pytorch_lightning/accelerators/tpu.py

Lines changed: 1 addition & 17 deletions
@@ -11,9 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Callable, Optional, TYPE_CHECKING, Union
+from typing import Any, Callable, TYPE_CHECKING, Union

-import torch
 from torch.optim import Optimizer

 from pytorch_lightning.accelerators.accelerator import Accelerator
@@ -57,21 +56,6 @@ def run_optimizer_step(
     ) -> None:
         xm.optimizer_step(optimizer, barrier=False, optimizer_args={'closure': lambda_closure, **kwargs})

-    def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor:
-        """
-        Function to gather a tensor from several distributed processes
-        Args:
-            tensor: tensor of shape (batch, ...)
-            group: not available with TPUs
-            sync_grads: not available with TPUs
-        Return:
-            A tensor of shape (world_size, batch, ...)
-        """
-        # todo: Add support for backward with all_gather
-        if isinstance(self.training_type_plugin, TPUSpawnPlugin) and self.training_type_plugin.is_distributed:
-            return xm.all_gather(tensor).view(-1, *tensor.shape)
-        return tensor
-
     def clip_gradients(self, optimizer: Optimizer, clip_val: Union[float, int], norm_type: float = 2.0):

         model = self.lightning_module

pytorch_lightning/plugins/training_type/tpu_spawn.py

Lines changed: 15 additions & 3 deletions
@@ -195,14 +195,14 @@ def broadcast(self, obj: object, src: int = 0) -> object:
         return obj

     def reduce_boolean_decision(self, decision: bool) -> bool:
-        decision = torch.tensor(int(decision), device=self.device)
-        decision = self.reduce(decision, "sum")
+        decision = torch.tensor(int(decision), device=self.lightning_module.device)
+        decision = self.reduce(decision, reduce_op="sum")
         decision = bool(decision == self.world_size)
         return decision

     def reduce(self, output, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None):
         if not isinstance(output, torch.Tensor):
-            output = torch.tensor(output, device=self.device)
+            output = torch.tensor(output, device=self.lightning_module.device)

         _invalid_reduce_op = isinstance(reduce_op, ReduceOp) and reduce_op != ReduceOp.SUM
         _invalid_reduce_op_str = isinstance(reduce_op, str) and reduce_op.lower() not in ("sum", "mean", "avg")
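A minimal CPU-only sketch of the sum-and-compare logic above, for readers skimming the diff: each process contributes `int(decision)`, the values are summed, and the result is True only when every process voted True. The `simulate_reduce_boolean_decision` helper and the `decisions` list are illustrative stand-ins for the real cross-process reduce, not part of the plugin.

import torch

def simulate_reduce_boolean_decision(decisions):
    # Stand-in for the plugin logic: each process contributes int(decision),
    # the contributions are summed, and the result is True only when the sum
    # equals world_size (i.e. every process agreed).
    world_size = len(decisions)
    summed = torch.tensor([int(d) for d in decisions]).sum()
    return bool(summed == world_size)

assert simulate_reduce_boolean_decision([True, True, True]) is True
assert simulate_reduce_boolean_decision([True, False, True]) is False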
@@ -267,3 +267,15 @@ def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None:
         if _OMEGACONF_AVAILABLE:
             checkpoint = apply_to_collection(checkpoint, (DictConfig, ListConfig), OmegaConf.to_container)
         self.save({k: v for k, v in checkpoint.items() if k != "callbacks"}, filepath)
+
+    def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor:
+        """
+        Function to gather a tensor from several distributed processes
+        Args:
+            tensor: tensor of shape (batch, ...)
+            group: not available with TPUs
+            sync_grads: not available with TPUs
+        Return:
+            A tensor of shape (world_size, batch, ...)
+        """
+        return xm.all_gather(tensor.unsqueeze(0))
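A short note on the new return shape: `xm.all_gather` concatenates its inputs along dimension 0 across processes, so gathering `tensor.unsqueeze(0)` (shape `(1, batch, ...)`) yields the documented `(world_size, batch, ...)` tensor. Below is a minimal sketch that only simulates that concatenation on CPU with `torch.cat`; `fake_xla_all_gather`, `world_size`, `batch`, and `features` are illustrative assumptions, not torch_xla APIs.

import torch

def fake_xla_all_gather(per_process_inputs):
    # Stand-in for xm.all_gather: concatenate the per-process tensors along dim 0.
    return torch.cat(per_process_inputs, dim=0)

world_size, batch, features = 8, 4, 32
# Each process holds a (batch, features) tensor and contributes it as (1, batch, features).
contributions = [torch.randn(batch, features).unsqueeze(0) for _ in range(world_size)]
gathered = fake_xla_all_gather(contributions)
assert gathered.shape == (world_size, batch, features)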

tests/models/test_tpu.py

Lines changed: 2 additions & 2 deletions
@@ -229,8 +229,8 @@ def test_tpu_clip_grad_by_value(tmpdir):
         progress_bar_refresh_rate=0,
         max_epochs=4,
         tpu_cores=1,
-        limit_train_batches=4,
-        limit_val_batches=4,
+        limit_train_batches=10,
+        limit_val_batches=10,
         gradient_clip_val=0.5,
         gradient_clip_algorithm='value'
     )
