
Pass kwargs to nutpie + create env.yml file #855

Merged
8 commits merged on Dec 21, 2024

Conversation

AlexAndorra
Copy link
Contributor

Currently, some kwargs are not passed to Bayeux when fitting the model. This PR makes sure they are.
The only change is on line 270 of bambi/backend/pymc.py -- the rest is only formatting.

Also added an env file for creation with mamba. Ready for review!

@tomicapretto
Copy link
Collaborator

@AlexAndorra I think this is already handled? I have not looked deeply into the details, but have a look at this example:

https://bambinos.github.io/bambi/notebooks/alternative_samplers.html#blackjax

@AlexAndorra
Copy link
Contributor Author

Yeah I saw that @tomicapretto , but it doesn't seem to work:

data = bmb.load_data("sleepstudy")
model = bmb.Model('Reaction ~ Days', data)
kwargs = {
    "draws": 40,
    "chains": 2,
    "cores": 3,
}
results = model.fit(inference_method="nutpie", **kwargs)
results.posterior

will still give you 8 chains and 1000 draws

@AlexAndorra
Copy link
Contributor Author

Interestingly, the blackjax nuts example from the NB errors out:

ValueError: not enough values to unpack (expected 2, got 1)

@tomicapretto
Copy link
Collaborator

Interestingly, the blackjax nuts example from the NB errors out:

ValueError: not enough values to unpack (expected 2, got 1)

There is currently a problem with the dependencies. I just pinned them in a separate PR because it was being problematic. I think we need a new release of bayeux. Unfortunately, I'm not familiar enough with it to work on it.

These are the dependencies I've pinned

bambi/pyproject.toml

Lines 41 to 46 in 7a18fb9

# TODO: Unpin this before making a release
jax = [
"bayeux-ml==0.1.14",
"blackjax==1.2.3",
"jax<=0.4.33",
"jaxlib<=0.4.33",

@tomicapretto
Copy link
Collaborator

Yeah I saw that @tomicapretto , but it doesn't seem to work:

data = bmb.load_data("sleepstudy")
model = bmb.Model('Reaction ~ Days', data)
kwargs = {
    "draws": 40,
    "chains": 2,
    "cores": 3,
}
results = model.fit(inference_method="nutpie", **kwargs)
results.posterior

will still give you 8 chains and 1000 draws

Interesting, I'll double check what's happening

@AlexAndorra
Copy link
Contributor Author

There is currently a problem with the dependencies. I just pinned them in a separate PR because it was being problematic. I think we need a new release of bayeux. Unfortunately, I'm not familiar enough with it to work on it.

Ooooh, I definitely need to do that on my branch then! Shouldn't we merge that into main while this is still an issue?
Maybe @ColCarroll can help with the Bayeux release?

Interesting, I'll double check what's happening

I think this is silently ignored by Bayeux because the arguments aren't passed explicitly. The changes I've made in this PR solve it, but may not cover all the cases. They will, though, once I add Colin's suggestion from above.

@tomicapretto

This comment was marked as resolved.

@AlexAndorra
Copy link
Contributor Author

I think so @tomicapretto, because then Bambi has to pass them explicitly to bx.sample when using nutpie. Do you confirm it works with BlackJAX once dependencies are pinned?

@tomicapretto

This comment was marked as resolved.

@tomicapretto
Copy link
Collaborator

I resolved my previous two comments because I realized they were not correct. The reason for kwargs in

idata = bx_sampler(seed=jax_seed, **kwargs)

not containing chains, draws, tune, and cores is because they are also keyword arguments of the Model.fit() method, so Python does not include them in the kwargs dictionary.
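
The explanation above can be sketched in a few lines (the function below is illustrative only, not Bambi's actual signature):

```python
# Illustrative sketch (not Bambi's actual code): arguments matching a
# parameter named explicitly in the signature are bound there, so they
# never appear in **kwargs and are not forwarded to a downstream sampler.
def fit(draws=1000, chains=4, **kwargs):
    return draws, chains, kwargs

# draws and chains are captured by the named parameters; kwargs stays empty.
print(fit(draws=40, chains=2))   # → (40, 2, {})
# Only arguments with no matching named parameter end up in kwargs.
print(fit(seed=123))             # → (1000, 4, {'seed': 123})
```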

@tomicapretto
Copy link
Collaborator

I see the following alternatives

  1. Accept the changes proposed by @AlexAndorra, which always pass draws, tune, chains, and cores to the sampling function. If bayeux simply ignores them when the underlying sampler doesn't expect them, then it's all good.
  2. Change the signature of Model.fit() to accept arguments passed to PyMC and other samplers as kwargs. The downside I see here is lack of autocomplete.
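
Alternative 1 amounts to merging the explicitly named arguments back into the kwargs dict before calling bayeux. A minimal sketch, assuming a helper inside the backend (the name `build_sampler_kwargs` is hypothetical):

```python
# Hypothetical helper illustrating alternative 1: re-insert the arguments
# that Model.fit() names explicitly, so the bayeux sampler actually sees
# them alongside any extra user-provided kwargs.
def build_sampler_kwargs(draws, tune, chains, cores, **kwargs):
    kwargs.update({"draws": draws, "tune": tune, "chains": chains, "cores": cores})
    return kwargs

# The merged dict carries the user's settings plus any extra kwargs.
print(build_sampler_kwargs(40, 500, 2, 3, seed=1))
# → {'seed': 1, 'draws': 40, 'tune': 500, 'chains': 2, 'cores': 3}
```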

@GStechschulte
Copy link
Collaborator

GStechschulte commented Nov 9, 2024

I am late to the party @tomicapretto. Regarding nutpie, the kwargs returned are not "pretty":

bmb.inference_methods.get_kwargs("nutpie")
{<function nutpie.compiled_pyfunc.from_pyfunc(ndim: int, make_logp_fn: Callable, make_expand_fn: Callable, expanded_dtypes: list[numpy.dtype], expanded_shapes: list[tuple[int, ...]], expanded_names: list[str], *, initial_mean: numpy.ndarray | None = None, coords: dict[str, typing.Any] | None = None, dims: dict[str, tuple[str, ...]] | None = None, shared_data: dict[str, typing.Any] | None = None)>: {'ndim': 1,
  'make_logp_fn': <function bayeux._src.mcmc.nutpie._NutpieSampler._get_aux.<locals>.make_logp_fn()>,
  'make_expand_fn': <function bayeux._src.mcmc.nutpie._NutpieSampler.get_kwargs.<locals>.make_expand_fn(*args, **kwargs)>,
  'expanded_shapes': [(1,)],
  'expanded_names': ['x'],
  'expanded_dtypes': [numpy.float64]},
 <function nutpie.sample.sample(compiled_model: nutpie.sample.CompiledModel, *, draws: int = 1000, tune: int = 300, chains: int = 6, cores: Optional[int] = None, seed: Optional[int] = None, save_warmup: bool = True, progress_bar: bool = True, low_rank_modified_mass_matrix: bool = False, init_mean: Optional[numpy.ndarray] = None, return_raw_trace: bool = False, blocking: bool = True, progress_template: Optional[str] = None, progress_style: Optional[str] = None, progress_rate: int = 100, **kwargs) -> arviz.data.inference_data.InferenceData>: {'draws': 1000,
  'tune': 300,
  'chains': 8,
  'cores': 8,
  'seed': None,
  'save_warmup': True,
  'progress_bar': True,
  'low_rank_modified_mass_matrix': False,
  'init_mean': None,
  'return_raw_trace': False,
  'blocking': True,
  'progress_template': None,
  'progress_style': None,
  'progress_rate': 100},
 'extra_parameters': {'flatten': <function bayeux._src.mcmc.nutpie._NutpieSampler._get_aux.<locals>.flatten(pytree)>,
  'unflatten': <jax._src.util.HashablePartial at 0x3299964e0>,
  'return_pytree': False}}

It is a nested dictionary whose keys are objects. A nested dictionary isn't a problem per se, e.g. passing nested args to BlackJAX NUTS:

kwargs = {
    "adapt.run": {"num_steps": 500},
    "num_chains": 4,
    "num_draws": 250,
    "num_adapt_draws": 250
}

@tomicapretto
Copy link
Collaborator

I think so @tomicapretto, because then Bambi has to pass them explicitly to bx.sample when using nutpie. Do you confirm it works with BlackJAX once dependencies are pinned?

Not sure what you mean by "it works". If it is passing arguments to BlackJAX, that has always worked. If it is the tests, yes, they work now after pinning the deps.

@tomicapretto
Copy link
Collaborator

@GStechschulte I think that is not the problem (see #855 (comment) and my resolved comments)

@AlexAndorra
Copy link
Contributor Author

AlexAndorra commented Nov 9, 2024 via email

@tomicapretto
Copy link
Collaborator

@AlexAndorra I'm going to incorporate your changes, just modified things a bit. The environment goes under a conda-envs directory. I'm doing the same as I saw in PyMC: https://github.com/pymc-devs/pymc/tree/main/conda-envs.

Could you install from your branch and try to run nutpie passing those kwargs? If tests are OK, and what you run works, then I'll merge.


@AlexAndorra
Copy link
Contributor Author

It's all working now @tomicapretto 🥳
Just tested and reran the alternative_samplers NB. I think we can merge!

@GStechschulte
Copy link
Collaborator

GStechschulte commented Dec 11, 2024

Looks like we just need to run black 😄

@AlexAndorra
Copy link
Contributor Author

Ah indeed, thanks @GStechschulte ! Will do that locally, and see if I can run the tests locally too before I push

@AlexAndorra
Copy link
Contributor Author

I actually caught an issue with nutpie when cores and chains were None. I fixed it, and now the tests, as well as black and pylint, pass locally. I think we can merge once this turns green here 🥳
(can't do it myself as I don't have write access)
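
The thread doesn't show the fix itself; one plausible shape for it (a sketch under that assumption, not the actual Bambi code) is to drop None-valued entries so the sampler falls back to its own defaults:

```python
# Sketch only: filter out None values so e.g. cores=None or chains=None
# doesn't override the sampler's own defaults.
def drop_none_kwargs(kwargs):
    return {k: v for k, v in kwargs.items() if v is not None}

print(drop_none_kwargs({"draws": 40, "chains": None, "cores": None}))
# → {'draws': 40}
```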

@ColCarroll
Copy link
Collaborator

Thanks for sticking with the long breakage I caused 😅 -- I ran the tests, but will wait for Tomás to merge!

@tomicapretto
Copy link
Collaborator

Very glad to see all green, guys! Thanks for all the work.

I've just added a comment regarding which version of bayeux-ml we use. After we sort that out, we can merge.

@tomicapretto
Copy link
Collaborator

I'm going to work on this

@tomicapretto tomicapretto self-assigned this Dec 16, 2024
@AlexAndorra
Copy link
Contributor Author

AlexAndorra commented Dec 16, 2024

I just pushed the changes @tomicapretto ! Should be good to merge now

@tomicapretto
Copy link
Collaborator

Self reminder to add a test for this (I'm thinking of nutpie; it would be good to test against it as it's a solid backend and people seem to like it, I do too).

@tomicapretto
Copy link
Collaborator

Will merge after CI passes

@tomicapretto tomicapretto merged commit 27f8136 into bambinos:main Dec 21, 2024
4 checks passed