Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: DBSCAN via Array API based on #2096 #2100

Draft
wants to merge 53 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
53 commits
Select commit Hold shift + click to select a range
8bedde1
ENH: array api dispatching
samir-nasibli Oct 2, 2024
b11fcf3
Deselect some scikit-learn Array API tests
samir-nasibli Oct 4, 2024
467634a
Merge branch 'intel:main' into enh/array_api_dispatching
samir-nasibli Oct 4, 2024
31030f7
Merge branch 'intel:main' into enh/array_api_dispatching
samir-nasibli Oct 8, 2024
943796e
deselect more tests
samir-nasibli Oct 8, 2024
ef42daa
deselect more tests
samir-nasibli Oct 8, 2024
3bc755d
disabled tests for
samir-nasibli Oct 8, 2024
76f1876
fix the deselection comment
samir-nasibli Oct 8, 2024
ce0b8e1
disabled test for Ridge regression
samir-nasibli Oct 8, 2024
404e8c0
Disabled tests and added comment
samir-nasibli Oct 8, 2024
ced43bf
ENH: Array API dispatching
samir-nasibli Oct 8, 2024
968365f
Merge branch 'intel:main' into enh/array_api_dispatching_testing
samir-nasibli Oct 9, 2024
c395d03
Revert adding dpctl into Array PI conformance testing
samir-nasibli Oct 9, 2024
9271479
Merge branch 'enh/array_api_dispatching_testing' of https://github.co…
samir-nasibli Oct 9, 2024
5784c25
minor refactoring onedal _array_api
samir-nasibli Oct 9, 2024
8d7f664
add tests
samir-nasibli Oct 9, 2024
63d8f30
addressed memory usage tests
samir-nasibli Oct 9, 2024
6bd0280
Address some array api test fails
samir-nasibli Oct 9, 2024
90411e7
linting
samir-nasibli Oct 9, 2024
2b7bbc5
addressed test_get_namespace
samir-nasibli Oct 9, 2024
b7b8f03
adding test case for validate_data check with Array API inputs
samir-nasibli Oct 9, 2024
169009d
minor refactoring
samir-nasibli Oct 9, 2024
9ca118c
addressed test_patch_map_match fail
samir-nasibli Oct 9, 2024
7ddcf40
Added docstrings for get_namespace
samir-nasibli Oct 9, 2024
ec90d43
docstrings for Array API tests
samir-nasibli Oct 9, 2024
6e7e547
updated minimal scikit-learn version for Array API dispatching
samir-nasibli Oct 9, 2024
e5db839
updated minimal scikit-learn version for Array API dispatching in _de…
samir-nasibli Oct 9, 2024
f99a92b
fix test test_get_namespace_with_config_context
samir-nasibli Oct 9, 2024
4e3286a
Merge branch 'main' into enh/array_api_dispatching_testing
samir-nasibli Oct 10, 2024
bc10579
ENH: DBSCAN via Array API
samir-nasibli Oct 10, 2024
1cb07a2
refactor onedal/datatypes/_data_conversion.py
samir-nasibli Oct 11, 2024
acf5689
minor fix
samir-nasibli Oct 11, 2024
9b2a8e9
minor update
samir-nasibli Oct 11, 2024
cfd91ea
added _check_sample_weight via Array API
samir-nasibli Oct 13, 2024
e18d65b
Merge branch 'main' into enh/dbscan_array_api_enh
samir-nasibli Oct 13, 2024
ce9da92
correction for array api
samir-nasibli Oct 13, 2024
016e6e0
returned relative import for _is_csr
samir-nasibli Oct 13, 2024
4f0cfa4
Merge branch 'intel:main' into enh/dbscan_array_api_enh
samir-nasibli Oct 14, 2024
67ecbda
Merge branch 'intel:main' into enh/dbscan_array_api_enh
samir-nasibli Oct 15, 2024
d41468d
Merge branch 'intel:main' into enh/dbscan_array_api_enh
samir-nasibli Oct 18, 2024
cbd9113
Merge branch 'intel:main' into enh/dbscan_array_api_enh
samir-nasibli Oct 19, 2024
1cf2a1c
Merge branch 'main' into enh/dbscan_array_api_enh
samir-nasibli Oct 22, 2024
a430271
Merge branch 'main' into enh/dbscan_array_api_enh
samir-nasibli Nov 7, 2024
b5138cb
re-impl for array api onedal4py dbscan
samir-nasibli Nov 8, 2024
8c31263
add array api ravel func
samir-nasibli Nov 8, 2024
b1411ff
minor refactoring for sklearnex/utils/_array_api.py
samir-nasibli Nov 8, 2024
ee9edd3
minor fix for _check_sample_weight
samir-nasibli Nov 8, 2024
10da160
update for DBSCAN
samir-nasibli Nov 12, 2024
255d605
Merge branch 'main' into enh/dbscan_array_api_enh
samir-nasibli Nov 12, 2024
ada49b8
update to_table call
samir-nasibli Nov 12, 2024
7ceffdc
Merge branch 'intel:main' into enh/dbscan_array_api_enh
samir-nasibli Nov 21, 2024
3b1f431
fixes
samir-nasibli Dec 4, 2024
641235f
Merge branch 'main' into enh/dbscan_array_api_enh
samir-nasibli Dec 5, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 77 additions & 19 deletions onedal/cluster/dbscan.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,22 @@
# ===============================================================================

import numpy as np
from sklearn.utils import check_array

from daal4py.sklearn._utils import get_dtype, make2d
from onedal.utils._array_api import get_dtype, make2d

from ..common._base import BaseEstimator
from ..common._mixin import ClusterMixin
from ..datatypes import _convert_to_supported, from_table, to_table
from ..utils import _check_array
from ..utils._array_api import (
_asarray,
_convert_to_numpy,
_ravel,
get_dtype,
get_namespace,
make2d,
sklearn_array_api_dispatch,
)


class BaseDBSCAN(BaseEstimator, ClusterMixin):
Expand All @@ -46,38 +55,82 @@ def __init__(
self.p = p
self.n_jobs = n_jobs

def _get_onedal_params(self, dtype=np.float32):
def _get_onedal_params(self, xp, dtype):
# TODO:
# change "fptype": dtype,
return {
"fptype": dtype,
"fptype": "float" if dtype == xp.float32 else "double",
"method": "by_default",
"min_observations": int(self.min_samples),
"epsilon": float(self.eps),
"mem_save_mode": False,
"result_options": "core_observation_indices|responses",
}

def _fit(self, X, y, sample_weight, module, queue):
@sklearn_array_api_dispatch()
def _fit(self, X, sua_iface, xp, is_array_api_compliant, y, sample_weight, queue):
policy = self._get_policy(queue, X)
X = _check_array(X, accept_sparse="csr", dtype=[np.float64, np.float32])
# TODO:
# check on dispatching and warn.
# using scikit-learn primitives will require array_api_dispatch=True
X = check_array(X, accept_sparse="csr", dtype=[xp.float64, xp.float32])

sample_weight = make2d(sample_weight) if sample_weight is not None else None
X = make2d(X)
# X_device = X.device if xp else None

# TODO:
# move to _convert_to_supported to do astype conversion
# at once.
types = [xp.float32, xp.float64]

types = [np.float32, np.float64]
# TODO:
# could be impossible, if device doesn't support fp65
# make sense update _convert_to_supported for it.
if get_dtype(X) not in types:
X = X.astype(np.float64)
X = _convert_to_supported(policy, X)
X = X.astype(xp.float64)
X = _convert_to_supported(policy, X, xp=xp)
# TODO:
# remove if not required.
sample_weight = (
_convert_to_supported(policy, sample_weight, xp=xp)
if sample_weight is not None
else None
)
dtype = get_dtype(X)
params = self._get_onedal_params(dtype)
result = module.compute(policy, params, to_table(X), to_table(sample_weight))
params = self._get_onedal_params(xp, dtype)
X_table = to_table(X)
sample_weight_table = to_table(sample_weight)

self.labels_ = from_table(result.responses).ravel()
if result.core_observation_indices is not None:
self.core_sample_indices_ = from_table(
result.core_observation_indices
).ravel()
result = self._get_backend("dbscan", "clustering", None).compute(
policy, params, X_table, sample_weight_table
)
self.labels_ = _ravel(
from_table(result.responses, sua_iface=sua_iface, sycl_queue=queue, xp=xp), xp
)
if (
result.core_observation_indices is not None
and not result.core_observation_indices.kind == "empty"
):
self.core_sample_indices_ = _ravel(
from_table(
result.core_observation_indices,
sycl_queue=queue,
sua_iface=sua_iface,
xp=xp,
),
xp,
)
else:
self.core_sample_indices_ = np.array([], dtype=np.intc)
self.components_ = np.take(X, self.core_sample_indices_, axis=0)
# TODO:
# self.core_sample_indices_ = _asarray([], xp, sycl_queue=queue, dtype=xp.int32)
if sua_iface:
self.core_sample_indices_ = xp.asarray(
[], sycl_queue=queue, dtype=xp.int32
)
else:
self.core_sample_indices_ = xp.asarray([], dtype=xp.int32)
self.components_ = xp.take(X, self.core_sample_indices_, axis=0)
self.n_features_in_ = X.shape[1]
return self

Expand Down Expand Up @@ -105,6 +158,11 @@ def __init__(
self.n_jobs = n_jobs

def fit(self, X, y=None, sample_weight=None, queue=None):
sua_iface, xp, is_array_api_compliant = get_namespace(X)
# TODO:
# update for queue getting.
if sua_iface:
queue = X.sycl_queue
return super()._fit(
X, y, sample_weight, self._get_backend("dbscan", "clustering", None), queue
X, sua_iface, xp, is_array_api_compliant, y, sample_weight, queue
)
20 changes: 20 additions & 0 deletions onedal/cluster/tests/test_dbscan.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,16 @@

import numpy as np
import pytest
from numpy.testing import assert_allclose
from sklearn.cluster import DBSCAN as DBSCAN_SKLEARN
from sklearn.cluster.tests.common import generate_clustered_data

from onedal.cluster import DBSCAN as ONEDAL_DBSCAN
from onedal.tests.utils._dataframes_support import (
_as_numpy,
_convert_to_dataframe,
get_dataframes_and_queues,
)
from onedal.tests.utils._device_selection import get_queues


Expand Down Expand Up @@ -123,3 +129,17 @@ def _test_across_grid_parameter_numpy_gen(queue, metric, use_weights: bool):
@pytest.mark.parametrize("queue", get_queues())
def test_across_grid_parameter_numpy_gen(queue, metric, use_weights: bool):
_test_across_grid_parameter_numpy_gen(queue, metric=metric, use_weights=use_weights)


# TODO:
# dtypes.
@pytest.mark.parametrize("dataframe,queue", get_dataframes_and_queues())
def test_base_dbscan(dataframe, queue):

X = np.array([[1, 2], [2, 2], [2, 3], [8, 7], [8, 8], [25, 80]])
X = _convert_to_dataframe(X, sycl_queue=queue, target_df=dataframe)
dbscan = ONEDAL_DBSCAN(eps=3, min_samples=2).fit(X)

result = dbscan.labels_
expected = np.array([0, 0, 0, 1, 1, -1], dtype=np.int32)
assert_allclose(expected, _as_numpy(result))
8 changes: 4 additions & 4 deletions onedal/datatypes/_data_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def _table_to_array(table, xp=None):

from ..common._policy import _HostInteropPolicy

def _convert_to_supported(policy, *data):
def _convert_to_supported(policy, *data, xp=np):
def func(x):
return x

Expand All @@ -93,13 +93,13 @@ def func(x):
device = policy._queue.sycl_device

def convert_or_pass(x):
if (x is not None) and (x.dtype == np.float64):
if (x is not None) and (x.dtype == xp.float64):
warnings.warn(
"Data will be converted into float32 from "
"float64 because device does not support it",
RuntimeWarning,
)
return x.astype(np.float32)
return xp.astype(x, dtype=xp.float32)
else:
return x

Expand Down Expand Up @@ -132,7 +132,7 @@ def convert_one_from_table(table, sycl_queue=None, sua_iface=None, xp=None):

else:

def _convert_to_supported(policy, *data):
def _convert_to_supported(policy, *data, xp=np):
def func(x):
return x

Expand Down
1 change: 1 addition & 0 deletions onedal/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
_check_array,
_check_classification_targets,
_check_n_features,
_check_sample_weight,
_check_X_y,
_column_or_1d,
_is_arraylike,
Expand Down
Loading
Loading