Fix tv parallelization #758
Conversation
Force-pushed from a415093 to 0ddf3e5
It turned out this is actually not the right thing to do. The parallel map contains mappings created with the forward-bcast-mismatch enabled, so, for example, it would create a mapping between
This PR would do
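To make the forward-bcast-mismatch concern above concrete, here is an illustrative sketch in the style of the nvfuser C++ tests (not taken from this PR; headers and test scaffolding are omitted, and helper names such as makeSymbolicTensor follow the test-suite conventions and are assumptions here). It builds a fusion where a broadcast axis of one tensor is consumed together with a concrete, parallelized axis of another, which is the kind of pairing a forward-bcast-mismatch mapping would relate in the parallel map.

```cpp
// Illustrative only: a broadcast axis resolved against a concrete axis.
Fusion fusion;
FusionGuard fg(&fusion);

auto tv0 = makeSymbolicTensor(1); // [i0]
auto tv1 = makeSymbolicTensor(2); // [i1, i2]
fusion.addInput(tv0);
fusion.addInput(tv1);

auto tv2 = broadcast(tv0, {false, true}); // [i0, b0]
auto tv3 = add(tv2, tv1);                 // [i0, i2]
fusion.addOutput(tv3);

tv2->computeAt(tv3, -1);
tv3->axis(1)->parallelize(ParallelType::TIDx);

// If the parallel map forwards the broadcast mismatch, tv2's broadcast axis
// b0 can end up mapped to tv3's concrete axis i2, making b0 look as if it
// were parallelized with TIDx even though tv2 never resolves that axis itself.
```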
Force-pushed from 0ddf3e5 to 9f9f902
Ended up adding
More broadly, I think a bigger problem is that we don't have a straightforward way to know the parallelism of
I thought that was how we started down this path: #622
Clearing approval, as it looks like we should discuss/think about this more. Broadcast parallelism is definitely a complex topic.
The substitution is done here. If we allow it, I see two options:
I'd say the second option is more conservative: it only solves this particular problem with BroadcastOp. The first option may be preferable as it conveys the parallelism information to KIR; however, there may be side effects such as the substitution problem. The first option seems better to me from a long-term perspective.
Can you try option 1 and run the reduction benchmark suite before and after to see if you find any serious perf regressions? We use a lot of tensor size information in our kernels anyway, so I think removing the substitution shouldn't be significant, though I could definitely be wrong.
…onships" This reverts commit ab83e3e6367ab186498b2d0ab81ca09dcb52f434.
There is no easy way to know which parallel types are used for a kir::TensorView after lowering, as the ComputeAt parallel map is not maintained. This adds that information to kir::BroadcastOp, as it is needed for codegen.
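As a rough illustration of what carrying that information on the lowered broadcast op could look like, here is a standalone, hypothetical sketch (not the PR's actual implementation; the real kernel IR types and enums differ). The idea is that lowering records which thread/block parallel types the broadcast is bound to, and codegen queries that record instead of the no-longer-available ComputeAt parallel map.

```cpp
#include <array>
#include <cstddef>

// Hypothetical stand-ins for nvfuser's parallel types (not the real enum).
enum class ParallelType { BIDx, BIDy, BIDz, TIDx, TIDy, TIDz };

// A small set recording which parallel types a broadcast is bound to.
class ParallelTypeSet {
 public:
  void set(ParallelType pt) { bits_[static_cast<size_t>(pt)] = true; }
  bool get(ParallelType pt) const { return bits_[static_cast<size_t>(pt)]; }

 private:
  std::array<bool, 6> bits_{};
};

// Sketch of a lowered broadcast op that carries the set filled in at lowering
// time, so the code generator can decide between thread- and block-level
// broadcasts without re-deriving the mapping.
class BroadcastOpSketch {
 public:
  explicit BroadcastOpSketch(const ParallelTypeSet& pts)
      : parallel_types_(pts) {}

  bool needsBlockBroadcast() const {
    return parallel_types_.get(ParallelType::TIDx) ||
           parallel_types_.get(ParallelType::TIDy) ||
           parallel_types_.get(ParallelType::TIDz);
  }

 private:
  ParallelTypeSet parallel_types_;
};
```

Storing the information on the op itself keeps codegen independent of fusion-level analysis that is no longer valid after lowering, which seems to be the motivation described in the commit message above.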
This reverts commit 2caed18.
Force-pushed from 30d5976 to c0fa0c9
@csarofeen Done. I changed
…e registered EBCs with shardedTensors as registered modules (#758) (pytorch#88026)

Summary:
X-link: meta-pytorch/torchrec#758

This PR fixes a bug in FSDP/DDP where ShardedTensors are not supported even if passed in as params to ignore. This is important for composability, because TorchRec named_parameters() will return FQNs of ShardedTensors (as defined in goals).

It defines the device of a ShardedTensor to be None when local_tensor() does not exist on the rank, and updates ShardedEmbeddingBagCollection to be composable according to https://docs.google.com/document/d/1TBJSd5zgEg6cRcXv3Okuj7bBkqQwGS2IPh4TLWNNzFI/edit

Differential Revision: D40458625
Pull Request resolved: pytorch#88026
Approved by: https://github.com/wanchaol, https://github.com/rohan-varma
Fixes #757