pullback of norm causes forward pass to be type unstable #663

Closed
dsweber2 opened this issue May 28, 2020 · 6 comments · Fixed by #847

Comments

@dsweber2
Contributor

Minimal example

using Zygote, LinearAlgebra
y, back = pullback(norm, randn(Float32, 10))
typeof(y)
typeof(back(y))

y comes back as a Float64 even though the input is Float32 and, as expected given that, back(y) also has an eltype of Float64. I'm also not sure why the pullback output has length 2, with the second entry a single Complex{Float64}. Perhaps something strange is happening with p? I'm not sure how that would end up as a complex number, though.

@dsweber2 dsweber2 changed the title pullback causes forward pass to be type unstable pullback of norm causes forward pass to be type unstable May 28, 2020
@sethaxen
Contributor

This is not unique to norm; it affects any function whose pullback involves a power with a potentially negative base. A minimal example:

julia> y, back = Zygote.pullback(^, -2.0, 2.0);

julia> y
4.0

julia> back(1.0)
(-4.0, 2.772588722239781 - 12.566370614359172im)

See also ChainRules' definition of this rule: https://github.com/JuliaDiff/ChainRules.jl/blob/master/src/rulesets/Base/fastmath_able.jl#L49-L52
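To see where the complex second entry comes from, here is a minimal sketch in plain Julia. It assumes the adjoint with respect to the exponent has the form Δ * conj(x^p * log(complex(x))); that formula is my assumption for illustration and is not copied from Zygote or ChainRules.

```julia
# Assumed adjoint of x^p with respect to p (illustrative, not library code):
#   dp = Δ * conj(x^p * log(complex(x)))
x, p, Δ = -2.0, 2.0, 1.0

# The log of a negative real only exists after promoting to Complex,
# which is what introduces the imaginary part (log(2) + π*im here).
logx = log(complex(x))

dp = Δ * conj(x^p * logx)
println(dp)   # roughly 2.77 - 12.57im, i.e. 4*log(2) - 4π*im
```

Under that assumption, the imaginary part is just the branch of the complex logarithm showing up in the gradient, which is why a real input produces a Complex{Float64} adjoint.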

The issue is that only for integer powers is the power of a negative real well-defined. Otherwise you must make it complex first:

julia> (-2.0)^2.1
ERROR: DomainError with -2.0:
Exponentiation yielding a complex result requires a complex argument.
Replace x^y with (x+0im)^y, Complex(x)^y, or similar.
Stacktrace:
 [1] throw_exp_domainerror(::Float64) at ./math.jl:37
 [2] ^(::Float64, ::Float64) at ./math.jl:872
 [3] top-level scope at REPL[30]:1

Probably the only type-stable thing to do here is to add a rule for integer exponents to avoid complexifying the pullback of the exponent. But if the exponent must be exactly an integer for this, then strictly speaking, its pullback should be 0 or nothing. This is what we do for exponents of Symmetric/Hermitian matrices.
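The idea of a dedicated integer-exponent rule can be sketched as below. The function name pow_with_pullback is hypothetical and this is not how Zygote or ChainRules define their rules; it only illustrates returning nothing for the exponent so that no complex value is ever created.

```julia
# Hypothetical helper, not part of Zygote or ChainRules.
# For an integer exponent the adjoint of the exponent is `nothing`,
# so the pullback stays real and keeps the float type of x.
function pow_with_pullback(x::AbstractFloat, n::Integer)
    y = x^n
    function back(Δ)
        dx = Δ * n * x^(n - 1)   # d/dx x^n = n * x^(n-1)
        return (dx, nothing)     # no gradient with respect to n
    end
    return y, back
end

y, back = pow_with_pullback(-2.0, 2)
grads = back(1.0)   # (-4.0, nothing)
```

The first entry matches the -4.0 in the REPL output above. The sketch ignores edge cases such as x == 0 with n == 0.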

Related proposal for norm: JuliaDiff/ChainRules.jl#204

@dsweber2
Contributor Author

So this nicely explains why the exponent's adjoint is complex, and that does seem like a thorny issue.

However, the primary problem I'm having is the conversion of Float32 to Float64. This breaks ∇maxpool as some of the arguments are Float32 and some are Float64, and it insists on the same element type.

Perhaps I should file this on ChainRules.jl? Briefly looking at the norm rrule in the LinearAlgebra rulesets, it's not clear to me why it is switching a Float32 to a Float64.

@dsweber2
Contributor Author

The Float32 -> Float64 promotion also appears when taking a square root via (x).^(1/2)

@sethaxen
Contributor

Perhaps I should file this on ChainRules.jl? Briefly looking at the norm rrule in the LinearAlgebra rulesets, it's not clear to me why it is switching a Float32 to a Float64.

Zygote actually doesn't use that rule. Instead it has its own norm: https://github.com/FluxML/Zygote.jl/blob/master/src/lib/array.jl#L426-L429. The type promotion with norm is happening for the same reason as the (x).^(1/2) example you shared. 1/2 evaluates to a Float64, hence the type is promoted. Simply changing the exponent to 1f0/2 fixes this. I'll open a PR to add the fix to norm.
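The promotion described here can be checked without Zygote. The sketch below uses only Base Julia; the variable names are illustrative.

```julia
x = rand(Float32, 10)   # non-negative Float32 values

t_f64   = eltype(x .^ (1/2))                 # Float64: 1/2 evaluates to a Float64
t_f32   = eltype(x .^ (1f0/2))               # Float32: exponent is a Float32
t_match = eltype(x .^ (one(eltype(x)) / 2))  # Float32: exponent follows eltype(x)
t_sqrt  = eltype(sqrt.(x))                   # Float32
```

Deriving the exponent from eltype(x), as in the third line, is one way to keep the result type tied to the input rather than hard-coding Float32.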

@dsweber2
Contributor Author

awesome, thanks for the help seth!

@sethaxen
Contributor

sethaxen commented Jul 8, 2020

So just to update on this, I've abandoned #666 (😈 ) in favor of getting a more general solution into ChainRules. JuliaDiff/ChainRules.jl#224 corrected the rule for ^(x::Real, p::Real) to avoid introducing a complex adjoint (you should be able to see that now with a fresh Zygote install). JuliaDiff/ChainRules.jl#226 will introduce rules for norm that will be significantly faster and avoid the unnecessary type promotion that you're seeing. Once that is finished, I'll open a PR to remove Zygote's rule for norm.
