-
-
Notifications
You must be signed in to change notification settings - Fork 2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Require all step methods to return stats #6313
Conversation
3bc8c36
to
f21ef72
Compare
Codecov Report
Additional details and impacted files@@ Coverage Diff @@
## main #6313 +/- ##
==========================================
+ Coverage 94.17% 94.22% +0.05%
==========================================
Files 111 111
Lines 23908 23865 -43
==========================================
- Hits 22515 22487 -28
+ Misses 1393 1378 -15
|
The reason for this change is the resulting simplification of code, including simpler branching and less type ambiguity. At the same time it allowed for fixing of a lot of type hints and method signatures on step methods. Closes pymc-devs#6270
f21ef72
to
9545b62
Compare
def astep( | ||
self, apoint: RaveledVars, point: PointType, *args | ||
) -> Union[RaveledVars, Tuple[RaveledVars, StatsType]]: | ||
def astep(self, apoint: RaveledVars, *args) -> Tuple[RaveledVars, StatsType]: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The exact composition of the args
is specified via the constructor parameters 🤐
def astep(self, q0: RaveledVars, logp) -> Tuple[RaveledVars, List[Dict[str, Any]]]: | ||
|
||
def astep(self, q0: RaveledVars, *args) -> Tuple[RaveledVars, StatsType]: | ||
logp = args[0] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
BinaryMetropolis
is not even covered by our tests!
pymc/step_methods/metropolis.py
Outdated
def astep(self, q0: RaveledVars, *args) -> Tuple[RaveledVars, StatsType]: | ||
logp: Callable[[RaveledVars], np.ndarray] = args[0] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
BinaryGibbsMetropolis
is also not covered by tests!
if rng is None: | ||
rng = nr | ||
return rng.normal(scale=self.s) | ||
return (rng or nr).normal(scale=self.s) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These changes resolve two mypy
errors:
rng
is no longer reassigned, and- the
(rng or nr)
has typeGenerator | Module("numpy.random")
which resolvesItem "None" of "Optional[Generator]" has no attribute "normal"
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks much less readable though
We shouldn't remove the trust_input. It allows evaluating the Aesara function much faster because it doesn't check the input types are valid/need to be converted every time it's called |
But we're not using the |
Ah I see, it's some feature of the function... I'll investigate and drop that commit. Anything else you want to get changed, @ricardoV94 ? |
9545b62
to
9e09aa4
Compare
9e09aa4
to
da7e14a
Compare
The reason for these changes are the resulting simplification of code, including simpler branching and less type ambiguity.
This change may break custom step methods, but whoever had enough skill to implement a custom step method should have no trouble to append a
, []
to thereturn
of their(a)step
method.Closes #6270
Checklist
Major / Breaking Changes
step
/asteo
method.BlockedStep.generates_stats
attribute was removed.Bugfixes / New features
Compound.step
that supportednamedtuple
stats
which are in violation of theBlockedStep.step
signature.Docs / Maintenance
step
andastep
methods. (arraystep.py
now passing mypy!)