Add complex tensor with subclassing #48
Conversation
Looks really cool!
Do you want to fix the autograd support before I merge this?
```python
if func is torch.ops.aten.mm.default:
    assert not kwargs
    x, y = args
    re = x.re * y.re - x.im * y.im
```
These should be @ right?
Yes! 🙇🏼♂️
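A quick sketch of why `@` is needed here: the real part of a complex matrix product decomposes as `re(X @ Y) = X.re @ Y.re - X.im @ Y.im` (and `im(X @ Y) = X.re @ Y.im + X.im @ Y.re`), so elementwise `*` in the `aten.mm` branch computes the wrong thing. This check uses plain complex tensors with `.real`/`.imag` rather than the PR's `ComplexTensor` class:

```python
import torch

# Verify the matmul decomposition against PyTorch's native complex matmul.
X = torch.randn(3, 4, dtype=torch.complex64)
Y = torch.randn(4, 5, dtype=torch.complex64)

# Decomposed real and imaginary parts, using @ (matmul), not * (elementwise).
re = X.real @ Y.real - X.imag @ Y.imag
im = X.real @ Y.imag + X.imag @ Y.real

expected = X @ Y
assert torch.allclose(re, expected.real, atol=1e-4)
assert torch.allclose(im, expected.imag, atol=1e-4)
```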
```python
dtype=torch.complex64,  # todo: real to complex dtype
layout=re.layout,
device=re.device,
requires_grad=False,  # todo: autograd support
```
The best way to add autograd support here is to parallel how plain `Tensor` works (the raw constructor is never differentiable). So I would recommend that `ComplexTensor(...)` is never differentiable, and that you add a `create_complex_tensor(...)` which is differentiable and built with a custom autograd Function (where you create a `ComplexTensor` during the forward).
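This pattern can be sketched as follows, with `torch.complex` standing in for the PR's `ComplexTensor` constructor (the names `CreateComplex` and `create_complex_tensor` are illustrative, not from the PR):

```python
import torch

class CreateComplex(torch.autograd.Function):
    """Differentiable factory; the constructor itself stays non-differentiable."""

    @staticmethod
    def forward(ctx, re, im):
        # In the real subclass this would build a ComplexTensor from re/im.
        return torch.complex(re, im)

    @staticmethod
    def backward(ctx, grad_output):
        # Route the upstream gradient back to the two real inputs.
        return grad_output.real, grad_output.imag

def create_complex_tensor(re, im):
    return CreateComplex.apply(re, im)

# Gradients now flow through the factory into the real/imaginary leaves.
re = torch.ones(2, requires_grad=True)
im = torch.zeros(2, requires_grad=True)
z = create_complex_tensor(re, im)
z.real.sum().backward()
```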
```python
size=re.size(),
strides=re.stride(),  # todo: contiguous only
storage_offset=0,
dtype=torch.complex64,  # todo: real to complex dtype
```
Suggested change:

```diff
-dtype=torch.complex64,  # todo: real to complex dtype
+dtype=re.dtype.to_complex(),
```

since v2.1
@albanD let's just merge and fix it up on main
Sorry, I had a very busy week; I'll let you take it from here. 😉 (except if you need anything from our side) Btw, @gautierronan is one of my colleagues; he works with me on the dynamiqs library to simulate quantum systems with PyTorch.
@pierreguilmin / @gautierronan if the two of you are interested in pushing this subclass forward, I'd recommend opening up a little repo with just this class and starting to chuck stuff into it. The subclass zoo here is just to show "it's possible"; it's not a good permanent home for a feature that people want to use.
Thanks for your advice @ezyang. Could you also advise on the next steps to make progress on this? What do you mean by "starting to chuck stuff into it", just implement more operators?
Yup!
Pair-programming with @ezyang at the PyTorch Conference 2023 for a WIP implementation of complex tensors working with `torch.compile`. The implementation is inspired from https://github.com/pytorch/pytorch/blob/main/torch/testing/_internal/two_tensor.py. A few todos left, notably a custom autograd for the constructor. This was tested with the nightly build `2.2.0.dev20231016`.