Import more rules from ChainRules ? #249

Open
yebai opened this issue Sep 11, 2024 · 1 comment
Labels
enhancement (performance) Would reduce the time it takes to run some bit of the code

Comments

@yebai
Contributor

yebai commented Sep 11, 2024

Tapir implements rules at a very low level at the moment. This design choice helps reduce the burden of writing and maintaining rules, thanks to the small number of primitives that require manually written rules. The good news is that Tapir seems to have excellent performance even with most of its rules derived automatically by Tapir's autodiff transform (known as DerivedRule).

However, this design choice occasionally (it seems most cases are due to one common root cause, i.e. vectorisation of specific loops) gets in the way of Julia's and LLVM's compiler optimisation passes. Our CI benchmarks (e.g., sum, kron, kron_view_sum, and gp_lml) reflect this. Meanwhile, Zygote, which imports most (or all) rules from ChainRules, performs well in these test cases.

Given that ChainRules is pretty well tested, should we import a more significant number of its rules into Tapir by default? If so, what specific rules should we import?

yebai added the enhancement (performance) label Sep 11, 2024
@willtebbutt
Member

willtebbutt commented Sep 11, 2024

I'm broadly in favour of selectively importing more rules. I'm working on the interface to import rules from ChainRules.jl at the minute, but the broad picture will remain as it currently is: making use of methods of rrule to implement methods of rrule!! is fine, but we'll need to restrict the argument types to ones that we are happy with. So, for example, if a method of rrule has the signature rrule(::typeof(*), ::AbstractMatrix, ::AbstractMatrix), the appropriate thing to do is not to define a method of Tapir.rrule!! with signature

Tapir.rrule!!(::CoDual{typeof(*)}, ::CoDual{<:AbstractMatrix}, ::CoDual{<:AbstractMatrix})

Rather, we would define rules for concrete types, or (small) finite unions of concrete types, and call out to rrule inside. For example, we might define a method of rrule!! with argument types

Tapir.rrule!!(::CoDual{typeof(*)}, ::CoDual{Matrix{P}}, ::CoDual{Matrix{P}}) where {P<:IEEEFloat}
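
To make the dispatch point concrete, here is a minimal, self-contained sketch. It does not use Tapir or ChainRules at all: ToyCoDual, ToyIEEEFloat, toy_rrule, toy_rrule!!, wrap and mul are hypothetical stand-ins invented for illustration, and the real rrule!! has its own conventions for return values and tangents that this toy ignores. It only shows how a rule restricted to dense IEEE-float matrices can call out to a rule with an abstract signature.

```julia
using LinearAlgebra  # stdlib; only needed for the Diagonal example type

# Hypothetical stand-in for Tapir's CoDual: a primal value paired with tangent data.
struct ToyCoDual{P,T}
    primal::P
    tangent::T
end

# The IEEE float types, spelled out so the sketch does not depend on Base internals.
const ToyIEEEFloat = Union{Float16,Float32,Float64}

# Toy stand-in for a ChainRules-style rrule with a broad, abstract signature.
function toy_rrule(::typeof(*), A::AbstractMatrix, B::AbstractMatrix)
    C = A * B
    pullback(dC) = (dC * B', A' * dC)  # cotangents with respect to A and B
    return C, pullback
end

# Toy stand-in for rrule!!: restricted to dense matrices of IEEE floats,
# delegating the actual maths to the broad rule above.
function toy_rrule!!(
    ::ToyCoDual{typeof(*)}, A::ToyCoDual{Matrix{P}}, B::ToyCoDual{Matrix{P}}
) where {P<:ToyIEEEFloat}
    C, pullback = toy_rrule(*, A.primal, B.primal)
    return ToyCoDual(C, zero(C)), pullback
end

wrap(x) = ToyCoDual(x, zero(x))  # helper: pair a value with a zero tangent
mul = ToyCoDual(*, nothing)      # the function itself carries no tangent here

# Check which argument types the restricted rule accepts.
dense_ok = applicable(toy_rrule!!, mul, wrap(rand(2, 2)), wrap(rand(2, 2)))
diag_ok = applicable(toy_rrule!!, mul, wrap(Diagonal([1.0, 2.0])), wrap(rand(2, 2)))
int_ok = applicable(toy_rrule!!, mul, wrap([1 2; 3 4]), wrap([1 2; 3 4]))
```

In this sketch only the dense Float64 call matches the restricted method: dense_ok is true, while diag_ok and int_ok are false, so a Diagonal or integer matrix would not hit the imported rule and would have to be handled some other way (in Tapir's case, presumably by a derived rule).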

> If so, what specific rules should we import?

I can think of two questions for any given signature:

  1. is a rule for this sufficiently beneficial to warrant adding a rule for it?
  2. is making use of the rule in CR the best way to implement it?

On 1, we have to weigh the maintenance burden and the chance of making mistakes against the performance gain. On 2: for anything involving heap-allocated data, making use of a CR rrule will probably involve slightly more allocations than would be necessary if we wrote an rrule!! directly. Consequently, for really simple functions like sum, we might want to consider just writing our own rule.

> (it seems most cases are due to one common root cause, i.e. vectorisation of specific loops) gets in the way of Julia's and LLVM's compiler optimisation passes

On a technical note, this actually isn't the whole story; it's just one facet. There are several more things that can be done at the Julia SSA IR level to improve performance in these cases that do not involve vectorisation (although they may affect how vectorisable the code is as a side effect). See #156.
