Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Proof of Concept: Types and MyPy #1906

Merged
merged 5 commits into from
Nov 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ jobs:
pip install https://github.com/pyro-ppl/funsor/archive/master.zip
pip install -r docs/requirements.txt
pip freeze
- name: Lint with ruff
- name: Lint with mypy and ruff
run: |
make lint
- name: Build documentation
Expand Down
7 changes: 7 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,13 @@ repos:
language: system
files: "(.py$)|(.*.ipynb$)"

- id: mypy
name: mypy
language: python
entry: mypy --install-types --non-interactive
files: ^numpyro/


- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.6.0
hooks:
Expand Down
1 change: 1 addition & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ lint: FORCE
ruff check .
ruff format . --check
python scripts/update_headers.py --check
mypy --install-types --non-interactive numpyro
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


license: FORCE
python scripts/update_headers.py
Expand Down
3 changes: 2 additions & 1 deletion numpyro/contrib/control_flow/cond.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# SPDX-License-Identifier: Apache-2.0

from functools import partial
from typing import Any, Callable

from jax import device_put, lax

Expand Down Expand Up @@ -72,7 +73,7 @@ def cond_wrapper(
return lax.cond(pred, wrapped_true_fun, wrapped_false_fun, wrapped_operand)


def cond(pred, true_fun, false_fun, operand):
def cond(pred: bool, true_fun: Callable, false_fun: Callable, operand: Any) -> Any:
"""
This primitive conditionally applies ``true_fun`` or ``false_fun``. See
:func:`jax.lax.cond` for more information.
Expand Down
17 changes: 14 additions & 3 deletions numpyro/contrib/control_flow/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from collections import OrderedDict
from functools import partial
from typing import Callable

import jax
from jax import device_put, lax, random
Expand Down Expand Up @@ -278,14 +279,17 @@ def scan_wrapper(
length,
reverse,
rng_key=None,
substitute_stack=[],
substitute_stack=None,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is something I spotted while adding hints. We would want to avoid having mutables as defaults.

enum=False,
history=1,
first_available_dim=None,
):
if length is None:
length = jnp.shape(jax.tree.flatten(xs)[0][0])[0]

if substitute_stack is None:
substitute_stack = []

if enum and history > 0:
return scan_enum( # TODO: replay for enum
f,
Expand Down Expand Up @@ -339,7 +343,14 @@ def body_fn(wrapped_carry, x):
return last_carry, (pytree_trace, ys)


def scan(f, init, xs, length=None, reverse=False, history=1):
def scan(
f: Callable,
init,
xs,
length: int | None = None,
reverse: bool = False,
history: int = 1,
):
"""
This primitive scans a function over the leading array axes of
`xs` while carrying along state. See :func:`jax.lax.scan` for more
Expand Down Expand Up @@ -433,7 +444,7 @@ def g(*args, **kwargs):
:param init: the initial carrying state
:param xs: the values over which we scan along the leading axis. This can
be any JAX pytree (e.g. list/dict of arrays).
:param length: optional value specifying the length of `xs`
:param int | None length: optional value specifying the length of `xs`
but can be used when `xs` is an empty pytree (e.g. None)
:param bool reverse: optional boolean specifying whether to run the scan iteration
forward (the default) or in reverse
Expand Down
8 changes: 6 additions & 2 deletions numpyro/contrib/hsgp/approximation.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,18 @@
import numpyro.distributions as dist


def _non_centered_approximation(phi: ArrayImpl, spd: ArrayImpl, m: int) -> ArrayImpl:
def _non_centered_approximation(
phi: ArrayImpl, spd: ArrayImpl, m: int | list[int]
) -> ArrayImpl:
with numpyro.plate("basis", m):
beta = numpyro.sample("beta", dist.Normal(loc=0.0, scale=1.0))

return phi @ (spd * beta)


def _centered_approximation(phi: ArrayImpl, spd: ArrayImpl, m: int) -> ArrayImpl:
def _centered_approximation(
phi: ArrayImpl, spd: ArrayImpl, m: int | list[int]
) -> ArrayImpl:
with numpyro.plate("basis", m):
beta = numpyro.sample("beta", dist.Normal(loc=0.0, scale=spd))

Expand Down
4 changes: 2 additions & 2 deletions numpyro/contrib/hsgp/spectral_densities.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def align_param(dim, param):

def spectral_density_squared_exponential(
dim: int, w: ArrayImpl, alpha: float, length: float | ArrayImpl
) -> float:
) -> ArrayImpl:
"""
Spectral density of the squared exponential kernel.

Expand All @@ -46,7 +46,7 @@ def spectral_density_squared_exponential(
:param float alpha: amplitude
:param float length: length scale
:return: spectral density value
:rtype: float
:rtype: ArrayImpl
"""
length = align_param(dim, length)
c = alpha * jnp.prod(jnp.sqrt(2 * jnp.pi) * length, axis=-1)
Expand Down
2 changes: 1 addition & 1 deletion numpyro/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -600,7 +600,7 @@ def __init__(
transition_matrix.ndim == 2
), "`transition_matrix` argument should be a square matrix"
self.transition_matrix = transition_matrix
# Expand the covariance/presicion/scale matrices to the right number of steps.
# Expand the covariance/precision/scale matrices to the right number of steps.
args = {
"covariance_matrix": covariance_matrix,
"precision_matrix": precision_matrix,
Expand Down
3 changes: 2 additions & 1 deletion numpyro/infer/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# SPDX-License-Identifier: Apache-2.0

from collections import namedtuple
from collections.abc import Sequence
from contextlib import contextmanager
from functools import partial
from typing import Callable, Optional
Expand Down Expand Up @@ -931,7 +932,7 @@ def __init__(
guide: Optional[Callable] = None,
params: Optional[dict] = None,
num_samples: Optional[int] = None,
return_sites: Optional[list[str]] = None,
return_sites: Optional[Sequence[str]] = None,
infer_discrete: bool = False,
parallel: bool = False,
batch_ndims: Optional[int] = None,
Expand Down
12 changes: 12 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -102,3 +102,15 @@ doctest_optionflags = [
"NORMALIZE_WHITESPACE",
"IGNORE_EXCEPTION_DETAIL",
]

[tool.mypy]
ignore_errors = true
ignore_missing_imports = true

[[tool.mypy.overrides]]
module = [
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here, we can keep adding the modules we want to check. Eventually, we would like to simply remove this and check everything

"numpyro.contrib.control_flow.*", # types missing
"numpyro.contrib.funsor.*", # types missing
"numpyro.contrib.hsgp.*",
]
ignore_errors = false
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
"test": [
"importlib-metadata<5.0",
"ruff>=0.1.8",
"mypy>=1.13",
"pytest>=4.1",
"pyro-api>=0.1.1",
"scikit-learn",
Expand Down
Loading