
Feature/cython 3 0 #36

Open · wants to merge 5 commits into master

Conversation

@nucccc (Contributor) commented on Oct 19, 2023

This PR proposes changes, tried on my system, that allow for cythonization of the codebase.

Cython 3.0 provides several features in pure Python mode, basically allowing legal Python lines to enable better cythonization (such as `i = cython.declare(cython.int)` to indicate that `i` shall be transpiled to a C integer).
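
In pure Python mode that looks roughly like this (a minimal sketch of the idea; `row_sum` is just an illustrative name, not code from this PR):

```python
import cython

def row_sum(values) -> cython.double:
    # Legal Python when interpreted; when compiled, Cython turns these
    # declarations into C variables and the loop into a C loop.
    i = cython.declare(cython.int)
    total = cython.declare(cython.double, 0.0)
    for i in range(len(values)):
        total += values[i]
    return total
```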

I added a pyproject.toml so the build system temporarily installs Cython during the package build.
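
Something along these lines (a sketch; the exact version pins are my assumption):

```toml
[build-system]
# Cython is needed only at build time: pip installs it into the isolated
# build environment, so it never becomes a runtime dependency.
requires = ["setuptools", "wheel", "Cython>=3.0"]
build-backend = "setuptools.build_meta"
```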

I also modified setup.py to import the Cython build machinery inside a try statement. I tested the fallback by forcing the flag to False, and that led to a normal, non-cythonized installation.
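
The fallback pattern is roughly this (a sketch; the module path is illustrative):

```python
from setuptools import setup

try:
    # If Cython is available, compile the pure-Python module to a C extension.
    from Cython.Build import cythonize
    ext_modules = cythonize(["similaritymeasures/similaritymeasures.py"])
except ImportError:
    # No Cython: fall back to a normal, interpreted installation.
    ext_modules = None

setup(ext_modules=ext_modules)
```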

Another idea: it would be possible to add a function to the package indicating whether or not the installed build is cythonized (I think Cython already provides a native way to check that).
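
For instance, `cython.compiled` is True inside a compiled extension and False under the plain interpreter, so a helper could be as simple as this (`is_compiled` is a name I just made up):

```python
import cython

def is_compiled() -> bool:
    """Report whether this module is running as a compiled Cython extension."""
    return cython.compiled
```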

I then just declared some variables to be integers and ran some benchmarks (the code I used can be found in my repo benchmark_similarity_measures, which installs first one version of the library and then the other to run the benchmarks).

The results on my machine with Python 3.11 are summarized in this image, with the improvement given in %:

[image: cython_improv — per-benchmark improvement in %]

with the numbers coming from the following CSV data:

| benchmark_name | master_avg | cython_avg | improv (%) |
| --- | --- | --- | --- |
| area_between_two_curves_c1_c2 | 1.7039237545600463 | 1.6884736603800048 | 0.9067362397345489 |
| frechet_c5_c6 | 2.0500493293800175 | 1.5103637769600209 | 26.325491035047925 |
| area_between_two_curves_c5_c6 | 0.47237960685998587 | 0.47055805053996663 | 0.3856128193440721 |
| dtw_c1_c2 | 0.9581847325999843 | 0.7472904101000131 | 22.009776958950333 |
| frechet_c1_c2 | 1.0203404406200298 | 0.7517932900799679 | 26.319367521772836 |
| curve_length_measure_c1_c2 | 0.051991073359977234 | 0.05167398480002703 | 0.6098903897496746 |
| dtw_c5_c6 | 1.958265644539997 | 1.570278274120028 | 19.812805862256162 |
| pcm_c1_c2 | 0.05582267186003264 | 0.05474561917999381 | 1.9294180019533789 |

Not all functions show a huge performance improvement: previous benchmarks showed about a 40% improvement for the Fréchet distance, which has now decreased to roughly 26%.

That is probably because my previous benchmarks ran against Python 3.10, while Python 3.11 greatly improved the execution speed of native Python code.

I hope this can be of interest.

@cjekel (Owner) commented on Nov 8, 2023

Hey! This is great, and it totally slipped under my radar. I promise to look at this more in depth around the time I push the other bugfix PR and do a new release.

Basically, I really want to read over the latest cython docs for best practices.

I wonder how much of the other performance gains were from the .pyx file changes you made as well.

@cjekel (Owner) commented on Nov 18, 2023

So, maybe it makes sense to fork this repo and go all out on performance. Honestly, this library isn't that much to maintain (there are not many commits), so I think I could also maintain a performant fork.

What are alternatives to cython?

  • a c++ version
  • jax
  • pytorch
  • numba

I don't really want to support a c++ port, but it would certainly have the most headroom for performance...

It turns out getting this ported to jax is messy. I heavily utilize in-place array operations in things like dtw and frechet. It's easy enough to do the minimum to get it working in jax, but it is very slow (making way too many array copies).

I haven't touched numba in some time.

I happened to give the pytorch version a shot. Porting dtw directly to plain pytorch was like 10-100x slower. Using torch.jit.script was like 2x slower. Something crazy happens with torch.jit.trace, which was an order of magnitude faster than the numpy version, but it appears to just make up numbers...

Any thoughts on this?

@nucccc (Contributor, Author) commented on Nov 25, 2023

I'm quite busy, but I'll try to give you my humble opinion:

  • A C++ version would be great, but I'm still worried that bypassing numpy would lose some SIMD optimizations that the library may apply (I don't know if it does). But C++ rules, and Cython, together with several other tools, can be used to wrap it.
  • JAX: I humbly have no idea what it is.
  • Pytorch: I used it to play with neural networks, and I don't know how it could work for algorithms that involve things like dynamic programming.
  • numba: humbly, never used it.

I sincerely proposed Cython because it doesn't modify the existing library much (which is already popular and stable), while it could remove some "Python overhead" from loops. I think that alone would reach a level of optimization quite near the one a pure C++ version could provide, but those are just my thoughts; I don't have benchmarks to prove that.

Just as a remark, I'll link this YouTube video (Speed up Python with hardware accelerators): between minutes 16:20 and 16:45 it shows pairwise distance benchmarks in which the scipy cdist function turns out to be the fastest implementation, beating tensorflow, numba, and numpy.

So I would say that the codebase has good foundations; I just thought that C-declaring some variables and maybe passing arrays as memoryviews (something I haven't tried yet) could lead to a very good level of performance.
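
As a rough sketch of what I mean (`squared_norm` is just an example, not code from this PR; pass it a NumPy array):

```python
import cython

def squared_norm(arr: cython.double[:]) -> cython.double:
    # Typed memoryview: compiled code indexes the buffer directly
    # instead of going through Python-level item access.
    total: cython.double = 0.0
    i: cython.Py_ssize_t
    for i in range(arr.shape[0]):
        total += arr[i] * arr[i]
    return total
```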

In any case, I think benchmarking and tracing are king: my first PR came from discovering, one year ago, that without invoking cdist the pairwise-distance computation called minkowski in pure Python (which is super slow), so simply reusing cdist in frechet, as was already done in dtw, gave a bigger boost than I could have imagined at the time.
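
For reference, that change boiled down to building the whole distance matrix with one vectorized call (the shapes here are illustrative):

```python
import numpy as np
from scipy.spatial.distance import cdist

exp_data = np.random.rand(100, 2)  # first curve: n points in 2D
num_data = np.random.rand(80, 2)   # second curve: m points in 2D

# One cdist call computes all pairwise Euclidean distances in C,
# replacing a Python-level double loop over minkowski().
c = cdist(exp_data, num_data)      # distance matrix, shape (100, 80)
```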

So I would just keep the codebase as it is and correctly cythonize its critical parts, maybe in a forked repo or on a dedicated "develop" branch, I don't know. The functions that did not show a great improvement could, in my opinion, be modified to avoid pure-Python sums and subtractions, but that would be guided by tracing and would need some work. Time permitting, I'd try to see whether cythonizing a few more things could tweak performance further, but that can be verified only through a worthy benchmarking activity.
