[Work in Progress DO NOT MERGE] Port math library from clojure to python #1893

Open
wants to merge 26 commits into
base: edge


jucor
Contributor

@jucor jucor commented Jan 30, 2025

⚠️ [DO NOT MERGE, FILED HERE FOR TRACKING PROGRESS!] ⚠️

Context

We (polis core team and advisors) have been discussing for the last few years whether Clojure is still the optimal language for the math library, given the evolution of the landscape and of polis needs.
This came up again when @metasoarous raised potential performance issues in #1579 (comment).

Porting any codebase, let alone in this case 7000+ lines of scientific Clojure written by a very smart developer (hats off @metasoarous!), with tons of embedded real-world safeguards and ten years of battle-testing, is a Very Big Endeavour, and super risky. Most of all, we need to keep all the domain knowledge that is embedded in the current codebase. This is not a rewrite from scratch, but a port!

So, crazy, but there's a lot to gain (massive ML ecosystem: people, libraries, etc), so we would be remiss not to at least explore how far we can go. The worst that can happen is that this completely fails, we lose time and I've got egg on my face. I'll mitigate the former by still working on the new LLM features with @colinmegill, and for the latter, well, I can live with that :)

So let's go!

Plan

I'll be focusing first on the core functionality: the math. Once that is clear and done, then I will work on the poller and runner. I'll be using numpy for all vector operations.

Preparation

  • 0/ besides the existing unit tests, have an integration test that updates through a full conversation.
  • 1a/ hack a basic json serialize/deserialize clojure function for vectors/matrices, and similar for python
  • 1b/ write a basic clojure function that serializes its arguments, calls a python command, and deserializes the return value
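
As an illustration of step 1a on the Python side, here is a minimal sketch of a JSON round trip for vectors and matrices. The helper names (`to_jsonable`, `dumps_args`, `loads_args`) and the `{"args": [...]}` envelope are assumptions made for this example, not the format used in the PR.

```python
import json

import numpy as np


def to_jsonable(obj):
    """Recursively convert numpy arrays and scalars to plain JSON types."""
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, dict):
        return {key: to_jsonable(value) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [to_jsonable(value) for value in obj]
    return obj


def dumps_args(*args):
    """Serialize positional arguments into a hypothetical {"args": [...]} envelope."""
    return json.dumps({"args": to_jsonable(list(args))})


def loads_args(text):
    """Deserialize the envelope; JSON lists come back as float arrays."""
    payload = json.loads(text)
    return [
        np.asarray(arg, dtype=float) if isinstance(arg, list) else arg
        for arg in payload["args"]
    ]
```

A matrix passed through `dumps_args` and then `loads_args` should come back as an equal `numpy` array, while scalars pass through unchanged.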

Core iteration

then, iterating this way:

  • 2/ identify one core function in the clojure game (pre-filtering, core PCA, etc), and its unit tests if any
  • 3a/ implement that exact function in python with numpy, with unit tests, possibly adding some
  • 3b/ replace the clojure function by a serialization + call to the replacing python function + deserialization
  • 4/ check that the end result of the full pipeline with and without python is still the same, on one or more real conversations from the database (in addition to the unit tests ofc).
  • 5/ Go to 2 with another core function, then climb up the call tree.
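
The Python command that Clojure shells out to in this loop could look roughly like the sketch below. The request shape (`{"fn": ..., "args": [...]}`), the `REGISTRY` table and the toy `column_means` function are all assumptions for illustration; the real ported functions would be registered in their place.

```python
import json
import sys

import numpy as np


def column_means(matrix):
    """Toy stand-in for a ported function (hypothetical, for illustration)."""
    return np.asarray(matrix, dtype=float).mean(axis=0)


# Hypothetical dispatch table: Clojure-style name -> Python implementation.
REGISTRY = {"column-means": column_means}


def handle_request(text):
    """Decode one JSON request, call the named function, encode the result."""
    request = json.loads(text)
    func = REGISTRY[request["fn"]]
    result = func(*request["args"])
    if isinstance(result, np.ndarray):
        result = result.tolist()
    return json.dumps({"result": result})


def main(stdin=sys.stdin, stdout=sys.stdout):
    """Entry point for a command-line call: request on stdin, reply on stdout."""
    stdout.write(handle_request(stdin.read()))
```

Keeping `handle_request` separate from `main` means the dispatch logic can be unit-tested without spawning a process.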

Expectations

It'll be a real slog at first for steps 0 and 1, getting familiar with running the various functions one by one in clojure when needed (i.e. without the realtime poller etc, which I'll keep for the end), but the pace should then increase.

Performance should also mechanically improve, as per #1579 and #1062 and #1580 .

Math part will be fun -- although we might start to see some small numerical differences appearing as we go, hopefully keeping them small.
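
To keep an eye on those numerical differences, outputs can be compared with a tolerance rather than for exact equality. The sketch below is one possible check for principal components; the function name and the default tolerances are assumptions. It also accepts a component that matches up to sign, since a principal direction is in general only defined up to sign.

```python
import numpy as np


def components_match(a, b, rtol=1e-6, atol=1e-9):
    """Return True if each row of `a` equals the same row of `b` up to sign,
    within the given tolerances. Shapes must be identical."""
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    if a.shape != b.shape:
        return False
    for row_a, row_b in zip(np.atleast_2d(a), np.atleast_2d(b)):
        same = np.allclose(row_a, row_b, rtol=rtol, atol=atol)
        flipped = np.allclose(row_a, -row_b, rtol=rtol, atol=atol)
        if not (same or flipped):
            return False
    return True
```

Whether a sign flip should be tolerated depends on what consumes the components downstream, so treat that part as a choice to revisit.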

@jucor jucor changed the title [Work in Progress DO NOT MERGE] Update math library from clojure to python [Work in Progress DO NOT MERGE] Port math library from clojure to python Jan 30, 2025
@ballPointPenguin
Contributor

This would be amazing!

jucor added 5 commits January 31, 2025 09:01
We serialize to JSON and, optionally, to temp file.
We probably will need to expand to take kw-args into account.
And some unit tests for fun, and even coverage!
The makefile will avoid having to remember the test syntax, making them
frictionless to run.
The partial runs are a way to keep focusing on just the output of
whichever part we want -- although it doesn't save the biggest
drag: clojure launch time.
@jucor
Contributor Author

jucor commented Jan 31, 2025

We now have serialization of arguments in Clojure, with tests (in Docker), and serialization + deserialization of arguments in Python, with tests (locally, not in Docker).
And a fully working makefile that runs those tests.

Next: wrapping a clojure function with a serializer of its input and output, and writing a python script that replays that call against the python function and checks that the output matches.
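
A replay script along those lines might be sketched as follows. The record layout (`{"fn", "args", "result"}` in one JSON file), the `replay` helper and the toy `row_sums` function are assumptions for illustration, not the format produced by the Clojure wrapper.

```python
import json
import tempfile
from pathlib import Path

import numpy as np


def replay(record_path, python_fn, rtol=1e-6, atol=1e-9):
    """Re-run a recorded call on `python_fn` and compare with the recorded result."""
    record = json.loads(Path(record_path).read_text())
    actual = np.asarray(python_fn(*record["args"]), dtype=float)
    expected = np.asarray(record["result"], dtype=float)
    if actual.shape != expected.shape:
        return False
    return bool(np.allclose(actual, expected, rtol=rtol, atol=atol))


def row_sums(matrix):
    """Toy stand-in for a ported function (hypothetical)."""
    return np.asarray(matrix, dtype=float).sum(axis=1)


# Write a hand-made record to a temporary file and replay it.
with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "row-sums.json"
    record = {"fn": "row-sums", "args": [[[1, 2], [3, 4]]], "result": [3, 7]}
    path.write_text(json.dumps(record))
    ok = replay(path, row_sums)
```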

@jucor
Contributor Author

jucor commented Jan 31, 2025

I am adding notes to README.pythonport.md, but also pasting them in this thread for visibility.

PCA

Let's start with the PCA, as it's well known and nicely isolated.

Let's first:

  • Task 1: document the calling graph of PCA to get a lay of the land
  • Task 2: add more clojure tests to the PCA functions
  • Task 3: code the Python PCA and the tests
  • Task 4: wrap the clojure PCA call to store its input and output, and run it on a basic conversation.
  • Task 5: load the input from clojure and run the PCA on that.
  • Task 6: load a big conversation into the database
  • Task 7: record PCA input/output on that, and compare.

Checking its clojure calling graph in pca.clj:

graph TD
    wrapped_pca[wrapped-pca] --> powerit_pca[powerit-pca]
    powerit_pca --> power_iteration[power-iteration]
    powerit_pca --> rand_starting_vec[rand-starting-vec]
    powerit_pca --> factor_matrix[factor-matrix]
    
    power_iteration --> xtxr[xtxr]
    power_iteration --> repeatv[repeatv]
    
    factor_matrix --> proj_vec[proj-vec]
    
    xtxr --> repeatv
    
    pca_project[pca-project]
    
    sparsity_aware_project_ptpts[sparsity-aware-project-ptpts] --> sparsity_aware_project_ptpt[sparsity-aware-project-ptpt]
    
    pca_project_cmnts[pca-project-cmnts] --> sparsity_aware_project_ptpts
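
For orientation, here is a minimal numpy sketch of a power-iteration PCA whose function names mirror the nodes of the graph above. It is a generic textbook version written under assumptions (fixed iteration count, mean-centering, random normal start vector, deflation by projection); it is not a transcription of `pca.clj`, whose safeguards and exact behaviour would have to be ported function by function.

```python
import numpy as np


def xtxr(data, vec):
    """Compute X^T (X r) without forming the X^T X matrix."""
    return data.T @ (data @ vec)


def power_iteration(data, iters=100, start_vec=None, rng=None):
    """Approximate the leading right singular vector of `data`."""
    if start_vec is None:
        rng = np.random.default_rng(rng)
        start_vec = rng.standard_normal(data.shape[1])
    vec = start_vec / np.linalg.norm(start_vec)
    for _ in range(iters):
        nxt = xtxr(data, vec)
        norm = np.linalg.norm(nxt)
        if norm == 0.0:  # data has no variance left along any direction
            break
        vec = nxt / norm
    return vec


def proj_vec(component, row):
    """Project `row` onto the direction `component`."""
    return (row @ component) / (component @ component) * component


def factor_matrix(data, component):
    """Remove from every row its projection onto `component` (deflation)."""
    return np.array([row - proj_vec(component, row) for row in data])


def powerit_pca(data, n_comps=2, iters=100, seed=0):
    """Return the column means and the first `n_comps` principal directions."""
    data = np.asarray(data, dtype=float)
    center = data.mean(axis=0)
    residual = data - center
    rng = np.random.default_rng(seed)
    comps = []
    for _ in range(n_comps):
        comp = power_iteration(residual, iters=iters, rng=rng)
        comps.append(comp)
        residual = factor_matrix(residual, comp)
    return {"center": center, "comps": np.vstack(comps)}
```

On well-conditioned data the rows of `comps` should agree, up to sign, with the leading right singular vectors returned by `numpy.linalg.svd` on the centered matrix, which gives a convenient reference for unit tests.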

- Add cloverage
- Explicitly deactivate conv-man tests, which require a full database,
  as warned in the comments.
- Fix a bug in the new test runner when parsing command line arguments.
@jucor
Contributor Author

jucor commented Jan 31, 2025

We want to make sure our matching python and clojure tests cover as much code as possible, so we need to measure it. Hence I have added coverage measurement to the test runner.

Now focusing on PCA coverage, we see some codepaths not tested, so let’s add tests for that.


I am trying to get full _branch_ coverage of the function
`power-iteration`, but failing to. There must be some corner cases I am
not surfacing (and neither is Cursor).
I will leave it aside for now, make a note that it needs fixing, and will
go on to cover the entirely missing lines in the wrapped-pca function.
@jucor
Contributor Author

jucor commented Jan 31, 2025

In spite of lots of PCA tests, I do not get full branch coverage in the key function power-iteration. To avoid going crazy, I will move to the simpler line-coverage of other functions, and will come back to branch coverage.

I do wonder whether we really need full branch coverage, knowing that the PCA method, while currently an elegant hand-coded power-iteration, might eventually be handed off to scikit-learn or LAPACK. But before that, we will need to check whether the iterative nature of the power-iteration is exploited for incremental updates of the conversation, as I suspect it is. So for now, we will stick to power-iteration.
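
To make the incremental-update question concrete, here is a toy sketch of why a warm start can matter: power iteration started from the previous leading component needs fewer steps to converge after a small change to the data than one started from an unrelated vector. The matrices, the stopping rule and the function name are assumptions chosen so the behaviour is easy to reason about; whether the Clojure code relies on this is exactly what remains to be checked.

```python
import numpy as np


def power_iteration_count(data, start_vec, tol=1e-10, max_iters=10_000):
    """Run power iteration until successive vectors differ by less than `tol`.
    Returns the final vector and the number of iterations used."""
    vec = start_vec / np.linalg.norm(start_vec)
    for i in range(1, max_iters + 1):
        nxt = data.T @ (data @ vec)
        nxt = nxt / np.linalg.norm(nxt)
        if np.linalg.norm(nxt - vec) < tol:
            return nxt, i
        vec = nxt
    return vec, max_iters


# Toy "before" data: the leading direction is exactly the first axis.
old_data = np.diag([3.0, 2.0, 1.0])
old_top = np.array([1.0, 0.0, 0.0])

# Toy "after" data: one extra row slightly tilts the leading direction.
new_data = np.vstack([old_data, [0.1, 0.1, 0.0]])

warm_vec, warm_iters = power_iteration_count(new_data, old_top)
cold_vec, cold_iters = power_iteration_count(new_data, np.ones(3))
```

Both runs end on the same direction; the warm run simply starts much closer to it. A direct call into a LAPACK-backed SVD would not benefit from the previous result in this way, which is the trade-off to weigh.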

jucor added 2 commits January 31, 2025 15:21
See compdemocracy#1894 and
`README.pythonport.md` in this commit for explanations of why
the code is unreachable.
@jucor
Contributor Author

jucor commented Jan 31, 2025

pca-project does not seem to be called anywhere. conversation.clj calls sparsity-aware-project-ptpt[s] instead, which makes sense. Therefore, deleting this function to remove dead code.

jucor added 4 commits February 3, 2025 16:07
I would rather keep it in the math folder, but calva does not find it
there...
This makes it possible to reflect changes or new files added in the container,
which is easier when developing.
@NewJerseyStyle
Contributor

Awesome! 🔥 Using JAX?

@jucor
Contributor Author

jucor commented Feb 6, 2025

Thanks for the enthusiasm :)

At first, using numpy and sklearn for the port, because we are so far using CPU nodes (a constant worker polling the database every few seconds), and as far as I know Jax is not necessarily faster than numpy on those types of nodes. Plus we don't necessarily want the cost of running a GPU node 24/7 for a PCA and a k-means 😅 Also thinking the codebase will be more accessible without Jax, for now.

However, the idea is to keep the current functional-style code organization (thanks Clojure!), which will then lend itself well to a potential move to jax.numpy once we start seeing the need (whether through scale of matrices, or through algorithmic ideas -- for example RL when we review comment routing ;) )
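
As a small illustration of what functional style means in practice here: a function that returns a new array instead of mutating its input keeps working when the array library is swapped. The `xp` parameter below is a hypothetical convention for this sketch, not something in the PR; passing `jax.numpy` for it is the kind of later move being described.

```python
import numpy as np


def center_columns(matrix, xp=np):
    """Return a column-centered copy of `matrix` without modifying the input.

    `xp` is the array namespace to use (numpy by default). No in-place
    writes are used, so immutable array types would also be acceptable.
    """
    matrix = xp.asarray(matrix, dtype=float)
    return matrix - matrix.mean(axis=0)
```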

I'm curious: what use case did you have in mind for a Jax port? Now's a great time to send tons of ideas :)

@NewJerseyStyle
Contributor

Ah, sorry, it was just a random thought, as my Google friend always throws Jax at any topic 😆 Maybe some JIT compilers will work better than Jax for optimizing numpy

@jucor
Contributor Author

jucor commented Feb 6, 2025

Yes, for example numba is a very powerful jit compiler these days. It does not necessarily accelerate the core functions of numpy, which are often outsourced to BLAS (or Intel MKL if you have it set up), but it does speed up the python code around them.

Jax is really powerful indeed, when you have hardware accelerators :) And I love the split-stream RNG for reproducibility, as well as the instruction-level parallelization with vmap. Super elegant. I'm using it on another project for Bayesian inference on Ordinary Differential Equations, it's awesome.

The marginal cost to take into account, however, is that it does come with extra complexity (see e.g. https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html ), and can get difficult to debug. Within Google Deepmind I used to have tons of Research Engineer colleagues, who are terrific at helping sort out such issues, but whom we do not have outside of Google 😆 Plus it's not yet version 1.0, so the API can and does shift from version to version, thus requiring extra maintenance -- unlike numpy which is stable (well, except when moving from 1.x to 2.0!)

It's definitely a cost worth paying when the time comes, though. I'm dreaming of having full-on differentiable reinforcement learning environments with the conversations...

@NewJerseyStyle
Contributor

I am not familiar with libraries, my comfort zone was to optimize specific algorithms myself with C++/rust using SIMD or maybe highway based on deployment environment, guess I need to adopt modern development practice.

My friend always tries to convince me to use Jax, but I feel so comfortable with PyTorch (and maybe also PyTorch/XLA for TPUs) that I never really used Jax (except for Neural Tangent I guess).

I am not sure about the differentiable reinforcement learning environments with the conversations... I prefer a rule-based/ranking-algorithm-based routing strategy with a metric reflecting opinion diversity as input to the routing strategy. [sidetracked] Working with people in CivicTechTO on user research to ensure the deliverables are aligned with the needs of end users IRL (and on a data-lakehouse-based pipeline as a modularized backend upgrade to our polis fork, so that ranking algorithms can be integrated and other routing solutions tested).
