
[Fix] Fix PyTorch1.11 Dist Remove _sync_params #1816

Merged · 2 commits · Mar 26, 2022
42 changes: 34 additions & 8 deletions mmcv/parallel/distributed.py
@@ -44,8 +44,15 @@ def train_step(self, *inputs, **kwargs):
                 'Reducer buckets have been rebuilt in this iteration.',
                 logger='mmcv')
 
-        if getattr(self, 'require_forward_param_sync', True):
-            self._sync_params()
+        if ('parrots' not in TORCH_VERSION
+                and digit_version(TORCH_VERSION) >= digit_version('1.11.0')):
+            if self._check_sync_bufs_pre_fwd():
+                self._sync_buffers()
+        else:
+            if (getattr(self, 'require_forward_param_sync', False)
+                    and self.require_forward_param_sync):
Comment on lines +52 to +53

zhouzaida (Collaborator) commented on Mar 24, 2022:
https://github.com/pytorch/pytorch/blob/50c90a22be3ee6a547ad0222951f2c9f50c02b50/torch/nn/parallel/distributed.py#L277
require_forward_param_sync has been an attribute of DDP since torch 1.2.0, so do we still need to use getattr to get it rather than self.require_forward_param_sync?

Collaborator: Please @luopeichao have a look.

Contributor: It works in parrots. LGTM.

Contributor: And parrots also supports self.require_forward_param_sync.
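To make the two access patterns discussed in this thread concrete, here is a minimal sketch that is not part of the PR; the single-process gloo setup and the `ddp` variable are illustrative assumptions only:

```python
import os

import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel

# Single-process 'gloo' group so the sketch runs without a launcher.
os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
os.environ.setdefault('MASTER_PORT', '29500')
dist.init_process_group('gloo', rank=0, world_size=1)

ddp = DistributedDataParallel(nn.Linear(4, 4))

# Defensive access, as in the patch: default to False if the attribute
# were ever missing on a DDP-like wrapper.
needs_sync = getattr(ddp, 'require_forward_param_sync', False)

# Direct access: per the thread, DDP has defined this attribute since
# torch 1.2.0, and parrots exposes it as well.
needs_sync = ddp.require_forward_param_sync

dist.destroy_process_group()
```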

+                self._sync_params()
 
         if self.device_ids:
             inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
             if len(self.device_ids) == 1:
@@ -57,8 +64,14 @@ def train_step(self, *inputs, **kwargs):
         else:
             output = self.module.train_step(*inputs, **kwargs)
 
-        if torch.is_grad_enabled() and getattr(
-                self, 'require_backward_grad_sync', True):
+        if ('parrots' not in TORCH_VERSION
+                and digit_version(TORCH_VERSION) >= digit_version('1.11.0')):
+            if self._check_sync_bufs_post_fwd():
+                self._sync_buffers()
+
+        if (torch.is_grad_enabled()
+                and getattr(self, 'require_backward_grad_sync', False)
+                and self.require_backward_grad_sync):
             if self.find_unused_parameters:
                 self.reducer.prepare_for_backward(list(_find_tensors(output)))
             else:
@@ -86,8 +99,15 @@ def val_step(self, *inputs, **kwargs):
                 'Reducer buckets have been rebuilt in this iteration.',
                 logger='mmcv')
 
-        if getattr(self, 'require_forward_param_sync', True):
-            self._sync_params()
+        if ('parrots' not in TORCH_VERSION
+                and digit_version(TORCH_VERSION) >= digit_version('1.11.0')):
+            if self._check_sync_bufs_pre_fwd():
+                self._sync_buffers()
+        else:
+            if (getattr(self, 'require_forward_param_sync', False)
+                    and self.require_forward_param_sync):
+                self._sync_params()
 
         if self.device_ids:
             inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
             if len(self.device_ids) == 1:
@@ -99,8 +119,14 @@ def val_step(self, *inputs, **kwargs):
         else:
             output = self.module.val_step(*inputs, **kwargs)
 
-        if torch.is_grad_enabled() and getattr(
-                self, 'require_backward_grad_sync', True):
+        if ('parrots' not in TORCH_VERSION
+                and digit_version(TORCH_VERSION) >= digit_version('1.11.0')):
+            if self._check_sync_bufs_post_fwd():
+                self._sync_buffers()
+
+        if (torch.is_grad_enabled()
+                and getattr(self, 'require_backward_grad_sync', False)
+                and self.require_backward_grad_sync):
             if self.find_unused_parameters:
                 self.reducer.prepare_for_backward(list(_find_tensors(output)))
             else:
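For reference, a minimal sketch of the version gate the patch repeats at all four call sites, assuming the TORCH_VERSION and digit_version utilities that mmcv/parallel/distributed.py already imports; the helper name below is hypothetical and not part of the PR. Per the PR title, torch 1.11 removed DDP._sync_params, so newer versions must go through the buffer-sync helpers instead.

```python
from mmcv.utils import TORCH_VERSION, digit_version


def _uses_torch_1_11_sync_api() -> bool:
    """Hypothetical helper mirroring the check added by this PR.

    parrots reports a version string that cannot be compared with PyTorch
    releases, so it is excluded before the numeric comparison.
    """
    return ('parrots' not in TORCH_VERSION
            and digit_version(TORCH_VERSION) >= digit_version('1.11.0'))
```

When this check is true, the wrapper takes the _check_sync_bufs_pre_fwd()/_check_sync_bufs_post_fwd() plus _sync_buffers() path shown in the diff; otherwise it falls back to require_forward_param_sync and _sync_params().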