-
-
Notifications
You must be signed in to change notification settings - Fork 127
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add vi_kwargs to fit #553
add vi_kwargs to fit #553
Conversation
Codecov Report
@@ Coverage Diff @@
## main #553 +/- ##
==========================================
- Coverage 86.93% 86.50% -0.43%
==========================================
Files 32 32
Lines 2655 2667 +12
==========================================
- Hits 2308 2307 -1
- Misses 347 360 +13
📣 We’re building smart automated test selection to slash your CI/CD build times. Learn more |
@@ -104,7 +105,9 @@ def run( | |||
**kwargs, | |||
) | |||
elif method.lower() == "vi": | |||
result = self._run_vi(**kwargs) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: add exception somewhere above that if vi_kwargs is passed but method is not VI either a warning or exception is provided
Also nit: It feels like if theres vi kwargs, there should also be mcmc kwargs to mirror the API
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Agree with the first comment. Not so sure about the second. I understand that having vi_kwargs
makes us think why not mcmc_kwargs
as well. But it also makes me wonder whether we need mcmc_kwargs
, vi_kwargs
, and laplace_kwargs
, which makes the signature much longer, or if it is better to replace method
with a better name that does not overlap with parameters in pymc.fit()
and keep a single **kwargs
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As method="mcmc" is the most common method, is very convenient to allow for kwargs. Having mcmc_kwargs will turn super annoying. Instead I see vi (and laplace) as unusual methods so is ok if they deviate a little from that pattern. Having said that changing the method argument is an option. I just wanted to avoid breaking backward compatibility, but given that mcmc is the most common method probably that change will most likely unnoticed by most users.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree that backwards compatibility may be a concern. Is it too overkill to have a deprecation warning now, saying we'll change the name of the method
parameter, and clean up the signature later?
I propose inference_method
as an alternative
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OK, let's do that
model.fit
has an argument calledmethod
as well aspymc.fit
to avoid those name clashing, this PR introduces avi_kwargs
argument that is passed topymc.fit
. Asmethod="mcmc"
is the most common method, we allow passing arguments as kwargs, instead of adding amcmc_kwargs argument
. This also add a missing test.