
Avoid wrapping LightningModule in *DataParallel overrides when not fitting #6977

@ananthsub


🚀 Feature

For distributed testing or prediction, we don't need to wrap the LightningModule in DistributedDataParallel or DataParallel, since there are no gradients to synchronize. The wrapper is only needed during the fit stage, when the model is actually trained.

Motivation

This can reduce overhead for distributed inference in Lightning. It would also allow running TorchScript modules, or models without any trainable parameters, purely for inference.

Pitch

We'd need the training type plugins to be aware of the Trainer's state somehow. Then we could apply the wrapper here only when the trainer is fitting: https://github.com/PyTorchLightning/pytorch-lightning/blob/80c529351439a0f8d3d6e9449cd47d16ba3abbec/pytorch_lightning/plugins/training_type/ddp.py#L249-L256
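A minimal sketch of the idea, written as a standalone helper rather than against Lightning's internals since the exact plugin hooks may change; the `trainer_stage` argument and the `"fit"` check are illustrative placeholders for however the training type plugin ends up learning about the trainer's state:

```python
import torch
from torch.nn.parallel import DistributedDataParallel


def maybe_wrap_for_ddp(
    module: torch.nn.Module, trainer_stage: str, **ddp_kwargs
) -> torch.nn.Module:
    """Wrap the module in DistributedDataParallel only while fitting.

    During test/predict there are no gradients to synchronize, so the
    LightningModule (or a TorchScript module / parameter-free model) can
    be used directly, skipping the DDP wrapper and its setup cost.
    """
    if trainer_stage == "fit":
        # Gradient synchronization is only needed for training.
        return DistributedDataParallel(module, **ddp_kwargs)
    # Inference-only stages: return the unwrapped module.
    return module
```

In the plugin itself, this check would live in `configure_ddp` (linked above), gated on whatever trainer state is exposed to the training type plugin.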

Alternatives

Additional context
