Cannot test rule for structure where one (non-differentiable) field cannot be vectorized #256

Open
gaurav-arya opened this issue Jul 15, 2022 · 7 comments


@gaurav-arya

gaurav-arya commented Jul 15, 2022

I am writing an rrule for *(::Struct, arr) where the structure and function look like,

struct Struct{T<:Real}
    x     # stuff that can't be differentiated
    y::T
end

Base.:*(s::Struct, arr) = s.y * s.x(arr)

However, I am unable to test my rrule, even with a manually provided tangent for an instance of Struct that looks like Tangent{Struct}(;x=NoTangent(), y=one(T)).

The reason seems to be that finite differences tries to to_vec the instance of struct. Given that the struct is not completely ignored, only the field x, it ends up trying to to_vec the field x as well. But this field is a reference to a rather crazy mutable structure with circular references, and so I end up with an error, and am unable to test that the rrule is correct w.r.t. y.

(To make it more concrete, the structure in question is a ScaledPlan from AbstractFFTs and the field x refers to a primitive FFT plan from FFTW, which is mutable because of its pinv cache.)

Ideally, it shouldn't matter what value is in the field x, since it is marked as NoTangent in the user-provided tangent? Just as how the entire input gets ignored if it is marked as NoTangent.

gaurav-arya changed the title from "to_vec is applied recursively to entire structure, even when struct is only partially differentiable" to "Cannot test rule for structure where one (non-differentiable) field cannot be vectorized" Jul 15, 2022
@devmotion
Member

Tangent{Struct}(;x=NoTangent(), y=one(T))

Did you try Tangent{typeof(mystruct)}(; y=randn(T)) or similar as well (i.e., without specifying x)?

@gaurav-arya
Author

gaurav-arya commented Jul 15, 2022

I just did, and unfortunately that fails too. Looking at the code,

sigargs = zip(xs[.!ignores], ẋs[.!ignores])

it seems like the only influence of the tangent on the finite difference call is whether to ignore the full input or not? (This code is called from test_frule, in which I also get an error)

@mzgubic
Member

mzgubic commented Jul 18, 2022

Ideally, it shouldn't matter what value is in the field x, since it is marked as NoTangent in the user-provided tangent? Just as how the entire input gets ignored if it is marked as NoTangent.

I agree this is the ideal behaviour. CRTU is not perfect: you can be almost certain that your rule is correct if the tests pass, but a failing test could mean either an issue with the rule or an issue with CRTU.

A slightly less ideal but quicker solution would be to try defining a to_vec method for the type of the x field. (It's type piracy, but should be OK if used only in tests.)

@gaurav-arya
Author

Any pointers on doing the necessary type piracy?

For test_frule, I think I would need to overload to_vec, but what should the overload be so that the check will go through when the input tangent has a NoTangent() for field x? And for test_rrule, I think I might need to overload something else, like ProjectTo?

@mzgubic
Member

mzgubic commented Jul 26, 2022

Sure, I would try with something like

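function FiniteDifferences.to_vec(x::FFTPlan) # or whatever the type of field x is
    # Ignore the vector and hand back the original plan unchanged
    function FFTPlan_from_vec(x_vec::Vector)
        return x
    end
    # An empty vector: the plan contributes no numbers to perturb
    return Bool[], FFTPlan_from_vec
end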

This should work for both the frule and rrule I believe. See more examples here: https://github.com/JuliaDiff/FiniteDifferences.jl/blob/main/src/to_vec.jl
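To see the effect of such an overload in isolation, here is a minimal sketch on toy types. OpaquePlan and Scaled are hypothetical stand-ins (not the real FFTW or AbstractFFTs types), and the sketch assumes FiniteDifferences' generic struct fallback for to_vec, which vectorizes a struct field by field:

```julia
using FiniteDifferences

# Hypothetical stand-in for a mutable, non-differentiable field
mutable struct OpaquePlan
    cache::Any
end

# Hypothetical stand-in for the wrapper struct: one opaque field, one real field
struct Scaled{T<:Real}
    plan::OpaquePlan
    scale::T
end

# The opaque field contributes no numbers; reconstruction returns the same object
function FiniteDifferences.to_vec(p::OpaquePlan)
    OpaquePlan_from_vec(::AbstractVector) = p
    return Bool[], OpaquePlan_from_vec
end

s = Scaled(OpaquePlan(nothing), 2.0)
v, back = to_vec(s)    # only the `scale` field should end up in `v`
s2 = back([3.0])       # rebuild with a perturbed scale, same plan object
println(v)
println(s2.scale, " ", s2.plan === s.plan)
```

With this in place, finite differencing only ever perturbs the real-valued field, and the opaque field is carried through reconstruction untouched.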

@gaurav-arya
Author

gaurav-arya commented Aug 6, 2022

For posterity, here's what I needed to do to get test_frule and test_rrule to work on a structure with a non-differentiable, troublesome nested field (in this case of type InnerPlan):

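import FiniteDifferences, ChainRulesTestUtils, ChainRulesCore
using Random: AbstractRNG

# The plan contributes no numbers; from_vec returns the original plan unchanged
function FiniteDifferences.to_vec(x::InnerPlan)
    function FFTPlan_from_vec(x_vec::Vector)
        return x
    end
    return Bool[], FFTPlan_from_vec
end
# Any zero tangent is accepted as matching the plan
ChainRulesTestUtils.test_approx(::ChainRulesCore.AbstractZero, x::InnerPlan, msg=""; kwargs...) = true
# Random tangents for the plan are always NoTangent()
ChainRulesTestUtils.rand_tangent(::AbstractRNG, x::InnerPlan) = ChainRulesCore.NoTangent()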

Note that I needed the test_approx one as well, otherwise there was a call zero(::InnerPlan) being made in test_rrule. The rand_tangent one is not strictly necessary if one manually supplies the primal tangent, but is nicer as then ChainRulesTestUtils automatically fills out random tangents for all the other fields in the struct.

@oxinabox
Member

oxinabox commented Aug 8, 2022

it seems like the only influence of the tangent on the finite difference call is whether to ignore the full input or not? (This code is called from test_frule, in which I also get an error)

That is correct.
It is also sometimes used for testing accumulation (+).
Overloading FiniteDifferences.to_vec is the way to go.
Though we should probably make that one a little smarter, so this is needed less often.

And in the long term we want to stop using it entirely.
