Skip to content

Commit

Permalink
feat: add mise.toml and support numpy>=2 (#254)
Browse files Browse the repository at this point in the history
* `mise.toml` and support numpy>=2

* fix: remove unsupported `np.float_`

* tests: fix unfitted PCA error
  • Loading branch information
eonu authored Dec 22, 2024
1 parent b5a4b0f commit bad75a0
Show file tree
Hide file tree
Showing 5 changed files with 14 additions and 8 deletions.
6 changes: 6 additions & 0 deletions mise.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
[tools]
poetry = { version = 'latest', pyproject = 'pyproject.toml' }
python = '3.12'

[env]
_.python.venv = ".venv"
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,9 @@ requires = ['poetry-core~=1.0', 'Cython>=0.28.5']
build-backend = 'poetry.core.masonry.api'

[tool.poetry.dependencies]
python = "^3.11"
python = "^3.11,<3.13"
numba = ">=0.56,<1"
numpy = "^1.19.5"
numpy = ">=1.19.5,<3"
hmmlearn = ">=0.2.8,<1"
dtaidistance = "^2.3.10"
scikit-learn = "^1.4"
Expand Down
4 changes: 2 additions & 2 deletions sequentia/_internal/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,6 @@

__all__ = ["FloatArray", "IntArray", "Array"]

FloatArray = npt.NDArray[np.float_]
IntArray = npt.NDArray[np.int_]
FloatArray = npt.NDArray[np.float64]
IntArray = npt.NDArray[np.int64]
Array = FloatArray | IntArray
6 changes: 3 additions & 3 deletions sequentia/_internal/_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def check_X(
X: t.Iterable[int] | t.Iterable[float],
/,
*,
dtype: np.float_ | np.int_,
dtype: np.float64 | np.int64,
univariate: bool = False,
) -> Array:
if not isinstance(X, np.ndarray):
Expand Down Expand Up @@ -133,7 +133,7 @@ def check_X_lengths(
/,
*,
lengths: t.Iterable[int] | None,
dtype: np.float_ | np.int_,
dtype: np.float64 | np.int64,
univariate: bool = False,
) -> tuple[Array, IntArray]:
# validate observations
Expand Down Expand Up @@ -172,7 +172,7 @@ def check_y(
/,
*,
lengths: IntArray,
dtype: np.float_ | np.int_ | None = None,
dtype: np.float64 | np.int64 | None = None,
) -> Array:
if y is None:
msg = "No output values `y` provided"
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def test_pipeline_with_transforms(
)

# check that transforming without fitting doesn't work
with pytest.raises(NotFittedError):
with pytest.raises((NotFittedError, AttributeError)):
pipeline.transform(**data.X_lengths)

# check that fitting without y works
Expand Down

0 comments on commit bad75a0

Please sign in to comment.