AOTDispatch: allow subclasses to correct when we guess metadata of tangents incorrectly #118670
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/118670
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 6d43803 with merge base 9347a79.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
…adata of tangents incorrectly" This PR is enough to fix #118600. More description of the problem is in the issue, but the high-level problem is similar to the "tangents might be non-contiguous" problem that we handle today, via forcing all tangents to be contiguous. There, the problem was something like: "We guessed the tangent strides incorrectly, because strides on the runtime tangents were different from strides on the forward outputs, which we used to generate tangents" Here, the problem is similar: "We guessed the tangent tensor subclass's metadata incorrectly, because the runtime tangent was a subclass with different metadata than the forward output subclass". This happened in an internal DTensor issue, where the metadata in question was the `placements` (shard vs. replicate vs. Partial). One option is to solve this problem via backward guards. This is needed to unblock internal though, so I figured handling this similarly to how we handle non-contiguous tangents would be reasonable. I did this by: (1) Assert that the metadata on subclass tangents is the same as what we guessed, and if not raise a loud error (2) In the error message, provide the name of an optional method that the subclass must implement to handle this case: `def __force_same_metadata__(self, metadata_tensor):`: If the forward output had a `Replicate()` placement, but the runtime tangent had a `Shard(1)` placement, this method allows a subclass to take the tangent and "convert" it to one with a `Replicate()` placement. `__force_standard_metadata__(self)`: One issue is that there is another placement called `_Partial`, and its semantics are such that DTensor is **unable** to convert a DTensor with some placement type into another DTensor with a `_Partial` placement. `__force_standard_metadata__` is now called on all (fake) subclass forward outs at trace-time to generate tangents, and gives subclasses a chance to "fix" any outputs with metadata that they cannot convert to later. Morally, this is similar to the fact that we force a `contiguous()` call on all tangents at trace-time. I'm interested in thoughts/feedback! Two new dunder methods on traceable subclasses is definitely a contentious change. cc mrshenli pritamdamania87 zhaojuanmao satgera rohan-varma gqchen aazzolini osalpekar jiayisuse H-Huang kwen2501 awgu penguinwu fegin XilunWu wanchaol fduwjj wz337 tianyu-l wconstab yf225 [ghstack-poisoned]
…adata of tangents incorrectly" This PR is enough to fix #118600. More description of the problem is in the issue, but the high-level problem is similar to the "tangents might be non-contiguous" problem that we handle today, via forcing all tangents to be contiguous. There, the problem was something like: "We guessed the tangent strides incorrectly, because strides on the runtime tangents were different from strides on the forward outputs, which we used to generate tangents" Here, the problem is similar: "We guessed the tangent tensor subclass's metadata incorrectly, because the runtime tangent was a subclass with different metadata than the forward output subclass". This happened in an internal DTensor issue, where the metadata in question was the `placements` (shard vs. replicate vs. Partial). One option is to solve this problem via backward guards. This is needed to unblock internal though, so I figured handling this similarly to how we handle non-contiguous tangents would be reasonable. I did this by: (1) Assert that the metadata on subclass tangents is the same as what we guessed, and if not raise a loud error (2) In the error message, provide the name of an optional method that the subclass must implement to handle this case: `def __force_same_metadata__(self, metadata_tensor):`: If the forward output had a `Replicate()` placement, but the runtime tangent had a `Shard(1)` placement, this method allows a subclass to take the tangent and "convert" it to one with a `Replicate()` placement. `__force_standard_metadata__(self)`: One issue is that there is another placement called `_Partial`, and its semantics are such that DTensor is **unable** to convert a DTensor with some placement type into another DTensor with a `_Partial` placement. `__force_standard_metadata__` is now called on all (fake) subclass forward outs at trace-time to generate tangents, and gives subclasses a chance to "fix" any outputs with metadata that they cannot convert to later. Morally, this is similar to the fact that we force a `contiguous()` call on all tangents at trace-time. I'm interested in thoughts/feedback! Two new dunder methods on traceable subclasses is definitely a contentious change. cc mrshenli pritamdamania87 zhaojuanmao satgera rohan-varma gqchen aazzolini osalpekar jiayisuse H-Huang kwen2501 awgu penguinwu fegin XilunWu wanchaol fduwjj wz337 tianyu-l wconstab yf225 [ghstack-poisoned]
…adata of tangents incorrectly" This PR is enough to fix #118600. More description of the problem is in the issue, but the high-level problem is similar to the "tangents might be non-contiguous" problem that we handle today, via forcing all tangents to be contiguous. There, the problem was something like: "We guessed the tangent strides incorrectly, because strides on the runtime tangents were different from strides on the forward outputs, which we used to generate tangents" Here, the problem is similar: "We guessed the tangent tensor subclass's metadata incorrectly, because the runtime tangent was a subclass with different metadata than the forward output subclass". This happened in an internal DTensor issue, where the metadata in question was the `placements` (shard vs. replicate vs. Partial). One option is to solve this problem via backward guards. This is needed to unblock internal though, so I figured handling this similarly to how we handle non-contiguous tangents would be reasonable. I did this by: (1) Assert that the metadata on subclass tangents is the same as what we guessed, and if not raise a loud error (2) In the error message, provide the name of an optional method that the subclass must implement to handle this case: `def __force_same_metadata__(self, metadata_tensor):`: If the forward output had a `Replicate()` placement, but the runtime tangent had a `Shard(1)` placement, this method allows a subclass to take the tangent and "convert" it to one with a `Replicate()` placement. `__force_standard_metadata__(self)`: One issue is that there is another placement called `_Partial`, and its semantics are such that DTensor is **unable** to convert a DTensor with some placement type into another DTensor with a `_Partial` placement. `__force_standard_metadata__` is now called on all (fake) subclass forward outs at trace-time to generate tangents, and gives subclasses a chance to "fix" any outputs with metadata that they cannot convert to later. Morally, this is similar to the fact that we force a `contiguous()` call on all tangents at trace-time. I'm interested in thoughts/feedback! Two new dunder methods on traceable subclasses is definitely a contentious change. cc mrshenli pritamdamania87 zhaojuanmao satgera rohan-varma gqchen aazzolini osalpekar jiayisuse H-Huang kwen2501 awgu penguinwu fegin XilunWu wanchaol fduwjj wz337 tianyu-l wconstab yf225 [ghstack-poisoned]
…adata of tangents incorrectly" This PR is enough to fix #118600. More description of the problem is in the issue, but the high-level problem is similar to the "tangents might be non-contiguous" problem that we handle today, via forcing all tangents to be contiguous. There, the problem was something like: "We guessed the tangent strides incorrectly, because strides on the runtime tangents were different from strides on the forward outputs, which we used to generate tangents" Here, the problem is similar: "We guessed the tangent tensor subclass's metadata incorrectly, because the runtime tangent was a subclass with different metadata than the forward output subclass". This happened in an internal DTensor issue, where the metadata in question was the `placements` (shard vs. replicate vs. Partial). One option is to solve this problem via backward guards. This is needed to unblock internal though, so I figured handling this similarly to how we handle non-contiguous tangents would be reasonable. I did this by: (1) Assert that the metadata on subclass tangents is the same as what we guessed, and if not raise a loud error (2) In the error message, provide the name of an optional method that the subclass must implement to handle this case: `def __force_same_metadata__(self, metadata_tensor):`: If the forward output had a `Replicate()` placement, but the runtime tangent had a `Shard(1)` placement, this method allows a subclass to take the tangent and "convert" it to one with a `Replicate()` placement. `__force_standard_metadata__(self)`: One issue is that there is another placement called `_Partial`, and its semantics are such that DTensor is **unable** to convert a DTensor with some placement type into another DTensor with a `_Partial` placement. `__force_standard_metadata__` is now called on all (fake) subclass forward outs at trace-time to generate tangents, and gives subclasses a chance to "fix" any outputs with metadata that they cannot convert to later. Morally, this is similar to the fact that we force a `contiguous()` call on all tangents at trace-time. I'm interested in thoughts/feedback! Two new dunder methods on traceable subclasses is definitely a contentious change. cc mrshenli pritamdamania87 zhaojuanmao satgera rohan-varma gqchen aazzolini osalpekar jiayisuse H-Huang kwen2501 awgu penguinwu fegin XilunWu wanchaol fduwjj wz337 tianyu-l wconstab yf225 [ghstack-poisoned]
…adata of tangents incorrectly" This PR is enough to fix #118600. More description of the problem is in the issue, but the high-level problem is similar to the "tangents might be non-contiguous" problem that we handle today, via forcing all tangents to be contiguous. There, the problem was something like: "We guessed the tangent strides incorrectly, because strides on the runtime tangents were different from strides on the forward outputs, which we used to generate tangents" Here, the problem is similar: "We guessed the tangent tensor subclass's metadata incorrectly, because the runtime tangent was a subclass with different metadata than the forward output subclass". This happened in an internal DTensor issue, where the metadata in question was the `placements` (shard vs. replicate vs. Partial). One option is to solve this problem via backward guards. This is needed to unblock internal though, so I figured handling this similarly to how we handle non-contiguous tangents would be reasonable. I did this by: (1) Assert that the metadata on subclass tangents is the same as what we guessed, and if not raise a loud error (2) In the error message, provide the name of an optional method that the subclass must implement to handle this case: `def __force_same_metadata__(self, metadata_tensor):`: If the forward output had a `Replicate()` placement, but the runtime tangent had a `Shard(1)` placement, this method allows a subclass to take the tangent and "convert" it to one with a `Replicate()` placement. `__force_standard_metadata__(self)`: One issue is that there is another placement called `_Partial`, and its semantics are such that DTensor is **unable** to convert a DTensor with some placement type into another DTensor with a `_Partial` placement. `__force_standard_metadata__` is now called on all (fake) subclass forward outs at trace-time to generate tangents, and gives subclasses a chance to "fix" any outputs with metadata that they cannot convert to later. Morally, this is similar to the fact that we force a `contiguous()` call on all tangents at trace-time. I'm interested in thoughts/feedback! Two new dunder methods on traceable subclasses is definitely a contentious change. cc mrshenli pritamdamania87 zhaojuanmao satgera rohan-varma gqchen aazzolini osalpekar jiayisuse H-Huang kwen2501 awgu penguinwu fegin XilunWu wanchaol fduwjj wz337 tianyu-l wconstab yf225 [ghstack-poisoned]
```python
if not isinstance(x, Tensor):
    return x
out = x.detach().contiguous()
# Note [Tangents must be contiguous, Part 2]
```
cc @ezyang / @zou3519, I have CI mostly passing on my stack now, so this PR is ready for another look. (I couldn't quite remember the final names we came up with from the API bikeshedding, Ed, but I tried to update them here.)
My read of the situation of this PR is something like:
(1) These two new APIs are not very ideal (two new subclass APIs specific to tangents), but at the very least they are purely optional: you get a loud error in the rare situation where your subclass needed them but didn't provide them, and in the long-term state we can effectively forget about them.
(2) The "right" solution would be to retrace the backward. There is still some design necessary; we should probably sit down with Horace and get in agreement on a long-term design. It's probably not worth blocking internal models on this though, hence the PR.
This is nice and simple, thanks.
FWIW - after looking at the performance profile of the model that needs this fix, this PR is even more of a reason to eventually do the "retrace the backward" work. This PR fixes the problem by essentially forcing DTensor to perform extra collectives at runtime, and each of these collectives can potentially be extremely bad for performance (but for other subclasses, maybe this "coercing" won't have as much of a runtime cost).
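To illustrate the runtime cost being described, here is a hedged, single-process sketch (gloo backend, world size 1, illustrative address/port) of what the coercion boils down to for DTensor: converting a `Shard(0)` tensor back to `Replicate()` via `redistribute`, which at real world sizes is an all-gather collective.

```python
import os
import torch
import torch.distributed as dist
from torch.distributed._tensor import DeviceMesh, Replicate, Shard, distribute_tensor

# Single-process setup so the sketch is runnable without a launcher.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group("gloo", rank=0, world_size=1)

mesh = DeviceMesh("cpu", [0])
sharded = distribute_tensor(torch.randn(4, 4), mesh, [Shard(0)])

# A __force_same_metadata__-style coercion for DTensor would boil down to this
# redistribute, which is a collective (all-gather) at real world sizes.
replicated = sharded.redistribute(mesh, [Replicate()])
print(replicated.placements)  # (Replicate(),)

dist.destroy_process_group()
```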
…adata of tangents incorrectly" This PR is enough to fix #118600. More description of the problem is in the issue, but the high-level problem is similar to the "tangents might be non-contiguous" problem that we handle today, via forcing all tangents to be contiguous. There, the problem was something like: "We guessed the tangent strides incorrectly, because strides on the runtime tangents were different from strides on the forward outputs, which we used to generate tangents" Here, the problem is similar: "We guessed the tangent tensor subclass's metadata incorrectly, because the runtime tangent was a subclass with different metadata than the forward output subclass". This happened in an internal DTensor issue, where the metadata in question was the `placements` (shard vs. replicate vs. Partial). One option is to solve this problem via backward guards. This is needed to unblock internal though, so I figured handling this similarly to how we handle non-contiguous tangents would be reasonable. I did this by: (1) Assert that the metadata on subclass tangents is the same as what we guessed, and if not raise a loud error (2) In the error message, provide the name of an optional method that the subclass must implement to handle this case: `def __force_same_metadata__(self, metadata_tensor):`: If the forward output had a `Replicate()` placement, but the runtime tangent had a `Shard(1)` placement, this method allows a subclass to take the tangent and "convert" it to one with a `Replicate()` placement. `__force_standard_metadata__(self)`: One issue is that there is another placement called `_Partial`, and its semantics are such that DTensor is **unable** to convert a DTensor with some placement type into another DTensor with a `_Partial` placement. `__force_standard_metadata__` is now called on all (fake) subclass forward outs at trace-time to generate tangents, and gives subclasses a chance to "fix" any outputs with metadata that they cannot convert to later. Morally, this is similar to the fact that we force a `contiguous()` call on all tangents at trace-time. I'm interested in thoughts/feedback! Two new dunder methods on traceable subclasses is definitely a contentious change. cc mrshenli pritamdamania87 zhaojuanmao satgera rohan-varma gqchen aazzolini osalpekar jiayisuse H-Huang kwen2501 awgu penguinwu fegin XilunWu wanchaol fduwjj wz337 tianyu-l wconstab yf225 [ghstack-poisoned]
…adata of tangents incorrectly" This PR is enough to fix #118600. More description of the problem is in the issue, but the high-level problem is similar to the "tangents might be non-contiguous" problem that we handle today, via forcing all tangents to be contiguous. There, the problem was something like: "We guessed the tangent strides incorrectly, because strides on the runtime tangents were different from strides on the forward outputs, which we used to generate tangents" Here, the problem is similar: "We guessed the tangent tensor subclass's metadata incorrectly, because the runtime tangent was a subclass with different metadata than the forward output subclass". This happened in an internal DTensor issue, where the metadata in question was the `placements` (shard vs. replicate vs. Partial). One option is to solve this problem via backward guards. This is needed to unblock internal though, so I figured handling this similarly to how we handle non-contiguous tangents would be reasonable. I did this by: (1) Assert that the metadata on subclass tangents is the same as what we guessed, and if not raise a loud error (2) In the error message, provide the name of an optional method that the subclass must implement to handle this case: `def __force_same_metadata__(self, metadata_tensor):`: If the forward output had a `Replicate()` placement, but the runtime tangent had a `Shard(1)` placement, this method allows a subclass to take the tangent and "convert" it to one with a `Replicate()` placement. `__force_standard_metadata__(self)`: One issue is that there is another placement called `_Partial`, and its semantics are such that DTensor is **unable** to convert a DTensor with some placement type into another DTensor with a `_Partial` placement. `__force_standard_metadata__` is now called on all (fake) subclass forward outs at trace-time to generate tangents, and gives subclasses a chance to "fix" any outputs with metadata that they cannot convert to later. Morally, this is similar to the fact that we force a `contiguous()` call on all tangents at trace-time. I'm interested in thoughts/feedback! Two new dunder methods on traceable subclasses is definitely a contentious change. cc mrshenli pritamdamania87 zhaojuanmao satgera rohan-varma gqchen aazzolini osalpekar jiayisuse H-Huang kwen2501 awgu penguinwu fegin XilunWu wanchaol fduwjj wz337 tianyu-l wconstab yf225 [ghstack-poisoned]
…adata of tangents incorrectly" This PR is enough to fix #118600. More description of the problem is in the issue, but the high-level problem is similar to the "tangents might be non-contiguous" problem that we handle today, via forcing all tangents to be contiguous. There, the problem was something like: "We guessed the tangent strides incorrectly, because strides on the runtime tangents were different from strides on the forward outputs, which we used to generate tangents" Here, the problem is similar: "We guessed the tangent tensor subclass's metadata incorrectly, because the runtime tangent was a subclass with different metadata than the forward output subclass". This happened in an internal DTensor issue, where the metadata in question was the `placements` (shard vs. replicate vs. Partial). One option is to solve this problem via backward guards. This is needed to unblock internal though, so I figured handling this similarly to how we handle non-contiguous tangents would be reasonable. I did this by: (1) Assert that the metadata on subclass tangents is the same as what we guessed, and if not raise a loud error (2) In the error message, provide the name of an optional method that the subclass must implement to handle this case: `def __force_same_metadata__(self, metadata_tensor):`: If the forward output had a `Replicate()` placement, but the runtime tangent had a `Shard(1)` placement, this method allows a subclass to take the tangent and "convert" it to one with a `Replicate()` placement. `__force_standard_metadata__(self)`: One issue is that there is another placement called `_Partial`, and its semantics are such that DTensor is **unable** to convert a DTensor with some placement type into another DTensor with a `_Partial` placement. `__force_standard_metadata__` is now called on all (fake) subclass forward outs at trace-time to generate tangents, and gives subclasses a chance to "fix" any outputs with metadata that they cannot convert to later. Morally, this is similar to the fact that we force a `contiguous()` call on all tangents at trace-time. I'm interested in thoughts/feedback! Two new dunder methods on traceable subclasses is definitely a contentious change. cc mrshenli pritamdamania87 zhaojuanmao satgera rohan-varma gqchen aazzolini osalpekar jiayisuse H-Huang kwen2501 awgu penguinwu fegin XilunWu wanchaol fduwjj wz337 tianyu-l wconstab yf225 [ghstack-poisoned]
…adata of tangents incorrectly" This PR is enough to fix #118600. More description of the problem is in the issue, but the high-level problem is similar to the "tangents might be non-contiguous" problem that we handle today, via forcing all tangents to be contiguous. There, the problem was something like: "We guessed the tangent strides incorrectly, because strides on the runtime tangents were different from strides on the forward outputs, which we used to generate tangents" Here, the problem is similar: "We guessed the tangent tensor subclass's metadata incorrectly, because the runtime tangent was a subclass with different metadata than the forward output subclass". This happened in an internal DTensor issue, where the metadata in question was the `placements` (shard vs. replicate vs. Partial). One option is to solve this problem via backward guards. This is needed to unblock internal though, so I figured handling this similarly to how we handle non-contiguous tangents would be reasonable. I did this by: (1) Assert that the metadata on subclass tangents is the same as what we guessed, and if not raise a loud error (2) In the error message, provide the name of an optional method that the subclass must implement to handle this case: `def __force_same_metadata__(self, metadata_tensor):`: If the forward output had a `Replicate()` placement, but the runtime tangent had a `Shard(1)` placement, this method allows a subclass to take the tangent and "convert" it to one with a `Replicate()` placement. `__force_standard_metadata__(self)`: One issue is that there is another placement called `_Partial`, and its semantics are such that DTensor is **unable** to convert a DTensor with some placement type into another DTensor with a `_Partial` placement. `__force_standard_metadata__` is now called on all (fake) subclass forward outs at trace-time to generate tangents, and gives subclasses a chance to "fix" any outputs with metadata that they cannot convert to later. Morally, this is similar to the fact that we force a `contiguous()` call on all tangents at trace-time. I'm interested in thoughts/feedback! Two new dunder methods on traceable subclasses is definitely a contentious change. cc mrshenli pritamdamania87 zhaojuanmao satgera rohan-varma gqchen aazzolini osalpekar jiayisuse H-Huang kwen2501 awgu penguinwu fegin XilunWu wanchaol fduwjj wz337 tianyu-l wconstab yf225 [ghstack-poisoned]
…adata of tangents incorrectly" This PR is enough to fix #118600. More description of the problem is in the issue, but the high-level problem is similar to the "tangents might be non-contiguous" problem that we handle today, via forcing all tangents to be contiguous. There, the problem was something like: "We guessed the tangent strides incorrectly, because strides on the runtime tangents were different from strides on the forward outputs, which we used to generate tangents" Here, the problem is similar: "We guessed the tangent tensor subclass's metadata incorrectly, because the runtime tangent was a subclass with different metadata than the forward output subclass". This happened in an internal DTensor issue, where the metadata in question was the `placements` (shard vs. replicate vs. Partial). One option is to solve this problem via backward guards. This is needed to unblock internal though, so I figured handling this similarly to how we handle non-contiguous tangents would be reasonable. I did this by: (1) Assert that the metadata on subclass tangents is the same as what we guessed, and if not raise a loud error (2) In the error message, provide the name of an optional method that the subclass must implement to handle this case: `def __force_same_metadata__(self, metadata_tensor):`: If the forward output had a `Replicate()` placement, but the runtime tangent had a `Shard(1)` placement, this method allows a subclass to take the tangent and "convert" it to one with a `Replicate()` placement. `__force_standard_metadata__(self)`: One issue is that there is another placement called `_Partial`, and its semantics are such that DTensor is **unable** to convert a DTensor with some placement type into another DTensor with a `_Partial` placement. `__force_standard_metadata__` is now called on all (fake) subclass forward outs at trace-time to generate tangents, and gives subclasses a chance to "fix" any outputs with metadata that they cannot convert to later. Morally, this is similar to the fact that we force a `contiguous()` call on all tangents at trace-time. I'm interested in thoughts/feedback! Two new dunder methods on traceable subclasses is definitely a contentious change. cc mrshenli pritamdamania87 zhaojuanmao satgera rohan-varma gqchen aazzolini osalpekar jiayisuse H-Huang kwen2501 awgu penguinwu fegin XilunWu wanchaol fduwjj wz337 tianyu-l wconstab yf225 [ghstack-poisoned]
…adata of tangents incorrectly" This PR is enough to fix #118600. More description of the problem is in the issue, but the high-level problem is similar to the "tangents might be non-contiguous" problem that we handle today, via forcing all tangents to be contiguous. There, the problem was something like: "We guessed the tangent strides incorrectly, because strides on the runtime tangents were different from strides on the forward outputs, which we used to generate tangents" Here, the problem is similar: "We guessed the tangent tensor subclass's metadata incorrectly, because the runtime tangent was a subclass with different metadata than the forward output subclass". This happened in an internal DTensor issue, where the metadata in question was the `placements` (shard vs. replicate vs. Partial). One option is to solve this problem via backward guards. This is needed to unblock internal though, so I figured handling this similarly to how we handle non-contiguous tangents would be reasonable. I did this by: (1) Assert that the metadata on subclass tangents is the same as what we guessed, and if not raise a loud error (2) In the error message, provide the name of an optional method that the subclass must implement to handle this case: `def __force_same_metadata__(self, metadata_tensor):`: If the forward output had a `Replicate()` placement, but the runtime tangent had a `Shard(1)` placement, this method allows a subclass to take the tangent and "convert" it to one with a `Replicate()` placement. `__force_standard_metadata__(self)`: One issue is that there is another placement called `_Partial`, and its semantics are such that DTensor is **unable** to convert a DTensor with some placement type into another DTensor with a `_Partial` placement. `__force_standard_metadata__` is now called on all (fake) subclass forward outs at trace-time to generate tangents, and gives subclasses a chance to "fix" any outputs with metadata that they cannot convert to later. Morally, this is similar to the fact that we force a `contiguous()` call on all tangents at trace-time. I'm interested in thoughts/feedback! Two new dunder methods on traceable subclasses is definitely a contentious change. cc mrshenli pritamdamania87 zhaojuanmao satgera rohan-varma gqchen aazzolini osalpekar jiayisuse H-Huang kwen2501 awgu penguinwu fegin XilunWu wanchaol fduwjj wz337 tianyu-l wconstab yf225 [ghstack-poisoned]
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
AOTDispatch: allow subclasses to correct when we guess metadata of tangents incorrectly (#118670)
Pull Request resolved: #118670
Approved by: https://github.com/ezyang
This PR is enough to fix #118600.

More description of the problem is in the issue, but the high-level problem is similar to the "tangents might be non-contiguous" problem that we handle today, via forcing all tangents to be contiguous. There, the problem was something like:

"We guessed the tangent strides incorrectly, because strides on the runtime tangents were different from strides on the forward outputs, which we used to generate tangents."

Here, the problem is similar:

"We guessed the tangent tensor subclass's metadata incorrectly, because the runtime tangent was a subclass with different metadata than the forward output subclass."

This happened in an internal DTensor issue, where the metadata in question was the `placements` (shard vs. replicate vs. Partial).

One option is to solve this problem via backward guards. This is needed to unblock internal though, so I figured handling this similarly to how we handle non-contiguous tangents would be reasonable. I did this by:

(1) Assert that the metadata on subclass tangents is the same as what we guessed, and if not raise a loud error.

(2) In the error message, provide the name of an optional method that the subclass must implement to handle this case:

`def __force_same_metadata__(self, metadata_tensor):` If the forward output had a `Replicate()` placement, but the runtime tangent had a `Shard(1)` placement, this method allows a subclass to take the tangent and "convert" it to one with a `Replicate()` placement.

`__force_standard_metadata__(self)`: One issue is that there is another placement called `_Partial`, and its semantics are such that DTensor is **unable** to convert a DTensor with some placement type into another DTensor with a `_Partial` placement. `__force_standard_metadata__` is now called on all (fake) subclass forward outs at trace time to generate tangents, and gives subclasses a chance to "fix" any outputs with metadata that they cannot convert to later. Morally, this is similar to the fact that we force a `contiguous()` call on all tangents at trace time.

I'm interested in thoughts/feedback! Two new dunder methods on traceable subclasses is definitely a contentious change.
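As a concrete (if toy) picture of what implementing these hooks could look like, here is a minimal sketch. `TaggedTensor` and its `layout_tag` string are made-up stand-ins for DTensor and its placements, not anything in this PR; the two dunder names are the ones proposed above; a real traceable subclass would also implement `__tensor_flatten__`/`__tensor_unflatten__` and a working `__torch_dispatch__`, which are omitted to keep the hook semantics in focus.

```python
import torch

class TaggedTensor(torch.Tensor):
    """Toy wrapper subclass whose only extra metadata is a `layout_tag` string."""

    @staticmethod
    def __new__(cls, data, layout_tag):
        # Wrapper-subclass boilerplate: mirror the inner tensor's sizes/strides/dtype/device.
        t = torch.Tensor._make_wrapper_subclass(
            cls, data.size(), strides=data.stride(), dtype=data.dtype, device=data.device
        )
        t._data = data
        t.layout_tag = layout_tag  # e.g. "replicate", "shard", "partial"
        return t

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        raise NotImplementedError("sketch only; op dispatch is out of scope here")

    def __force_standard_metadata__(self):
        # Called at trace time on (fake) forward outputs before tangents are generated:
        # coerce metadata we could never convert *to* later (here: "partial") into a
        # canonical form, analogous to DTensor not being able to target _Partial.
        if self.layout_tag == "partial":
            return TaggedTensor(self._data.clone(), "replicate")
        return self

    def __force_same_metadata__(self, metadata_tensor):
        # Called at runtime when a tangent's metadata differs from the trace-time guess:
        # return a tensor whose metadata matches `metadata_tensor`. For DTensor this is
        # where a redistribute (i.e. an extra collective) would happen.
        if self.layout_tag == metadata_tensor.layout_tag:
            return self
        return TaggedTensor(self._data.clone(), metadata_tensor.layout_tag)

# Usage: a "shard"-tagged runtime tangent gets coerced to match the "replicate"-tagged
# tangent the graph was traced with.
traced = TaggedTensor(torch.zeros(2, 2), "replicate")
runtime = TaggedTensor(torch.ones(2, 2), "shard")
print(runtime.__force_same_metadata__(traced).layout_tag)  # "replicate"
```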
Stack from ghstack (oldest at bottom):
cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @yf225 @chauhang