Import more rules from ChainRules? #249
Comments
I'm broadly in favour of selectively importing more rules. I'm working on the interface to import rules for ChainRules.jl at the minute, but the broad picture will remain as it currently is: making use of methods of Tapir.rrule!!(::CoDual{typeof(*)}, ::CoDual{<:AbstractMatrix}, ::CoDual{<:AbstractMatrix}) Rather, we would define rules for concrete types, or (small) finite unions of concrete types, and call out to Tapir.rrule!!(::CoDual{typeof(*)}, ::CoDual{Matrix{P}}, ::CoDual{Matrix{P}}) where {P<:IEEEFloat}
I can think of two questions for any given signature:
On 1, we have to consider the maintenance burden + chance of making mistakes vs performance. For 2: for anything involving heap-allocated data, making use of a CR
On a technical note, this actually isn't the whole story; it's just one facet. There are several more things that can be done at the Julia SSA IR level to improve performance in these cases that do not involve vectorisation (although they may affect how vectorisable the code is as a side-effect). See #156
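The abstract-versus-concrete distinction drawn in this comment can be sketched in plain Julia. This is only an illustration under stated assumptions: `matmul_rule` is a hypothetical name, not Tapir's `rrule!!` or anything from ChainRules, and it does not use `CoDual`. It only shows what restricting a hand-written rule to `Matrix{P}` with `P<:IEEEFloat` looks like at the dispatch level.

```julia
# Illustrative sketch only: `matmul_rule` is a hypothetical name, not part of
# Tapir or ChainRules. It mimics a rule restricted to concrete matrix types.

# Reverse rule for C = A * B, defined only for dense Matrix{P} with P an IEEE float.
function matmul_rule(A::Matrix{P}, B::Matrix{P}) where {P<:Base.IEEEFloat}
    C = A * B
    # Pullback: given the cotangent dC of C, return the cotangents of A and B.
    pullback(dC) = (dC * B', A' * dC)
    return C, pullback
end

A = [1.0 2.0; 3.0 4.0]
B = [5.0 6.0; 7.0 8.0]

C, pb = matmul_rule(A, B)
dA, dB = pb(ones(2, 2))   # cotangents for the scalar loss sum(A * B)

# The hand-written rule applies to dense Float64 matrices...
has_dense_rule = applicable(matmul_rule, A, B)
# ...but not to other AbstractMatrix subtypes such as views.
has_view_rule = applicable(matmul_rule, view(A, :, :), B)
```

Because the signature names a concrete array type, a view or other `AbstractMatrix` subtype does not match this method and would need to be handled some other way, which keeps the set of hand-maintained rules small.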
Tapir implements rules at a very low level at the moment. This design choice helps reduce the burden of writing and maintaining rules, thanks to the small number of primitives that require manually written rules. The good news is that Tapir seems to have excellent performance even with most of its rules derived automatically by Tapir's autodiff transform (known as DerivedRule).
However, this design choice occasionally gets in the way of Julia's and LLVM's compiler optimisation passes (most cases seem to share one root cause, namely vectorisation of specific loops). Our CI benchmarks (e.g., sum, kron, kron_view_sum, and gp_lml) reflect this. Meanwhile, Zygote, which imports most (or all) rules from ChainRules, performs well in these test cases.
Given that ChainRules is pretty well tested, should we import a more significant number of its rules into Tapir by default? If so, what specific rules should we import?