Complex valued AD for TensorOperations
#151
Comments
This kind of behavior for the complex gradient should be expected; that is, for holomorphic functions such as matrix multiplication, the result needs to be conjugated. We just wrote a paper (https://arxiv.org/pdf/2003.04295.pdf) which proposes an automatic differentiation algorithm for general complex loss functions that derives the correct gradient (the same result as if the complex numbers were treated as a tuple of two real numbers). For a general function, the adjoint function should be defined as in Eq. 15. In the special case of a holomorphic function, the result should be conjugated.
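For concreteness, under the convention where the gradient of a real loss L is 2 ∂L/∂z̄, the standard pullback rules being referred to can be written as follows (textbook forms, not quoted from the paper's Eq. 15; here w̄ denotes the cotangent arriving at the output and z̄, Ā, B̄ the resulting input cotangents):

```latex
% Holomorphic scalar map w = f(z): the incoming cotangent is multiplied by
% the conjugated derivative
\bar{z} = \overline{f'(z)}\,\bar{w}

% Matrix multiplication C = A B, with ^\dagger the conjugate transpose
\bar{A} = \bar{C}\,B^{\dagger}, \qquad \bar{B} = A^{\dagger}\,\bar{C}
```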
@guochu Maybe you wanted to comment under this issue: #29. What you're describing is exactly Wirtinger calculus. Your formula is the same as the one in Akira Hirose's book “Complex Valued Neural Networks”, and many of the rules in your work have already been covered in my blog. Although it is not something new, it is still very impressive to figure it out on your own!
Thanks for your reply! I read your blog and yes, I think the correct gradients as well as the complex chain rule are well known. But to define a reasonable adjoint function after knowing the chain rule, one needs to go one step further (from Eq. 10 to Eq. 15 in the referenced paper); at least that step was non-trivial for me. I knew the required gradient is 2 ∂f/∂z* and I also knew the complex chain rule, but when I was using Zygote with complex functions I just could not get convergence. I then tried to work out how to define a correct adjoint function to make it work, which led to the manuscript mentioned above. It can serve as a reference for myself and hopefully clarify some doubts for newcomers. Maybe these rules are already the guideline of Zygote for complex functions, and there are just bugs in some functions. I will report a bug then. Thanks again!
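For reference, the Wirtinger derivatives and the complex chain rule being discussed are the standard ones (written out here for convenience, not copied from the manuscript):

```latex
% Wirtinger derivatives for z = x + i y
\frac{\partial}{\partial z} = \frac{1}{2}\left(\frac{\partial}{\partial x} - i\,\frac{\partial}{\partial y}\right),
\qquad
\frac{\partial}{\partial \bar z} = \frac{1}{2}\left(\frac{\partial}{\partial x} + i\,\frac{\partial}{\partial y}\right)

% Gradient of a real-valued loss L used for gradient descent
\nabla_z L = 2\,\frac{\partial L}{\partial \bar z} = \overline{2\,\frac{\partial L}{\partial z}}

% Chain rule for h(z) = g(f(z)), with w = f(z)
\frac{\partial h}{\partial z} =
\frac{\partial g}{\partial w}\,\frac{\partial w}{\partial z} +
\frac{\partial g}{\partial \bar w}\,\frac{\partial \bar w}{\partial z}
```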
Better to submit an issue. |
Sure, I have already submitted an issue. If Zygote is designed to follow this rule, then the definition of the adjoint function for "dot" (a really frequently used function, I think) is wrong. This, I figured out, is the reason my program did not converge, and I had to redefine it myself. I am sure there are similar problems with other non-holomorphic functions. The goal of the manuscript is not to explain how to compute the complex gradient using the chain rule, which is well known and is Eq. 10 in the manuscript. In reverse-mode automatic differentiation the computer does not directly evaluate the chain rule, but uses a user-defined "adjoint function" (which you must be very familiar with), so that it can automatically derive the gradient of a composite function, matching the result of the chain rule. The manuscript explains how to define such an adjoint function, which is Eq. 15 in the general case, rather than giving the explicit form of the complex gradient. In principle Eq. 15 can be derived straightforwardly from the chain rule in Eq. 10, but I have not seen it written down formally anywhere else. If you know of such references, I would be very grateful if you could point them out to me.
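As an illustration of the kind of redefinition described above, here is a minimal sketch of a dot-like adjoint under the real-pair convention (the wrapper name `mydot` is hypothetical and this is not the manuscript's code):

```julia
using LinearAlgebra
using Zygote

# dot(x, y) = sum(conj(x) .* y) depends on conj(x), so it is not holomorphic in x;
# its pullback therefore needs an explicit conjugation of the output cotangent.
mydot(x, y) = dot(x, y)

Zygote.@adjoint function mydot(x, y)
    s = dot(x, y)
    # Real-pair convention: x̄ = conj(Δ) .* y and ȳ = Δ .* x
    return s, Δ -> (conj(Δ) .* y, Δ .* x)
end
```

With this rule, `Zygote.gradient(x -> real(mydot(x, x)), x0)[1]` comes out as `2 .* x0` for a complex vector `x0`, which is the gradient obtained by treating the real and imaginary parts as independent real parameters.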
where?
explaining? I am interested to know how you would comment on the backward rules defined in OMEinsum.jl; they are similar to dot. We use them to write our TRG and CTMRG algorithms, and the training works very well, thankfully.
I suppose you mean yes?
But I am not sure you can find a digital version of this book. I read it in the library.
Here it is. #540
The correct version can be derived from Eq. 15, and it is listed in the table. The table also contains the adjoint function for matrix multiplication, which is similar to a tensor contraction; as you can see, a conjugate has to be taken. I looked at this page. The Autodiff section mentions the implementation of the gradient. I am sure this form would be problematic for complex tensors, but I am not sure whether the conjugation is done internally for complex tensors or not. If this package does not do the conjugation and your code still works well with it, I am afraid that is either because your system has time-reversal symmetry or because you are really lucky. Thanks for pointing out the book, but I do not have access to it... I could only find this thesis, and I do not know whether it relates to Sec. 4 of the book you mentioned. In the thesis I only found the complex chain rule; I did not find an explicit rule for the adjoint function of a complex function.
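One way to settle whether the conjugation is done internally is to compare the AD gradient against the gradient obtained by treating the real and imaginary parts as independent real parameters, e.g. via finite differences. A minimal sketch (the `loss` below is just a placeholder; substitute the @tensor / OMEinsum expression whose backward rule you want to check):

```julia
using LinearAlgebra
using Zygote

# Placeholder real-valued loss; replace with the contraction under test.
loss(A) = real(tr(A * A'))

# "Real-pair" gradient by central finite differences:
# dL/dRe(A[idx]) + im * dL/dIm(A[idx]), i.e. 2 ∂L/∂conj(A[idx])
function fd_gradient(f, A; ϵ = 1e-6)
    G = zero(A)
    for idx in eachindex(A)
        Ap = copy(A); Ap[idx] += ϵ
        Am = copy(A); Am[idx] -= ϵ
        dRe = (f(Ap) - f(Am)) / (2ϵ)
        Ap = copy(A); Ap[idx] += im * ϵ
        Am = copy(A); Am[idx] -= im * ϵ
        dIm = (f(Ap) - f(Am)) / (2ϵ)
        G[idx] = dRe + im * dIm
    end
    return G
end

A = randn(ComplexF64, 4, 4)
G_ad = Zygote.gradient(loss, A)[1]
G_fd = fd_gradient(loss, A)
isapprox(G_ad, G_fd; rtol = 1e-4)  # true iff the AD rule follows the real-pair convention
```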
Bump. Any updates on this? |
It seems that you can no longer get the gradient of TensorOperations functions simply from the above code. Anyway, you can still manually write the adjoint function for a tensor contraction, which should be relatively straightforward.
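For example, a hand-written adjoint for a simple pairwise contraction could look like the sketch below (the wrapper name `contract` is hypothetical; the pullback is the standard matrix-multiplication rule with the conjugations written out via @tensor):

```julia
using TensorOperations
using Zygote

# Forward pass: C[i,j] = Σ_k A[i,k] * B[k,j]
function contract(A, B)
    @tensor C[i, j] := A[i, k] * B[k, j]
    return C
end

Zygote.@adjoint function contract(A, B)
    C = contract(A, B)
    function contract_pullback(ΔC)
        # Real-pair convention: dA = ΔC * B†, dB = A† * ΔC († = conjugate transpose)
        @tensor dA[i, k] := ΔC[i, j] * conj(B[k, j])
        @tensor dB[k, j] := conj(A[i, k]) * ΔC[i, j]
        return (dA, dB)
    end
    return C, contract_pullback
end
```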
I still have not found time for this myself, but there is the integration of TensorOperations.jl in Tullio.jl, which can also compute the gradients, and there are also https://github.com/mcabbott/TensorGrad.jl and https://github.com/ho-oto/TensorRules.jl, which might still be functional.
Zygote.jl now works pretty well with TensorOperations.jl: it gives correct results for real numbers without any effort. But for complex numbers, the gradient differs by a conjugate. As an example:

The untouched version

The correct version (I believe)

`f` is equivalent to `g`; the exact gradient should be `f'(a) = 2a`, but `g'(a)` gives `2conj(a)`. I wonder which function `@tensor` calls into that gives the incorrect gradient. @Jutho @under-Peter
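The original code snippets are not reproduced above, but the comparison described is of the following kind (a minimal sketch; the function bodies and the specific contraction are assumptions, and it may not run against current package versions):

```julia
using LinearAlgebra
using TensorOperations
using Zygote

# Plain-Julia version: tr(A*A) is holomorphic in A, with derivative 2*transpose(A)
f(A) = tr(A * A)

# The same contraction written with @tensor
function g(A)
    @tensor B[i, j] := A[i, k] * A[k, j]
    return tr(B)
end

A = randn(ComplexF64, 3, 3)
_, back_f = Zygote.pullback(f, A)
_, back_g = Zygote.pullback(g, A)
back_f(one(ComplexF64))[1]  # gradient from the generic Zygote rules
back_g(one(ComplexF64))[1]  # reported to differ from the line above by a conjugate
```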
Performance
After fixing the conjugate array broadcasting issue (#146 (comment)), the backward pass is 5 times slower than the forward pass (a factor of two may be explained by differentiating with respect to the two inputs `A` and `B`).

BTW: I notice that the undesired gradient in the output tuple is still calculated; is it possible to avoid this kind of overhead, e.g. by knowing which gradient is actually needed? @MikeInnes
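On the question of computing only the gradients that are actually needed: the rule-definition layer does provide a lazy mechanism for this. A minimal sketch using ChainRulesCore thunks (the `lazy_matmul` wrapper is hypothetical, and whether Zygote exploits this for @tensor pullbacks is a separate question):

```julia
using ChainRulesCore

# Hypothetical wrapper, used only to illustrate the lazy-pullback pattern.
lazy_matmul(A, B) = A * B

function ChainRulesCore.rrule(::typeof(lazy_matmul), A::AbstractMatrix, B::AbstractMatrix)
    C = A * B
    function lazy_matmul_pullback(ΔC)
        # Each input gradient is wrapped in a thunk: the product inside is only
        # evaluated if the caller unthunks it, so an unused gradient stays cheap.
        dA = @thunk(ΔC * B')
        dB = @thunk(A' * ΔC)
        return (NoTangent(), dA, dB)
    end
    return C, lazy_matmul_pullback
end
```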