Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cast on host instead of IPU when using precision=16 #13880

Merged
merged 19 commits into from
Jul 28, 2022
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/pytorch_lightning/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Changed

- When training with `precision=16` on IPU, the cast has been moved off the IPU onto the host, making the copies from host to IPU cheaper ([#13880](https://github.com/Lightning-AI/lightning/pull/13880))
hmellor marked this conversation as resolved.
Show resolved Hide resolved

- `accelerator="gpu"` now automatically selects an available GPU backend (CUDA and MPS currently) ([#13642](https://github.com/Lightning-AI/lightning/pull/13642))


Expand Down
30 changes: 11 additions & 19 deletions src/pytorch_lightning/strategies/ipu.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,25 +42,8 @@


class LightningIPUModule(_LightningModuleWrapperBase):
hmellor marked this conversation as resolved.
Show resolved Hide resolved
hmellor marked this conversation as resolved.
Show resolved Hide resolved
def __init__(
self, pl_module: Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase], precision: Union[str, int]
) -> None:
def __init__(self, pl_module: Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]) -> None:
super().__init__(pl_module)
self.precision = precision

def forward(self, *inputs: Any, **kwargs: Any) -> Any:
if self.precision in (PrecisionType.MIXED, PrecisionType.HALF):
inputs = self._move_float_tensors_to_half(inputs)

return super().forward(*inputs, **kwargs)

@staticmethod
def batch_to(data: Tensor) -> Tensor:
return data.half()

def _move_float_tensors_to_half(self, batch: Any) -> Any:
batch = apply_to_collection(batch, (FloatTensor, torch.cuda.FloatTensor), function=self.batch_to)
return batch


class IPUStrategy(ParallelStrategy):
Expand Down Expand Up @@ -142,7 +125,7 @@ def setup(self, trainer: "pl.Trainer") -> None:
self._optimizer_zero_grad_original = self.lightning_module.optimizer_zero_grad
self._disable_zero_grad()

model = LightningIPUModule(self.lightning_module, self.precision_plugin.precision)
model = LightningIPUModule(self.lightning_module)
self.model = model

# reset the backup
Expand Down Expand Up @@ -272,6 +255,15 @@ def to_tensor(x):
args = apply_to_collection(args, dtype=(int, float), function=to_tensor)
return args

@staticmethod
def batch_to(data: Tensor) -> Tensor:
hmellor marked this conversation as resolved.
Show resolved Hide resolved
return data.half()

def batch_to_device(self, batch: Any, device: Optional[torch.device] = None, dataloader_idx: int = 0) -> Any:
hmellor marked this conversation as resolved.
Show resolved Hide resolved
if self.precision_plugin.precision in (PrecisionType.MIXED, PrecisionType.HALF):
batch = apply_to_collection(batch, (FloatTensor, torch.cuda.FloatTensor), function=self.batch_to)
return batch
hmellor marked this conversation as resolved.
Show resolved Hide resolved

def _disable_zero_grad(self) -> None:
lightning_module = self.lightning_module
if is_overridden("optimizer_zero_grad", lightning_module):
Expand Down
2 changes: 1 addition & 1 deletion tests/tests_pytorch/accelerators/test_ipu.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ def setup(self, trainer: Trainer, pl_module: LightningModule, stage: Optional[st
def test_pure_half_precision(tmpdir):
class TestCallback(Callback):
def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
assert trainer.strategy.model.precision == 16
assert trainer.strategy.precision_plugin.precision == 16
for param in trainer.strategy.model.parameters():
assert param.dtype == torch.float16
raise SystemExit
Expand Down