Skip to content

Commit 2909b30

Browse files
authored
Merge 1005d2d into 226d141
2 parents 226d141 + 1005d2d commit 2909b30

File tree

2 files changed

+29
-9
lines changed

2 files changed

+29
-9
lines changed

arviz/data/base.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import functools
44
import warnings
55
from copy import deepcopy
6-
from typing import Any, Dict, List
6+
from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union
77

88
import numpy as np
99
import pkg_resources
@@ -21,6 +21,8 @@
2121

2222
CoordSpec = Dict[str, List[Any]]
2323
DimSpec = Dict[str, List[str]]
24+
RequiresArgTypeT = TypeVar("RequiresArgTypeT")
25+
RequiresReturnTypeT = TypeVar("RequiresReturnTypeT")
2426

2527

2628
class requires: # pylint: disable=invalid-name
@@ -32,19 +34,34 @@ class requires: # pylint: disable=invalid-name
3234
missing. Both functionalities can be combined as desired.
3335
"""
3436

35-
def __init__(self, *props):
36-
self.props = props
37-
38-
def __call__(self, func): # noqa: D202
37+
def __init__(self, *props: Union[str, List[str]]) -> None:
38+
self.props: Tuple[Union[str, List[str]], ...] = props
39+
40+
# Until typing.ParamSpec (https://www.python.org/dev/peps/pep-0612/) is available
41+
# in all our supported Python versions, there is no way to simultaneously express
42+
# the following two properties:
43+
# - the input function may take arbitrary args/kwargs, and
44+
# - the output function takes those same arbitrary args/kwargs, but has a different return type.
45+
# We either have to limit the input function to e.g. only allowing a "self" argument,
46+
# or we have to adopt the current approach of annotating the returned function as if
47+
# it was defined as "def f(*args: Any, **kwargs: Any) -> Optional[RequiresReturnTypeT]".
48+
#
49+
# Since all functions decorated with @requires currently only accept a single argument,
50+
# we choose to limit application of @requires to only functions of one argument.
51+
# When typing.ParamSpec is available, this definition can be updated to use it.
52+
# See https://github.com/arviz-devs/arviz/pull/1504 for more discussion.
53+
def __call__(
54+
self, func: Callable[[RequiresArgTypeT], RequiresReturnTypeT]
55+
) -> Callable[[RequiresArgTypeT], Optional[RequiresReturnTypeT]]: # noqa: D202
3956
"""Wrap the decorated function."""
4057

41-
def wrapped(cls, *args, **kwargs):
58+
def wrapped(cls: RequiresArgTypeT) -> Optional[RequiresReturnTypeT]:
4259
"""Return None if not all props are available."""
4360
for prop in self.props:
4461
prop = [prop] if isinstance(prop, str) else prop
4562
if all([getattr(cls, prop_i) is None for prop_i in prop]):
4663
return None
47-
return func(cls, *args, **kwargs)
64+
return func(cls)
4865

4966
return wrapped
5067

arviz/data/io_pymc3.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -214,10 +214,13 @@ def log_likelihood_vals_point(self, point, var, log_like_fun):
214214
log_like_val = np.where(mask, np.nan, log_like_val)
215215
return log_like_val
216216

217-
@requires("trace")
218-
@requires("model")
219217
def _extract_log_likelihood(self, trace):
220218
"""Compute log likelihood of each observation."""
219+
if self.trace is None:
220+
return None
221+
if self.model is None:
222+
return None
223+
221224
# If we have predictions, then we have a thinned trace which does not
222225
# support extracting a log likelihood.
223226
if self.log_likelihood is True:

0 commit comments

Comments
 (0)