Add gather and scatter_add strategies #81
Conversation
They were taken from #29.
Would be good to have those in PyTorch, but I've seen that `gather` was useful for `CrossEntropyLoss` as well, so probably better to unblock first.
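For context, a minimal sketch (not from the PR) of why a `gather` sharding strategy matters for `CrossEntropyLoss`: the loss can be written as a log-softmax followed by a `gather` along the class dimension, so sharding support for `gather` unblocks that pattern. Shapes and names below are illustrative only.

```python
import torch
import torch.nn.functional as F

logits = torch.randn(4, 10)
target = torch.randint(0, 10, (4,))

# Cross-entropy expressed via gather along the class dim; this is the access
# pattern that makes a gather sharding strategy useful for CrossEntropyLoss.
log_probs = F.log_softmax(logits, dim=1)
loss = -log_probs.gather(1, target.unsqueeze(1)).squeeze(1).mean()

assert torch.allclose(loss, F.cross_entropy(logits, target))
```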
```python
single_mesh_dim_strategies = []

# placement list stores placements of [output, input, index]
```
nit: [output, input, index, src]
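For reference, a hedged sketch of the four-slot convention this nit refers to: each strategy is a `PlacementList` ordered as `[output, input, index, src]`. The variable names and import path below are assumptions for illustration, not the PR's actual code.

```python
# Hypothetical illustration of the [output, input, index, src] placement order.
from torch.distributed.tensor import Replicate  # public DTensor placements in recent PyTorch

single_mesh_dim_strategies = []

# Fully replicated strategy: always valid, serves as a fallback.
all_replicate = [Replicate(), Replicate(), Replicate(), Replicate()]
single_mesh_dim_strategies.append(all_replicate)
```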
| """ | ||
| # index sharding, input replicated, index sharded, output follows index | ||
| # this only works when the sharding dimension is the gather dimension | ||
| index_sharding: PlacementList = [Shard(dim), Replicate(), Shard(dim), Shard(dim)] |
I feel this may not be correct. Taking the example from https://docs.pytorch.org/docs/stable/generated/torch.Tensor.scatter_add_.html#torch.Tensor.scatter_add_:

```python
>>> src = torch.ones((2, 5))
>>> index = torch.tensor([[0, 1, 2, 0, 0], [0, 1, 2, 2, 2]])
>>> torch.zeros(3, 5, dtype=src.dtype).scatter_add_(0, index, src)
tensor([[2., 0., 0., 1., 1.],
        [0., 2., 0., 0., 0.],
        [0., 0., 2., 1., 1.]])
```

the output can become Partial as: `[Partial(), Replicate(), Shard(dim), Shard(dim)]`
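To make the Partial argument concrete, here is a minimal sketch (plain torch, two simulated ranks, following the zero-initialized example above rather than the PR's DTensor code): sharding `index` and `src` along `dim` and running the local scatter_add on each rank produces partial results whose elementwise sum equals the full result, which is exactly the `Partial()` output placement.

```python
import torch

dim = 0
src = torch.ones((2, 5))
index = torch.tensor([[0, 1, 2, 0, 0], [0, 1, 2, 2, 2]])

full = torch.zeros(3, 5).scatter_add_(dim, index, src)

# "Shard" index and src along dim across two simulated ranks; the input stays a
# replicated zero tensor, so each local output only holds that rank's contribution.
partials = []
for rank in range(2):
    local_index = index[rank : rank + 1]  # Shard(0) of index
    local_src = src[rank : rank + 1]      # Shard(0) of src
    partials.append(torch.zeros(3, 5).scatter_add_(dim, local_index, local_src))

# Summing the per-rank outputs (the Partial() reduction) recovers the full result.
assert torch.equal(sum(partials), full)
```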
Yes, thanks for the review! I had roughly copy-pasted the gather rule and didn't fix this part. Will adapt it shortly.
Also, do you think we could have this implemented natively in PyTorch?
Yes! The current upstream scatter_add strategy is just a quick workaround. We can follow up.
zpcore left a comment:
I think scatter_add may produce incorrect output.
```python
if len(input_shape) == len(index_shape):
    for d in range(len(input_shape)):
        if d != dim:
            sharding = [Shard(d), Shard(d), Shard(d), Shard(d)]
```
I tried more tests and noticed that with `[Shard(d), Shard(d), Shard(d), Shard(d)]`, we can't simply shard the output and input. E.g., if dim = 1 and we want to shard on d = 0, the input can have many more rows than the index, so we would most likely only modify the first shard of the input, because input rows and index rows map one to one.
Correct me if I'm wrong, but I thought the shapes needed to match except on dim?
Right, the shapes need to match. I changed it to `if d != dim and input_shape[d] == index_shape[d]:` and the op coverage tests pass now. I created the PR with the update here: pytorch/pytorch#160140.
Oh that's true, I definitely missed that case!
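For anyone following along, a minimal sketch (plain torch, two simulated ranks; shapes chosen for illustration, not taken from the PR) of the failure mode discussed in this thread: with `dim = 1` and sharding on `d = 0`, if the input has more rows than the index, naively halving every tensor along `d = 0` breaks the one-to-one row correspondence between input and index.

```python
import torch

dim = 1
inp = torch.zeros(4, 3)                        # 4 rows
index = torch.tensor([[0, 1, 2], [2, 1, 0]])   # only 2 rows
src = torch.ones(2, 3)

expected = inp.clone().scatter_add_(dim, index, src)  # only input rows 0 and 1 are touched

# Naive Shard(0) on everything: input rows [0,1] | [2,3], index/src rows [0] | [1].
shards = []
for rank in range(2):
    local_inp = inp[2 * rank : 2 * rank + 2].clone()
    local_index = index[rank : rank + 1]
    local_src = src[rank : rank + 1]
    shards.append(local_inp.scatter_add_(dim, local_index, local_src))

naive = torch.cat(shards, dim=0)
# False: rank 1 applied index row 1 to global input row 2 instead of row 1,
# which is why the extra input_shape[d] == index_shape[d] check is needed.
print(torch.equal(naive, expected))
```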
Subsumed by pytorch/pytorch#160140
As title. This PR made a small fix on top of meta-pytorch/autoparallel#81. Pull Request resolved: #160140 Approved by: https://github.com/fmassa