
Fabric.all_reduce modifies the input value inplace on GPU #18228

Closed
function2-llx opened this issue Aug 4, 2023 · 4 comments · Fixed by #18235
Labels
bug Something isn't working fabric lightning.fabric.Fabric strategy: ddp DistributedDataParallel ver: 2.0.x

Comments

@function2-llx
Contributor

function2-llx commented Aug 4, 2023

Bug description

When calling Fabric.all_reduce on tensors on the GPU, the input tensor is modified in place: after the call it holds the sum of the values across all processes.

What version are you seeing the problem on?

v2.0

How to reproduce the bug

Run the following code with multiple (e.g., 4) GPUs:

```python
from lightning import Fabric
import torch

def main():
    torch.set_default_device('cuda')
    fabric = Fabric()
    fabric.launch()
    # Each rank holds its own index: 0., 1., 2., 3. on 4 GPUs
    x = torch.tensor(fabric.global_rank).float()
    # reduce_op defaults to "mean"
    rx = fabric.all_reduce(x)
    # Print one rank at a time so the output is not interleaved
    for i in range(fabric.world_size):
        if i == fabric.global_rank:
            print('rank:', i)
            print('input:', x)
            print('reduce:', rx)
        fabric.barrier()

if __name__ == '__main__':
    main()
```

Error messages and logs

output:

```
rank: 0
input: tensor(6., device='cuda:0')
reduce: tensor(1.5000, device='cuda:0')
rank: 1
input: tensor(6., device='cuda:1')
reduce: tensor(1.5000, device='cuda:1')
rank: 2
input: tensor(6., device='cuda:2')
reduce: tensor(1.5000, device='cuda:2')
rank: 3
input: tensor(6., device='cuda:3')
reduce: tensor(1.5000, device='cuda:3')
```

Environment

Current environment
  • CUDA:
    • GPU:
      • NVIDIA A100 80GB PCIe
      • NVIDIA A100 80GB PCIe
      • NVIDIA A100 80GB PCIe
      • NVIDIA A100 80GB PCIe
    • available: True
    • version: 11.8
  • Lightning:
    • lightning: 2.0.6
    • lightning-cloud: 0.5.37
    • lightning-utilities: 0.9.0
    • pytorch-lightning: 2.0.5
    • torch: 2.0.1
    • torchmetrics: 0.11.4
    • torchvision: 0.15.2
  • Packages:
    • addict: 2.4.0
    • aiohttp: 3.8.4
    • aiosignal: 1.3.1
    • antlr4-python3-runtime: 4.9.3
    • anyio: 3.7.1
    • appdirs: 1.4.4
    • arrow: 1.2.3
    • async-timeout: 4.0.2
    • attrs: 23.1.0
    • backoff: 2.2.1
    • beautifulsoup4: 4.12.2
    • blessed: 1.20.0
    • brotli: 1.0.9
    • certifi: 2023.7.22
    • charset-normalizer: 3.2.0
    • click: 8.1.4
    • colorama: 0.4.6
    • contourpy: 1.1.0
    • croniter: 1.4.1
    • cycler: 0.11.0
    • cytoolz: 0.12.1
    • dataclasses: 0.8
    • datasets: 2.13.1
    • dateutils: 0.6.12
    • deepdiff: 6.3.1
    • dill: 0.3.6
    • docker-pycreds: 0.4.0
    • docstring-parser: 0.15
    • einops: 0.6.1
    • et-xmlfile: 1.1.0
    • fastapi: 0.100.0
    • filelock: 3.12.2
    • fonttools: 4.41.0
    • frozenlist: 1.3.3
    • fsspec: 2023.6.0
    • gitdb: 4.0.10
    • gitpython: 3.1.32
    • gmpy2: 2.1.2
    • h11: 0.14.0
    • huggingface-hub: 0.16.2
    • idna: 3.4
    • importlib-metadata: 6.8.0
    • importlib-resources: 6.0.0
    • inquirer: 3.1.3
    • itk-core: 5.3.0
    • itk-filtering: 5.3.0
    • itk-io: 5.3.0
    • itk-numerics: 5.3.0
    • itk-registration: 5.3.0
    • itk-segmentation: 5.3.0
    • itsdangerous: 2.1.2
    • jinja2: 3.1.2
    • joblib: 1.3.0
    • jsonargparse: 4.23.0
    • kiwisolver: 1.4.4
    • lightning: 2.0.6
    • lightning-cloud: 0.5.37
    • lightning-utilities: 0.9.0
    • markdown-it-py: 3.0.0
    • markupsafe: 2.1.3
    • matplotlib: 3.7.2
    • mdurl: 0.1.2
    • mmcv-full: 1.7.1
    • monai: 1.2.0+78.g71aaa2259
    • mpmath: 1.3.0
    • multidict: 6.0.4
    • multiprocess: 0.70.14
    • munkres: 1.1.4
    • mypy-extensions: 1.0.0
    • networkx: 3.1
    • nibabel: 5.1.0
    • nptyping: 2.5.0
    • numpy: 1.25.1
    • omegaconf: 2.3.0
    • opencv-python: 4.8.0.74
    • openpyxl: 3.1.2
    • ordered-set: 4.1.0
    • packaging: 23.1
    • pandas: 2.0.3
    • pathtools: 0.1.2
    • pillow: 9.4.0
    • pip: 23.1.2
    • platformdirs: 3.8.1
    • ply: 3.11
    • pooch: 1.7.0
    • protobuf: 4.23.3
    • psutil: 5.9.5
    • pyarrow: 12.0.1
    • pydantic: 1.10.11
    • pydicom: 2.4.1
    • pygments: 2.15.1
    • pyjwt: 2.7.0
    • pynrrd: 1.0.0
    • pyparsing: 3.0.9
    • pyqt5: 5.15.7
    • pyqt5-sip: 12.11.0
    • pyre-extensions: 0.0.29
    • pysocks: 1.7.1
    • python-dateutil: 2.8.2
    • python-editor: 1.0.4
    • python-multipart: 0.0.6
    • pytorch-lightning: 2.0.5
    • pytz: 2023.3
    • pyyaml: 6.0
    • readchar: 4.0.5
    • regex: 2023.6.3
    • requests: 2.31.0
    • responses: 0.18.0
    • rich: 13.4.2
    • sacremoses: 0.0.53
    • safetensors: 0.3.1
    • scipy: 1.11.1
    • sentry-sdk: 1.28.0
    • setproctitle: 1.3.2
    • setuptools: 68.0.0
    • sip: 6.7.9
    • six: 1.16.0
    • smmap: 3.0.5
    • sniffio: 1.3.0
    • soupsieve: 2.4.1
    • starlette: 0.27.0
    • starsessions: 1.3.0
    • sympy: 1.12
    • timm: 0.9.2
    • tokenizers: 0.13.3
    • toml: 0.10.2
    • tomli: 2.0.1
    • toolz: 0.12.0
    • torch: 2.0.1
    • torchmetrics: 0.11.4
    • torchvision: 0.15.2
    • tornado: 6.3.2
    • tqdm: 4.65.0
    • traitlets: 5.9.0
    • transformers: 4.30.2
    • triton: 2.0.0
    • typeshed-client: 2.3.0
    • typing-extensions: 4.7.1
    • typing-inspect: 0.9.0
    • tzdata: 2023.3
    • urllib3: 2.0.3
    • uvicorn: 0.22.0
    • wandb: 0.15.7
    • wcwidth: 0.2.6
    • websocket-client: 1.6.1
    • websockets: 11.0.3
    • wheel: 0.40.0
    • xformers: 0.0.20
    • xxhash: 0.0.0
    • yapf: 0.40.1
    • yarl: 1.9.2
    • zipp: 3.16.0
  • System:
    • OS: Linux
    • architecture:
      • 64bit
      • ELF
    • processor: x86_64
    • python: 3.11.4
    • release: 5.15.0-73-generic
    • version: #80-Ubuntu SMP Mon May 15 15:18:26 UTC 2023

More info

The bug does not occur for tensors on the CPU.

cc @carmocca @justusschock @awaelchli

@function2-llx function2-llx added bug Something isn't working needs triage Waiting to be triaged by maintainers labels Aug 4, 2023
@awaelchli awaelchli added fabric lightning.fabric.Fabric strategy: ddp DistributedDataParallel and removed needs triage Waiting to be triaged by maintainers labels Aug 4, 2023
@awaelchli
Contributor

Hey @function2-llx

This is the default behavior in torch: torch.distributed.all_reduce works in place. See here how we call it under the hood:
https://github.com/Lightning-AI/lightning/blob/0aeeb60566cc0375df3cf1a4458592651f143717/src/lightning/fabric/utilities/distributed.py#L149

Here are the docs for PyTorch's all_reduce:
https://pytorch.org/docs/stable/distributed.html#torch.distributed.all_reduce

Reducing the tensor this way is more efficient than making a copy first, so I think we should consider keeping it this way. The documentation should be clearer though. WDYT?

@peymanrahi
That was my question too, thanks.

@function2-llx
Contributor Author

@awaelchli Thanks for the clarification. I agree that, most of the time, reducing the tensors without copying them first is a good idea for performance, and users can copy the tensors themselves if necessary. So this is generally not a bug, although clarifying it in the documentation would help.
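A minimal sketch of the copy-first workaround, assuming plain PyTorch only. `clone_tensors` is a hypothetical helper (not a Lightning API) that clones every tensor inside a tensor, dict, list, or tuple, so the in-place collective only overwrites the copies. In the reproduction script this would read `rx = fabric.all_reduce(clone_tensors(x))`; for a single tensor, `x.clone()` is enough.

```python
from typing import Any

import torch


def clone_tensors(data: Any) -> Any:
    """Hypothetical helper: return a copy of `data` in which every tensor is cloned.

    Handles tensors and (nested) dicts, lists and tuples; any other value is returned
    unchanged. Namedtuples come back as plain tuples.
    """
    if isinstance(data, torch.Tensor):
        return data.clone()
    if isinstance(data, dict):
        return {key: clone_tensors(value) for key, value in data.items()}
    if isinstance(data, list):
        return [clone_tensors(value) for value in data]
    if isinstance(data, tuple):
        return tuple(clone_tensors(value) for value in data)
    return data
```

Cloning costs one extra allocation and copy per tensor, which is the overhead the in-place default avoids, so it is only worth paying when the caller still needs the original values.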

There is still one small problem, though. When Fabric.all_reduce is called with the "mean" operation, it actually reduces with "sum" first (why not just use ReduceOp.AVG?), so the input tensors are overwritten with the sum across all processes. The final division by the world size, however, is computed out of place into a new tensor, so the input is left holding the sum: an unexpected change that matches neither the original value nor the returned mean.
https://github.com/Lightning-AI/lightning/blob/0aeeb60566cc0375df3cf1a4458592651f143717/src/lightning/fabric/utilities/distributed.py#L126-L128
https://github.com/Lightning-AI/lightning/blob/0aeeb60566cc0375df3cf1a4458592651f143717/src/lightning/fabric/utilities/distributed.py#L151-L152
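This accounts for the numbers in the report: with 4 ranks holding 0, 1, 2 and 3, the in-place SUM leaves 0 + 1 + 2 + 3 = 6 in every input, and the out-of-place division returns 6 / 4 = 1.5. The single-process simulation below reproduces that without any real communication. A list of tensors stands in for "one tensor per rank", and `simulated_all_reduce_sum_`, `mean_out_of_place` and `mean_in_place` are hypothetical names. The in-place variant only illustrates one way to make the input and the result agree; it is not necessarily what the eventual fix does.

```python
from typing import List

import torch


def simulated_all_reduce_sum_(per_rank: List[torch.Tensor]) -> None:
    """Mimic all_reduce(op=SUM): overwrite every rank's tensor in place with the total."""
    total = torch.stack(per_rank).sum(dim=0)  # computed before any tensor is overwritten
    for tensor in per_rank:
        tensor.copy_(total)


def mean_out_of_place(per_rank: List[torch.Tensor]) -> List[torch.Tensor]:
    """SUM in place, then divide into new tensors: the inputs are left holding the sum."""
    simulated_all_reduce_sum_(per_rank)
    world_size = len(per_rank)
    return [tensor / world_size for tensor in per_rank]


def mean_in_place(per_rank: List[torch.Tensor]) -> List[torch.Tensor]:
    """SUM in place, then divide in place: inputs and results are the same mean tensors."""
    simulated_all_reduce_sum_(per_rank)
    world_size = len(per_rank)
    for tensor in per_rank:
        tensor.div_(world_size)
    return per_rank


inputs = [torch.tensor(float(rank)) for rank in range(4)]
results = mean_out_of_place(inputs)
print([t.item() for t in inputs])   # [6.0, 6.0, 6.0, 6.0]
print([t.item() for t in results])  # [1.5, 1.5, 1.5, 1.5]
```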

@awaelchli
Contributor

Yes, this part could indeed be considered a bug. If I remember correctly, this was done to support other backends that don't have the "avg" option. See this note in the torch docs:

AVG divides values by the world size before summing across ranks. AVG is only available with the NCCL backend, and only for NCCL versions 2.10 or later.

That is why the SUM-then-divide-by-world-size approach was chosen. This is, of course, just one way to solve it. Alternatives would have been to raise an explicit error when the backend doesn't support the option, or to pick one approach or the other depending on the backend.
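A sketch of the "pick the approach depending on the backend" option, assuming PyTorch 2.x (where `torch.cuda.nccl.version()` returns a tuple) and floating-point tensors (`div_` raises on integer dtypes). `can_use_native_avg` and `all_reduce_mean_` are hypothetical helpers, not Lightning APIs. Both branches leave the mean in the input tensor, so the caller never observes the intermediate sum.

```python
from __future__ import annotations

from typing import Optional, Tuple, Union

import torch
import torch.distributed as dist


def can_use_native_avg(backend: str, nccl_version: Optional[Tuple[Union[int, str], ...]]) -> bool:
    """ReduceOp.AVG is documented as NCCL-only, and only for NCCL 2.10 or later."""
    if backend != "nccl" or nccl_version is None:
        return False
    return tuple(nccl_version[:2]) >= (2, 10)


def all_reduce_mean_(tensor: torch.Tensor, group: Optional[dist.ProcessGroup] = None) -> torch.Tensor:
    """Average a floating-point `tensor` across ranks, writing the mean into it in place."""
    backend = dist.get_backend(group)
    if backend == "nccl":
        from torch.cuda import nccl  # imported lazily; only meaningful on NCCL builds
        nccl_version = nccl.version()
    else:
        nccl_version = None

    if can_use_native_avg(backend, nccl_version):
        # Native average: a single collective, result written into `tensor`.
        dist.all_reduce(tensor, op=dist.ReduceOp.AVG, group=group)
    else:
        # Fallback (e.g. gloo, or older NCCL): SUM in place, then divide in place,
        # so `tensor` ends up holding the mean instead of the intermediate sum.
        dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group)
        tensor.div_(dist.get_world_size(group))
    return tensor
```

`all_reduce_mean_` assumes the default process group is already initialized; only the pure `can_use_native_avg` check can be exercised without one.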


4 participants
@awaelchli @function2-llx @peymanrahi and others