```julia
# the adjoint jacobian of an FFT with respect to its input is the reverse FFT of the
# gradient of its inputs, but with different normalization factor
@adjoint function *(P::AbstractFFTs.Plan, xs)
  return P * xs, function(Δ)
    N = prod(size(xs)[[P.region...]])
    return (nothing, N * (P \ Δ))
  end
end

@adjoint function \(P::AbstractFFTs.Plan, xs)
  return P \ xs, function(Δ)
    N = prod(size(Δ)[[P.region...]])
    return (nothing, (P * Δ)/N)
  end
end
```
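The comment in that adjoint is exact for a full complex FFT: if `F` is the unnormalized DFT matrix of size `N`, its adjoint `F'` equals `N * inv(F)`, which is what `N * (P \ Δ)` computes for a complex `fft` plan. A minimal check with an explicit matrix, using only Base and the LinearAlgebra stdlib (the helper `dftmatrix` is hypothetical, not FFTW API):

```julia
using LinearAlgebra

# Hypothetical helper: unnormalized forward DFT matrix, F[k, j] = exp(-2πi (k-1)(j-1) / N),
# i.e. the same sign convention FFTW uses for fft
dftmatrix(N) = [cis(-2π * (k - 1) * (j - 1) / N) for k in 1:N, j in 1:N]

N = 8
F = dftmatrix(N)
Δ = randn(ComplexF64, N)   # an arbitrary complex cotangent

# Adjoint of the unnormalized DFT equals N times its inverse,
# the matrix form of "N * (P \ Δ)" for a complex fft plan P
@assert F' ≈ N * inv(F)
@assert F' * Δ ≈ N * (F \ Δ)
```

The identity relies on `F` being square and unitary up to the factor `N`. That holds for `fft`, but not for `rfft`, whose output keeps only the first `N ÷ 2 + 1` rows of `F`.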
Minimal Workable Example
```julia
using Pkg
Pkg.activate(; temp=true)
Pkg.add(["Zygote", "FFTW", "ChainRulesCore"])

using Zygote
using FFTW
using ChainRulesCore

x = rand(3, 3)

# No dims: rfft over both dimensions
p = plan_rfft(x)
y, back = Zygote.pullback(*, p, x)
@info "Zygote - No Dims" back(one.(y))[2]      # Zygote returns (p̄, x̄), so x̄ is element 2
y, back = ChainRulesCore.rrule(*, p, x)
@info "ChainRules - No Dims" back(one.(y))[3]  # rrule returns (f̄, p̄, x̄), so x̄ is element 3

# dims = 1: rfft along the first dimension only
p = plan_rfft(x, 1)
y, back = Zygote.pullback(*, p, x)
@info "Zygote - dims=1" back(one.(y))[2]
y, back = ChainRulesCore.rrule(*, p, x)
@info "ChainRules - dims=1" back(one.(y))[3]
```
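To tell which of the two printed gradients is correct, the exact pullback can be written out for a single axis. Along a length-`N` axis, `rfft` of a real vector is multiplication by the first `N ÷ 2 + 1` rows `R` of the unnormalized DFT matrix, so a cotangent `Δ` pulls back to `real(R' * Δ)`. The sketch below is 1D only and uses hypothetical stdlib-only helpers (`dftmatrix`, `rfftmatrix`, `rfft_vjp_exact`), not FFTW; it also runs the dot-product test that any correct pullback of a linear map has to pass.

```julia
using LinearAlgebra

# Hypothetical helpers: the unnormalized forward DFT matrix (FFTW sign convention) and its
# first N ÷ 2 + 1 rows, which is the linear map rfft applies to a real length-N vector
dftmatrix(N) = [cis(-2π * (k - 1) * (j - 1) / N) for k in 1:N, j in 1:N]
rfftmatrix(N) = dftmatrix(N)[1:(N ÷ 2 + 1), :]

# Exact pullback of x ↦ rfft(x) for real x: apply R' and project back onto the reals
rfft_vjp_exact(Δ, N) = real(rfftmatrix(N)' * Δ)

N = 6
Δ = ones(ComplexF64, N ÷ 2 + 1)   # all-ones cotangent, like one.(y) in the MWE
dx = rfft_vjp_exact(Δ, N)         # ≈ [4, 0, 1, 0, 1, 0]

# Dot-product test: Re⟨Δ, R v⟩ must equal ⟨dx, v⟩ for every real direction v
v = randn(N)
@assert isapprox(real(dot(Δ, rfftmatrix(N) * v)), dot(dx, v); atol = 1e-10)
```

For the 3×3 input with `dims=1`, the same reasoning applies column by column with `N = 3`.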
The gradients Zygote returns for a planned `rfft` differ significantly from those returned by the ChainRulesCore `rrule`. This appears to be related to #1437, #899, and #1377.
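One way to see the size of the mismatch: `rfft` keeps only the `N ÷ 2 + 1` non-redundant bins, while `P \ Δ` (assuming AbstractFFTs implements it as `inv(P) * Δ`, i.e. an `irfft` for an rfft plan) treats its input as half of a Hermitian-symmetric spectrum. Multiplying that by `N`, as the quoted adjoint does, then counts every bin other than DC (and Nyquist, for even `N`) twice. The 1D sketch below checks this with explicit matrices. `irfft_naive`, `zygote_style_vjp`, `exact_vjp` and `halfspectrum_weights` are hypothetical helpers, not Zygote or FFTW API, and treating `irfft_naive` as equivalent to FFTW's `irfft` is an assumption.

```julia
using LinearAlgebra

# Hypothetical helpers (1D, stdlib only): unnormalized DFT matrix with FFTW's sign convention,
# and the N ÷ 2 + 1 rows of it that rfft keeps
dftmatrix(N) = [cis(-2π * (k - 1) * (j - 1) / N) for k in 1:N, j in 1:N]
rfftmatrix(N) = dftmatrix(N)[1:(N ÷ 2 + 1), :]

# Stand-in for irfft(Δ, N): rebuild the missing half by Hermitian symmetry,
# apply the normalized inverse DFT, and keep the real part
function irfft_naive(Δ, N)
    n = length(Δ)
    full = zeros(ComplexF64, N)
    full[1:n] = Δ
    for k in (n + 1):N
        full[k] = conj(Δ[N - k + 2])   # bin k-1 mirrors bin N-k+1
    end
    return real(inv(dftmatrix(N)) * full)
end

# What the quoted adjoint returns for an rfft plan, assuming P \ Δ acts like irfft(Δ, N)
zygote_style_vjp(Δ, N) = N * irfft_naive(Δ, N)

# Exact pullback for a real input: real(R' * Δ)
exact_vjp(Δ, N) = real(rfftmatrix(N)' * Δ)

# Weight of each kept bin: 1 for DC (and for the Nyquist bin when N is even), 2 otherwise
halfspectrum_weights(N) = [k == 1 || (iseven(N) && k == N ÷ 2 + 1) ? 1.0 : 2.0 for k in 1:(N ÷ 2 + 1)]

N = 6
Δ = ones(ComplexF64, N ÷ 2 + 1)
g_zygote = zygote_style_vjp(Δ, N)   # ≈ [6, 0, 0, 0, 0, 0]
g_exact = exact_vjp(Δ, N)           # ≈ [4, 0, 1, 0, 1, 0]

# Dividing the cotangent by the weights before the inverse transform recovers the exact pullback
w = halfspectrum_weights(N)
@assert zygote_style_vjp(Δ ./ w, N) ≈ g_exact
```

Halving the non-DC, non-Nyquist bins of the cotangent before the inverse transform appears to be the same correction that AbstractFFTs' own `rfft` `rrule` applies before calling `brfft`.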
Adjoint in question: `Zygote.jl/src/lib/array.jl`, lines 645 to 659 at commit 54f1e80.
Output: