Skip to content

Commit

Permalink
Remove LightningIPUModule
Browse files Browse the repository at this point in the history
  • Loading branch information
hmellor committed Jul 28, 2022
1 parent 2cd738d commit b9ca3d9
Showing 1 changed file with 6 additions and 16 deletions.
22 changes: 6 additions & 16 deletions src/pytorch_lightning/strategies/ipu.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,6 @@
poptorch = None


class LightningIPUModule(_LightningModuleWrapperBase):
def __init__(self, pl_module: Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]) -> None:
super().__init__(pl_module)


class IPUStrategy(ParallelStrategy):
"""Plugin for training on IPU devices."""

Expand Down Expand Up @@ -125,8 +120,7 @@ def setup(self, trainer: "pl.Trainer") -> None:
self._optimizer_zero_grad_original = self.lightning_module.optimizer_zero_grad
self._disable_zero_grad()

model = LightningIPUModule(self.lightning_module)
self.model = model
self.model = _LightningModuleWrapperBase(self.lightning_module)

# reset the backup
self.poptorch_models = {}
Expand All @@ -139,22 +133,22 @@ def setup(self, trainer: "pl.Trainer") -> None:
training_opts = self.training_opts
inference_opts = self.inference_opts
optimizer = self.lightning_module.trainer.optimizers[0]
model = poptorch.trainingModel(model=model, options=training_opts, optimizer=optimizer)
model = poptorch.trainingModel(model=self.model, options=training_opts, optimizer=optimizer)
self.poptorch_models[RunningStage.TRAINING] = model

if self.lightning_module.trainer.enable_validation:
model = poptorch.inferenceModel(model=model, options=inference_opts)
model = poptorch.inferenceModel(model=self.model, options=inference_opts)
self.poptorch_models[RunningStage.VALIDATING] = model
if self.lightning_module.trainer.num_sanity_val_steps > 0:
self.poptorch_models[RunningStage.SANITY_CHECKING] = model
elif trainer_fn == TrainerFn.VALIDATING:
model = poptorch.inferenceModel(model=model, options=self.inference_opts)
model = poptorch.inferenceModel(model=self.model, options=self.inference_opts)
self.poptorch_models[RunningStage.VALIDATING] = model
elif trainer_fn == TrainerFn.TESTING:
model = poptorch.inferenceModel(model=model, options=self.inference_opts)
model = poptorch.inferenceModel(model=self.model, options=self.inference_opts)
self.poptorch_models[RunningStage.TESTING] = model
elif trainer_fn == TrainerFn.PREDICTING:
model = poptorch.inferenceModel(model=model, options=self.inference_opts)
model = poptorch.inferenceModel(model=self.model, options=self.inference_opts)
self.poptorch_models[RunningStage.PREDICTING] = model

def setup_optimizers(self, trainer: "pl.Trainer") -> None:
Expand Down Expand Up @@ -202,10 +196,6 @@ def inference_opts(self) -> "poptorch.Options":
self._inference_opts = self._create_opts(training=False)
return self._inference_opts

@property
def lightning_module(self) -> Optional["pl.LightningModule"]:
return self.model.module if isinstance(self.model, LightningIPUModule) else self.model

def _convert_to_poptorch_loader(
self, dataloader: DataLoader, sampler, mode: Optional[RunningStage] = None
) -> "poptorch.DataLoader":
Expand Down

0 comments on commit b9ca3d9

Please sign in to comment.