diff --git a/mmseg/models/segmentors/base.py b/mmseg/models/segmentors/base.py index 944da0f2e4..f0f320ffbf 100644 --- a/mmseg/models/segmentors/base.py +++ b/mmseg/models/segmentors/base.py @@ -188,6 +188,17 @@ def _parse_losses(losses): loss = sum(_value for _key, _value in log_vars.items() if 'loss' in _key) + # If log_vars has a different length across GPUs, raise an + # assertion error to prevent GPUs from waiting indefinitely. + if dist.is_available() and dist.is_initialized(): + log_var_length = torch.tensor(len(log_vars), device=loss.device) + dist.all_reduce(log_var_length) + message = (f'rank {dist.get_rank()}' + + f' len(log_vars): {len(log_vars)}' + ' keys: ' + + ','.join(log_vars.keys()) + '\n') + assert log_var_length == len(log_vars) * dist.get_world_size(), \ + 'loss log variables are different across GPUs!\n' + message + log_vars['loss'] = loss for loss_name, loss_value in log_vars.items(): # reduce loss when distributed training