This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Add support for fairscale sharded ddp #3415

Closed
stephenroller opened this issue Jan 26, 2021 · 1 comment

stephenroller commented Jan 26, 2021

Let's add support for Fairscale's Sharded DDP.

Right now we hard-code the use of PyTorch's DDP, but let's generalize this:

self.model = torch.nn.parallel.DistributedDataParallel(
    self.model, device_ids=device_ids, broadcast_buffers=False
)
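
For reference, fairscale's sharded DDP is wrapped differently: the optimizer state is sharded via fairscale's OSS wrapper, and the model is then wrapped in ShardedDataParallel instead of torch's DistributedDataParallel. A minimal sketch based on fairscale's documented usage (the dummy model and optimizer settings are just for illustration):

import torch
from fairscale.optim.oss import OSS
from fairscale.nn.data_parallel import ShardedDataParallel

# assumes torch.distributed is already initialized, as in our distributed setup
model = torch.nn.Linear(8, 8).cuda()  # stand-in for the real ParlAI model
# OSS shards the optimizer state across ranks; ShardedDataParallel then handles
# gradient reduction in place of torch's DistributedDataParallel
optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3)
model = ShardedDataParallel(model, optimizer)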

Create a new "distributed" folder inside parlai/nn. Inside parlai/nn/distributed/__init__.py, create a helper class which finds and instantiates the "right" version of data parallel. Something like

class DistributedFactory:
    @classmethod
    def add_cmdline_args(cls, parser, partial_opt):
        # add --distributed-method here, as well as moving the options from
        # https://github.com/facebookresearch/ParlAI/blob/67433e376fc361dee5aa045cb6bb2b68d3faa478/parlai/core/params.py#L760-L768
        ...

    @classmethod
    def factory(cls, model: torch.nn.Module, opt: Opt):
        # based on opt['distributed_method'], instantiate a DDP or ShardedDDP object
        ...
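
For concreteness, here is a rough sketch of what the factory dispatch could look like. The fairscale import is real, but the option values and the extra optimizer argument are only illustrative; note that ShardedDataParallel also needs the optimizer wrapped in fairscale's OSS, so the real helper will have to thread the optimizer through as well.

import torch
from parlai.core.opt import Opt


class DistributedFactory:
    @classmethod
    def factory(cls, model: torch.nn.Module, opt: Opt, optimizer=None, device_ids=None):
        method = opt.get('distributed_method', 'ddp')
        if method == 'ddp':
            # current behavior: plain PyTorch DDP
            return torch.nn.parallel.DistributedDataParallel(
                model, device_ids=device_ids, broadcast_buffers=False
            )
        elif method == 'sharded_ddp':
            # fairscale's sharded DDP; `optimizer` must already be a
            # fairscale.optim.OSS instance wrapping the base optimizer
            from fairscale.nn.data_parallel import ShardedDataParallel
            return ShardedDataParallel(model, optimizer)
        else:
            raise ValueError(f'Unknown --distributed-method: {method}')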

Upgrade TGA/TCA/TRA (TorchGeneratorAgent / TorchClassifierAgent / TorchRankerAgent) to use this helper.

It would also be nice to use a @register_distributed pattern (see how we do it for Agents, Teachers, and scripts presently), so that we can add internal-only approaches.
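
That registry could mirror how agents, teachers, and scripts are registered today; a minimal sketch (all names here are illustrative, not an existing API):

DISTRIBUTED_REGISTRY = {}


def register_distributed(name):
    """
    Register a data-parallel wrapper under a short name, so internal-only
    implementations can be added without modifying the factory itself.
    """
    def _register(cls):
        DISTRIBUTED_REGISTRY[name] = cls
        return cls
    return _register


@register_distributed('ddp')
class TorchDDPWrapper:
    ...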

@stephenroller (Contributor, Author)

Implemented in #3740.
