Add complex tensor with subclassing #48

Merged: 1 commit, Oct 20, 2023

Changes from all commits
57 changes: 57 additions & 0 deletions complex_tensor.py
@@ -0,0 +1,57 @@
import torch


class ComplexTensor(torch.Tensor):
    def __new__(cls, re, im):
        assert (
            re.device == im.device
            and re.layout == im.layout
            and re.requires_grad == im.requires_grad
            and re.dtype == im.dtype
        )
        res = torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
            cls,
            size=re.size(),
            strides=re.stride(),  # todo: contiguous only
            storage_offset=0,
            dtype=torch.complex64,  # todo: real to complex dtype

Review comment — Suggested change:
-            dtype=torch.complex64,  # todo: real to complex dtype
+            dtype=re.dtype.to_complex(),
(dtype.to_complex() is available since v2.1)

            layout=re.layout,
            device=re.device,
            requires_grad=False,  # todo: autograd support
Review comment (Owner): The best way to add autograd support here is to parallel Tensor (which is never differentiable). So I would recommend that ComplexTensor(...) is never differentiable, and that you have a create_complex_tensor(...) which is differentiable and built with a custom autograd Function (where you create a ComplexTensor during the forward).

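A rough sketch of what that comment describes (the name create_complex_tensor and the gradient convention below are assumptions for illustration, not part of this PR):

class _CreateComplex(torch.autograd.Function):  # hypothetical, not in the PR
    @staticmethod
    def forward(ctx, re, im):
        # Build the (non-differentiable) subclass inside the custom Function.
        return ComplexTensor(re.detach(), im.detach())

    @staticmethod
    def backward(ctx, grad_out):
        # Assumes the incoming gradient is itself a ComplexTensor and routes
        # its real/imaginary parts back to the two inputs.
        return grad_out.re, grad_out.im


def create_complex_tensor(re, im):  # differentiable entry point (assumed name)
    return _CreateComplex.apply(re, im)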
        )
        res.re = re
        res.im = im
        return res

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        if func is torch.ops.aten.mm.default:
            assert not kwargs
            x, y = args
            re = x.re * y.re - x.im * y.im
Review comment (Owner): These should be @, right?

Reply (Contributor Author): Yes! 🙇🏼‍♂️

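Per that exchange, the corrected lines would presumably use matrix multiplication (a sketch; the diff as merged still shows elementwise *):

            re = x.re @ y.re - x.im @ y.im
            im = x.re @ y.im + x.im @ y.re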
            im = x.re * y.im + x.im * y.re
            return ComplexTensor(re, im)
        raise NotImplementedError(f"todo {func}")

    def __tensor_flatten__(self):
        return ["re", "im"], None

    @staticmethod
    def __tensor_unflatten__(inner_tensors, meta):
        assert meta is None
        re, im = inner_tensors["re"], inner_tensors["im"]
        return ComplexTensor(re, im)

    def __repr__(self):
        return f"ComplexTensor(real={self.re}, imag={self.im})"


if __name__ == "__main__":

    @torch.compile()
    def f(x, y):
        return x @ y

    x = ComplexTensor(torch.tensor([[1]]), torch.tensor([[2]]))
    y = ComplexTensor(torch.tensor([[3]]), torch.tensor([[4]]))

    print(f(x, y))  # (1 + 2i) * (3 + 4i) = -5 + 10i
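
As a quick sanity check of the expected value, the same product with PyTorch's built-in complex dtype (not part of this diff) gives the same result:

a = torch.tensor([[1 + 2j]], dtype=torch.complex64)
b = torch.tensor([[3 + 4j]], dtype=torch.complex64)
print(a @ b)  # expect -5 + 10i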