Cast on host instead of IPU when using precision=16 #13880

Merged · 19 commits · Jul 28, 2022
Changes from 10 commits
2 changes: 2 additions & 0 deletions src/pytorch_lightning/CHANGELOG.md
@@ -113,6 +113,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Changed

+- When training with `precision=16` on IPU, the cast has been moved off the IPU onto the host, making the copies from host to IPU cheaper ([#13880](https://github.com/Lightning-AI/lightning/pull/13880))

- `accelerator="gpu"` now automatically selects an available GPU backend (CUDA and MPS currently) ([#13642](https://github.com/Lightning-AI/lightning/pull/13642))


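For context on the changelog entry above: casting to half precision on the host before the copy halves the number of bytes sent over the host-to-IPU link. A minimal sketch of the effect (illustrative only, not code from this PR):

```python
import torch

# A float32 batch costs 4 bytes per element to copy; casting on the host
# first cuts the transfer size in half.
batch = torch.randn(1024, 1024)   # float32
half = batch.half()               # float16

print(batch.nelement() * batch.element_size())  # 4194304 bytes
print(half.nelement() * half.element_size())    # 2097152 bytes
```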
33 changes: 22 additions & 11 deletions src/pytorch_lightning/strategies/ipu.py
@@ -33,6 +33,7 @@
from pytorch_lightning.utilities.enums import PrecisionType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
+from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation
from pytorch_lightning.utilities.types import STEP_OUTPUT

if _POPTORCH_AVAILABLE:
@@ -45,6 +46,7 @@ class LightningIPUModule(_LightningModuleWrapperBase):
    def __init__(
        self, pl_module: Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase], precision: Union[str, int]
    ) -> None:
+        rank_zero_deprecation("`LightningIPUModule` is deprecated in v1.8 and will be removed in v2.0.0")
        super().__init__(pl_module)
        self.precision = precision

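For background, `rank_zero_deprecation` emits the warning only on the rank-zero process, so a multi-process run prints it once rather than once per worker. A simplified sketch of the pattern (not the actual Lightning implementation):

```python
import os
import warnings

def rank_zero_deprecation_sketch(message: str) -> None:
    # Hypothetical simplification: warn only when this process is (local) rank 0.
    if int(os.environ.get("LOCAL_RANK", "0")) == 0:
        warnings.warn(message, DeprecationWarning)
```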
@@ -142,8 +144,7 @@ def setup(self, trainer: "pl.Trainer") -> None:
        self._optimizer_zero_grad_original = self.lightning_module.optimizer_zero_grad
        self._disable_zero_grad()

-        model = LightningIPUModule(self.lightning_module, self.precision_plugin.precision)
-        self.model = model
+        self.model = _LightningModuleWrapperBase(self.lightning_module)

        # reset the backup
        self.poptorch_models = {}
@@ -156,22 +157,22 @@ def setup(self, trainer: "pl.Trainer") -> None:
            training_opts = self.training_opts
            inference_opts = self.inference_opts
            optimizer = self.lightning_module.trainer.optimizers[0]
-            model = poptorch.trainingModel(model=model, options=training_opts, optimizer=optimizer)
+            model = poptorch.trainingModel(model=self.model, options=training_opts, optimizer=optimizer)
            self.poptorch_models[RunningStage.TRAINING] = model

            if self.lightning_module.trainer.enable_validation:
-                model = poptorch.inferenceModel(model=model, options=inference_opts)
+                model = poptorch.inferenceModel(model=self.model, options=inference_opts)
                self.poptorch_models[RunningStage.VALIDATING] = model
                if self.lightning_module.trainer.num_sanity_val_steps > 0:
                    self.poptorch_models[RunningStage.SANITY_CHECKING] = model
        elif trainer_fn == TrainerFn.VALIDATING:
-            model = poptorch.inferenceModel(model=model, options=self.inference_opts)
+            model = poptorch.inferenceModel(model=self.model, options=self.inference_opts)
            self.poptorch_models[RunningStage.VALIDATING] = model
        elif trainer_fn == TrainerFn.TESTING:
-            model = poptorch.inferenceModel(model=model, options=self.inference_opts)
+            model = poptorch.inferenceModel(model=self.model, options=self.inference_opts)
            self.poptorch_models[RunningStage.TESTING] = model
        elif trainer_fn == TrainerFn.PREDICTING:
-            model = poptorch.inferenceModel(model=model, options=self.inference_opts)
+            model = poptorch.inferenceModel(model=self.model, options=self.inference_opts)
            self.poptorch_models[RunningStage.PREDICTING] = model

    def setup_optimizers(self, trainer: "pl.Trainer") -> None:
@@ -219,10 +220,6 @@ def inference_opts(self) -> "poptorch.Options":
        self._inference_opts = self._create_opts(training=False)
        return self._inference_opts

-    @property
-    def lightning_module(self) -> Optional["pl.LightningModule"]:
-        return self.model.module if isinstance(self.model, LightningIPUModule) else self.model

    def _convert_to_poptorch_loader(
        self, dataloader: DataLoader, sampler, mode: Optional[RunningStage] = None
    ) -> "poptorch.DataLoader":
@@ -272,6 +269,20 @@ def to_tensor(x):
        args = apply_to_collection(args, dtype=(int, float), function=to_tensor)
        return args

+    def batch_to_device(self, batch: Any, device: Optional[torch.device] = None, dataloader_idx: int = 0) -> Any:
+        # This override is necessary because the cast must occur before the data
+        # is moved to the device to prevent wasteful host->device copies.
+        if self.precision_plugin.precision in (PrecisionType.MIXED, PrecisionType.HALF):
+
+            def to_half(data: Tensor) -> Tensor:
+                return data.half()
+
+            batch = apply_to_collection(batch, (FloatTensor, torch.cuda.FloatTensor), function=to_half)
+        # We don't call `super().batch_to_device` because `data.to(device)` is not
+        # currently necessary for IPUs. The movement of data from host<->IPU is
+        # currently handled by PopTorch.
+        return batch

Comment on lines +277 to +279

Contributor: If that's the case, we need to do a follow-up for users to let them know that some of the hooks won't work with IPUs, like `transfer_batch_to_device`, ...

Contributor (author): Is this something you could do? I don't know the codebase well enough to know all the implications.

    def _disable_zero_grad(self) -> None:
        lightning_module = self.lightning_module
        if is_overridden("optimizer_zero_grad", lightning_module):
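For reference, the new `batch_to_device` leans on `apply_to_collection` to cast only the float32 tensors in an arbitrarily nested batch. A small sketch of that behavior (hypothetical batch, assuming PL 1.x's `pytorch_lightning.utilities.apply_func`):

```python
import torch
from pytorch_lightning.utilities.apply_func import apply_to_collection

# Only CPU float32 tensors match `torch.FloatTensor`; the integer tensor
# passes through untouched.
batch = {"pixels": torch.randn(4, 3), "labels": torch.tensor([0, 1, 2, 3])}
half = apply_to_collection(batch, torch.FloatTensor, lambda t: t.half())

assert half["pixels"].dtype == torch.float16
assert half["labels"].dtype == torch.int64
```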
5 changes: 4 additions & 1 deletion tests/tests_pytorch/accelerators/test_ipu.py
@@ -40,16 +40,19 @@

class IPUModel(BoringModel):
    def training_step(self, batch, batch_idx):
+        assert self.precision == torch.finfo(batch.dtype).bits
        output = self(batch)
        loss = self.loss(batch, output)
        return loss

    def validation_step(self, batch, batch_idx):
+        assert self.precision == torch.finfo(batch.dtype).bits
        output = self(batch)
        loss = self.loss(batch, output)
        return loss

    def test_step(self, batch, batch_idx):
+        assert self.precision == torch.finfo(batch.dtype).bits
        output = self(batch)
        loss = self.loss(batch, output)
        return loss
@@ -205,7 +208,7 @@ def setup(self, trainer: Trainer, pl_module: LightningModule, stage: Optional[str]
def test_pure_half_precision(tmpdir):
    class TestCallback(Callback):
        def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
-            assert trainer.strategy.model.precision == 16
+            assert trainer.strategy.precision_plugin.precision == 16
            for param in trainer.strategy.model.parameters():
                assert param.dtype == torch.float16
            raise SystemExit
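A note on the assertions added to `IPUModel` above: `torch.finfo(dtype).bits` reports the bit width of a floating-point dtype, so `self.precision == torch.finfo(batch.dtype).bits` holds only if the batch already arrives cast to the configured precision. A quick sanity check:

```python
import torch

assert torch.finfo(torch.float16).bits == 16
assert torch.finfo(torch.float32).bits == 32
```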
8 changes: 8 additions & 0 deletions tests/tests_pytorch/deprecated_api/test_remove_2-0.py
@@ -19,6 +19,7 @@
import pytorch_lightning
from pytorch_lightning import Trainer
from pytorch_lightning.demos.boring_classes import BoringModel
+from pytorch_lightning.strategies.ipu import LightningIPUModule
from tests_pytorch.callbacks.test_callbacks import OldStatefulCallback
from tests_pytorch.helpers.runif import RunIf

@@ -50,6 +51,13 @@ def test_v2_0_0_deprecated_ipus(_, monkeypatch):
        _ = Trainer(ipus=4)


+@mock.patch("pytorch_lightning.accelerators.ipu.IPUAccelerator.is_available", return_value=True)
+def test_v2_0_0_deprecated_lightning_ipu_module(_, monkeypatch):
+    monkeypatch.setattr(pytorch_lightning.strategies.ipu, "_IPU_AVAILABLE", True)
+    with pytest.deprecated_call(match=r"is deprecated in v1.8 and will be removed in v2.0."):
+        _ = LightningIPUModule(BoringModel(), 32)


def test_v2_0_resume_from_checkpoint_trainer_constructor(tmpdir):
    # test resume_from_checkpoint still works until v2.0 deprecation
    model = BoringModel()
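For readers unfamiliar with the pytest helper used in the new test: `pytest.deprecated_call` fails unless the enclosed code emits a `DeprecationWarning` (or `PendingDeprecationWarning`). A standalone sketch:

```python
import warnings
import pytest

def old_api() -> None:
    warnings.warn("old_api is deprecated", DeprecationWarning)

def test_old_api_warns() -> None:
    with pytest.deprecated_call(match="deprecated"):
        old_api()
```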