Skip to content

Commit

Permalink
feat: jaxtyping support for coreax.data.Data
Browse files Browse the repository at this point in the history
Enables `Data` and `SupervisedData` instances to be given a type hint
that indicates the type and size of their respective `Data.data`
attributes.

For example, `x: Int[Data, "n d"] = ...` indicates that `x` is an instance
of `Data` whose `data` attribute is an integer array of shape `n d`.

Refs: #765
  • Loading branch information
tc85324 committed Nov 24, 2024
1 parent 3556e02 commit c1c9d40
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 4 deletions.
1 change: 1 addition & 0 deletions .cspell/library_terms.txt
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ jacrev
jax
jaxlib
jaxopt
jaxtyped
jaxtyping
jumanjihouse
keepends
Expand Down
17 changes: 17 additions & 0 deletions coreax/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,13 @@ class Data(eqx.Module):
`n`-vector inputs for `data` are interpreted as `n` points in 1-dimension and
converted to a `(n, 1)` array.
Compatible with :mod:`jaxtyping` -- :class:`Data` is interpreted as an array type,
whose shape is the expected shape of :attr:`Data.data`.
.. example::
A `Data` object whose :attr:`Data.data` is expected to be a floating point array
with shape `a b`, can be type hinted as `x: Float[Data, " a b"] = ...`.
:param data: An :math:`n \times d` array defining the features of the unsupervised
dataset
:param weights: An :math:`n`-vector of weights where each element of the weights
Expand Down Expand Up @@ -164,6 +171,16 @@ def __len__(self) -> int:
"""Return data length."""
return len(self.data)

@property
def dtype(self):
    """Expose the dtype of the wrapped ``data``; read by jaxtyping annotations."""
    wrapped = self.data
    return wrapped.dtype

@property
def shape(self):
    """Expose the shape of the wrapped ``data``; read by jaxtyping annotations."""
    wrapped = self.data
    return wrapped.shape

def normalize(self, *, preserve_zeros: bool = False) -> Self:
"""
Return a copy of ``self`` with ``weights`` that sum to one.
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ dependencies = [
"flax",
"jax",
"jaxopt",
"jaxtyping",
"jaxtyping>0.2.31",
"optax",
"scikit-learn",
"tqdm",
Expand Down
66 changes: 63 additions & 3 deletions tests/unit/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,16 @@
produce the expected results on simple examples.
"""

from functools import partial
from contextlib import nullcontext

import equinox as eqx
import jax.numpy as jnp
import jax.tree_util as jtu
import jaxtyping
import pytest
from jax import Array
from jaxtyping import Float, Shaped, jaxtyped
from typeguard import typechecked

import coreax.data

Expand Down Expand Up @@ -99,8 +102,8 @@ def test_atleast_2d_consistent(arrays: tuple[Array]) -> None:
@pytest.mark.parametrize(
"data_type",
[
partial(coreax.data.Data, DATA_ARRAY),
partial(coreax.data.SupervisedData, DATA_ARRAY, SUPERVISION),
jtu.Partial(coreax.data.Data, DATA_ARRAY),
jtu.Partial(coreax.data.SupervisedData, DATA_ARRAY, SUPERVISION),
],
)
class TestData:
Expand Down Expand Up @@ -147,6 +150,16 @@ def test_len(self, data_type):
_data = data_type()
assert len(_data) == len(_data.data)

def dtype(self, data_type):
"""Test dtype property; required for jaxtyping annotations."""
_data = data_type()
assert _data.data.dtype == _data.dtype

def shape(self, data_type):
"""Test shape property; required for jaxtyping annotations."""
_data = data_type()
assert _data.data.shape == _data.shape

@pytest.mark.parametrize("weights", (None, 0, 3, DATA_ARRAY.reshape(-1)))
def test_normalize(self, data_type, weights):
"""Test weight normalization."""
Expand All @@ -160,6 +173,53 @@ def test_normalize(self, data_type, weights):
normalized_with_zeros.weights, jnp.nan_to_num(expected_weights)
)

def test_jaxtyping_compatibility(self, data_type):
    """
    Test `Data` compatibility with jaxtyping annotations.

    Two instances are built from ``data_type`` -- one with every positional
    argument cast to ``float32`` and one with every argument cast to ``int32``
    -- and each is passed through three runtime-checked functions, covering:

    - Correct shape and data type,
    - Correct shape and incorrect data type,
    - Incorrect shape and correct data type.
    """

    def _build_with_dtype(dtype):
        # Swap the partial's stored arguments for copies cast to ``dtype``,
        # then call the partial to construct the data object.
        cast_args = jtu.tree_map(lambda y: jnp.astype(y, dtype), data_type.args)
        recast = eqx.tree_at(lambda x: x.args, data_type, replace=cast_args)
        return recast()

    float_data = _build_with_dtype(jnp.float32)
    int_data = _build_with_dtype(jnp.int32)

    @jaxtyped(typechecker=typechecked)
    def good_shape_check(x: Shaped[coreax.data.Data, "3 1"]):
        del x

    @jaxtyped(typechecker=typechecked)
    def float_dtype_check(x: Float[coreax.data.Data, "..."]):
        del x

    @jaxtyped(typechecker=typechecked)
    def bad_shape_check(x: Shaped[coreax.data.Data, "3"]):
        del x

    # Pair each instance with the outcome expected from `float_dtype_check`:
    # the integer data must be rejected, the floating point data accepted.
    cases = [
        (int_data, pytest.raises(jaxtyping.TypeCheckError)),
        (float_data, nullcontext()),
    ]
    for data, float_check_context in cases:
        good_shape_check(data)
        with pytest.raises(jaxtyping.TypeCheckError):
            bad_shape_check(data)
        with float_check_context:
            float_dtype_check(data)


class TestSupervisedData:
"""Test operation of SupervisedData class."""
Expand Down

0 comments on commit c1c9d40

Please sign in to comment.