Incorrect gradients for plan_rfft(x) * x #1496

Closed
awadell1 opened this issue Jan 17, 2024 · 2 comments

@awadell1

Gradients given by Zygote for a planned rfft differ significantly from the ChainRulesCore gradients. This appears to be related to #1437, #899, and #1377.

Adjoint in question:

Zygote.jl/src/lib/array.jl, lines 645 to 659 at commit 54f1e80:

# the adjoint jacobian of an FFT with respect to its input is the reverse FFT of the
# gradient of its inputs, but with different normalization factor
@adjoint function *(P::AbstractFFTs.Plan, xs)
return P * xs, function(Δ)
N = prod(size(xs)[[P.region...]])
return (nothing, N * (P \ Δ))
end
end
@adjoint function \(P::AbstractFFTs.Plan, xs)
return P \ xs, function(Δ)
N = prod(size(Δ)[[P.region...]])
return (nothing, (P * Δ)/N)
end
end
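Why N * (P \ Δ) goes wrong for real-input plans can be sketched without FFTW. rfft keeps only bins 0:n÷2 of the spectrum, while P \ Δ (an irfft) treats its input as half of a Hermitian-symmetric spectrum, so every bin except DC (and Nyquist when n is even) is implicitly counted twice. The sketch below is a dependency-free illustration using naive O(n²) DFT sums along a single dimension in place of FFTW. The helper names (naive_rfft, naive_brfft, zygote_style_pullback, corrected_pullback) are hypothetical. Halving the mirrored bins before the unnormalized inverse is one way to obtain the adjoint, not necessarily how AbstractFFTs/ChainRules implements it.

```julia
using LinearAlgebra  # for dot

# Naive real-input DFT of a vector: returns only bins k = 0:n÷2,
# the same half spectrum that rfft keeps.
function naive_rfft(x::AbstractVector{<:Real})
    n = length(x)
    m = n ÷ 2 + 1
    return [sum(x[j+1] * cis(-2π * j * k / n) for j in 0:n-1) for k in 0:m-1]
end

# Naive unnormalized inverse of a half spectrum (what brfft computes): the missing
# bins are filled in by Hermitian symmetry, so every bin except DC (and Nyquist
# for even n) contributes twice.
function naive_brfft(y::AbstractVector{<:Complex}, n::Integer)
    m = length(y)
    full = [k < m ? y[k+1] : conj(y[n-k+1]) for k in 0:n-1]
    return [real(sum(full[k+1] * cis(2π * j * k / n) for k in 0:n-1)) for j in 0:n-1]
end

# Mirrors the Zygote adjoint: N * (P \ Δ) == N * brfft(Δ, n) / N == brfft(Δ, n).
zygote_style_pullback(ȳ, n) = naive_brfft(ȳ, n)

# Adjoint of the real-to-complex map: halve the mirrored bins first so that
# the Hermitian expansion in naive_brfft does not double-count them.
function corrected_pullback(ȳ, n)
    ȳs = complex.(ȳ)            # broadcasting makes a fresh copy
    ȳs[2:((n + 1) ÷ 2)] ./= 2    # bins 1..(n-1)÷2 have mirror images
    return naive_brfft(ȳs, n)
end

ȳ = ones(ComplexF64, 2)
zygote_style_pullback(ȳ, 3)   # ≈ [3.0, 0.0, 0.0]
corrected_pullback(ȳ, 3)      # ≈ [2.0, 0.5, 0.5]

# Dot-product test: a true adjoint satisfies real(dot(ȳ, A x)) == dot(A' ȳ, x).
for n in (5, 6)
    x = rand(n)
    ȳr = rand(ComplexF64, n ÷ 2 + 1)
    @assert real(dot(ȳr, naive_rfft(x))) ≈ dot(corrected_pullback(ȳr, n), x)
end
```

With a seed of ones and n = 3, zygote_style_pullback reproduces the [3.0, 0.0, 0.0] columns of the Zygote dims=1 output below, and corrected_pullback reproduces the ChainRules [2.0, 0.5, 0.5] columns. The dot-product loop checks the corrected version against the forward map for both an odd and an even length.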

Minimal Working Example

using Pkg
Pkg.activate(; temp=true)
Pkg.add(["Zygote", "FFTW", "ChainRulesCore"])
using Zygote
using FFTW
using ChainRulesCore

x = rand(3,3)

# No Dims
p = plan_rfft(x)
y, back = Zygote.pullback(*, p, x)
@info "Zygote - No Dims" back(one.(y))[2]
y, back = ChainRulesCore.rrule(*, p, x)
@info "ChainRules - No Dims" back(one.(y))[3]

# dims = 1
p = plan_rfft(x, 1)
y, back = Zygote.pullback(*, p, x)
@info "Zygote - dims=1" back(one.(y))[2]
y, back = ChainRulesCore.rrule(*, p, x)
@info "ChainRules - dims=1" back(one.(y))[3]

Output:

❯ julia zygote.jl 
  Activating new project at `/tmp/jl_9uzHHd`
   Resolving package versions...
    Updating `/tmp/jl_9uzHHd/Project.toml`
  [d360d2e6] + ChainRulesCore v1.19.1
  [7a1cc6ca] + FFTW v1.7.2
  [e88e6eb3] + Zygote v0.6.68
    Updating `/tmp/jl_9uzHHd/Manifest.toml`
  [621f4979] + AbstractFFTs v1.5.0
[...]
┌ Info: Zygote - No Dims
│   (back(one.(y)))[2] =
│    3×3 Matrix{Float64}:
│     9.0  0.0  0.0
│     0.0  0.0  0.0
│     0.0  0.0  0.0
┌ Info: ChainRules - No Dims
│   (back(one.(y)))[3] =
│    3×3 Matrix{Float64}:
│     6.0  0.0  0.0
│     1.5  0.0  0.0
│     1.5  0.0  0.0
┌ Info: Zygote - dims=1
│   (back(one.(y)))[2] =
│    3×3 Matrix{Float64}:
│     3.0  3.0  3.0
│     0.0  0.0  0.0
│     0.0  0.0  0.0
┌ Info: ChainRules - dims=1
│   (back(one.(y)))[3] =
│    3×3 Matrix{Float64}:
│     2.0  2.0  2.0
│     0.5  0.5  0.5
│     0.5  0.5  0.5
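As an independent check of which "No Dims" result is right: assuming Zygote's convention that a pullback seeded with Δ returns the gradient of real(dot(Δ, y)), seeding with one.(y) should give the gradient of real(sum(rfft(x))). That function is linear in x, so its gradient can be read off exactly by evaluating it on unit matrices. The sketch below is dependency-free: an explicit cosine sum stands in for FFTW's rfft over both dimensions, and the helper names real_sum_rfft and linear_gradient are hypothetical.

```julia
# real(sum(rfft(x))) for a real matrix x transformed over both dimensions,
# written as an explicit sum: rfft keeps rows k1 = 0:n1÷2 along dim 1 and all
# columns k2 = 0:n2-1 along dim 2, and real(exp(-iθ)) == cos(θ).
function real_sum_rfft(x::AbstractMatrix{<:Real})
    n1, n2 = size(x)
    m1 = n1 ÷ 2 + 1
    s = 0.0
    for k1 in 0:m1-1, k2 in 0:n2-1, j1 in 0:n1-1, j2 in 0:n2-1
        s += x[j1+1, j2+1] * cos(2π * (j1 * k1 / n1 + j2 * k2 / n2))
    end
    return s
end

# The function is linear in x, so the gradient entry at idx is the function
# evaluated at the unit matrix with a 1 at idx (no finite-difference step needed).
function linear_gradient(f, sz)
    g = zeros(sz)
    for idx in CartesianIndices(g)
        e = zeros(sz)
        e[idx] = 1.0
        g[idx] = f(e)
    end
    return g
end

g = linear_gradient(real_sum_rfft, (3, 3))
# ≈ [6.0 0.0 0.0; 1.5 0.0 0.0; 1.5 0.0 0.0], matching the ChainRules "No Dims" result
```

The 9.0 that Zygote reports in the top-left entry comes from the Hermitian expansion counting the k1 = 1 row twice, which is the same double counting seen in the one-dimensional case.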
@ToucheSir
Member

Can you elaborate on why this isn't just a duplicate of #899? Otherwise I'll probably close it as such.

@awadell1
Author

Seeing as it's the same root cause, that makes sense to me.
