
SyncBatchNorm size check update #37133

Closed

Conversation

TomoshibiAkira (Contributor)

Update the requirements on input dimensions for torch.nn.SyncBatchNorm:

  1. Check the aggregated batch size count_all instead of the per-process batch size in each DDP process (SyncBatchNorm size check #36865)
  2. Add a test for SyncBatchNorm where every process has only 1 input
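The aggregated check can be sketched as follows. This is illustrative only, not the actual implementation (which lives in torch/nn/modules/_functions.py and operates on the all-gathered `count_all` tensor); plain Python ints stand in for the per-process element counts, and the helper name is made up.

```python
# Sketch of the revised SyncBatchNorm size check: validate the
# aggregated count across all DDP processes, not each process alone.

def check_sync_bn_size(per_process_counts):
    # Sum the element counts contributed by every DDP process. Training
    # is valid as long as the *aggregated* size exceeds 1, even when each
    # process holds only a single sample.
    size = sum(per_process_counts)
    if size == 1:
        raise ValueError(
            'Expected more than 1 value per channel when training, '
            'got input size {}'.format(size))
    return size
```

With two processes holding one sample each, a per-process check would reject training, but the aggregated check accepts it and returns 2.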

@dr-ci

dr-ci bot commented Apr 23, 2020

💊 Build failures summary and remediations

As of commit 8dfd2d0 (more details on the Dr. CI page):


💚 💚 Looks good so far! There are no failures yet. 💚 💚



@mrshenli (Contributor)

cc @zhaojuanmao

@zhangguanheng66 zhangguanheng66 added oncall: distributed Add this issue/PR to distributed oncall triage queue, module: nn Related to torch.nn, triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module labels Apr 23, 2020
@zhaojuanmao (Contributor) left a comment

lgtm. thanks for working on it. just one minor comment.

question: did the test fail without moving to check count_all?

@@ -2135,6 +2135,56 @@ def test_DistributedDataParallel_SyncBatchNorm_2D_Input(self):
)
self._barrier()

@unittest.skipIf(BACKEND != 'nccl' and BACKEND != 'gloo',
Contributor

@unittest.skipIf(BACKEND != 'nccl' ....), SyncBatchNorm is only supported for nccl

Contributor Author

Yeah, you're right. But since the original SyncBatchNorm test function test_DistributedDataParallel_SyncBatchNorm() uses this flag (and I'm not sure why), I might just continue the tradition. :)

About the failing test on CI, it seems related to some networking issue (a strange timeout) and is probably not related to the code.

@zhaojuanmao (Contributor)

looks good, would you please rebase? there are some irrelevant test failures

@TomoshibiAkira TomoshibiAkira force-pushed the syncbn_size_check_fix branch from 5e7a49f to 2f71a1d on April 29, 2020 18:41
@TomoshibiAkira TomoshibiAkira force-pushed the syncbn_size_check_fix branch from 2f71a1d to 8dfd2d0 on April 29, 2020 18:51
@TomoshibiAkira (Contributor Author) commented Apr 30, 2020

@zhaojuanmao done. all tests pass now.

@facebook-github-bot left a comment

@zhaojuanmao is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot

@zhaojuanmao merged this pull request in ae755a7.

@TomoshibiAkira TomoshibiAkira deleted the syncbn_size_check_fix branch May 3, 2020 16:05
size = count_all.view(-1).long().sum()
if size == 1:
raise ValueError('Expected more than 1 value per channel when training, got input size {}'.format(size))

@xwang233 (Collaborator) commented Oct 27, 2020

cc @ngimel @jjsjann123 @ptrblck

I found that this change to the size calculation introduces a large regression in NCCL reduction. For a sync BN forward of a (64, 2048, 4, 4) float tensor on my machine with 2 GPUs, the previous code takes 0.7 ms, but the current code takes 2.4 ms.

I think this is because the old size calculation runs purely on the CPU, while the new size calculation runs on the GPU and needs a device synchronization for the size == 1 comparison on the CPU. This can easily be recovered by either 1. using the old size calculation, or 2. removing the if statement.

I'm currently working on migrating the sync BN channels-last from apex, and we can discuss a fix there.

Also, I suggest that we require performance benchmarks for relevant PRs in the future.
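The two size checks being compared can be sketched as below. This is a rough sketch, not the actual PyTorch source; the function names are made up for illustration. On a CUDA `count_all`, the first variant's `size == 1` comparison against a Python int forces a device-to-host copy that blocks until queued kernels finish, while the second variant touches only host-side shape metadata.

```python
import torch

def size_check_device(count_all):
    # count_all is a tensor (resident on the GPU under NCCL). Comparing
    # the reduced scalar with a Python int copies it back to the host,
    # stalling the pipeline until all queued kernels complete.
    size = count_all.view(-1).long().sum()
    if size == 1:
        raise ValueError('Expected more than 1 value per channel when '
                         'training, got input size {}'.format(size))
    return size

def size_check_host(input):
    # Derive the per-process size from the input's shape: pure host-side
    # metadata, so no device synchronization is required.
    size = input.numel() // input.size(1)
    if size == 1:
        raise ValueError('Expected more than 1 value per channel when '
                         'training, got input size {}'.format(size))
    return size
```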

xwang233 added a commit to xwang233/pytorch that referenced this pull request Oct 27, 2020
facebook-github-bot pushed a commit that referenced this pull request Mar 3, 2021
Summary:
per title

This PR did
- Migrate `apex.parallel.SyncBatchNorm` channels_last to pytorch `torch.nn.SyncBatchNorm`
- Fix a TODO here by fusing `sum`, `div` kernels into backward elementwise kernel
https://github.com/pytorch/pytorch/blob/b167402e2e66a663cd9913885552929b4c045ffa/torch/nn/modules/_functions.py#L76-L95

Todo
- [x] Discuss a regression introduced in #37133 (comment), which is the synchronized copy here
https://github.com/pytorch/pytorch/blob/b167402e2e66a663cd9913885552929b4c045ffa/torch/nn/modules/_functions.py#L32-L34

**Comment**: This PR uses the apex version of the size check. Tests passed and I haven't seen anything wrong so far.

- [x] The restriction to use channels_last kernel will be like this
```
inline bool batch_norm_use_channels_last_kernels(const at::Tensor& self) {
  return self.is_contiguous(at::MemoryFormat::ChannelsLast) || self.ndimension() == 2;
}
```
I think we can relax that for channels_last_3d as well?

**Comment**: we don't have benchmark for this now, will check this and add functionality later when needed.
- [x] Add test
- [x] Add benchmark

Detailed benchmark is at https://github.com/xwang233/code-snippet/tree/master/syncbn-channels-last
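The C++ restriction quoted in the checklist above roughly corresponds to this Python-level predicate (a hypothetical analogue for illustration, not an actual PyTorch API):

```python
import torch

def use_channels_last_kernels(t):
    # Mirrors batch_norm_use_channels_last_kernels: take the
    # channels-last path for NHWC-contiguous tensors, or for plain
    # 2-D (N, C) inputs.
    return (t.is_contiguous(memory_format=torch.channels_last)
            or t.dim() == 2)
```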

Close #50781

Pull Request resolved: #46906

Reviewed By: albanD

Differential Revision: D26771437

Pulled By: malfet

fbshipit-source-id: d00387044e9d43ac7e6c0e32a2db22c63d1504de
aocsa pushed a commit to Quansight/pytorch that referenced this pull request Mar 15, 2021
xsacha pushed a commit to xsacha/pytorch that referenced this pull request Mar 31, 2021
Labels
Merged module: nn Related to torch.nn oncall: distributed Add this issue/PR to distributed oncall triage queue open source triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module
8 participants