-
Notifications
You must be signed in to change notification settings - Fork 246
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Proof of Concept: Types and MyPy #1906
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,6 +3,7 @@ | |
|
||
from collections import OrderedDict | ||
from functools import partial | ||
from typing import Callable | ||
|
||
import jax | ||
from jax import device_put, lax, random | ||
|
@@ -278,14 +279,17 @@ def scan_wrapper( | |
length, | ||
reverse, | ||
rng_key=None, | ||
substitute_stack=[], | ||
substitute_stack=None, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is something I spotted while adding hints. We would want to avoid having mutables as defaults. |
||
enum=False, | ||
history=1, | ||
first_available_dim=None, | ||
): | ||
if length is None: | ||
length = jnp.shape(jax.tree.flatten(xs)[0][0])[0] | ||
|
||
if substitute_stack is None: | ||
substitute_stack = [] | ||
|
||
if enum and history > 0: | ||
return scan_enum( # TODO: replay for enum | ||
f, | ||
|
@@ -339,7 +343,14 @@ def body_fn(wrapped_carry, x): | |
return last_carry, (pytree_trace, ys) | ||
|
||
|
||
def scan(f, init, xs, length=None, reverse=False, history=1): | ||
def scan( | ||
f: Callable, | ||
init, | ||
xs, | ||
length: int | None = None, | ||
reverse: bool = False, | ||
history: int = 1, | ||
): | ||
""" | ||
This primitive scans a function over the leading array axes of | ||
`xs` while carrying along state. See :func:`jax.lax.scan` for more | ||
|
@@ -433,7 +444,7 @@ def g(*args, **kwargs): | |
:param init: the initial carrying state | ||
:param xs: the values over which we scan along the leading axis. This can | ||
be any JAX pytree (e.g. list/dict of arrays). | ||
:param length: optional value specifying the length of `xs` | ||
:param int | None length: optional value specifying the length of `xs` | ||
but can be used when `xs` is an empty pytree (e.g. None) | ||
:param bool reverse: optional boolean specifying whether to run the scan iteration | ||
forward (the default) or in reverse | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -102,3 +102,15 @@ doctest_optionflags = [ | |
"NORMALIZE_WHITESPACE", | ||
"IGNORE_EXCEPTION_DETAIL", | ||
] | ||
|
||
[tool.mypy] | ||
ignore_errors = true | ||
ignore_missing_imports = true | ||
|
||
[[tool.mypy.overrides]] | ||
module = [ | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Here, we can keep adding the modules we want to check. Eventually, we would like to simply remove this and check everything |
||
"numpyro.contrib.control_flow.*", # types missing | ||
"numpyro.contrib.funsor.*", # types missing | ||
"numpyro.contrib.hsgp.*", | ||
] | ||
ignore_errors = false |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
See https://github.com/pyro-ppl/pyro/blob/455f7b3b8b21f8e93a96235fc6bd58cb60f8a3fa/Makefile#L24