diff --git a/mmcv/parallel/distributed.py b/mmcv/parallel/distributed.py
index b799a213d8..791b6c080c 100644
--- a/mmcv/parallel/distributed.py
+++ b/mmcv/parallel/distributed.py
@@ -44,8 +44,15 @@ def train_step(self, *inputs, **kwargs):
                 'Reducer buckets have been rebuilt in this iteration.',
                 logger='mmcv')
 
-        if getattr(self, 'require_forward_param_sync', True):
-            self._sync_params()
+        if ('parrots' not in TORCH_VERSION
+                and digit_version(TORCH_VERSION) >= digit_version('1.11.0')):
+            if self._check_sync_bufs_pre_fwd():
+                self._sync_buffers()
+        else:
+            if (getattr(self, 'require_forward_param_sync', False)
+                    and self.require_forward_param_sync):
+                self._sync_params()
+
         if self.device_ids:
             inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
             if len(self.device_ids) == 1:
@@ -57,8 +64,14 @@ def train_step(self, *inputs, **kwargs):
         else:
             output = self.module.train_step(*inputs, **kwargs)
 
-        if torch.is_grad_enabled() and getattr(
-                self, 'require_backward_grad_sync', True):
+        if ('parrots' not in TORCH_VERSION
+                and digit_version(TORCH_VERSION) >= digit_version('1.11.0')):
+            if self._check_sync_bufs_post_fwd():
+                self._sync_buffers()
+
+        if (torch.is_grad_enabled()
+                and getattr(self, 'require_backward_grad_sync', False)
+                and self.require_backward_grad_sync):
             if self.find_unused_parameters:
                 self.reducer.prepare_for_backward(list(_find_tensors(output)))
             else:
@@ -86,8 +99,15 @@ def val_step(self, *inputs, **kwargs):
                 'Reducer buckets have been rebuilt in this iteration.',
                 logger='mmcv')
 
-        if getattr(self, 'require_forward_param_sync', True):
-            self._sync_params()
+        if ('parrots' not in TORCH_VERSION
+                and digit_version(TORCH_VERSION) >= digit_version('1.11.0')):
+            if self._check_sync_bufs_pre_fwd():
+                self._sync_buffers()
+        else:
+            if (getattr(self, 'require_forward_param_sync', False)
+                    and self.require_forward_param_sync):
+                self._sync_params()
+
         if self.device_ids:
             inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
             if len(self.device_ids) == 1:
@@ -99,8 +119,14 @@ def val_step(self, *inputs, **kwargs):
         else:
             output = self.module.val_step(*inputs, **kwargs)
 
-        if torch.is_grad_enabled() and getattr(
-                self, 'require_backward_grad_sync', True):
+        if ('parrots' not in TORCH_VERSION
+                and digit_version(TORCH_VERSION) >= digit_version('1.11.0')):
+            if self._check_sync_bufs_post_fwd():
+                self._sync_buffers()
+
+        if (torch.is_grad_enabled()
+                and getattr(self, 'require_backward_grad_sync', False)
+                and self.require_backward_grad_sync):
             if self.find_unused_parameters:
                 self.reducer.prepare_for_backward(list(_find_tensors(output)))
             else:
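
Note: the same version-gated sync pattern appears four times in the hunks above (pre- and post-forward, in both `train_step` and `val_step`). The sketch below pulls that pattern out for readability only; `_torch_ge_1_11` and `pre_forward_sync` are illustrative names that do not exist in mmcv, and the DDP internals (`_check_sync_bufs_pre_fwd`, `_sync_buffers`, `_sync_params`, `require_forward_param_sync`) are assumed to behave exactly as the patch uses them.

```python
# Minimal, hypothetical sketch of the version gate used in the patch above.
# Helper names are illustrative; only TORCH_VERSION and digit_version are
# real mmcv utilities.
from mmcv.utils import TORCH_VERSION, digit_version


def _torch_ge_1_11() -> bool:
    # parrots reports a non-numeric version string, so it is excluded before
    # attempting the numeric comparison.
    return ('parrots' not in TORCH_VERSION
            and digit_version(TORCH_VERSION) >= digit_version('1.11.0'))


def pre_forward_sync(ddp_module) -> None:
    """Run the pre-forward sync matching the installed torch version."""
    if _torch_ge_1_11():
        # torch >= 1.11 path used by the patch: check whether buffers need
        # syncing before forward, then sync only the buffers.
        if ddp_module._check_sync_bufs_pre_fwd():
            ddp_module._sync_buffers()
    elif getattr(ddp_module, 'require_forward_param_sync', False):
        # Older torch versions still expose the combined helper.
        ddp_module._sync_params()
```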