Packaging updates
- Use pyproject.toml, hatch, hatch-vcs
- Use ruff

Closes #984
TomAugspurger committed Mar 30, 2024
1 parent b3954e9 commit 3114c7f
Showing 21 changed files with 187 additions and 203 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -125,3 +125,4 @@ docs/source/auto_examples/
docs/source/examples/mydask.png

dask-worker-space
.direnv
27 changes: 9 additions & 18 deletions .pre-commit-config.yaml
@@ -1,19 +1,10 @@
repos:
- repo: https://github.com/psf/black
rev: 23.12.1
hooks:
- id: black
language_version: python3
args:
- --target-version=py39
- repo: https://github.com/pycqa/flake8
rev: 7.0.0
hooks:
- id: flake8
language_version: python3
args: ["--ignore=E501,W503,E203,E741,E731"]
- repo: https://github.com/pycqa/isort
rev: 5.13.2
hooks:
- id: isort
language_version: python3
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.3.4
hooks:
# Run the linter.
- id: ruff
args: [ --fix ]
# Run the formatter.
- id: ruff-format
1 change: 0 additions & 1 deletion ci/environment-latest.yaml

This file was deleted.

18 changes: 3 additions & 15 deletions dask_ml/__init__.py
@@ -1,18 +1,6 @@
from pkg_resources import DistributionNotFound, get_distribution

# Ensure we always register tokenizers
from dask_ml.model_selection import _normalize

__all__ = []

try:
__version__ = get_distribution(__name__).version
__all__.append("__version__")
except DistributionNotFound:
# package is not installed
pass
from dask_ml.model_selection import _normalize # noqa: F401

from ._version import __version__

del DistributionNotFound
del get_distribution
del _normalize
__all__ = ["__version__"]
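With this change dask_ml.__version__ keeps working for callers, but it is now read from the hatch-vcs-generated dask_ml/_version.py instead of a runtime pkg_resources lookup. A minimal sketch of the difference; the importlib.metadata call is an aside for comparison only, not part of this commit (pkg_resources itself is deprecated):

    import dask_ml

    print(dask_ml.__version__)  # e.g. "2024.3.21.dev0+gb3954e9e.d20240330"

    # Roughly what the deleted pkg_resources code used to compute, shown here
    # with the maintained importlib.metadata API for a regular install:
    from importlib.metadata import version
    print(version("dask-ml"))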
17 changes: 17 additions & 0 deletions dask_ml/_version.py
@@ -0,0 +1,17 @@
# file generated by setuptools_scm
# don't change, don't track in version control
TYPE_CHECKING = False
if TYPE_CHECKING:
from typing import Tuple, Union

VERSION_TUPLE = Tuple[Union[int, str], ...]
else:
VERSION_TUPLE = object

version: str
__version__: str
__version_tuple__: VERSION_TUPLE
version_tuple: VERSION_TUPLE

__version__ = version = "2024.3.21.dev0+gb3954e9e.d20240330"
__version_tuple__ = version_tuple = (2024, 3, 21, "dev0", "gb3954e9e.d20240330")
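The generated module also exposes a tuple form of the version. A small hedged sketch (not part of the commit) of how that form can be used for programmatic comparisons; the values are the build-time snapshot shown above:

    from dask_ml._version import __version__, __version_tuple__

    print(__version__)        # "2024.3.21.dev0+gb3954e9e.d20240330"
    print(__version_tuple__)  # (2024, 3, 21, "dev0", "gb3954e9e.d20240330")

    # The leading integer components compare naturally, e.g. for feature gating:
    if __version_tuple__[:3] >= (2024, 3, 21):
        pass  # running at least the 2024.3.21 line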
14 changes: 6 additions & 8 deletions dask_ml/cluster/spectral.py
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
"""Algorithms for spectral clustering
"""
"""Algorithms for spectral clustering"""

import logging

import dask.array as da
@@ -272,9 +272,7 @@ def fit(self, X, y=None):
# Eq 16. This is OK when V2 is orthogonal
V2 = da.sqrt(float(n_components) / n) * da.vstack([A2, B2.T]).dot(
U_A[:, :n_clusters]
).dot(
da.diag(1.0 / da.sqrt(S_A[:n_clusters]))
) # (n, k)
).dot(da.diag(1.0 / da.sqrt(S_A[:n_clusters]))) # (n, k)
_log_array(logger, V2, "V2.1")

if isinstance(B2, da.Array):
@@ -366,9 +364,9 @@ def _slice_mostly_sorted(array, keep, rest, ind=None):
slices.append([keep[0]])
windows = zip(keep[:-1], keep[1:])

for l, r in windows:
if r > l + 1: # avoid creating empty slices
slices.append(slice(l + 1, r))
for left, r in windows:
if r > left + 1: # avoid creating empty slices
slices.append(slice(left + 1, r))
slices.append([r])

if keep[-1] < len(array) - 1: # avoid creating empty slices
5 changes: 3 additions & 2 deletions dask_ml/decomposition/truncated_svd.py
@@ -148,8 +148,9 @@ def fit(self, X, y=None):
def _check_array(self, X):
if self.n_components >= X.shape[1]:
raise ValueError(
"n_components must be < n_features; "
"got {} >= {}".format(self.n_components, X.shape[1])
"n_components must be < n_features; " "got {} >= {}".format(
self.n_components, X.shape[1]
)
)
return X

2 changes: 1 addition & 1 deletion dask_ml/ensemble/_blockwise.py
@@ -41,7 +41,7 @@ def fit(self, X, y, **kwargs):
]
results = [
estimator_.fit(X_, y_, **kwargs)
for estimator_, X_, y_, in zip(estimators, Xs, ys)
for estimator_, X_, y_ in zip(estimators, Xs, ys)
]
results = list(dask.compute(*results))
self.estimators_ = results
5 changes: 3 additions & 2 deletions dask_ml/impute.py
@@ -35,8 +35,9 @@ def fit(self, X, y=None):
allowed_strategies = ["mean", "median", "most_frequent", "constant"]
if self.strategy not in allowed_strategies:
raise ValueError(
"Can only use these strategies: {0} "
" got strategy={1}".format(allowed_strategies, self.strategy)
"Can only use these strategies: {0} " " got strategy={1}".format(
allowed_strategies, self.strategy
)
)

if not (pd.isna(self.missing_values) or self.strategy == "constant"):
1 change: 1 addition & 0 deletions dask_ml/linear_model/glm.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
"""Generalized Linear Models for large datasets."""

import textwrap

from dask_glm import algorithms, families
3 changes: 1 addition & 2 deletions dask_ml/linear_model/utils.py
@@ -1,5 +1,4 @@
"""
"""
""" """

import dask.array as da
import dask.dataframe as dd
5 changes: 3 additions & 2 deletions dask_ml/metrics/scorer.py
@@ -39,8 +39,9 @@ def get_scorer(scoring: Union[str, Callable], compute: bool = True) -> Callable:
scorer, kwargs = SCORERS[scoring]
except KeyError:
raise ValueError(
"{} is not a valid scoring value. "
"Valid options are {}".format(scoring, sorted(SCORERS))
"{} is not a valid scoring value. " "Valid options are {}".format(
scoring, sorted(SCORERS)
)
)
else:
scorer = scoring
3 changes: 1 addition & 2 deletions dask_ml/model_selection/_split.py
@@ -1,5 +1,4 @@
"""Utilities for splitting datasets.
"""
"""Utilities for splitting datasets."""

import itertools
import logging
20 changes: 11 additions & 9 deletions dask_ml/model_selection/utils_test.py
@@ -171,16 +171,18 @@ def fit(self, X, y, **fit_params):
self.classes_ = np.unique(check_array(y, ensure_2d=False, allow_nd=True))
if self.expected_fit_params:
missing = set(self.expected_fit_params) - set(fit_params)
assert (
len(missing) == 0
), "Expected fit parameter(s) %s not " "seen." % list(missing)
assert len(missing) == 0, (
"Expected fit parameter(s) %s not " "seen." % list(missing)
)
for key, value in fit_params.items():
assert len(value) == len(
X
), "Fit parameter %s has length" "%d; expected %d." % (
key,
len(value),
len(X),
assert len(value) == len(X), (
"Fit parameter %s has length"
"%d; expected %d."
% (
key,
len(value),
len(X),
)
)
return self

3 changes: 1 addition & 2 deletions dask_ml/preprocessing/__init__.py
@@ -1,5 +1,4 @@
"""Utilties for Preprocessing data.
"""
"""Utilties for Preprocessing data."""

from ._block_transformer import BlockTransformer
from ._encoders import OneHotEncoder
5 changes: 3 additions & 2 deletions dask_ml/preprocessing/label.py
@@ -219,8 +219,9 @@ def _check_and_search_block(arr, uniques, onehot_dtype=None, block_info=None):

if diff:
msg = (
"Block contains previously unseen values {}.\nBlock info:\n\n"
"{}".format(diff, block_info)
"Block contains previously unseen values {}.\nBlock info:\n\n" "{}".format(
diff, block_info
)
)
raise ValueError(msg)

28 changes: 14 additions & 14 deletions dask_ml/utils.py
@@ -119,9 +119,9 @@ def assert_estimator_equal(left, right, exclude=None, **kwargs):
assert left_attrs2 == right_attrs2, left_attrs2 ^ right_attrs2

for attr in left_attrs2:
l = getattr(left, attr)
r = getattr(right, attr)
_assert_eq(l, r, name=attr, **kwargs)
lattr = getattr(left, attr)
rattr = getattr(right, attr)
_assert_eq(lattr, rattr, name=attr, **kwargs)


def check_array(
@@ -218,27 +218,27 @@ def check_array(
return sk_validation.check_array(array, *args, **kwargs)


def _assert_eq(l, r, name=None, **kwargs):
def _assert_eq(lattr, rattr, name=None, **kwargs):
array_types = (np.ndarray, da.Array)
if getattr(dd, "_dask_expr_enabled", lambda: False)():
from dask_expr import FrameBase

frame_types = (pd.core.generic.NDFrame, FrameBase)
else:
frame_types = (pd.core.generic.NDFrame, dd._Frame)
if isinstance(l, array_types):
assert_eq_ar(l, r, **kwargs)
elif isinstance(l, frame_types):
assert_eq_df(l, r, **kwargs)
elif isinstance(l, Sequence) and any(
isinstance(x, array_types + frame_types) for x in l
if isinstance(lattr, array_types):
assert_eq_ar(lattr, rattr, **kwargs)
elif isinstance(lattr, frame_types):
assert_eq_df(lattr, rattr, **kwargs)
elif isinstance(lattr, Sequence) and any(
isinstance(x, array_types + frame_types) for x in lattr
):
for a, b in zip(l, r):
for a, b in zip(lattr, rattr):
_assert_eq(a, b, **kwargs)
elif np.isscalar(r) and np.isnan(r):
assert np.isnan(l), (name, l, r)
elif np.isscalar(rattr) and np.isnan(rattr):
assert np.isnan(lattr), (name, lattr, rattr)
else:
assert l == r, (name, l, r)
assert lattr == rattr, (name, lattr, rattr)


def check_random_state(random_state):
105 changes: 105 additions & 0 deletions pyproject.toml
@@ -0,0 +1,105 @@
[build-system]
requires = ["hatchling", "hatch-vcs"]
build-backend = "hatchling.build"

[project]
name = "dask-ml"
dynamic = ["version"]
description = "A library for distributed and parallel machine learning"
readme = "README.rst"
license = {file = 'LICENSE.txt'}
requires-python = ">=3.8"
authors = [{ name = "Tom Augspurger", email = "taugspurger@anaconda.com" }]
classifiers = [
"Development Status :: 5 - Production/Stable",
"Intended Audience :: Developers",
"License :: OSI Approved :: BSD License",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Topic :: Database",
"Topic :: Scientific/Engineering",
]
dependencies = [
"dask-glm>=0.2.0",
"dask[array,dataframe]>=2.4.0",
"distributed>=2.4.0",
"multipledispatch>=0.4.9",
"numba>=0.51.0",
"numpy>=1.20.0",
"packaging",
"pandas>=0.24.2",
"scikit-learn>=1.2.0",
"scipy",
]

[project.optional-dependencies]
complete = ["dask-xgboost", "xgboost"]
dev = [
"black",
"coverage",
"flake8",
"isort",
"nbsphinx",
"numpydoc",
"pytest",
"pytest-cov",
"pytest-mock",
"sphinx",
"sphinx-gallery",
"sphinx-rtd-theme",
]
docs = ["nbsphinx", "numpydoc", "sphinx", "sphinx-gallery", "sphinx-rtd-theme"]
test = [
"black",
"coverage",
"flake8",
"isort",
"pytest",
"pytest-cov",
"pytest-mock",
]
xgboost = ["dask-xgboost", "xgboost"]

[project.urls]
Homepage = "https://github.com/dask/dask-ml"

[tool.hatch.version]
source = "vcs"

[tool.hatch.build.hooks.vcs]
version-file = "dask_ml/_version.py"

[tool.hatch.build.targets.sdist]
include = ["/dask_ml"]

[tool.mypy]
ignore_missing_imports = true
no_implicit_optional = true
check_untyped_defs = true
strict_equality = true

[[tool.mypy-dask_ml.metrics]]
check_untyped_defs = false

[[tool.mypy.overrides]]
module = "dask_ml.model_selection"
follow_imports = "skip"

[tool.coverage]
source = "dask_ml"

[tool.pytest]
addopts = "-rsx -v --durations=10 --color=yes"
minversion = "3.2"
xfail_strict = true
junit_family = "xunit2"
filterwarnings = [
"error:::dask_ml[.*]",
"error:::sklearn[.*]",
]


[tool.ruff.lint]
ignore = ["E721", "E731", "E741"]
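For reference, a hedged illustration (not from the commit) of what the three ignored rules flag; E731 and E741 also appeared in the old flake8 ignore list in .pre-commit-config.yaml, and each line below would normally trigger the named code:

    l = [1, 2, 3]                # E741: ambiguous single-character name "l"
    square = lambda x: x * x     # E731: lambda assigned to a name instead of a def
    same = type(1) == type(2.0)  # E721: type comparison with == rather than isinstance()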