Representing DTensor in thunder traces #1907
Conversation
Co-authored-by: Masaki Kozuki <mkozuki@nvidia.com>
This reverts commit 225f2e3.
Gentle ping @IvanYashchuk
With the latest merge, I am seeing a failure in the test.

Cause -

But this PR updates this map to be

Workaround -

cc: @IvanYashchuk
Ping @t-vi for stamp
Thank you @kshitij12345 @IvanYashchuk @crcrpar
@kshitij12345 will you file an issue about USE_DISTRIBUTED=OFF?
Opened #2233 to track making thunder work with PyTorch compiled without distributed support.
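For context, the usual way to keep imports working against a PyTorch built with USE_DISTRIBUTED=OFF is to guard them on `torch.distributed.is_available()`. A generic sketch of that pattern (not the actual fix tracked in #2233):

```python
import torch

if torch.distributed.is_available():
    from torch.distributed.tensor import DTensor
else:
    # PyTorch built with USE_DISTRIBUTED=OFF: DTensor is unavailable, so
    # downstream code should skip DTensor-specific registration.
    DTensor = None
```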
Fixes #1898
Design Doc - https://docs.google.com/document/d/1Gqb_jXrL-sSqs-D8KrZdcQinxuUSlccZBnnvbYJfYl0/edit?usp=sharing
Changes -
This PR adds support for DTensor inputs to the jitted function. Most of the additions required to support DTensor live in `thunder/torch/experimental`, like the `DTensorProxy` and the related prims. NOTE: coverage is currently limited to `torch.mul` (and no broadcast); broader coverage will follow in subsequent PRs.

Following are the main updates:
- Prologue: Adds a new primitive `check_dtensor_spec_repr` which matches the repr of the `DTensorSpec` of the DTensor in question (see the example below). The PR also makes sure that, besides the `DTensorSpec` check, there is a tensor metadata check for the `DTensor` object as well as for the local tensor that it points to. NOTE - Another option for checking the `DTensorSpec` would be to keep the input's `DTensorSpec` in the TracingContext and have the prologue verify equality. (A conceptual sketch of the repr-based check follows this list.)
- DTensorProxy: Adds a new Proxy object to represent a `DTensor`. This class inherits from `TensorProxy`, since `DTensor` is a tensor subclass and implements all the same methods that a tensor implements.
- Prims and Operations: For the computation trace, we add prims and torch-level operations for DTensor. We add new prims and operations instead of re-using the existing ones to prevent the executors from claiming an operation on DTensor by mistake.
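As a rough illustration of what the repr-based prologue check amounts to, here is a conceptual sketch (not thunder's actual primitive); it relies on `_spec`, the private attribute where `DTensor` stores its `DTensorSpec`:

```python
from torch.distributed.tensor import DTensor


def check_dtensor_spec_repr(dt: DTensor, expected_repr: str) -> None:
    """Conceptual stand-in for the prologue primitive added by this PR."""
    # DTensor keeps its DTensorSpec (mesh, placements, tensor meta) in the
    # private `_spec` attribute; the prologue compares its repr against the
    # repr recorded when the function was traced.
    actual = repr(dt._spec)
    if actual != expected_repr:
        raise RuntimeError(f"Expected DTensorSpec {expected_repr}, got {actual}")
```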
Representation in trace -
Example Program
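The original snippet is not reproduced here; the following is a minimal, hypothetical sketch of the kind of program this PR enables, assuming a recent PyTorch where `torch.distributed.tensor` is the public DTensor namespace and processes are launched with torchrun (names and shapes are illustrative):

```python
import os

import torch
import thunder
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, distribute_tensor

# Assumes torchrun has set WORLD_SIZE/RANK and one GPU per process.
world_size = int(os.environ.get("WORLD_SIZE", "1"))
mesh = init_device_mesh("cuda", (world_size,))


def fn(x, y):
    # Only a plain pointwise multiply (no broadcast) is covered by this PR.
    return torch.mul(x, y)


# Shard both inputs along dim 0 across the 1-D mesh.
x = distribute_tensor(torch.randn(16, 16), mesh, [Shard(0)])
y = distribute_tensor(torch.randn(16, 16), mesh, [Shard(0)])

jfn = thunder.jit(fn)
out = jfn(x, y)  # a DTensor with the same sharding as the inputs
```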
Prologue Trace (relevant snippet)
Computation Trace: There is a `torch`-level symbol `dtensor_mul` which is decomposed into prims for DTensor operations.

Backward Trace (initial trace)
Thank you Masaki, Ivan and Mike for the helpful discussions and guidance!