pullback of norm causes forward pass to be type unstable #663
This is not unique to
julia> y, back = Zygote.pullback(^, -2.0, 2.0);

julia> y
4.0

julia> back(1.0)
(-4.0, 2.772588722239781 - 12.566370614359172im)

See also ChainRules' definition of this rule: https://github.com/JuliaDiff/ChainRules.jl/blob/master/src/rulesets/Base/fastmath_able.jl#L49-L52

The issue is that a power of a negative real is only well-defined as a real number when the exponent is an integer. Otherwise you must make the base complex first:
julia> (-2.0)^2.1
ERROR: DomainError with -2.0:
Exponentiation yielding a complex result requires a complex argument.
Replace x^y with (x+0im)^y, Complex(x)^y, or similar.
Stacktrace:
 [1] throw_exp_domainerror(::Float64) at ./math.jl:37
 [2] ^(::Float64, ::Float64) at ./math.jl:872
 [3] top-level scope at REPL[30]:1

The same thing happens inside the pullback: the partial derivative with respect to the exponent is x^p * log(x), and the log of a negative real is complex, so the exponent's tangent comes out complex even when x^p itself is real.

Probably the only type-stable thing to do here is to add a rule for integer exponents, so that the pullback of the exponent is not complexified. But if the exponent must be exactly an integer for this, then strictly speaking, its pullback should be 0 or

Related proposal for
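To make the explanation above concrete, here is a minimal Base-Julia sketch (no Zygote) that writes out both partials of Ω = x^p by hand. The function name pow_partials is hypothetical, and taking conj of the exponent's partial is an assumption chosen to match the sign of the imaginary part in the back(1.0) output above. The Float64 method reproduces that complex exponent tangent; the Integer method is one possible shape for the integer-exponent rule suggested above, returning 0 for the exponent's tangent so all three outputs stay Float64.

```julia
using Test: @inferred

# Hypothetical helper: the partials of Ω = x^p written out by hand.
# ∂Ω/∂x = p * x^(p - 1) and ∂Ω/∂p = x^p * log(x); for x < 0 the log is
# only defined on a complex argument, hence complex(x) below.
function pow_partials(x::Float64, p::Float64)
    Ω = x^p                          # for x < 0 this only works when p is integer-valued
    ∂x = p * x^(p - 1)
    ∂p = conj(Ω * log(complex(x)))   # conj: assumption, to match the sign in back(1.0)
    return Ω, ∂x, ∂p
end

# Hypothetical integer-exponent method: p is discrete, so its tangent is
# taken to be zero, and all three outputs stay Float64.
function pow_partials(x::Float64, p::Integer)
    Ω = x^p
    ∂x = p * x^(p - 1)
    return Ω, ∂x, zero(x)
end

# Usage
Ω, ∂x, ∂p = pow_partials(-2.0, 2.0)  # ∂p ≈ 2.772588722239781 - 12.566370614359172im
@inferred pow_partials(-2.0, 2)      # passes: inferred Tuple{Float64, Float64, Float64}
```

With a Float64 exponent the tangent for p is a ComplexF64 even though x^p is the real number 4.0; dispatching on an Integer exponent sidesteps log(x) entirely, which is what keeps the outputs real.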
So this nicely explains why the exponent's gradient is complex, and that does seem like a thorny issue. However, the primary problem I'm having is the conversion of

Perhaps I should file this on ChainRules.jl? Briefly looking at the norm rrule in the LinearAlgebra rulesets, it's not clear to me why it is switching a Float64 to a Float32.
Zygote actually doesn't use that rule. Instead it has its own

Awesome, thanks for the help, Seth!

So, just to update on this: I've abandoned #666 (😈) in favor of getting a more general solution into ChainRules. JuliaDiff/ChainRules.jl#224 corrected the rule for
Minimal example

As expected, back(y) also has an eltype of Float64. I'm also not sure why the pullback has length 2, with the second entry a single Complex{Float64}. Perhaps something strange is happening with p? I'm not sure how that would end up as a complex number, though.
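The second entry of the pullback is presumably the tangent for the exponent p, which picks up a complex log through the generic ^ rule discussed in the thread. For contrast, here is a minimal sketch of a hand-written pullback for norm(x, p) whose p-tangent stays real: the base |xᵢ| is never negative, so log|xᵢ| is real and no complex promotion is needed. It assumes p ≥ 1 and that x has no zero entries; the name pnorm_pullback is hypothetical, and this is not Zygote's or ChainRules' actual rule.

```julia
using LinearAlgebra: norm

# Hypothetical hand-written rule for N = norm(x, p) = (Σ|xᵢ|^p)^(1/p).
# Assumes p ≥ 1 and no zero entries in x (log(0) would give NaN below).
# Not Zygote's or ChainRules' actual rule.
function pnorm_pullback(x::AbstractVector{<:Real}, p::Real)
    a = abs.(x)              # |xᵢ| ≥ 0, so log.(a) is real: no complex promotion
    S = sum(a .^ p)
    N = S^(1 / p)
    function back(N̄::Real)
        # ∂N/∂xᵢ = sign(xᵢ) |xᵢ|^(p-1) / N^(p-1)
        x̄ = N̄ .* sign.(x) .* a .^ (p - 1) ./ N^(p - 1)
        # ∂N/∂p = N * (Σ|xᵢ|^p log|xᵢ| / (p S) - log(S) / p^2)
        p̄ = N̄ * N * (sum(a .^ p .* log.(a)) / (p * S) - log(S) / p^2)
        return x̄, p̄
    end
    return N, back
end

# Usage: compare the real p-tangent against a central finite difference.
x, p = [3.0, -4.0], 2.0
N, back = pnorm_pullback(x, p)
x̄, p̄ = back(1.0)             # x̄ ≈ [0.6, -0.8], p̄ isa Float64
h = 1e-6
fd = (norm(x, p + h) - norm(x, p - h)) / (2h)
```

Here back(1.0) returns a Float64 vector for x and a plain Float64 for p that agrees with the finite-difference estimate fd, so both entries of the pullback keep the eltype of the input.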