add chainrules for r2r, dct #272

Open
vpuri3 opened this issue Jun 14, 2023 · 3 comments

Comments

@vpuri3

vpuri3 commented Jun 14, 2023

No description provided.

vpuri3 changed the title from "add rrules for r2r" to "add chainrules for r2r, dct" on Jun 14, 2023
@vpuri3
Author

vpuri3 commented Jun 18, 2023

How is the gradient computed for plan_dct if there's no rrule for dct?

using LinearAlgebra, FFTW, Zygote

x = rand(4)
C = plan_dct(x)

f(x) = C \ (C * x) |> norm
g(x) = x |> dct |> idct |> norm
h(x) = plan_dct(x) \ (plan_dct(x) * x) |> norm

@show Zygote.gradient(f, x) # ([0.7499995699183157, 0.5170775887690442, 0.3522881598130941, 0.2145331321046639],)
@show Zygote.gradient(g, x) # errors
@show Zygote.gradient(h, x) # errors

error message:

julia> Zygote.gradient(f, x)
ERROR: Compiling Tuple{Type{FFTW.r2rFFTWPlan{Float64, Any, false, 1}}, Vector{Float64}, FFTW.FakeArray{Float64, 1}, UnitRange{Int64}, Int64, UInt32, Float64}: try/catch is not supported.
Refer to the Zygote documentation for fixes.
https://fluxml.ai/Zygote.jl/latest/limitations
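A possible explanation, offered as an assumption rather than something confirmed in this thread: f only applies a precomputed plan, so the AD system never has to trace FFTW's plan construction (the r2rFFTWPlan constructor with the try/catch in the error above), while g and h build the plan inside the differentiated function. In addition, FFTW.jl's dct uses the orthonormal DCT-II normalization, so the plan's adjoint coincides with its inverse, and a rule that differentiates plan application generically can still give the right gradient for C \ (C * x). The sketch below only checks that orthonormality numerically, by building the plan's matrix column by column; it says nothing about Zygote's internals.

```julia
using FFTW, LinearAlgebra

n = 4
C = plan_dct(zeros(n))

# Build the n×n matrix of the DCT plan by applying it to each unit vector.
E = Matrix{Float64}(I, n, n)
M = reduce(hcat, [C * E[:, j] for j in 1:n])

# FFTW.jl's dct is the orthonormal DCT-II, so M' * M should be the identity,
# i.e. the adjoint of the plan equals its inverse.
@assert M' * M ≈ E

# Consequently, applying the adjoint agrees with applying C \ y.
y = rand(n)
@assert M' * y ≈ C \ y
```

For a non-orthonormal transform such as the raw REDFT10 behind r2r plans, the adjoint and the inverse differ, which is why r2r plans need their own AdjointStyle.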

@danielwe

Cosigning this: having AdjointStyle and adjoint_mul for r2r plans would be great. I'm working out the one-dimensional case, but haven't quite wrapped my head around multidimensional FFTW yet.
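For the one-dimensional case, the adjoint of REDFT10 follows directly from FFTW's definitions of the unnormalized REDFT10 (DCT-II) and REDFT01 (DCT-III). The short derivation below is an editorial addition, not taken from the thread: writing REDFT10 as a matrix A and transposing it recovers REDFT01 plus one extra copy of the first input, which is the correction applied in the adjoint_mul sketch in the next comment.

```latex
% FFTW's unnormalized 1-D transforms of length n:
\text{REDFT10: } Y_k = 2\sum_{j=0}^{n-1} X_j \cos\left[\frac{\pi (j + \tfrac{1}{2}) k}{n}\right]
\qquad
\text{REDFT01: } Y_k = X_0 + 2\sum_{j=1}^{n-1} X_j \cos\left[\frac{\pi j (k + \tfrac{1}{2})}{n}\right]

% REDFT10 as a matrix: A_{kj} = 2\cos[\pi (j + 1/2) k / n]. Its transpose acts as
(A^{\mathsf{T}} y)_j = 2\sum_{k=0}^{n-1} y_k \cos\left[\frac{\pi k (j + \tfrac{1}{2})}{n}\right]
                     = 2 y_0 + 2\sum_{k=1}^{n-1} y_k \cos\left[\frac{\pi k (j + \tfrac{1}{2})}{n}\right]
                     = \mathrm{REDFT01}(y)_j + y_0

% FFTW's inverse pair: REDFT01(REDFT10(X)) = 2n X, so A^{-1} = REDFT01 / (2n) \neq A^{\mathsf{T}}.
```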

@danielwe

Here's an extremely rudimentary implementation of adjoint_mul for the 1-D REDFT10 (DCT-II) case, in case anyone finds it helpful as a starting point:

using AbstractFFTs
using FFTW

struct R2RFFTAdjointStyle <: AbstractFFTs.AdjointStyle end

AbstractFFTs.AdjointStyle(::FFTW.r2rFFTWPlan) = R2RFFTAdjointStyle()

function AbstractFFTs.adjoint_mul(
    p::FFTW.r2rFFTWPlan{T}, x::AbstractVector{T}, ::R2RFFTAdjointStyle
) where {T}
    (length(p.kinds) == 1) || throw(ArgumentError("Multidimensional r2r transforms not yet supported"))
    (only(p.kinds) == FFTW.REDFT10) || throw(ArgumentError("r2r kinds other than REDFT10 not yet supported"))
    pinv = inv(p)
    unscaled_pinv = (pinv isa AbstractFFTs.ScaledPlan) ? pinv.p : pinv
    y = unscaled_pinv * x
    # The unscaled inverse of REDFT10 is REDFT01, which equals the transpose of REDFT10
    # except that it weights the first (k = 0) input by 1 instead of 2. To obtain the
    # true adjoint, add the DC component first(x) once more to every output.
    y .+= first(x)
    return y
end
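One way to sanity-check the 1-D correction above, and to extend it to several dimensions, is a dot-product test: for any x and y, dot(p * x, y) should equal dot(x, p' * y). The sketch below is standalone, so it does not register anything with AbstractFFTs; the helpers redft10_adjoint_dim and redft10_adjoint are hypothetical names introduced here. It assumes that a multidimensional REDFT10 plan over dims is the product of commuting 1-D transforms along each dimension, so its adjoint is the product of the 1-D adjoints.

```julia
using FFTW, LinearAlgebra

# Hypothetical helper (not part of FFTW.jl or AbstractFFTs): adjoint of the
# unnormalized REDFT10 along dimension d. Apply REDFT01 along d, then add the
# k = 0 slice along d back to every position along d.
function redft10_adjoint_dim(y::Array{Float64}, d::Int)
    z = FFTW.r2r(y, FFTW.REDFT01, d)
    z .+= selectdim(y, d, 1:1)   # size-1 slice broadcasts along d
    return z
end

# Hypothetical helper: the 1-D transforms along different dimensions commute,
# so the adjoint over `dims` is the composition of the 1-D adjoints.
function redft10_adjoint(y::Array{Float64}, dims)
    z = y
    for d in dims
        z = redft10_adjoint_dim(z, d)
    end
    return z
end

# 1-D dot-product test: <A x, y> == <x, A' y>
x1, y1 = rand(8), rand(8)
p1 = plan_r2r(x1, FFTW.REDFT10)
@assert dot(p1 * x1, y1) ≈ dot(x1, redft10_adjoint(y1, 1))

# 2-D test: a single kind passed to plan_r2r is applied along all dimensions
X2, Y2 = rand(5, 3), rand(5, 3)
p2 = plan_r2r(X2, FFTW.REDFT10)
@assert dot(p2 * X2, Y2) ≈ dot(X2, redft10_adjoint(Y2, 1:2))
```

Handling mixed kinds per dimension would follow the same pattern, with each dimension's 1-D adjoint chosen according to its own kind.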
