Refactor domains to be types #352
@fehiepsi any idea whether the jax failure is real?
@fritzo Did you rerun the job? I can't observe the failure in the link. I observed this error when
I didn't find anything wrong with the current implementation of the binomial sampler in NumPyro.
@fehiepsi Thanks for checking. I did rerun the test, and it passed. Before running I captured the failure output (in
@eb8680 I'd like to perform this refactoring in two PRs:
The first PR is small, and the second PR will be big but easy to review.
funsor/domains.py
    """
    def __new__(cls, shape, dtype):
        assert isinstance(shape, tuple)
class Domain(type):
Would it be more Pythonic to define `Domain` with `typing.Generic`?
    from typing import Generic, Tuple, TypeVar, Union
    from typing_extensions import Literal  # backport of typing.Literal in 3.8

    S = TypeVar("S", bound=Tuple)
    D = TypeVar("D", bound=Union[int, str])

    class Domain(Generic[S, D]):
        def __init__(self, shape: S, dtype: D):
            self.shape = shape
            self.dtype = dtype

    # version 1, more concise
    Bint = Domain[Tuple[()], TypeVar("DB", bound=int)]
    Reals = Domain[S, Literal['real']]

    # version 2, nicer printing/pickling
    class Bint(Domain[Tuple[()], TypeVar("DB", bound=int)]):
        pass

    class Reals(Domain[S, Literal['real']]):
        pass

    def reals(*shape):
        return Reals[Tuple[tuple(Literal[size] for size in shape)]]

    def bint(size):
        return Bint[Literal[size]]
My longer term vision is for `Domain` to be the type of more general ground values in Funsor including `Tuple[]`, eventually including pytrees. Therefore I think it no longer makes sense to parameterize `Domain` by (`shape`, `dtype`).
Indeed I feel we are already shoe-horning `bint()` and `reals()` into an ill-fitting abstraction. For example to address #322 we will probably need a heterogeneous `size` attribute, which is currently aliased to `dtype`. Even that alias feels like a hack to me.
> My longer term vision is for Domain to be the type of more general ground values in Funsor including Tuple[], eventually including pytrees

I'm definitely on board with this, but it seems like in that vision there's still a need for something like `Reals` and `Bint` (and hence my version of `Domain` above, which we could call `Shape` or `Array` after NumPy?) to represent dependent tensor output types. Then we could just have the more general ground value base type be `typing.Any`, or a very broad `Union[int, float, str, Shape, Tuple]`, rather than a custom `Domain` type?
Hmm, I guess there is a future where `Bint` and `Real` are more similar... I'll have to think about this more. I really don't like the way we currently use `domain.dtype` as an alias for `size`, and I'd prefer to pull these concepts apart.
I'm fine with separating `Bint` and `Real`, it's the `BintType`/`RealType` metaclasses in this implementation that I am wondering if we could replace with something more Pythonic based on `typing.Generic` and `typing.Literal`. To clarify, this version of `Reals`, simplified from my previous comment, inherits directly from `Generic`:

    S = TypeVar("S", bound=Tuple[int, ...])

    class Reals(Generic[S]):  # subtypes are cons-hashed
        def __init__(self, shape: S):
            self.shape = shape
        # dtype, etc. here

    def reals(*shape):  # for illustration
        return Reals[Tuple[tuple(Literal[size] for size in shape)]]

It's more verbose, but conforms more closely to the spirit of `typing`.
Also, I'm not sure pointwise functions annotated with the metaclass-based design (`def f(x: Reals[1]): ...`) would be compatible with static type checkers when checked without the `function` wrapper (no types parametrized directly by values are allowed, regardless of the details of their implementations). To what extent should that serve as a design constraint?
I do like the idea of being more Pythonic and trying to use `Generic`. However it remains to be seen whether `Generic` can serve our purpose. Issues include: we want to memoize via `WeakValueDictionary` rather than `lru_cache`, imho; and `Generic`'s implementation changes across Python versions, so it may be difficult to override `__class_getitem__` (to cast parameters to int while jitting). But you have more experience with typing than I, so let's pair code on it.

> should [compatibility with type checkers] serve as a design constraint?

Yeah good question. I guess the objective of this refactoring is to replace Funsor's bespoke type annotations with more standard annotations and thereby make it easier for new Funsor users to learn and contribute. While that suggests mypy support is a secondary goal, I think it would be appreciated by IDE-based users. Let's see how hard it is, maybe we can add a `mypy` test stage.
Also I'd like to see whether we can roughly follow NumPy types. Let's discuss tomorrow.
Some other challenges with more Pythonic type annotations are:
- some of `typing`'s annotations simply disappear during runtime (mypy inspects the ASTs rather than the runtime values, so Funsor inspection may be more challenging)
- there is not yet a standard way to inspect type hints; we may need to roll our own version-agnostic module similar to `typing_inspect`
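As a rough illustration of both bullets, the sketch below uses only the standard library on Python 3.8 or later, where `typing.get_origin` and `typing.get_args` are available. The helper `shape_from_hint` and the function `f` are hypothetical. It shows that a local variable annotation leaves no trace at runtime, while parameter annotations can still be inspected.

```python
from typing import Literal, Tuple, get_args, get_origin, get_type_hints


def f(x: Tuple[Literal[3], Literal[4]], y: "int") -> float:
    z: int = 0  # local annotation: not recorded anywhere at runtime
    return float(z)


def shape_from_hint(hint):
    """Hypothetical helper: recover sizes from Tuple[Literal[m], Literal[n]]."""
    assert get_origin(hint) is tuple
    # Each argument is a Literal[...] whose single argument is the size.
    return tuple(get_args(size)[0] for size in get_args(hint))


hints = get_type_hints(f)  # also evaluates the string annotation "int"
print(sorted(hints))  # parameter names plus "return"; the local z is absent
print(shape_from_hint(hints["x"]))
```

Older Python versions lack `get_origin` and `get_args`, which is where a version-agnostic shim along the lines of `typing_inspect` would come in.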
Bint[5]  # integers ranging in {0,1,2,3,4}

To dispatch on domain type, we recommend either ``@singledispatch``,
``@multipledispatch``, or ``isinstance(domain, BintType)``.
When would we dispatch on domain type with `isinstance` rather than `issubclass`? More broadly, when will domain types be instantiated, if ever?
Currently we dispatch via `if domain.dtype == "real"`. I was thinking that if we are to generalize to other domain types (like `Tuple[...]`) then it would be cleanest to use `@singledispatch`. For lighter-weight uses (in my immediate refactoring plans) I will probably replace `if x.dtype == "real"` with `isinstance()`. Do you have a cleaner suggestion?
To be clear, we would never instantiate domains like `Real`, but we do instantiate domain types like `RealType`.
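The distinction in this thread can be sketched as follows. Everything here is a hypothetical stand-in, not Funsor's implementation: `BintType` and `RealsType` are bare metaclasses, and `Reals2` plays the role of a parametrized domain such as `Reals[2]`. Because a domain is an instance of its metaclass, both `isinstance(domain, RealsType)` and `@singledispatch` (which dispatches on the class of its first argument) can tell domains apart without ever instantiating the domain itself.

```python
from functools import singledispatch


class BintType(type):
    """Hypothetical metaclass for bounded-integer domains."""


class RealsType(type):
    """Hypothetical metaclass for real-tensor domains."""


class Bint(metaclass=BintType):
    pass


class Reals(metaclass=RealsType):
    pass


class Reals2(Reals):  # stands in for a parametrized domain such as Reals[2]
    shape = (2,)


@singledispatch
def describe(domain):
    raise NotImplementedError("unsupported domain: {!r}".format(domain))


@describe.register(BintType)
def _describe_bint(domain):
    return "bounded integer"


@describe.register(RealsType)
def _describe_reals(domain):
    return "real tensor of shape {}".format(domain.shape)


print(describe(Bint))
print(describe(Reals2))
print(isinstance(Reals2, RealsType), issubclass(Reals2, Reals))
```

In this sketch `isinstance` answers "which kind of domain is this?" by looking at the metaclass, while `issubclass` answers "is this a specialization of `Reals`?" by looking at the domain's own bases.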
funsor/domains.py
# in some JAX versions, shape can be np.int64 type
if get_tracing_state() or funsor.get_backend() == "jax":
    shape = tuple(map(int, shape))
Note that if we switch to `typing.Generic` we'll need to override `__getitem__` or `__class_getitem__` to do some casting.
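A minimal sketch of what such casting might look like follows. It uses a plain class with `__class_getitem__` (PEP 560) rather than `typing.Generic`, since whether `Generic` can be overridden this way is exactly the open question. `IntLike` is a hypothetical stand-in for an integer-like scalar such as `np.int64`, and the unconditional cast is a simplification of the backend check in the diff above.

```python
class IntLike:
    """Hypothetical stand-in for an integer-like scalar (e.g. np.int64)."""

    def __init__(self, value):
        self._value = value

    def __int__(self):
        return self._value


class Reals:
    """Hypothetical domain base; parametrize it as Reals[m, n]."""

    shape = ()
    _cache = {}

    def __class_getitem__(cls, shape):
        if not isinstance(shape, tuple):
            shape = (shape,)
        # Normalize every size to a builtin int before using it as a cache key.
        shape = tuple(int(size) for size in shape)
        if shape not in cls._cache:
            name = "Reals[{}]".format(",".join(map(str, shape)))
            cls._cache[shape] = type(name, (cls,), {"shape": shape})
        return cls._cache[shape]


print(Reals[IntLike(3), 4] is Reals[3, 4])
```

Casting before the cache lookup means integer-like sizes and builtin ints map to the same parametrized type, and the stored `shape` always contains plain `int` values.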
Looks good after our discussion today, just one nit
Thanks for reviewing!
Addresses #351
This refactors so that
- `Domain = type`
- `Bint[n]` and `Reals[m,n]` are type hints (still instances of `Domain`)

This does not appear to cause any overhead, e.g. the runtime of test_sum_product.py seems to drop from 95 sec to 89 sec.
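For readers unfamiliar with the pattern, the design summarized above can be sketched roughly as below. Only the names `Domain`, `Bint`, `Reals`, and `BintType` appear in the PR; `RealsType`, the cache dictionaries, and the method bodies are assumptions made for illustration and may differ from the actual implementation.

```python
# Hypothetical sketch; names mirror the PR but the bodies are assumptions.
Domain = type  # every domain is itself a Python type


class BintType(Domain):
    """Metaclass of bounded-integer domains such as Bint[5]."""

    _cache = {}

    def __getitem__(cls, size):
        size = int(size)
        if size not in BintType._cache:
            # A new subclass of Bint whose own class is BintType.
            BintType._cache[size] = BintType(
                "Bint[{}]".format(size), (cls,), {"size": size, "shape": ()}
            )
        return BintType._cache[size]


class RealsType(Domain):
    """Metaclass of real-tensor domains such as Reals[3, 4]."""

    _cache = {}

    def __getitem__(cls, shape):
        if not isinstance(shape, tuple):
            shape = (shape,)
        shape = tuple(map(int, shape))
        if shape not in RealsType._cache:
            name = "Reals[{}]".format(",".join(map(str, shape)))
            RealsType._cache[shape] = RealsType(name, (cls,), {"shape": shape})
        return RealsType._cache[shape]


class Bint(metaclass=BintType):
    """Base bounded-integer domain; parametrize it as Bint[n]."""


class Reals(metaclass=RealsType):
    """Base real-tensor domain; parametrize it as Reals[m, n]."""

    shape = ()


print(isinstance(Bint[5], Domain), Bint[5] is Bint[5], Reals[3, 4].shape)
```

Since `Bint[5]` and `Reals[3, 4]` are ordinary classes, they can appear directly in annotations such as `def f(x: Reals[3, 4]): ...` at runtime, which is the sense in which they act as type hints.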
Tasks
- `Domain`, `Reals[]`, `Bint[]`
- `reals()` and `bint()`

for follow-up PRs:
- `bint()`, `reals()`, etc.)