
Refactor domains to be types #352

Merged
11 commits merged into master from domain-type on Aug 17, 2020
Conversation

fritzo (Member) commented Aug 16, 2020

Addresses #351

This refactors so that

  • Domain = type
  • Bint[n] and Reals[m,n] are type hints (still instances of Domain)

This does not appear to add any overhead; e.g. test_sum_product.py seems to run slightly faster, dropping from 95 sec to 89 sec.
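
For illustration, here is a minimal sketch of the intended usage, assuming the deprecated shims return the same cons-hashed types as the new hints (not an excerpt from the diff):

from funsor.domains import Bint, Domain, Reals, bint, reals

d = Reals[3, 4]               # a type hint describing real-valued tensors of shape (3, 4)
print(isinstance(d, Domain))  # True: Bint[n] and Reals[m, n] are still instances of Domain
print(d.shape, d.dtype)       # (3, 4) 'real'
print(bint(5) is Bint[5])     # deprecated shim, expected to return the same cons-hashed type
print(reals(3, 4) is Reals[3, 4])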

Tasks

  • refactor Domain, Reals[], Bint[]
  • get tests to pass using shims reals() and bint()
  • mark shims deprecated

for follow-up PRs:

  • remove usage of shims from library code
  • remove shims (bint(), reals(), etc.)

@fritzo fritzo changed the title Domain type Refactor domains to be types Aug 16, 2020
fritzo (Member, Author) commented Aug 16, 2020

@fehiepsi any idea whether the jax failure is real?

FAILED test/test_distribution.py::test_binomial_sample[()-sample_inputs0-True] 
=================================== FAILURES ===================================
_________________ test_binomial_sample[()-sample_inputs0-True] _________________
[gw1] linux -- Python 3.6.7 /home/travis/virtualenv/python3.6.7/bin/python
with_lazy = True, batch_shape = (), sample_inputs = ()
    @pytest.mark.parametrize("with_lazy", [True, xfail_param(False, reason="missing pattern")])
    @pytest.mark.parametrize('sample_inputs', [(), ('ii',), ('ii', 'jj'), ('ii', 'jj', 'kk')])
    @pytest.mark.parametrize('batch_shape', [(), (5,), (2, 3)], ids=str)
    def test_binomial_sample(with_lazy, batch_shape, sample_inputs):
        batch_dims = ('i', 'j', 'k')[:len(batch_shape)]
        inputs = OrderedDict((k, bint(v)) for k, v in zip(batch_dims, batch_shape))
    
        max_count = 10
        total_count_data = random_tensor(inputs, bint(max_count)).data
        if get_backend() == "torch":
            total_count_data = ops.astype(total_count_data, 'float')
        total_count = total_count_data
        probs = rand(batch_shape)
        funsor_dist_class = dist.Binomial
        params = (total_count, probs)
    
>       _check_sample(funsor_dist_class, params, sample_inputs, inputs, atol=2e-2, skip_grad=True, with_lazy=with_lazy)
test/test_distribution.py:931: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
test/test_distribution.py:791: in _check_sample
    _get_stat_diff_fn(params)
test/test_distribution.py:719: in _get_stat_diff
    sample_value = funsor_dist.sample(frozenset(['value']), sample_inputs, rng_key=rng_key)
funsor/terms.py:482: in sample
    result = interpreter.debug_logged(self.unscaled_sample)(sampled_vars, sample_inputs, rng_key)
funsor/distribution.py:129: in unscaled_sample
    raw_sample = ops.detach(raw_dist.sample(*sample_args))
../../../virtualenv/python3.6.7/lib/python3.6/site-packages/numpyro/distributions/discrete.py:161: in sample
    return binomial(key, self.probs, n=self.total_count, shape=sample_shape + self.batch_shape)
../../../virtualenv/python3.6.7/lib/python3.6/site-packages/numpyro/distributions/util.py:161: in binomial
    return _binomial(key, p, n, shape)
../../../virtualenv/python3.6.7/lib/python3.6/site-packages/jax/api.py:171: in f_jitted
    name=flat_fun.__name__, donated_invars=donated_invars)
../../../virtualenv/python3.6.7/lib/python3.6/site-packages/jax/core.py:1134: in bind
    return call_bind(self, fun, *args, **params)
../../../virtualenv/python3.6.7/lib/python3.6/site-packages/jax/core.py:1123: in call_bind
    outs = primitive.impl(fun, *args, **params)
../../../virtualenv/python3.6.7/lib/python3.6/site-packages/jax/interpreters/xla.py:527: in _xla_call_impl
    *unsafe_map(arg_spec, args))
../../../virtualenv/python3.6.7/lib/python3.6/site-packages/jax/linear_util.py:224: in memoized_fun
    ans = call(fun, *args)
../../../virtualenv/python3.6.7/lib/python3.6/site-packages/jax/interpreters/xla.py:598: in _xla_callable
    fun, pvals, instantiate=False, stage_out=True, bottom=True)
../../../virtualenv/python3.6.7/lib/python3.6/site-packages/jax/interpreters/partial_eval.py:423: in trace_to_jaxpr
    jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
../../../virtualenv/python3.6.7/lib/python3.6/site-packages/jax/linear_util.py:150: in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
../../../virtualenv/python3.6.7/lib/python3.6/site-packages/numpyro/distributions/util.py:154: in _binomial
    (key, p, n))
../../../virtualenv/python3.6.7/lib/python3.6/site-packages/jax/lax/lax_control_flow.py:1848: in map
    _, ys = scan(g, (), xs)
../../../virtualenv/python3.6.7/lib/python3.6/site-packages/jax/lax/lax_control_flow.py:1232: in scan
    jaxpr, consts, out_tree = _initial_style_jaxpr(f, in_tree, carry_avals + x_avals)
../../../virtualenv/python3.6.7/lib/python3.6/site-packages/jax/lax/lax_control_flow.py:74: in _initial_style_jaxpr
    fun, in_tree, in_avals)
../../../virtualenv/python3.6.7/lib/python3.6/site-packages/jax/lax/lax_control_flow.py:68: in _initial_style_untyped_jaxpr
    wrapped_fun, in_pvals, instantiate=True, stage_out=False)
../../../virtualenv/python3.6.7/lib/python3.6/site-packages/jax/interpreters/partial_eval.py:423: in trace_to_jaxpr
    jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
../../../virtualenv/python3.6.7/lib/python3.6/site-packages/jax/linear_util.py:150: in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
../../../virtualenv/python3.6.7/lib/python3.6/site-packages/jax/lax/lax_control_flow.py:1847: in <lambda>
    g = lambda _, x: ((), f(x))
../../../virtualenv/python3.6.7/lib/python3.6/site-packages/numpyro/distributions/util.py:153: in <lambda>
    ret = lax.map(lambda x: _binomial_dispatch(*x),
../../../virtualenv/python3.6.7/lib/python3.6/site-packages/numpyro/distributions/util.py:142: in _binomial_dispatch
    lambda _: jnp.where(cond0, n, 0))
../../../virtualenv/python3.6.7/lib/python3.6/site-packages/jax/lax/lax_control_flow.py:663: in cond
    return _cond_with_per_branch_args(*ba.args)
../../../virtualenv/python3.6.7/lib/python3.6/site-packages/jax/lax/lax_control_flow.py:729: in _cond_with_per_branch_args
    (true_operand, false_operand))
../../../virtualenv/python3.6.7/lib/python3.6/site-packages/jax/lax/lax_control_flow.py:695: in _cond
    (true_fun, false_fun), ops_tree, ops_avals)
../../../virtualenv/python3.6.7/lib/python3.6/site-packages/jax/lax/lax_control_flow.py:91: in _initial_style_jaxprs_with_common_consts
    _initial_style_untyped_jaxpr(fun, in_tree, in_avals) for fun in funs])
../../../virtualenv/python3.6.7/lib/python3.6/site-packages/jax/lax/lax_control_flow.py:91: in <listcomp>
    _initial_style_untyped_jaxpr(fun, in_tree, in_avals) for fun in funs])
../../../virtualenv/python3.6.7/lib/python3.6/site-packages/jax/lax/lax_control_flow.py:68: in _initial_style_untyped_jaxpr
    wrapped_fun, in_pvals, instantiate=True, stage_out=False)
../../../virtualenv/python3.6.7/lib/python3.6/site-packages/jax/interpreters/partial_eval.py:423: in trace_to_jaxpr
    jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
../../../virtualenv/python3.6.7/lib/python3.6/site-packages/jax/linear_util.py:154: in call_wrapped
    ans = gen.send(ans)
../../../virtualenv/python3.6.7/lib/python3.6/site-packages/jax/interpreters/partial_eval.py:438: in trace_to_subjaxpr
    out_tracers = map(trace.full_raise, map(core.full_lower, ans))
../../../virtualenv/python3.6.7/lib/python3.6/site-packages/jax/util.py:34: in safe_map
    return list(map(f, *args))
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
self = JaxprTrace(level=1/1)
val = Traced<ShapedArray(int32[]):JaxprTrace(level=2/1)>
    def full_raise(self, val) -> 'Tracer':
      if not isinstance(val, Tracer):
        return self.pure(val)
      level = self.level
      sublevel = self.sublevel
      if val._trace.master is self.master:
        if val._trace.sublevel == sublevel:
          return val
        elif val._trace.sublevel < sublevel:
          return self.sublift(val)
        else:
          raise escaped_tracer_error("Can't lift sublevels {} to {}"
                                     .format(val._trace.sublevel, sublevel))
      elif val._trace.level < level:
        if val._trace.sublevel > sublevel:
          raise escaped_tracer_error("Incompatible sublevel: {}, {}"
                                     .format(val._trace, (level, sublevel)))
        return self.lift(val)
      elif val._trace.level > level:
        raise escaped_tracer_error("Can't lift level {} to {}"
>                                  .format(val, self))
E       jax.core.UnexpectedTracerError: Encountered an unexpected tracer. Perhaps this tracer escaped through global state from a previously traced function.
E       The functions being transformed should not save traced values to global state.
E       Details: Can't lift level Traced<ShapedArray(int32[]):JaxprTrace(level=2/1)> to JaxprTrace(level=1/1).
../../../virtualenv/python3.6.7/lib/python3.6/site-packages/jax/core.py:392: UnexpectedTracerError
 1 failed, 3362 passed, 3097 skipped, 150 xfailed, 36 xpassed in 812.00 seconds 
Makefile:46: recipe for target 'test' failed
make: *** [test] Error 1
The command "CI=1 FUNSOR_BACKEND=jax make test" exited with 2.

fehiepsi (Member) commented:

@fritzo Did you rerun the job? I can't see the failure at the link. I have observed this error when:

  • we trace the sample_shape argument of a sampler,
  • or when there is a JAX bug.

I didn't find anything wrong with the current implementation of the binomial sampler in NumPyro.

fritzo (Member, Author) commented Aug 16, 2020

@fehiepsi Thanks for checking. I did rerun the test, and it passed; before rerunning I captured the failure output (in the details above). I guess we should keep in mind that JAX samplers can spuriously fail tests.

@fritzo fritzo requested a review from eb8680 August 16, 2020 15:17
fritzo (Member, Author) commented Aug 16, 2020

@eb8680 I'd like to perform this refactoring in two PRs:

  1. this PR that changes the underlying types but preserves the old interface in most of funsor
  2. a follow-up PR that modernizes usage in funsor code: bint(3) -> Bint[3], reals() -> Real, etc.

The first PR is small, and the second PR will be big but easy to review.
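
For concreteness, a hypothetical before/after of that second PR, following the mapping above (the Real alias for reals() and the import location are assumptions):

from funsor.domains import Bint, Real, Reals, bint, reals

old_inputs = dict(i=bint(3), x=reals(), y=reals(2, 5))  # deprecated shim style
new_inputs = dict(i=Bint[3], x=Real, y=Reals[2, 5])     # type-hint style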

"""
def __new__(cls, shape, dtype):
assert isinstance(shape, tuple)
class Domain(type):
eb8680 (Member) commented:

Would it be more Pythonic to define Domain with typing.Generic?

from typing import Generic, Tuple, TypeVar, Union
from typing_extensions import Literal  # backport of typing.Literal in 3.8

S = TypeVar("S", bound=Tuple)
D = TypeVar("D", bound=Union[int, str])

class Domain(Generic[S, D]):
    def __init__(self, shape: S, dtype: D):
        self.shape = shape
        self.dtype = dtype

# version 1, more concise
Bint = Domain[Tuple[()], TypeVar("DB", bound=int)]
Reals = Domain[S, Literal['real']]

# version 2, nicer printing/pickling
class Bint(Domain[Tuple[()], TypeVar("DB", bound=int)]):
    pass

class Reals(Domain[S, Literal['real']]):
    pass

def reals(*shape):
    return Reals[Tuple[tuple(Literal[size] for size in shape)]]

def bint(size):
    return Bint[Literal[size]]

fritzo (Member, Author) replied:

My longer-term vision is for Domain to be the type of more general ground values in Funsor, including Tuple[] and eventually pytrees. Therefore I think it no longer makes sense to parametrize Domain by (shape, dtype).

Indeed I feel we are already shoe-horning bint() and reals() into an ill-fitting abstraction. For example, to address #322 we will probably need a heterogeneous size attribute, which is currently aliased to dtype. Even that alias feels like a hack to me.

eb8680 (Member) replied Aug 16, 2020:

My longer-term vision is for Domain to be the type of more general ground values in Funsor, including Tuple[] and eventually pytrees

I'm definitely on board with this, but it seems like in that vision there is still a need for something like Reals and Bint (and hence my version of Domain above, which we could call Shape or Array after NumPy?) to represent dependent tensor output types. Then we could just have the more general ground-value base type be typing.Any, or a very broad Union[int, float, str, Shape, Tuple], rather than a custom Domain type?

fritzo (Member, Author) replied:

Hmm, I guess there is a future where Bint and Real are more similar... I'll have to think about this more. I really don't like the way we currently use domain.dtype as an alias for size, and I'd prefer to pull these concepts apart.

eb8680 (Member) replied Aug 17, 2020:

I'm fine with separating Bint and Real; it's the BintType/RealType metaclasses in this implementation that I'm wondering whether we could replace with something more Pythonic based on typing.Generic and typing.Literal. To clarify, this version of Reals, simplified from my previous comment, inherits directly from Generic:

S = TypeVar("S", bound=Tuple[int, ...])

class Reals(Generic[S]):  # subtypes are cons-hashed
    def __init__(self, shape: S):
        self.shape = shape
        # dtype, etc. here

def reals(*shape):  # for illustration
    return Reals[Tuple[tuple(Literal[size] for size in shape)]]

It's more verbose, but conforms more closely to the spirit of typing.

Also, I'm not sure pointwise functions annotated with the metaclass-based design (def f(x: Reals[1]): ...) would be compatible with static type checkers when checked without the function wrapper, since no types parametrized directly by values are allowed, regardless of the details of their implementation. To what extent should that serve as a design constraint?

fritzo (Member, Author) replied:

I do like the idea of being more Pythonic and trying to use Generic. However, it remains to be seen whether Generic can serve our purpose. Issues include: we want to memoize via a WeakValueDictionary rather than lru_cache, IMHO; and Generic's implementation changes across Python versions, so it may be difficult to override __class_getitem__ to cast parameters to int while jitting. But you have more experience with typing than I do, so let's pair-code on it.

should [compatibility with type checkers] serve as a design constraint?

Yeah, good question. I guess the objective of this refactoring is to replace Funsor's bespoke type annotations with more standard annotations and thereby make it easier for new Funsor users to learn and contribute. While that suggests mypy support is a secondary goal, I think it would be appreciated by IDE-based users. Let's see how hard it is; maybe we can add a mypy test stage.

Also I'd like to see whether we can roughly follow NumPy types. Let's discuss tomorrow.
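
For reference, a rough sketch (an illustration of the idea, not code from this PR) of cons-hashing parametrized domains through a weakref.WeakValueDictionary rather than lru_cache, including the cast of sizes to int needed while jitting:

import weakref

class DomainMeta(type):
    """Metaclass that cons-hashes parametrized subclasses, e.g. Reals[3, 4]."""
    _cache = weakref.WeakValueDictionary()

    def __getitem__(cls, shape):
        if not isinstance(shape, tuple):
            shape = (shape,)
        shape = tuple(map(int, shape))  # cast np.int64 etc. back to int while jitting
        key = cls, shape
        try:
            return DomainMeta._cache[key]
        except KeyError:
            subcls = DomainMeta(f"{cls.__name__}{list(shape)}", (cls,), {"shape": shape})
            DomainMeta._cache[key] = subcls
            return subcls

class Reals(metaclass=DomainMeta):
    pass

assert Reals[3, 4] is Reals[3, 4]  # repeated lookups return the same weakly-cached class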

fritzo (Member, Author) added Aug 17, 2020:

Some other challenges with more Pythonic type annotations are:

  • some of typing's annotations simply disappear at runtime (mypy inspects the ASTs rather than the runtime values, so Funsor's runtime inspection may be more challenging);
  • there is not yet a standard way to inspect type hints at runtime; we may need to roll our own version-agnostic module similar to typing_inspect.

Review comment on funsor/domains.py (outdated):
Bint[5] # integers ranging in {0,1,2,3,4}

To dispatch on domain type, we recommend either ``@singledispatch``,
``@multipledispatch``, or ``isinstance(domain, BintType)``.
eb8680 (Member) commented:

When would we dispatch on domain type with isinstance rather than issubclass? More broadly, when will domain types be instantiated, if ever?

fritzo (Member, Author) replied:

Currently we dispatch via if domain.dtype == "real". I was thinking that if we are to generalize to other domain types (like Tuple[...]) then it would be cleanest to use @singledispatch. For lighter-weight uses (in my immediate refactoring plans) I will probably replace if x.dtype == "real" with isinstance(). Do you have a cleaner suggestion?
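
For example, a hedged sketch of both dispatch styles (assuming BintType and RealType are importable from funsor.domains and that a bint domain's dtype doubles as its size, per the discussion above):

from functools import singledispatch
from funsor.domains import Bint, BintType, RealType, Reals

@singledispatch
def describe(domain):
    raise NotImplementedError(domain)

@describe.register(BintType)
def _(domain):
    return f"bounded integers in range(0, {domain.dtype})"

@describe.register(RealType)
def _(domain):
    return f"real tensor of shape {domain.shape}"

print(describe(Bint[5]))       # dispatches on type(Bint[5]), i.e. the BintType metaclass
print(describe(Reals[3, 4]))

# lighter-weight alternative, replacing the old `domain.dtype == "real"` check:
print(isinstance(Reals[3, 4], RealType))  # True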

fritzo (Member, Author) added:

To be clear, we would never instantiate domains like Real, but we do instantiate domain types like RealType.

Comment on lines 30 to 32
# in some JAX versions, shape can be np.int64 type
if get_tracing_state() or funsor.get_backend() == "jax":
    shape = tuple(map(int, shape))
fritzo (Member, Author) commented:

Note that if we switch to typing.Generic we'll need to override __getitem__ or __class_getitem__ to do some casting.
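
For instance, a sketch of what that override might look like in a Generic-based Reals, reusing the Literal-parametrized scheme from the earlier comments (an assumption, not part of this PR):

from typing import Generic, Tuple, TypeVar
from typing_extensions import Literal

S = TypeVar("S", bound=Tuple[int, ...])

class Reals(Generic[S]):
    def __class_getitem__(cls, shape):
        if not isinstance(shape, tuple):
            shape = (shape,)
        # cast np.int64 (seen under some JAX versions) back to plain int before parametrizing
        shape = tuple(int(size) for size in shape)
        return super().__class_getitem__(Tuple[tuple(Literal[size] for size in shape)])

print(Reals[3, 4])  # e.g. Reals[Tuple[Literal[3], Literal[4]]]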

eb8680 (Member) left a review:

Looks good after our discussion today; just one nit.

@eb8680 eb8680 merged commit 2921bda into master Aug 17, 2020
@eb8680 eb8680 deleted the domain-type branch August 17, 2020 20:39
fritzo (Member, Author) commented Aug 17, 2020

Thanks for reviewing!
