RFC: Rules for real-to-complex and complex-to-real functions #176

Open
sethaxen opened this issue Jun 25, 2020 · 2 comments
Labels
Complex Differentiation Relating to any form of complex differentiation

Comments

@sethaxen
Member

Consider a function f: ℝᵐ → ℝⁿ that factors through complex intermediates as ℝᵐ → ℂʳ → ℂˢ → ℝⁿ, so that f = f₃ ∘ f₂ ∘ f₁ with f₁: ℝᵐ → ℂʳ, f₂: ℂʳ → ℂˢ, and f₃: ℂˢ → ℝⁿ.
Typically f₁ produces a complex output by adding, subtracting, multiplying, or dividing a real number and a complex number,
or by calling promote, complex, Complex, or cis.
Typically f₃ produces a real output by calling a non-holomorphic function like real, imag, abs, abs2, hypot, or angle.

From #167, the fact that there are complex intermediates to f is just an implementation detail. We could have defined f: ℝᵐ → ℝ²ʳ → ℝ²ˢ → ℝⁿ, and the pushforwards and pullbacks of this new f should behave the same.

Since in general tangents are derivatives of a primal with respect to a real, and cotangents are derivatives of a real with respect to a primal,
the pushforward through f₁: ℝᵐ → ℂʳ should take a real tangent and produce a complex tangent, while the pushforward through f₃: ℂˢ → ℝⁿ should produce a real tangent.
Conversely, the pullback through f₃ should produce a complex cotangent, and the pullback through f₁ should produce a real cotangent.
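
To make the expected types concrete, here is a minimal sketch for the chain x ↦ cis(x) ↦ real(cis(x)) = cos(x). The helper names (`pushforward_cis`, `pullback_real`, and so on) are hypothetical plain functions written for illustration, not the ChainRules API, and the cotangent convention used (project with `real(conj(·) * Δ)`) is an assumption consistent with treating ℂ as ℝ².

```julia
# Hypothetical helpers, not ChainRules rules.
# f₁(x) = cis(x) maps ℝ → ℂ; f₃(z) = real(z) maps ℂ → ℝ.

# Pushforwards: real tangent in, complex tangent out; then complex in, real out.
pushforward_cis(x::Real, dx::Real) = im * cis(x) * dx
pushforward_real(z::Complex, dz::Complex) = real(dz)

# Pullbacks: real cotangent in, complex cotangent out; then complex in, real out.
pullback_real(z::Complex, Δ::Real) = complex(Δ)
pullback_cis(x::Real, Δz::Complex) = real(conj(im * cis(x)) * Δz)

x = 0.3
z = cis(x)

dz = pushforward_cis(x, 1.0)       # complex tangent
dy = pushforward_real(z, dz)       # real tangent, equals -sin(x)

zbar = pullback_real(z, 1.0)       # complex cotangent
xbar = pullback_cis(x, zbar)       # real cotangent, equals -sin(x)
```

Both directions recover the derivative of cos(x), and the values that reach the real ends of the program are real numbers rather than complex numbers with a zero imaginary part.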

The pushforward case is pretty easy to handle. We can either 1) assume that a nonsensical tangent will not be passed and do nothing special (i.e. assume the upstream AD did the right thing), or 2) define custom frules for unary functions f₃(::Complex)::Real that ensure the produced tangent is real.
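
As a rough illustration of option 2, the sketch below contrasts a pushforward for abs2 that is written so its result is real by construction with a naive one derived from z * conj(z), which returns a complex number whose imaginary part happens to be zero. All function names are hypothetical stand-ins, not actual frule methods.

```julia
# Hypothetical stand-ins for frules of f₃(::Complex)::Real; not the ChainRules API.

# Real by construction: the result is typed as a real number.
pushforward_abs2(z::Complex, dz::Number) = 2 * real(conj(z) * dz)
pushforward_abs(z::Complex, dz::Number) = real(conj(z) * dz) / abs(z)

# Naive product rule on z * conj(z): same value, but the result is typed Complex.
naive_pushforward_abs2(z::Complex, dz::Number) = conj(z) * dz + z * conj(dz)

z = 3.0 + 4.0im
dz = 1.0 + 2.0im

t_abs2 = pushforward_abs2(z, dz)
t_abs = pushforward_abs(z, dz)
t_naive = naive_pushforward_abs2(z, dz)
```

The naive version carries the same information, but downstream rules would receive a Complex tangent for a real primal.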

The pullback case is more complicated. Right now, e.g. in Zygote, unless you create a complex number from reals by calling complex, you end up pulling back complex cotangents through the initial real part of your program, which is not only wasteful but could also break assumptions made by the rrules of upstream functions. For binary functions f₁, I propose adding custom rrules for f₁(::Real, ::Complex)::Complex and f₁(::Complex, ::Real)::Complex to ensure that the cotangent pulled back to a real primal is actually real.
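
A minimal sketch of what such a rule could look like for f₁(x, z) = x * z with real x and complex z follows. `mul_with_pullback` is a hypothetical plain function that mimics the shape of an rrule (primal output plus a pullback closure); it is not an actual ChainRules method, and the projection `real(conj(z) * Δw)` assumes the convention that ℂ is treated as ℝ².

```julia
# Hypothetical stand-in for an rrule of f₁(x::Real, z::Complex) = x * z.
function mul_with_pullback(x::Real, z::Complex)
    w = x * z
    function pullback(Δw)
        xbar = real(conj(z) * Δw)  # project the cotangent for the real input onto ℝ
        zbar = x * Δw              # cotangent for the complex input stays complex
        return xbar, zbar
    end
    return w, pullback
end

w, back = mul_with_pullback(2.0, 1.0 + 3.0im)
xbar, zbar = back(1.0 - 1.0im)
```

Here `xbar` is a real number, so nothing complex is pulled back into the real part of the program, while `zbar` remains complex.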

This came up as a point of discussion in JuliaDiff/ChainRules.jl#196, and I would appreciate feedback so we can clarify our conventions here.

@sethaxen
Member Author

See also this issue about why Zygote doesn't do this (tl;dr: Zygote basically treats all reals as embedded in the complex numbers): FluxML/Zygote.jl#342, and an update here: FluxML/Zygote.jl#472.

Also ccing @MikeInnes because this could change the behavior of Zygote.

@ettersi
Contributor

ettersi commented Jun 26, 2020

+1 for the pushforward / pullback of f(::Real)::Real with real sensitivities / adjoints to be real, i.e. for ChainRules to stay real if all of its inputs are real. This convention allows complex pushforwards / pullbacks to be obtained using

invoke(frule, Tuple{typeof(Δx), typeof(f), complex(typeof(x))}, Δx, f, x)
invoke(rrule, Tuple{typeof(f), complex(typeof(x))}, f, x)

Conversely, if we defined pushforwards / pullbacks to be complex for some functions f(::Real)::Real, then there would be no way to get the real version of the derivatives.


Edit: Actually, this does not work since !(Real <: Complex) 😒
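
The obstacle can be reproduced with a toy function; `describe` below is a hypothetical example, not part of any library. Since a Float64 value is not an instance of ComplexF64, `invoke` rejects the call instead of selecting the complex method.

```julia
# Toy function with separate real and complex methods (hypothetical, for illustration).
describe(x::Real) = "real method"
describe(x::Complex) = "complex method"

x = 1.5

real_is_complex = Real <: Complex     # false: Real is not a subtype of Complex
complex_type = complex(typeof(x))     # ComplexF64

# invoke requires the arguments to be instances of the requested types,
# so passing a Float64 where ComplexF64 is requested throws.
invoke_failed = try
    invoke(describe, Tuple{complex_type}, x)
    false
catch
    true
end

# Converting the value itself does dispatch to the complex method.
converted = describe(complex(x))
```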

Of course, a similar effect could be achieved using

frule(Δx, f, complex(x))
rrule(f, complex(x))

but this incurs some runtime penalty. In the case of Complex vs Real, this penalty would probably be acceptable in most circumstances, but I've been thinking that the same approach could also address analogous issues with AbstractArrays, see e.g. JuliaDiff/ChainRules.jl#191 and JuliaDiff/ChainRules.jl#52. The problem there is that it is not clear whether the adjoint of e.g. f(A,B) = A*B with respect to A::Diagonal should be a Diagonal or a Matrix. If the above invoke worked, it would provide an interface for clarifying the intent: rrule(*, A::Diagonal, B)[2](ΔC)[1] would be a Diagonal, and if you wanted a Matrix instead, you could invoke the Matrix method of the rrule. In this case, it would clearly not be acceptable to call rrule(*, Matrix(A::Diagonal), B).
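
To illustrate the two candidate answers for the adjoint with respect to A, here is a small sketch using only LinearAlgebra. It computes the dense cotangent ΔC * B' by hand and then its projection onto the diagonal; the variable names are illustrative, and this is not how any existing rrule is necessarily implemented.

```julia
using LinearAlgebra

A = Diagonal([1.0, 2.0])
B = [1.0 2.0; 3.0 4.0]
ΔC = [1.0 0.0; 0.0 1.0]   # cotangent of C = A * B

# Treating A as a general matrix: the cotangent is a dense Matrix.
dense_cotangent = ΔC * B'

# Treating A as structurally diagonal: keep only the diagonal entries.
diag_cotangent = Diagonal(dense_cotangent)
```

Both results are legitimate under different interpretations of A, which is why an explicit way to state the intent would be useful.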
