-
Notifications
You must be signed in to change notification settings - Fork 9
Numba speedup for wiring + log potentials #133
Numba speedup for wiring + log potentials #133
Conversation
9a8b222
to
a159f1a
Compare
Codecov Report
@@ Coverage Diff @@
## master #133 +/- ##
=========================================
Coverage 100.00% 100.00%
=========================================
Files 13 13
Lines 917 950 +33
=========================================
+ Hits 917 950 +33
Continue to review full report at Codecov.
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The current implementation is valid, but one issue is to write a new factor group the developer needs to understand various bits of wiring compilation. This seems unnecessary. Instead, it would be best if we implement the wiring compilation once within the parent FactorGroup
class and everything is automatically taken care of in new factor groups.
One possible way to do this:
- For each type of factor, implement a static method
compile_wiring_numba
which takes some inputs and do numba wiring compilation. - For each type of factor, specify a property
wiring_compilation_arguments
, similar to the currentinference_arguments
in the wirings. - In
FactorGroup
, add an optionaluse_numba=True
flag tocompile_wiring
. Ifuse_numba
isTrue
, we do something like:
wiring_compilation_arguments = {
key: getattr(self, key) for key in self.factor_type.wiring_compilation_arguments
}
wiring = factor_type.compile_wiring_numba(
vars_to_starts=vars_to_starts,
**wiring_compilation_arguments
)
to get the wiring.
4. Get rid of all the customized wiring compilation implementations in the various factor groups.
In the process it would also be best if we can consolidate the number of arguments we need (for example if we make variables_for_factors
as tuple of tuples we can get rid of factor_sizes
and num_factors
).
@StannisZhou I am fine with the |
Then we can change the existing compile_wiring function to be functional. Seems like with the wiring_compilation_arguments it wouldn't make things harder to use so I'm fine with that |
@StannisZhou I have pushed a commit that should address your comments. However there is something weird with having the compile_wiring_arguments at the Factors level (note: I had to make this @staticmethod because a @Property is not iterable) because now |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Main comment is to use inspect
to get the necessary arguments. Other bits LGTM
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
One more minor comment. LGTM otherwise! And remember to get coverage back to 100% after replacing the assert with raise
) | ||
|
||
@nb.jit(parallel=False, cache=True, fastmath=True, nopython=True) | ||
def _compile_var_states_numba( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is there a reason you make the caller allocate these arrays? In general t's cleaner and less likely to result in error to allocate return arrays inside numba rather that mutating a passed in array. You can get a very small optimization by re-using arrays between calls (so highly performance sensitive code it can be useful), but you're not doing that here. You can refer to dtype of incoming arrays as well and copy that.
This PR is the continuation of #129 and part of our efforts to speed up the adding of FactorGroups and the wiring compilation.
As #129 has moved most of the wiring computation to the FactorGroup level, we can now use
numba
for fast computation of these wiringsAs a result: