Tapir implements rules at a very low level at the moment. This design choice helps reduce the burden of writing and maintaining rules, thanks to the small number of primitives that require manually written rules. The good news is that Tapir seems to have excellent performance even with most of its rules derived automatically by Tapir's autodiff transform (known as DerivedRule).
However, this design choice occasionally (it seems most cases are due to one common root cause, i.e. vectorisation of specific loops) gets in the way of Julia's and LLVM's compiler optimisation passes. Our CI benchmarks (e.g., sum, kron, kron_view_sum, and gp_lml) reflect this. Meanwhile, Zygote, which imports most (or all) rules from ChainRules, performs well in these test cases.
Given that ChainRules is pretty well tested, should we import a more significant number of its rules into Tapir by default? If so, what specific rules should we import?
I'm broadly in favour of selectively importing more rules. I'm working on the interface to import rules from ChainRules.jl at the minute, but the broad picture will remain as it currently is: making use of methods of rrule to implement methods of rrule!! is fine, but we'll need to restrict the argument types to ones that we are happy with. So, for example, if a method of rrule has the signature rrule(::typeof(*), ::AbstractMatrix, ::AbstractMatrix), the appropriate thing to do is not to define a method of Tapir.rrule!! with the same abstract argument types.
Rather, we would define rules for concrete types, or (small) finite unions of concrete types, and call out to rrule inside. For example, we might define a method of rrule!! with argument types
Tapir.rrule!!(::CoDual{typeof(*)}, ::CoDual{Matrix{P}}, ::CoDual{Matrix{P}}) where {P<:IEEEFloat}
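As a rough illustration of the pattern described above, here is a minimal, self-contained sketch. It does not use Tapir or ChainRules: `CoDualSketch`, `cr_style_rrule`, and `rrule_sketch!!` are hypothetical stand-ins invented for this example, and the real `Tapir.rrule!!` interface may well differ (for instance, the sketch takes the function directly rather than wrapped in a CoDual). The point is only that the broad rule accepts any `AbstractMatrix`, while the wrapper dispatches solely on `Matrix{P}` with `P<:IEEEFloat` and calls out to the broad rule inside.

```julia
using LinearAlgebra

# Hypothetical stand-in for a CoDual: a primal paired with gradient storage.
struct CoDualSketch{P,T}
    primal::P
    tangent::T
end

# Hypothetical ChainRules-style rule with broad, abstract argument types.
function cr_style_rrule(::typeof(*), A::AbstractMatrix, B::AbstractMatrix)
    C = A * B
    pullback(dC) = (dC * B', A' * dC)  # returns freshly allocated cotangents
    return C, pullback
end

# Wrapper restricted to concrete Matrix{P}; it reuses the broad rule internally.
function rrule_sketch!!(
    ::typeof(*),
    a::CoDualSketch{Matrix{P},Matrix{P}},
    b::CoDualSketch{Matrix{P},Matrix{P}},
) where {P<:Base.IEEEFloat}
    C, pb = cr_style_rrule(*, a.primal, b.primal)
    dC = zeros(P, size(C))  # storage for the output cotangent
    function pullback!!()
        dA, dB = pb(dC)
        a.tangent .+= dA  # accumulate into the existing gradient storage
        b.tangent .+= dB
        return nothing
    end
    return CoDualSketch(C, dC), pullback!!
end

a = CoDualSketch([1.0 2.0; 3.0 4.0], zeros(2, 2))
b = CoDualSketch([5.0 6.0; 7.0 8.0], zeros(2, 2))
c, pb!! = rrule_sketch!!(*, a, b)
c.tangent .= 1.0  # seed the output cotangent
pb!!()

# A view is an AbstractMatrix but not a Matrix, so the wrapper does not apply.
v = CoDualSketch(view([1.0 2.0; 3.0 4.0], :, :), zeros(2, 2))
wrapper_applies_to_view = applicable(rrule_sketch!!, *, v, v)
```

With this shape, unsupported argument types simply have no hand-written method, rather than silently hitting a rule that was never checked for them.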
If so, what specific rules should we import?
I can think of two questions for any given signature:
1. Is a rule for this sufficiently beneficial to warrant adding a rule for it?
2. Is making use of the rule in CR the best way to implement it?
On 1, we have to weigh the maintenance burden and the chance of making mistakes against performance. For 2: for anything involving heap-allocated data, making use of a CR rrule will probably involve slightly more allocations than would be necessary if we wrote an rrule!! directly. Consequently, for really simple functions like sum, we might want to consider just writing our own rule.
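To make the allocation point concrete, here is a small sketch for sum. Both functions are hypothetical illustrations written for this comment, not the actual ChainRules or Tapir rules: the first assumes a pullback that materialises a fresh cotangent array, the second accumulates into a buffer that already exists.

```julia
# Hypothetical ChainRules-style rule: the pullback allocates a new Vector.
function cr_style_sum_rrule(x::Vector{Float64})
    pullback(dy) = fill(dy, length(x))
    return sum(x), pullback
end

# Hypothetical directly written rule: the pullback adds into existing storage.
function direct_sum_rule(x::Vector{Float64}, dx::Vector{Float64})
    pullback!(dy) = (dx .+= dy; nothing)  # in place, no new array
    return sum(x), pullback!
end

x = [1.0, 2.0, 3.0]
dx = zeros(3)

y, pb! = direct_sum_rule(x, dx)
pb!(2.0)  # dx now holds the gradient contribution

y_cr, pb_cr = cr_style_sum_rrule(x)
dx_cr = pb_cr(2.0)  # same values, but in a newly allocated Vector
```

If a wrapper around the first style then has to add its result into existing gradient storage, that extra array is pure overhead, which is the kind of cost a hand-written rule could avoid.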
(it seems most cases are due to one common root cause, i.e. vectorisation of specific loops) gets in the way of Julia's and LLVM's compiler optimisation passes
On a technical note, this actually isn't the whole story; it's just one facet. There are several more things that can be done at the Julia SSA IR level to improve performance in these cases that do not involve vectorisation (although they may affect how vectorisable the code is as a side-effect). See #156