From a2bf484102e5712a32dadc1dc2a8b8d8df34868a Mon Sep 17 00:00:00 2001
From: awaelchli
Date: Thu, 16 May 2024 14:03:36 +0200
Subject: [PATCH] remove redundant strategy device setting

---
 src/lightning/pytorch/strategies/deepspeed.py | 4 ----
 src/lightning/pytorch/strategies/fsdp.py      | 4 ----
 2 files changed, 8 deletions(-)

diff --git a/src/lightning/pytorch/strategies/deepspeed.py b/src/lightning/pytorch/strategies/deepspeed.py
index 6be3d3f8ba590..382f8070898f8 100644
--- a/src/lightning/pytorch/strategies/deepspeed.py
+++ b/src/lightning/pytorch/strategies/deepspeed.py
@@ -337,10 +337,6 @@ def setup(self, trainer: "pl.Trainer") -> None:
         assert self.accelerator is not None
         self.accelerator.setup(trainer)
 
-        # we set the device so that optimizers can be created with distributed comms.
-        assert self.lightning_module is not None
-        self.lightning_module._device = self.root_device
-
         assert self.model is not None
         self.model = self.precision_plugin.convert_module(self.model)
         self.model = self._setup_model(self.model)
diff --git a/src/lightning/pytorch/strategies/fsdp.py b/src/lightning/pytorch/strategies/fsdp.py
index 1aae8b678b674..70590d2f254e2 100644
--- a/src/lightning/pytorch/strategies/fsdp.py
+++ b/src/lightning/pytorch/strategies/fsdp.py
@@ -305,10 +305,6 @@ def setup(self, trainer: "pl.Trainer") -> None:
         if trainer.state.fn == TrainerFn.FITTING and self._layer_sync:
             self.model = self._layer_sync.apply(self.model)
 
-        # we set the device so that optimizers can be created with distributed comms.
-        assert self.lightning_module is not None
-        self.lightning_module._device = self.root_device
-
         self.model = self.precision_plugin.convert_module(self.model)
 
         if is_overridden("configure_sharded_model", self.lightning_module):
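For context, a minimal sketch of the pattern the removed lines followed: before wrapping the model and creating optimizers, the strategy's `setup` pointed the module's private `_device` attribute at the strategy's root device, so that anything reading the module's `device` property during optimizer creation saw the correct device. This is not Lightning's actual implementation; `ToyModule`, `ToyStrategy`, and `root_device` below are illustrative assumptions.

```python
import torch


class ToyModule(torch.nn.Module):
    """Toy stand-in for a module that tracks its device via a private attribute."""

    def __init__(self) -> None:
        super().__init__()
        self._device = torch.device("cpu")
        self.layer = torch.nn.Linear(4, 4)

    @property
    def device(self) -> torch.device:
        # A `device` property backed by `_device`, mimicking the attribute the
        # removed lines assigned.
        return self._device


class ToyStrategy:
    """Toy stand-in for a training strategy that owns a root device."""

    def __init__(self, module: ToyModule, root_device: torch.device) -> None:
        self.module = module
        self.root_device = root_device

    def setup(self) -> None:
        # Equivalent of the removed assignment: make sure `module.device`
        # resolves to the root device before the optimizer is created.
        self.module._device = self.root_device
        self.optimizer = torch.optim.SGD(self.module.parameters(), lr=0.1)


strategy = ToyStrategy(ToyModule(), torch.device("cpu"))
strategy.setup()
print(strategy.module.device)  # -> cpu
```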