Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ruff to ci setup #82

Merged
merged 37 commits into from
Feb 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
42e2a8b
Add ruff to ci setup
adonath Dec 14, 2023
4d0ccb9
Fix ruff errors in common/
adonath Jan 26, 2024
fbe6bd4
Fix ruff errors in cupy/
adonath Jan 26, 2024
c7c27be
Fix ruff errors in numpy/
adonath Jan 26, 2024
d3d57b9
Fix ruff errors in torch/
adonath Jan 26, 2024
0d437cd
Fix ruff errors in tests/
adonath Jan 26, 2024
afccf29
Fix ruff errors in array_api_compat/__init__.py
adonath Jan 26, 2024
2395ea0
Implement _get_all_public_members
adonath Jan 26, 2024
645cef2
Move linalg aliases to _aliases
adonath Jan 26, 2024
5a6f411
Fix ruff errors in cupy/linalg
adonath Jan 26, 2024
0ff0836
Move linalg aliases to numpy/_aliases
adonath Jan 26, 2024
f4d78c7
Fix ruff errors in numpy/linalg
adonath Jan 26, 2024
31bbbfa
Hide helper variables in cupy/linalg.py
adonath Jan 26, 2024
5ecc7b5
Move linalg aliases to torch/_aliases
adonath Jan 26, 2024
8e4e9ca
Fix ruff errors in torch/linalg
adonath Jan 26, 2024
5c66efc
Fix final ruff errors in array_api_compat/torch/__init__.py
adonath Jan 26, 2024
bca606d
Expose public members from numpy an cupy in __all__ respectively
adonath Jan 26, 2024
890c497
Clean up
adonath Jan 26, 2024
b2f9557
Add importorskip torch
adonath Jan 26, 2024
52ef9ee
Use importorskip
adonath Jan 26, 2024
b069230
Add missing isdtype
adonath Jan 26, 2024
ff51015
Fix tests
adonath Jan 26, 2024
b0a323d
Rename import_ to import_or_skip_cupy
adonath Jan 27, 2024
0ec2d89
Add missing imports and sort __all__
adonath Jan 27, 2024
2baa4da
More cleanup
adonath Jan 27, 2024
efd745c
Remove redefinitions
adonath Jan 27, 2024
49f2b7a
Add ruff select F822 option [skip ci]
adonath Jan 29, 2024
a748bfa
Add PLC0414 error code as well
adonath Jan 29, 2024
6b4e92c
Avoid in place modification of __all__ in _get_all_public_members
adonath Feb 1, 2024
a92f640
Add sort check for __all__
adonath Feb 2, 2024
9b1110b
Sort __all__ lists
adonath Feb 2, 2024
68c788f
Use * import for array_api_compat/__init__.py
adonath Feb 2, 2024
1720fb6
Update array_api_compat/_internal.py
adonath Feb 7, 2024
5cd47df
Adapt dask
adonath Feb 7, 2024
c5d55ae
Fix ruff errors for dask/array/linalg
adonath Feb 7, 2024
49851b5
Fix __all__ order in dask linalg
adonath Feb 7, 2024
2db3d6a
Fix import of __all__ in dask/array/__init__.py
adonath Feb 7, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions .github/workflows/ruff.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
name: CI
on: [push, pull_request]
jobs:
check-ruff:
runs-on: ubuntu-latest
continue-on-error: true
steps:
- uses: actions/checkout@v3
- name: Install Python
uses: actions/setup-python@v4
with:
python-version: "3.11"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install ruff
# Update output format to enable automatic inline annotations.
- name: Run Ruff
run: ruff check --output-format=github --select F822,PLC0414,RUF022 --preview .
2 changes: 1 addition & 1 deletion array_api_compat/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,4 @@
"""
__version__ = '1.4.1'

from .common import *
from .common import * # noqa: F401, F403
34 changes: 33 additions & 1 deletion array_api_compat/_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from functools import wraps
from inspect import signature


def get_xp(xp):
"""
Decorator to automatically replace xp with the corresponding array module.
Expand All @@ -21,13 +22,16 @@ def func(x, /, xp, kwarg=None):
arguments.

"""

def inner(f):
@wraps(f)
def wrapped_f(*args, **kwargs):
return f(*args, xp=xp, **kwargs)

sig = signature(f)
new_sig = sig.replace(parameters=[sig.parameters[i] for i in sig.parameters if i != 'xp'])
new_sig = sig.replace(
parameters=[sig.parameters[i] for i in sig.parameters if i != "xp"]
)

if wrapped_f.__doc__ is None:
wrapped_f.__doc__ = f"""\
Expand All @@ -41,3 +45,31 @@ def wrapped_f(*args, **kwargs):
return wrapped_f

return inner


def _get_all_public_members(module, exclude=None, extend_all=False):
"""Get all public members of a module.

Parameters
----------
module : module
The module to get members from.
exclude : callable, optional
A callable that takes a name and returns True if the name should be
excluded from the list of members.
extend_all : bool, optional
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is monkeypatching torch.__all__ etc.? We don't want to do that.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, but this just keeps the current behavior. Take a look at https://github.com/data-apis/array-api-compat/blob/main/array_api_compat/torch/__init__.py#L3

I have not checked whether this is still necessary, but probably we have to keep it this way?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I checked with the following code:

import torch

torch_all = set(torch.__all__)
public = set([name for name in dir(torch) if not name.startswith("_")])

print(torch_all.difference(public))
print(public.difference(torch_all))

And this gives:

set()
{'complex64', 'eig', 'special', ... , 'QInt8Storage', 'segment_reduce', 'ComplexDoubleStorage'}

So indeed __all__ does not contain multiple members and most importantly it does not contain the dtypes.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, but this just keeps the current behavior. Take a look at main/array_api_compat/torch/init.py#L3

That code does not modify the torch.__all__ list:

>>> import torch
>>> torch_all = list(torch.__all__)
>>> import array_api_compat.torch
>>> torch_all2 = list(torch.__all__)
>>> torch_all == torch_all2
True

Generally speaking, this package should not monkeypatch the underlying libraries.

So indeed all does not contain multiple members and most importantly it does not contain the dtypes.

Yes, that's a known issue. pytorch/pytorch#91908

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, now I understand. I did not mean to actually modify torch.__all__ in place but copy and extend instead. I'll fix that behavior.

If True, extend the module's __all__ attribute with the members of the
module derived from dir(module). To be used for libraries that do not have a complete __all__ list.
"""
members = getattr(module, "__all__", [])

if members and not extend_all:
return members

if exclude is None:
exclude = lambda name: name.startswith("_") # noqa: E731

members = members + [_ for _ in dir(module) if not exclude(_)]

# remove duplicates
return list(set(members))
18 changes: 17 additions & 1 deletion array_api_compat/common/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,17 @@
from ._helpers import *
from ._helpers import (
array_namespace,
device,
get_namespace,
is_array_api_obj,
size,
to_device,
)

__all__ = [
"array_namespace",
"device",
"get_namespace",
"is_array_api_obj",
"size",
"to_device",
]
11 changes: 2 additions & 9 deletions array_api_compat/common/_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@

from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing import Optional, Sequence, Tuple, Union, List
import numpy as np
from typing import Optional, Sequence, Tuple, Union
from ._typing import ndarray, Device, Dtype, NestedSequence, SupportsBufferProtocol

from typing import NamedTuple
Expand Down Expand Up @@ -544,11 +545,3 @@ def isdtype(
# more strict here to match the type annotation? Note that the
# numpy.array_api implementation will be very strict.
return dtype == kind

__all__ = ['arange', 'empty', 'empty_like', 'eye', 'full', 'full_like',
'linspace', 'ones', 'ones_like', 'zeros', 'zeros_like',
'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult',
'unique_all', 'unique_counts', 'unique_inverse', 'unique_values',
'astype', 'std', 'var', 'permute_dims', 'reshape', 'argsort',
'sort', 'nonzero', 'sum', 'prod', 'ceil', 'floor', 'trunc',
'matmul', 'matrix_transpose', 'tensordot', 'vecdot', 'isdtype']
12 changes: 8 additions & 4 deletions array_api_compat/common/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,12 @@
"""
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
from typing import Optional, Union, Any
from ._typing import Array, Device

import sys
import math

Expand Down Expand Up @@ -142,7 +148,7 @@ def _check_device(xp, device):
# wrapping or subclassing them. These helper functions can be used instead of
# the wrapper functions for libraries that need to support both NumPy/CuPy and
# other libraries that use devices.
def device(x: "Array", /) -> "Device":
def device(x: Array, /) -> Device:
"""
Hardware device the array data resides on.

Expand Down Expand Up @@ -204,7 +210,7 @@ def _torch_to_device(x, device, /, stream=None):
raise NotImplementedError
return x.to(device)

def to_device(x: "Array", device: "Device", /, *, stream: "Optional[Union[int, Any]]" = None) -> "Array":
def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]] = None) -> Array:
"""
Copy the array from the device on which it currently resides to the specified ``device``.

Expand Down Expand Up @@ -252,5 +258,3 @@ def size(x):
if None in x.shape:
return None
return math.prod(x.shape)

__all__ = ['is_array_api_obj', 'array_namespace', 'get_namespace', 'device', 'to_device', 'size']
12 changes: 3 additions & 9 deletions array_api_compat/common/_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from typing import TYPE_CHECKING, NamedTuple
if TYPE_CHECKING:
from typing import Literal, Optional, Sequence, Tuple, Union
from typing import Literal, Optional, Tuple, Union
from ._typing import ndarray

import numpy as np
Expand All @@ -11,7 +11,7 @@
else:
from numpy.core.numeric import normalize_axis_tuple

from ._aliases import matmul, matrix_transpose, tensordot, vecdot, isdtype
from ._aliases import matrix_transpose, isdtype
from .._internal import get_xp

# These are in the main NumPy namespace but not in numpy.linalg
Expand Down Expand Up @@ -149,10 +149,4 @@ def trace(x: ndarray, /, xp, *, offset: int = 0, dtype=None, **kwargs) -> ndarra
dtype = xp.float64
elif x.dtype == xp.complex64:
dtype = xp.complex128
return xp.asarray(xp.trace(x, offset=offset, dtype=dtype, axis1=-2, axis2=-1, **kwargs))

__all__ = ['cross', 'matmul', 'outer', 'tensordot', 'EighResult',
'QRResult', 'SlogdetResult', 'SVDResult', 'eigh', 'qr', 'slogdet',
'svd', 'cholesky', 'matrix_rank', 'pinv', 'matrix_norm',
'matrix_transpose', 'svdvals', 'vecdot', 'vector_norm', 'diagonal',
'trace']
return xp.asarray(xp.trace(x, offset=offset, dtype=dtype, axis1=-2, axis2=-1, **kwargs))
3 changes: 3 additions & 0 deletions array_api_compat/common/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,6 @@ def __getitem__(self, key: int, /) -> _T_co | NestedSequence[_T_co]: ...
def __len__(self, /) -> int: ...

SupportsBufferProtocol = Any

Array = Any
Device = Any
151 changes: 144 additions & 7 deletions array_api_compat/cupy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,153 @@
from cupy import *
import cupy as _cp
from cupy import * # noqa: F401, F403

# from cupy import * doesn't overwrite these builtin names
from cupy import abs, max, min, round

from .._internal import _get_all_public_members
from ..common._helpers import (
array_namespace,
device,
get_namespace,
is_array_api_obj,
size,
to_device,
)

# These imports may overwrite names from the import * above.
from ._aliases import *
from ._aliases import (
UniqueAllResult,
UniqueCountsResult,
UniqueInverseResult,
acos,
acosh,
arange,
argsort,
asarray,
asarray_cupy,
asin,
asinh,
astype,
atan,
atan2,
atanh,
bitwise_invert,
bitwise_left_shift,
bitwise_right_shift,
bool,
ceil,
concat,
empty,
empty_like,
eye,
floor,
full,
full_like,
isdtype,
linspace,
matmul,
matrix_transpose,
nonzero,
ones,
ones_like,
permute_dims,
pow,
prod,
reshape,
sort,
std,
sum,
tensordot,
trunc,
unique_all,
unique_counts,
unique_inverse,
unique_values,
var,
vecdot,
zeros,
zeros_like,
)

# See the comment in the numpy __init__.py
__import__(__package__ + '.linalg')
__all__ = []

__all__ += _get_all_public_members(_cp)

__all__ += [
"abs",
"max",
"min",
"round",
]

from .linalg import matrix_transpose, vecdot
__all__ += [
"array_namespace",
"device",
"get_namespace",
"is_array_api_obj",
"size",
"to_device",
]

from ..common._helpers import *
__all__ += [
"UniqueAllResult",
"UniqueCountsResult",
"UniqueInverseResult",
"acos",
"acosh",
"arange",
"argsort",
"asarray",
"asarray_cupy",
"asin",
"asinh",
"astype",
"atan",
"atan2",
"atanh",
"bitwise_invert",
"bitwise_left_shift",
"bitwise_right_shift",
"bool",
"ceil",
"concat",
"empty",
"empty_like",
"eye",
"floor",
"full",
"full_like",
"isdtype",
"linspace",
"matmul",
"matrix_transpose",
"nonzero",
"ones",
"ones_like",
"permute_dims",
"pow",
"prod",
"reshape",
"sort",
"std",
"sum",
"tensordot",
"trunc",
"unique_all",
"unique_counts",
"unique_inverse",
"unique_values",
"var",
"zeros",
"zeros_like",
]

__all__ += [
"matrix_transpose",
"vecdot",
]

# See the comment in the numpy __init__.py
__import__(__package__ + ".linalg")

__array_api_version__ = '2022.12'
__array_api_version__ = "2022.12"
Loading