consistent behavior for reduce method across all Plugins #6011

Merged 8 commits on Feb 20, 2021
CHANGELOG.md (4 additions, 0 deletions)

@@ -22,6 +22,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Fixed


- Made the `Plugin.reduce` method more consistent across all Plugins to reflect a mean-reduction by default ([#6011](https://github.com/PyTorchLightning/pytorch-lightning/pull/6011))



## [1.2.0] - 2021-02-18

### Added
pytorch_lightning/plugins/training_type/ddp.py (16 additions, 4 deletions)

@@ -278,10 +278,22 @@ def model_to_device(self):
        torch.cuda.set_device(self.root_device)
        self.model.to(self.root_device)

    def reduce(self, output, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None):
        if isinstance(output, torch.Tensor):
            output = sync_ddp_if_available(output, group, reduce_op)
        return output
    def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean"):
        """
        Reduces a tensor from several distributed processes to one aggregated tensor.

        Args:
            tensor: the tensor to sync and reduce
            group: the process group to gather results from. Defaults to all processes (world)
            reduce_op: the reduction operation. Defaults to 'mean'/'avg'.
                Can also be a string 'sum' to calculate the sum during reduction.

        Return:
            reduced value, except when the input was not a tensor the output remains unchanged
        """
        if isinstance(tensor, torch.Tensor):
            tensor = sync_ddp_if_available(tensor, group, reduce_op=(reduce_op or "mean"))

Contributor Author:
Changed this to make the default "mean".
Let me know what you think @tchaton @justusschock

If you agree with this, I will change the PR title; otherwise I will revert to a pure docs update.

Contributor:

I think it makes sense to have mean as the default. What is the PyTorch default?

Contributor Author:

PyTorch differentiates between reduce (only process 0 gets the result) and all_reduce (all processes get the result). The default reduction is SUM.
https://pytorch.org/docs/stable/distributed.html?highlight=all_reduce#torch.distributed.all_reduce

Our internal implementation uses torch.distributed.all_reduce with SUM, and for mean we divide by world_size manually.
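For illustration only, a minimal sketch of how a mean-reduction can be layered on top of torch.distributed.all_reduce, which natively sums; this is not the actual sync_ddp_if_available code, and the all_reduce_mean helper name is made up:

```python
import torch
import torch.distributed as dist

def all_reduce_mean(tensor: torch.Tensor, group=None) -> torch.Tensor:
    # every process ends up holding the element-wise SUM of all ranks' tensors
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group)
    # divide by the number of participating processes to turn the sum into a mean
    return tensor / dist.get_world_size(group=group)
```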

        return tensor

    def training_step(self, *args, **kwargs):
        return self.model(*args, **kwargs)
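As a usage sketch of the new default (hypothetical names; `plugin` stands for an already-initialised DDP plugin inside a running distributed job, `loss` for a scalar tensor on this rank), callers now get a mean by default and must ask for a sum explicitly:

```python
mean_loss = plugin.reduce(loss)                      # default: averaged across processes
summed_loss = plugin.reduce(loss, reduce_op="sum")   # explicit sum across processes
```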
pytorch_lightning/plugins/training_type/ddp2.py (18 additions, 6 deletions)

@@ -25,14 +25,26 @@ def setup(self, model):
        self.task_idx = self.cluster_environment.local_rank()
        # the difference to DDP is that we don't call children processes here

    def reduce(self, output, *args, **kwargs):
        if isinstance(output, Result):
            output.dp_reduce()

        elif isinstance(output, torch.Tensor):
            output = output.mean()

        return output

    def reduce(self, tensor, *args, **kwargs):
        """
        Reduces a tensor from all processes to one aggregated tensor.
        In DDP2, the reduction here is only across local devices within the node.

        Args:
            tensor: the tensor to sync and reduce
            *args: ignored for DDP2
            **kwargs: ignored for DDP2

        Return:
            reduced value, except when the input was not a tensor the output remains unchanged
        """
        if isinstance(tensor, Result):
            tensor.dp_reduce()

        elif isinstance(tensor, torch.Tensor):
            tensor = tensor.mean()

        return tensor

    @property
    def root_device(self):
pytorch_lightning/plugins/training_type/ddp_spawn.py (16 additions, 4 deletions)

@@ -256,10 +256,22 @@ def pre_backward(self, closure_loss: torch.Tensor, should_accumulate: bool, opti
        if not self.lightning_module.automatic_optimization and self.model.require_backward_grad_sync:
            prepare_for_backward(self.model, closure_loss)

    def reduce(self, output, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None):
        if isinstance(output, torch.Tensor):
            output = sync_ddp_if_available(output, group, reduce_op)
        return output
    def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean"):
        """
        Reduces a tensor from several distributed processes to one aggregated tensor.

        Args:
            tensor: the tensor to sync and reduce
            group: the process group to gather results from. Defaults to all processes (world)
            reduce_op: the reduction operation. Defaults to 'mean'/'avg'.
                Can also be a string 'sum' to calculate the sum during reduction.

        Return:
            reduced value, except when the input was not a tensor the output remains unchanged
        """
        if isinstance(tensor, torch.Tensor):
            tensor = sync_ddp_if_available(tensor, group, reduce_op=(reduce_op or "mean"))
        return tensor

    def training_step(self, *args, **kwargs):
        return self.model(*args, **kwargs)
pytorch_lightning/plugins/training_type/dp.py (17 additions, 6 deletions)

@@ -31,14 +31,25 @@ def setup(self, model):
        model.to(self.root_device)
        self._model = DataParallel(LightningParallelModule(model), self.parallel_devices)

    def reduce(self, output, *args, **kwargs):
        if isinstance(output, Result):
            output.dp_reduce()

        elif isinstance(output, torch.Tensor):
            output = output.mean()

        return output

    def reduce(self, tensor, *args, **kwargs):
        """
        Reduces a tensor from all parallel processes to one aggregated tensor.

        Args:
            tensor: the tensor to sync and reduce
            *args: ignored for DP
            **kwargs: ignored for DP

        Return:
            reduced value, except when the input was not a tensor the output remains unchanged
        """
        if isinstance(tensor, Result):
            tensor.dp_reduce()

        elif isinstance(tensor, torch.Tensor):
            tensor = tensor.mean()

        return tensor

    @property
    def root_device(self):
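A rough illustration of why a plain .mean() is the right reduction here (assumed shapes, not taken from the PR): under DP each device typically produces its own scalar, DataParallel gathers them into a tensor with one entry per device, and averaging that tensor averages over devices.

```python
import torch

# e.g. per-device losses gathered by DataParallel from 2 GPUs
per_device_losses = torch.tensor([0.9, 1.1])
reduced = per_device_losses.mean()  # tensor(1.) -- the cross-device mean
```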
pytorch_lightning/plugins/training_type/horovod.py (17 additions, 5 deletions)

@@ -127,23 +127,35 @@ def model_to_device(self):
        torch.cuda.set_device(self.root_device)
        self.model.to(self.root_device)

    def reduce(self, output, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None):
    def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean"):
        """
        Reduces a tensor from several distributed processes to one aggregated tensor.

        Args:
            tensor: the tensor to sync and reduce
            group: the process group to gather results from. Defaults to all processes (world)
            reduce_op: the reduction operation. Defaults to 'mean'/'avg'.
                Can also be a string 'sum' to calculate the sum during reduction.

        Return:
            reduced value, except when the input was not a tensor the output remains unchanged
        """
        if group is not None:
            raise ValueError(
                "Horovod does not support allreduce using a subcommunicator at this time. "
                "Unset `group`."
            )

        if reduce_op is None or reduce_op == "sum":
            reduce_op = hvd.Sum
        elif isinstance(reduce_op, str) and reduce_op in ("avg", "mean"):
        if reduce_op in (None, "avg", "mean"):
            reduce_op = hvd.Average
        elif reduce_op == "sum":
            reduce_op = hvd.Sum
        else:
            raise ValueError(f"unrecognized `reduce_op`: {reduce_op}")

        # sync all processes before reduction
        hvd.join()
        return hvd.allreduce(output, op=reduce_op)
        return hvd.allreduce(tensor, op=reduce_op)

    def gather_all_tensors(self, result: Union[torch.Tensor], group: Optional[Any] = None):
        if group is not None:
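For reference, a hedged sketch of the Horovod calls the mapping above resolves to (assumes a Horovod job where hvd.init() has already run and the tensor values are only illustrative):

```python
import torch
import horovod.torch as hvd

hvd.init()
t = torch.ones(3) * hvd.rank()

averaged = hvd.allreduce(t, op=hvd.Average)  # element-wise mean across all workers
summed = hvd.allreduce(t, op=hvd.Sum)        # element-wise sum across all workers
```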
pytorch_lightning/plugins/training_type/single_device.py (14 additions, 2 deletions)

@@ -19,8 +19,20 @@ def on_tpu(self) -> bool:
    def on_gpu(self) -> bool:
        return self.device.type == "cuda" and torch.cuda.is_available()

    def reduce(self, output: Union[Any, torch.Tensor], *args: Any, **kwargs: Any) -> Union[Any, torch.Tensor]:
        return output
    def reduce(self, tensor: Union[Any, torch.Tensor], *args: Any, **kwargs: Any) -> Union[Any, torch.Tensor]:
        """
        Reduces a tensor from several distributed processes to one aggregated tensor.
        As this plugin only operates with a single device, the reduction is simply the identity.

        Args:
            tensor: the tensor to sync and reduce
            *args: ignored
            **kwargs: ignored

        Return:
            the unmodified input, as reduction is not needed for single-process operation
        """
        return tensor

    @property
    def root_device(self) -> torch.device:
@@ -55,8 +55,15 @@ def is_global_zero(self) -> bool:
"""Whether the current process is the rank zero process not only on the local node, but for all nodes."""

@abstractmethod
def reduce(self, output: Union[torch.Tensor, Any], *args: Any, **kwargs: Any) -> Union[torch.Tensor, Any]:
"""Reduces the given output (e.g. across GPUs/Processes)"""
def reduce(self, tensor: Union[torch.Tensor, Any], *args: Any, **kwargs: Any) -> Union[torch.Tensor, Any]:
"""
Reduces the given tensor (e.g. across GPUs/processes).

Args:
tensor: the tensor to sync and reduce
*args: plugin-specific positional arguments
**kwargs: plugin-specific keyword arguments
"""

    @abstractmethod
    def barrier(self, name: Optional[str] = None) -> None:
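To make the abstract contract concrete, a hypothetical minimal override (illustration only; it assumes the abstract base shown here is Lightning's TrainingTypePlugin, and a real subclass must also implement barrier, broadcast and the other abstract methods):

```python
class NoOpReducePlugin(TrainingTypePlugin):  # hypothetical name, for illustration
    def reduce(self, tensor, *args, **kwargs):
        # single-process style behaviour: nothing to aggregate, return input unchanged
        return tensor
```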