From b9ca3d9656457b45c1f2f77066e9bdcf4145a2c0 Mon Sep 17 00:00:00 2001 From: Harry Mellor Date: Thu, 28 Jul 2022 10:13:11 +0100 Subject: [PATCH] Remove `LightningIPUModule` --- src/pytorch_lightning/strategies/ipu.py | 22 ++++++---------------- 1 file changed, 6 insertions(+), 16 deletions(-) diff --git a/src/pytorch_lightning/strategies/ipu.py b/src/pytorch_lightning/strategies/ipu.py index 6a32485997c10..54daa808369be 100644 --- a/src/pytorch_lightning/strategies/ipu.py +++ b/src/pytorch_lightning/strategies/ipu.py @@ -41,11 +41,6 @@ poptorch = None -class LightningIPUModule(_LightningModuleWrapperBase): - def __init__(self, pl_module: Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]) -> None: - super().__init__(pl_module) - - class IPUStrategy(ParallelStrategy): """Plugin for training on IPU devices.""" @@ -125,8 +120,7 @@ def setup(self, trainer: "pl.Trainer") -> None: self._optimizer_zero_grad_original = self.lightning_module.optimizer_zero_grad self._disable_zero_grad() - model = LightningIPUModule(self.lightning_module) - self.model = model + self.model = _LightningModuleWrapperBase(self.lightning_module) # reset the backup self.poptorch_models = {} @@ -139,22 +133,22 @@ def setup(self, trainer: "pl.Trainer") -> None: training_opts = self.training_opts inference_opts = self.inference_opts optimizer = self.lightning_module.trainer.optimizers[0] - model = poptorch.trainingModel(model=model, options=training_opts, optimizer=optimizer) + model = poptorch.trainingModel(model=self.model, options=training_opts, optimizer=optimizer) self.poptorch_models[RunningStage.TRAINING] = model if self.lightning_module.trainer.enable_validation: - model = poptorch.inferenceModel(model=model, options=inference_opts) + model = poptorch.inferenceModel(model=self.model, options=inference_opts) self.poptorch_models[RunningStage.VALIDATING] = model if self.lightning_module.trainer.num_sanity_val_steps > 0: self.poptorch_models[RunningStage.SANITY_CHECKING] = model elif trainer_fn == TrainerFn.VALIDATING: - model = poptorch.inferenceModel(model=model, options=self.inference_opts) + model = poptorch.inferenceModel(model=self.model, options=self.inference_opts) self.poptorch_models[RunningStage.VALIDATING] = model elif trainer_fn == TrainerFn.TESTING: - model = poptorch.inferenceModel(model=model, options=self.inference_opts) + model = poptorch.inferenceModel(model=self.model, options=self.inference_opts) self.poptorch_models[RunningStage.TESTING] = model elif trainer_fn == TrainerFn.PREDICTING: - model = poptorch.inferenceModel(model=model, options=self.inference_opts) + model = poptorch.inferenceModel(model=self.model, options=self.inference_opts) self.poptorch_models[RunningStage.PREDICTING] = model def setup_optimizers(self, trainer: "pl.Trainer") -> None: @@ -202,10 +196,6 @@ def inference_opts(self) -> "poptorch.Options": self._inference_opts = self._create_opts(training=False) return self._inference_opts - @property - def lightning_module(self) -> Optional["pl.LightningModule"]: - return self.model.module if isinstance(self.model, LightningIPUModule) else self.model - def _convert_to_poptorch_loader( self, dataloader: DataLoader, sampler, mode: Optional[RunningStage] = None ) -> "poptorch.DataLoader":