Can rules make decisions based on which primal method is used? #237

Open
sethaxen opened this issue Jul 16, 2020 · 3 comments
Labels
design (Requires some design before changes are made), type constraints (Potentially raises a question about how tightly to constrain argument types for a rule. See #232)

Comments

@sethaxen
Member

Can we safely detect whether a call to a primal function is going to hit a specific default (fallback) method, and use that to change the logic in the frule or rrule?

For example, the rrule for det(A) calls inv(A). If we know that the primal call would hit the det(A::AbstractMatrix) definition in generic.jl, then we know that the primal computes the determinant via an LU decomposition, and we can reuse that factorization to compute the inverse faster. But if a specialized primal method is hit instead, then we probably just want to call the primal and compute the inverse separately, since that primal is probably more efficient for that type than LU.
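For a real square matrix A, Jacobi's formula gives the gradient of det(A) as det(A)·inv(A)ᵀ, so a single LU factorization can supply both the primal value and the inverse the pullback needs. A minimal sketch using only the LinearAlgebra standard library; `det_and_grad` and `fd_grad` are hypothetical helper names for illustration (not ChainRules API), and the result is checked against a central finite difference:

```julia
using LinearAlgebra

# Hypothetical helper: compute det(A) and its gradient from ONE LU factorization.
function det_and_grad(A::AbstractMatrix{<:Real})
    F = lu(A; check = false)   # factorize once
    Ω = det(F)                 # primal value taken from the factorization
    grad = Ω * inv(F)'         # Jacobi's formula: ∂det(A)/∂A = det(A) * inv(A)ᵀ
    return Ω, grad
end

# Hypothetical helper: central finite-difference estimate of ∂det(A)/∂A[i, j].
function fd_grad(A::Matrix{Float64}; h = 1e-6)
    G = similar(A)
    for j in axes(A, 2), i in axes(A, 1)
        Ap = copy(A); Ap[i, j] += h
        Am = copy(A); Am[i, j] -= h
        G[i, j] = (det(Ap) - det(Am)) / (2h)
    end
    return G
end

A = [4.0 1.0 2.0; 0.5 3.0 1.0; 1.0 2.0 5.0]
Ω, grad = det_and_grad(A)
println(isapprox(grad, fd_grad(A); rtol = 1e-6))
```

Since det(A) is affine in each individual entry, the central difference is exact up to rounding error, which makes it a convenient sanity check.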

@sethaxen
Member Author

e.g. something like this:

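```julia
using LinearAlgebra
using ChainRulesCore             # 2020-era API; NO_FIELDS was later replaced by NoTangent()
import ChainRulesCore: rrule     # needed to add methods to rrule

function rrule(::typeof(det), x::Union{Number, AbstractMatrix})
    # Factorize only if det(x) would dispatch to the generic LU-based fallback.
    F = if which(det, Tuple{typeof(x)}) === which(det, Tuple{AbstractMatrix})
        lu(x; check = false)
    else
        x
    end
    Ω = det(F)
    function det_pullback(ΔΩ)
        # Reuse F so that, in the LU case, inv(F) does not refactorize x.
        return NO_FIELDS, Ω * ΔΩ * inv(F)'
    end
    return Ω, det_pullback
end
```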

This is type-unstable, though (F is an LU factorization in one branch and x itself in the other), and the rrule is 10x slower, although the pullback is a little faster.
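The which comparison from the snippet above can be tried in isolation. In this minimal sketch, `hits_generic_det` is a hypothetical helper; for a dense Matrix{Float64} the check should pick out the generic fallback, while Diagonal has its own det method in LinearAlgebra. The last two lines show where the type instability comes from: the two branches bind F to values of different concrete types, and the choice depends on a runtime reflection call that type inference generally cannot see through.

```julia
using LinearAlgebra

# Hypothetical helper: true when det(x) would dispatch to the same method as
# det(::AbstractMatrix), i.e. the generic LU-based fallback.
hits_generic_det(x) = which(det, Tuple{typeof(x)}) === which(det, Tuple{AbstractMatrix})

A = [2.0 1.0; 1.0 3.0]        # dense Matrix{Float64}
D = Diagonal([2.0, 3.0])      # LinearAlgebra defines a specialized det for Diagonal

println(hits_generic_det(A))
println(hits_generic_det(D))

# The two branches of the rrule give F different concrete types:
println(typeof(lu(A; check = false)))   # an LU factorization object
println(typeof(D))                      # the input type itself
```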

@ettersi
Contributor

ettersi commented Jul 16, 2020

It seems to me that the "correct" way to distinguish these cases is just standard dispatch:

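```julia
# Generic fallback: follows the primal's own "det via LU" strategy.
function rrule(::typeof(det), x::AbstractMatrix)
    F = lu(x; check = false)
    Ω = det(F)
    function det_pullback(ΔΩ)
        return NO_FIELDS, Ω * ΔΩ * inv(F)'
    end
    return Ω, det_pullback
end

# A type with its own, more efficient det method gets its own, more specific rrule.
function rrule(::typeof(det), x::YourSpecialMatrixType)
    # Do whatever you have to do
end
```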

The rationale would be that if "det via LU" is a good enough fallback for the primal function, then it should also be good enough for the derivative.
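The dispatch-based design can be sketched without depending on ChainRulesCore. In this minimal sketch, `det_with_pullback` is a hypothetical stand-in for rrule(::typeof(det), x), and Diagonal plays the role of YourSpecialMatrixType: the generic method reuses an LU factorization, while the Diagonal method uses the fact that the determinant is the product of the diagonal, so no factorization is needed. Only real inputs are considered.

```julia
using LinearAlgebra

# Generic fallback: mirrors the primal's "det via LU" strategy (real inputs assumed).
function det_with_pullback(x::AbstractMatrix{<:Real})
    F = lu(x; check = false)
    Ω = det(F)
    det_pullback(ΔΩ) = Ω * ΔΩ * inv(F)'   # cotangent for x
    return Ω, det_pullback
end

# More specific method: Diagonal inputs never reach the LU-based fallback.
function det_with_pullback(x::Diagonal{<:Real})
    d = diag(x)
    Ω = prod(d)
    # ∂det/∂x[i,i] = Ω / x[i,i]; off-diagonal entries are structurally zero.
    det_pullback(ΔΩ) = Diagonal(Ω * ΔΩ ./ d)
    return Ω, det_pullback
end

D = Diagonal([2.0, 3.0, 4.0])
Ω_d, back_d = det_with_pullback(D)           # dispatches to the Diagonal method
Ω_g, back_g = det_with_pullback(Matrix(D))   # dispatches to the generic method

println(Ω_d ≈ Ω_g)
println(back_d(1.0) ≈ back_g(1.0))
```

With this design the strategy is selected at dispatch time, so each method stays type-stable, at the cost of writing one rule per specialized primal method.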

@nickrobinson251
Contributor

Maybe related: JuliaDiff/ChainRulesCore.jl#155

nickrobinson251 added the design and type constraints labels on Dec 28, 2020

3 participants