[BlenderBot2] add model.module to make BB2 DDP compatible #4259
Conversation
LGTM
@@ -800,7 +800,10 @@ def _set_batch_skip_search(self, valid_exs: List[Message], batch: Batch) -> Batc

     def eval_step(self, batch):
         output = super().eval_step(batch)
-        if output is None or not hasattr(self.model, 'retriever'):
+        model = self.model
+        if isinstance(self.model, torch.nn.parallel.DistributedDataParallel):
Actually, this is what self.model_api is meant to accomplish, if you can just switch to that.
+        model = self.model_api
+        if isinstance(model, torch.nn.parallel.DistributedDataParallel):
+            model = model.module
+        if output is None or not hasattr(model, 'retriever'):
Sorry, I might not have been clear enough: model_api already takes care of this check.
Suggested change:
-        model = self.model_api
-        if isinstance(model, torch.nn.parallel.DistributedDataParallel):
-            model = model.module
-        if output is None or not hasattr(model, 'retriever'):
+        if output is None or not hasattr(self.model_api, 'retriever'):
This is the only change required.
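For readers following along, here is a minimal sketch of why routing the check through a single unwrapping accessor is enough. It is not ParlAI's actual implementation: `ToyModel`, `ToyWrapper`, and `ToyAgent` are hypothetical stand-ins, and the only thing assumed about DistributedDataParallel is that it keeps the wrapped model under `.module`, so custom attributes such as `retriever` are not visible on the wrapper itself.

```python
class ToyModel:
    """Hypothetical model with a custom attribute, like BB2's retriever."""

    def __init__(self):
        self.retriever = object()  # placeholder for a real retriever


class ToyWrapper:
    """Hypothetical stand-in for DistributedDataParallel.

    Like DDP, it stores the wrapped model under `.module` and does not
    expose the wrapped model's custom attributes directly.
    """

    def __init__(self, module):
        self.module = module


class ToyAgent:
    """Hypothetical agent; `model_api` here is an assumption-based sketch."""

    def __init__(self, model):
        self.model = model

    @property
    def model_api(self):
        # Unwrap in exactly one place: return `.module` when the model is
        # wrapped, otherwise return the model unchanged.
        return getattr(self.model, "module", self.model)

    def has_retriever(self):
        # Call sites never need their own isinstance(...) check.
        return hasattr(self.model_api, "retriever")


plain_agent = ToyAgent(ToyModel())
wrapped_agent = ToyAgent(ToyWrapper(ToyModel()))

print(hasattr(wrapped_agent.model, "retriever"))  # check on the raw wrapper
print(plain_agent.has_retriever(), wrapped_agent.has_retriever())
```

Checking the raw wrapper misses the attribute, while the accessor finds it in both the wrapped and unwrapped cases, which is why the one-line version of the condition is sufficient.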
ah didn't know that! thanks for the pointer
thanks for making the changes!
Patch description
Change model -> model.module in DistributedDataParallel setting in
Testing steps
Other information