TST enable non-CPU device testing via array-api-strict #30090

Merged: 22 commits merged on Mar 7, 2025

ogrisel
Member

@ogrisel ogrisel commented Oct 17, 2024

This is an early draft PR that attempts to leverage the multi-device support recently merged in array-api-strict: data-apis/array-api-strict#59

We need to wait for a release of array-api-strict + a lock file update to actually get this to run on our CI.

However, I think we should investigate failures early in scikit-learn because I suspect that some (most?) of them are not necessarily a problem in scikit-learn but might be bugs in array-api-strict's device support itself.

/cc @betatim


github-actions bot commented Oct 17, 2024

✔️ Linting Passed

All linting checks passed. Your pull request is in excellent shape! ☀️

Generated for commit: 5e9ff07. Link to the linter CI: here

@ogrisel
Member Author

ogrisel commented Oct 17, 2024

Here is the output of

$ pytest -v -k array_api_strict  -l -x

on my machine with the main branch of array-api-strict:

==================================================================== test session starts ====================================================================
platform darwin -- Python 3.12.5, pytest-8.3.2, pluggy-1.5.0 -- /Users/ogrisel/miniforge3/envs/dev/bin/python3.12
cachedir: .pytest_cache
hypothesis profile 'default' -> database=DirectoryBasedExampleDatabase(PosixPath('/Users/ogrisel/code/scikit-learn/.hypothesis/examples'))
rootdir: /Users/ogrisel/code/scikit-learn
configfile: setup.cfg
testpaths: sklearn
plugins: repeat-0.9.2, hypothesis-6.112.2, anyio-4.4.0, run-parallel-0.1.0, xdist-3.6.1
collected 37763 items / 37368 deselected / 2 skipped / 395 selected                                                                                         

sklearn/decomposition/tests/test_pca.py::test_pca_array_api_compliance[PCA(n_components=2,svd_solver='full')-check_array_api_input_and_values-array_api_strict-device1-float64] PASSED [  0%]
sklearn/decomposition/tests/test_pca.py::test_pca_array_api_compliance[PCA(n_components=2,svd_solver='full')-check_array_api_input_and_values-array_api_strict-device2-float32] FAILED [  0%]

========================================================================= FAILURES ==========================================================================
__________ test_pca_array_api_compliance[PCA(n_components=2,svd_solver='full')-check_array_api_input_and_values-array_api_strict-device2-float32] ___________

estimator = PCA(n_components=2, svd_solver='full'), check = <function check_array_api_input_and_values at 0x1337b54e0>, array_namespace = 'array_api_strict'
device = array_api_strict.Device('device1'), dtype_name = 'float32'

    @pytest.mark.parametrize(
        "array_namespace, device, dtype_name", yield_namespace_device_dtype_combinations()
    )
    @pytest.mark.parametrize(
        "check",
        [check_array_api_input_and_values, check_array_api_get_precision],
        ids=_get_check_estimator_ids,
    )
    @pytest.mark.parametrize(
        "estimator",
        [
            PCA(n_components=2, svd_solver="full"),
            PCA(n_components=2, svd_solver="full", whiten=True),
            PCA(n_components=0.1, svd_solver="full", whiten=True),
            PCA(n_components=2, svd_solver="covariance_eigh"),
            PCA(n_components=2, svd_solver="covariance_eigh", whiten=True),
            PCA(
                n_components=2,
                svd_solver="randomized",
                power_iteration_normalizer="QR",
                random_state=0,  # how to use global_random_seed here?
            ),
        ],
        ids=_get_check_estimator_ids,
    )
    def test_pca_array_api_compliance(
        estimator, check, array_namespace, device, dtype_name
    ):
        name = estimator.__class__.__name__
>       check(name, estimator, array_namespace, device=device, dtype_name=dtype_name)

array_namespace = 'array_api_strict'
check      = <function check_array_api_input_and_values at 0x1337b54e0>
device     = array_api_strict.Device('device1')
dtype_name = 'float32'
estimator  = PCA(n_components=2, svd_solver='full')
name       = 'PCA'

sklearn/decomposition/tests/test_pca.py:1036: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
sklearn/utils/estimator_checks.py:861: in check_array_api_input_and_values
    return check_array_api_input(
        array_namespace = 'array_api_strict'
        device     = array_api_strict.Device('device1')
        dtype_name = 'float32'
        estimator_orig = PCA(n_components=2, svd_solver='full')
        name       = 'PCA'
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

name = 'PCA', estimator_orig = PCA(n_components=2, svd_solver='full'), array_namespace = 'array_api_strict', device = array_api_strict.Device('device1')
dtype_name = 'float32', check_values = True

    def check_array_api_input(
        name,
        estimator_orig,
        array_namespace,
        device=None,
        dtype_name="float64",
        check_values=False,
    ):
        """Check that the estimator can work consistently with the Array API
    
        By default, this just checks that the types and shapes of the arrays are
        consistent with calling the same estimator with numpy arrays.
    
        When check_values is True, it also checks that calling the estimator on the
        array_api Array gives the same results as ndarrays.
        """
        xp = _array_api_for_tests(array_namespace, device)
    
        X, y = make_classification(random_state=42)
        X = X.astype(dtype_name, copy=False)
    
        X = _enforce_estimator_tags_X(estimator_orig, X)
        y = _enforce_estimator_tags_y(estimator_orig, y)
    
        est = clone(estimator_orig)
    
        X_xp = xp.asarray(X, device=device)
        y_xp = xp.asarray(y, device=device)
    
        est.fit(X, y)
    
        array_attributes = {
            key: value for key, value in vars(est).items() if isinstance(value, np.ndarray)
        }
    
        est_xp = clone(est)
        with config_context(array_api_dispatch=True):
            est_xp.fit(X_xp, y_xp)
            input_ns = get_namespace(X_xp)[0].__name__
    
        # Fitted attributes which are arrays must have the same
        # namespace as the one of the training data.
        for key, attribute in array_attributes.items():
            est_xp_param = getattr(est_xp, key)
            with config_context(array_api_dispatch=True):
                attribute_ns = get_namespace(est_xp_param)[0].__name__
            assert attribute_ns == input_ns, (
                f"'{key}' attribute is in wrong namespace, expected {input_ns} "
                f"got {attribute_ns}"
            )
    
>           assert array_device(est_xp_param) == array_device(X_xp)
E           AssertionError

X          = array([[-2.0251427 ,  0.0291022 , -0.4749453 , ..., -0.33450124,
         0.8657552 , -1.2002964 ],
       [ 1.6137112... ],
       [-0.00607091,  1.3085763 , -0.17495976, ...,  0.99204236,
         0.3216978 , -0.66809046]], dtype=float32)
X_xp       = Array([[-2.0251427 ,
         0.0291022 ,
        -0.4749453 ,
        ...,
        -0.33450124,
         0.8657552 ,
...
         0.3216978 ,
        -0.66809046]], dtype=array_api_strict.float32, device=array_api_strict.Device('device1'))
array_attributes = {'components_': array([[ 0.03484652,  0.6045526 , -0.09228071, -0.09317975,  0.02118714,
        -0.46225083,  0.03672...52413,  0.13682878,
       -0.03120608,  0.05840071,  0.055825  ,  0.12556158, -0.03976958],
      dtype=float32), ...}
array_namespace = 'array_api_strict'
attribute  = array([ 0.18988031,  0.03833218,  0.07648806,  0.08370368,  0.02213484,
       -0.04884844,  0.02524958, -0.11081639, ... 0.00852413,  0.13682878,
       -0.03120608,  0.05840071,  0.055825  ,  0.12556158, -0.03976958],
      dtype=float32)
attribute_ns = 'array_api_strict'
check_values = True
device     = array_api_strict.Device('device1')
dtype_name = 'float32'
est        = PCA(n_components=2, svd_solver='full')
est_xp     = PCA(n_components=2, svd_solver='full')
est_xp_param = Array([ 0.18988031,  0.03833218,
        0.07648806,  0.08370368,
        0.02213484, -0.04884844,
        0.02524958,...682878, -0.03120608,
        0.05840071,  0.055825  ,
        0.12556158, -0.03976958], dtype=array_api_strict.float32)
estimator_orig = PCA(n_components=2, svd_solver='full')
input_ns   = 'array_api_strict'
key        = 'mean_'
name       = 'PCA'
xp         = <module 'array_api_strict' from '/Users/ogrisel/code/array-api-strict/array_api_strict/__init__.py'>
y          = array([0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0,
       0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0,...1,
       0, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 1, 1, 1,
       1, 0, 0, 1, 0, 1, 1, 0, 1, 1, 0, 0])
y_xp       = Array([0,
       0,
       1,
       1,
       0,
       0,
       0,
       1,
       0,
       1,
       1,
       0...   0,
       1,
       1,
       0,
       0], dtype=array_api_strict.int64, device=array_api_strict.Device('device1'))

sklearn/utils/estimator_checks.py:762: AssertionError
================================================================== short test summary info ==================================================================
FAILED sklearn/decomposition/tests/test_pca.py::test_pca_array_api_compliance[PCA(n_components=2,svd_solver='full')-check_array_api_input_and_values-array_api_strict-device2-float32] - AssertionError
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! stopping after 1 failures !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
=========================================== 1 failed, 1 passed, 2 skipped, 37368 deselected, 23 warnings in 8.24s ===========================================

This same test passes with PyTorch on CUDA or MPS devices, so I suspect that the lack of device propagation in the computation of the mean_ attribute might reveal a bug in array-api-strict itself. I have not yet investigated in detail.
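For readers unfamiliar with the failure mode, here is a minimal sketch of how a result can lose its device. It is a pure-Python toy, not array-api-strict and not the scikit-learn code: `ToyArray`, `asarray`, `mean_buggy`, `mean_fixed` and the `"CPU_DEVICE"` default are all hypothetical names made up for illustration. It only assumes that a creation call without `device=` lands on the default device, which is consistent with the traceback above, where `est_xp_param` shows no device while `X_xp` is on `device1`.

```python
from dataclasses import dataclass

DEFAULT_DEVICE = "CPU_DEVICE"  # hypothetical name for the default device


@dataclass(frozen=True)
class ToyArray:
    """Hypothetical stand-in for an array that remembers its device."""

    data: tuple
    device: str = DEFAULT_DEVICE


def asarray(values, device=None):
    # Mimics a creation function: without device=, use the default device.
    return ToyArray(tuple(values), device or DEFAULT_DEVICE)


def mean_buggy(x):
    # Creation call without device=: the result lands on the default device.
    return asarray([sum(x.data) / len(x.data)])


def mean_fixed(x):
    # Forward the input's device so the result stays where the input lives.
    return asarray([sum(x.data) / len(x.data)], device=x.device)


X = asarray([1.0, 2.0, 3.0], device="device1")
print(mean_buggy(X).device == X.device)  # False: device was dropped
print(mean_fixed(X).device == X.device)  # True: device was propagated
```

With a single-device library the buggy variant goes unnoticed, which is presumably why a namespace with several fake devices is useful for testing.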

@betatim
Member

betatim commented Oct 21, 2024

I think we will need data-apis/array-api-strict#73 and data-apis/array-api-strict#72 for this PR to work

@betatim
Member

betatim commented Oct 21, 2024

Another issue that needs resolving: scipy/scipy#21736

@ogrisel
Member Author

ogrisel commented Feb 6, 2025

We should update the lock file to try to run the tests in this PR with the new version of array-api-strict.

EDIT: the lock files have probably already been updated in main, let's just sync this branch and see what happens.

@ogrisel
Member Author

ogrisel commented Feb 7, 2025

@betatim I started to update this PR: it uncovered several device-handling issues and possibly some dtype-related ones. I have not fixed them all yet, so feel free to take over at any point :)

To compute batch sizes and memory sizes we don't need the array API; we can do that math with plain Python types.

This change also fixes a slicing error that only appears with array-api-strict and is unrelated to the switch to Python types.
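
As an illustration of the point about Python types, here is a rough sketch of batch-size arithmetic done entirely with built-in integers. The helper name `rows_per_batch`, its parameters and the numbers are assumptions for the example, not the code touched by this commit.

```python
import math


def rows_per_batch(n_rows, n_cols, itemsize, working_memory_mib=1024):
    """Hypothetical helper: how many rows fit in the memory budget."""
    bytes_per_row = n_cols * itemsize
    budget_bytes = working_memory_mib * 2**20
    # Plain integer arithmetic: no array namespace or device is involved.
    return max(1, min(n_rows, budget_bytes // bytes_per_row))


n_rows, n_cols = 1_000_000, 512
batch = rows_per_batch(n_rows, n_cols, itemsize=4, working_memory_mib=64)
n_batches = math.ceil(n_rows / batch)
print(batch, n_batches)  # 32768 31
```

Because shapes and item sizes are ordinary Python numbers, nothing here needs to be created on, or moved to, a particular device.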

The scipy implementation has a bug: it does not set the device on all of the arrays it creates. This adds xlogy() to the group of functions we implement ourselves.

This function is used in functions that support the xp short-circuiting, so I think it makes sense to make it look similar to get_namespace.
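
One possible reading of "look similar to get_namespace" is a helper that accepts several inputs, skips the ones that are not arrays, and checks that the rest agree on a device. The sketch below is written under that assumption: `common_device`, its `remove_none` and `remove_types` parameters, and the `"cpu"` fallback are hypothetical and should not be taken as the signature merged in this PR.

```python
from types import SimpleNamespace


def common_device(*arrays, remove_none=True, remove_types=(str,)):
    """Hypothetical sketch: return the single device shared by the inputs."""
    kept = [
        a
        for a in arrays
        if not (remove_none and a is None) and not isinstance(a, remove_types)
    ]
    if not kept:
        raise ValueError("At least one array input is required.")
    # Assumption: objects without a .device attribute are treated as "cpu".
    devices = {getattr(a, "device", "cpu") for a in kept}
    if len(devices) > 1:
        raise ValueError(f"Input arrays use different devices: {sorted(devices)}")
    return devices.pop()


# Toy objects standing in for arrays; only the .device attribute matters here.
a = SimpleNamespace(device="device1")
b = SimpleNamespace(device="device1")
c = SimpleNamespace(device="cpu")

print(common_device(a, None, "auto", b))  # device1
```

Calling `common_device(a, c)` raises a `ValueError`, which is the kind of mismatch that testing with more than one device is meant to surface.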
@betatim betatim requested review from glemaitre and OmarManzoor March 3, 2025 12:46
@betatim
Member

betatim commented Mar 3, 2025

I pinged Omar and Guillaume for reviews. You don't have to review this, but I thought it might be interesting for you two (and it would solve the problem that neither Olivier nor I can approve this).

Contributor

@OmarManzoor OmarManzoor left a comment

Thank you @betatim and @ogrisel.
Generally looks good, just a few comments

betatim and others added 2 commits March 4, 2025 09:19
This reduces the number of `xp.asarray` calls needed to convert scalars
to arrays for the array API
Co-authored-by: Omar Salman <omar.salman2007@gmail.com>
@ogrisel ogrisel added the CUDA CI label Mar 6, 2025
@github-actions github-actions bot removed the CUDA CI label Mar 6, 2025
@ogrisel
Member Author

ogrisel commented Mar 6, 2025

@OmarManzoor @betatim after #30090 (comment), the code is simpler, and all tests pass everywhere.

+1 for merge on my side.

@OmarManzoor
Contributor

> @OmarManzoor @betatim after #30090 (comment), the code is simpler, and all tests pass everywhere.
>
> +1 for merge on my side.

👍 Let's wait for the CI to complete and I'll review and merge

Contributor

@OmarManzoor OmarManzoor left a comment

LGTM. Thank you @ogrisel and @betatim

@OmarManzoor OmarManzoor merged commit 368a200 into scikit-learn:main Mar 7, 2025
37 checks passed
@ogrisel ogrisel deleted the multi-device-array-api-strict branch March 7, 2025 09:51