
Merge pull request #2 from PyTorchLightning/lite-deepspeed-hack
Hack to allow deepspeed to run fp16
awaelchli authored Oct 18, 2021
2 parents baedcf7 + bed73b8 commit bb40e7c
Showing 2 changed files with 20 additions and 7 deletions.
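The scenario this hack enables is a LightningLite script that combines the DeepSpeed strategy with 16-bit precision. A rough usage sketch follows; it is not part of the commit, it assumes deepspeed is installed and a GPU is available, and the constructor arguments and helper method names follow the LightningLite API of the 1.5 pre-release era, so they may differ slightly on this branch.

# Illustrative only, not part of this commit.
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

from pytorch_lightning.lite import LightningLite


class Lite(LightningLite):
    def run(self):
        model = nn.Linear(32, 2)
        optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
        # Hands the model/optimizer pair over to the DeepSpeed strategy.
        model, optimizer = self.setup(model, optimizer)

        dataloader = self.setup_dataloaders(DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=8))
        for (batch,) in dataloader:
            optimizer.zero_grad()
            loss = model(batch).sum()
            self.backward(loss)  # DeepSpeed handles fp16 loss scaling here
            optimizer.step()


# precision=16 is the value that _run_impl forwards to the DeepSpeed strategy below.
Lite(strategy="deepspeed", precision=16, gpus=1).run()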
12 changes: 11 additions & 1 deletion pytorch_lightning/lite/lite.py
@@ -23,7 +23,8 @@
 import torch.nn as nn
 from torch import Tensor
 from torch.optim import Optimizer
-from torch.utils.data import DataLoader, DistributedSampler, RandomSampler, Sampler, SequentialSampler
+from torch.utils.data import DataLoader, DistributedSampler, RandomSampler, SequentialSampler
+

 from pytorch_lightning import Trainer
 from pytorch_lightning.accelerators import Accelerator, TPUAccelerator
@@ -296,13 +297,22 @@ def _run_wrapper(self, run_method: Callable) -> Callable:
         return partial(self._run_impl, run_method)

     def _run_impl(self, run_method: Callable, *args: Any, **kwargs: Any) -> None:
+        if isinstance(self._strategy, DeepSpeedPlugin):
+            # todo: this is a hack as deepspeed currently relies on the precision plugin
+            self._set_deepspeed_precision_variables()
         self._strategy.setup_environment()
         if isinstance(self._strategy, DDPSpawnPlugin):
             self._strategy.spawn(run_method, *args, **kwargs)
         else:
             run_method(*args, **kwargs)
         # TODO: any teardown needed here?

+    def _set_deepspeed_precision_variables(self):
+        amp_type = self._accelerator_connector.amp_type
+        amp_level = self._accelerator_connector.amp_level
+        precision = self._accelerator_connector.precision
+        self._strategy.amp_level, self._strategy.amp_type, self._strategy._precision = amp_level, amp_type, precision
+
     def _setup_model_and_optimizers(
         self,
         model: nn.Module,
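The net effect of _set_deepspeed_precision_variables is that the strategy receives the precision settings directly from Lite's accelerator connector instead of reading them from a Trainer via the precision plugin. A minimal sketch of that hand-off, with hard-coded values standing in for the connector attributes (illustrative, assumes deepspeed is installed):

from pytorch_lightning.plugins import DeepSpeedPlugin
from pytorch_lightning.utilities import AMPType

strategy = DeepSpeedPlugin()

# What Lite copies over before strategy.setup_environment():
strategy.amp_type = AMPType.NATIVE   # native AMP -> DeepSpeed's own fp16 path
strategy.amp_level = None            # only consulted for Apex
strategy._precision = 16

# The precision property now resolves without touching self.lightning_module.trainer,
# so _format_precision_config can build the fp16 section of the DeepSpeed config.
assert strategy.precision == 16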
15 changes: 9 additions & 6 deletions pytorch_lightning/plugins/training_type/deepspeed.py
@@ -325,6 +325,10 @@ def __init__(
         self.hysteresis = hysteresis
         self.min_loss_scale = min_loss_scale

+        self._precision = None
+        self.amp_level = None
+        self.amp_type = None
+
     def _load_config(self, config):
         if config is None and self.DEEPSPEED_ENV_VAR in os.environ:
             rank_zero_info(f"Loading DeepSpeed config from set {self.DEEPSPEED_ENV_VAR} environment variable")
@@ -516,7 +520,7 @@ def model_sharded_context(self) -> Generator[None, None, None]:

     @property
     def precision(self) -> Union[str, int]:
-        return self.lightning_module.trainer.precision
+        return self._precision or self.lightning_module.trainer.precision

     def _set_deepspeed_activation_checkpointing(self):
         if self.config.get("activation_checkpointing"):
@@ -633,11 +637,10 @@ def _auto_select_batch_size(self):
         return batch_size

     def _format_precision_config(self):
-        # TODO: support precision
-        return
-        amp_type = self.lightning_module.trainer.accelerator_connector.amp_type
-        amp_level = self.lightning_module.trainer.accelerator_connector.amp_level
-        precision = self.lightning_module.trainer.accelerator_connector.precision
+        amp_type = self.amp_type or self.lightning_module.trainer.accelerator_connector.amp_type
+        precision = self.precision or self.lightning_module.trainer.accelerator_connector.precision
+        if amp_type == AMPType.APEX:
+            amp_level = self.amp_level or self.lightning_module.trainer.accelerator_connector.amp_level
         if precision in (16, "mixed"):
             if "fp16" not in self.config and amp_type == AMPType.NATIVE:
                 # FP16 is a DeepSpeed standalone AMP implementation
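The rest of _format_precision_config lies below the truncated hunk and is not shown in this diff. With amp_type == AMPType.NATIVE and precision 16, it fills the "fp16" section of the DeepSpeed config from the plugin's loss-scaling attributes (hysteresis, min_loss_scale, etc. from __init__ above). A sketch of the resulting config section, using DeepSpeed's documented fp16 keys and typical default values (illustrative):

# Roughly what the DeepSpeed config ends up containing for native fp16.
deepspeed_fp16_section = {
    "fp16": {
        "enabled": True,            # DeepSpeed's standalone fp16/AMP implementation
        "loss_scale": 0,            # 0 means dynamic loss scaling
        "initial_scale_power": 16,
        "loss_scale_window": 1000,
        "hysteresis": 2,
        "min_loss_scale": 1,
    }
}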
