
Add check for bf16 in deepspeed inference #16973

Merged · 9 commits · Mar 21, 2023
3 changes: 3 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
@@ -40,6 +40,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed WandbLogger not showing "best" aliases for model checkpoints when `ModelCheckpoint(save_top_k>0)` is used ([#17121](https://github.com/Lightning-AI/lightning/pull/17121))


- Fixed parsing the precision config for inference in `DeepSpeedStrategy` ([#16973](https://github.com/Lightning-AI/lightning/pull/16973))


- Fixed the availability check for `rich` that prevented Lightning to be imported in Google Colab ([#17156](https://github.com/Lightning-AI/lightning/pull/17156))


2 changes: 2 additions & 0 deletions src/lightning/pytorch/strategies/deepspeed.py
@@ -544,6 +544,8 @@ def _initialize_deepspeed_inference(self, model: Module) -> None:
        inference_config = {"train_micro_batch_size_per_gpu": 1}
        if "fp16" in self.config:
            inference_config.update({"fp16": self.config["fp16"]})
        if "bf16" in self.config:
            inference_config.update({"bf16": self.config["bf16"]})
        if self.zero_stage_3:
            inference_config.update(
                {
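The hunk above is the whole fix: the inference config previously forwarded only the `fp16` section of the user's DeepSpeed config, so a `bf16`-only config was silently dropped when building the inference-only engine. A minimal standalone sketch of the forwarding logic (using a hypothetical `build_inference_config` helper, not the actual strategy method) shows the before/after behavior:

```python
def build_inference_config(config: dict) -> dict:
    """Sketch: forward precision sections from a DeepSpeed config into the
    inference-only config, mirroring the key names checked in the diff above."""
    inference_config = {"train_micro_batch_size_per_gpu": 1}
    # Before this PR only "fp16" was checked; "bf16" is now carried over too.
    for key in ("fp16", "bf16"):
        if key in config:
            inference_config[key] = config[key]
    return inference_config


# A bf16-only user config is no longer dropped for inference:
print(build_inference_config({"bf16": {"enabled": True}}))
# → {'train_micro_batch_size_per_gpu': 1, 'bf16': {'enabled': True}}
```

The real method continues after this point to merge ZeRO stage 3 settings into the same dict, as the surrounding context lines show.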
26 changes: 26 additions & 0 deletions tests/tests_pytorch/strategies/test_deepspeed_strategy.py
@@ -349,6 +349,32 @@ def on_train_start(self, trainer, pl_module) -> None:
    trainer.fit(model)


@RunIf(min_cuda_gpus=1, standalone=True, deepspeed=True)
@pytest.mark.parametrize("precision", ["fp16", "bf16"])
def test_deepspeed_inference_precision_during_inference(precision, tmpdir):
    """Ensure that when we set fp16/bf16 precision for DeepSpeed and run inference-only, the DeepSpeed config
    contains these changes."""

    class TestCB(Callback):
        def on_validation_start(self, trainer, pl_module) -> None:
            assert trainer.strategy.config[precision]
            raise SystemExit()

    model = BoringModel()
    strategy = DeepSpeedStrategy(config={precision: {"enabled": True}})

    trainer = Trainer(
        default_root_dir=tmpdir,
        strategy=strategy,
        accelerator="cuda",
        devices=1,
        callbacks=[TestCB()],
        barebones=True,
    )
    with pytest.raises(SystemExit):
        trainer.validate(model)


@RunIf(deepspeed=True)
def test_deepspeed_custom_activation_checkpointing_params(tmpdir):
    """Ensure if we modify the activation checkpointing parameters, the deepspeed config contains these changes."""