Consider defining public API explicitly #7570
Thanks for the report – I'm a bit confused about what mypy requires here, because we currently do explicitly import these names. Do you know what we should be doing differently? |
You can just specify all these public functions in `__all__`:

```python
from jax._src.numpy.lax_numpy import zeros

__all__ = [
    "zeros",
    ...
]
```

That will make mypy happy. |
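For context, `__all__` is also the list that `from module import *` honors at runtime, so declaring it documents the public API for readers and tools alike. The sketch below builds a throwaway in-memory module (the name `toy_numpy` is hypothetical) whose `__all__` lists only `zeros`, so a star import skips the helper it imported. This only illustrates the plain-Python runtime behavior; how mypy reads the list is a separate, static matter.

```python
import sys
import types

# Hypothetical stand-in for a package __init__.py.
toy = types.ModuleType("toy_numpy")
exec(
    "from math import sqrt  # imported helper, not meant to be public\n"
    "def zeros(n):\n"
    "    return [0.0] * n\n"
    "__all__ = ['zeros']\n",
    toy.__dict__,
)
sys.modules["toy_numpy"] = toy

# A star import only pulls in the names listed in __all__.
importer_ns: dict = {}
exec("from toy_numpy import *", importer_ns)
star_names = sorted(k for k in importer_ns if not k.startswith("__"))
print(star_names)  # ['zeros']
```

Without the `__all__` line, `sqrt` would be pulled in too, because a star import then takes every name that doesn't start with an underscore.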
Another option would be to write

```python
from jax._src.numpy.lax_numpy import zeros as zeros
```

This redundant-alias form comes from PEP 484. But it will make pylint produce warnings, at least with default settings. |
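To make the two conventions concrete, here is a rough checker that parses a module's source with the standard `ast` module and collects the names explicitly marked as exported: string entries of a literal `__all__`, and imports spelled `X as X`. The helper `explicit_exports` is hypothetical and a simplification for illustration, not mypy's actual algorithm; for example, it ignores `__all__ +=`, comprehensions, and other computed forms.

```python
import ast
from typing import Set


def explicit_exports(source: str) -> Set[str]:
    """Collect names explicitly marked as exported in module source (approximation)."""
    tree = ast.parse(source)
    exported: Set[str] = set()
    for node in tree.body:
        if isinstance(node, (ast.Import, ast.ImportFrom)):
            for alias in node.names:
                # `import X as X` / `from m import X as X`: alias equals the name.
                if alias.asname is not None and alias.asname == alias.name:
                    exported.add(alias.asname)
        elif isinstance(node, ast.Assign):
            for target in node.targets:
                # Only a literal list/tuple assigned to __all__ is understood here.
                if (
                    isinstance(target, ast.Name)
                    and target.id == "__all__"
                    and isinstance(node.value, (ast.List, ast.Tuple))
                ):
                    for elt in node.value.elts:
                        if isinstance(elt, ast.Constant) and isinstance(elt.value, str):
                            exported.add(elt.value)
    return exported


sample = (
    "from jax_like import zeros as zeros\n"      # redundant alias: exported
    "from jax_like import ones\n"                # plain import: not exported
    "from jax_like import empty as zeros_like\n" # renamed alias: not exported
    "__all__ = ['full']\n"                       # literal __all__: exported
)
print(sorted(explicit_exports(sample)))  # ['full', 'zeros']
```

The module name `jax_like` is a placeholder; `ast.parse` never resolves imports, so it doesn't need to exist.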
Thanks, it's unfortunate that mypy requires so much duplicated boilerplate. The explicit imports already take up 43 full lines; I don't relish the thought of writing every name a second time. Do you know if there's any other way to appease mypy here? |
Unfortunately, these two seem to be the only options :-/ |
I find a decorator like https://stackoverflow.com/questions/41895077/export-decorator-that-manages-all useful to cut down on the boilerplate. Maybe a similar idea could be used here? |
Good thought – although that particular solution requires that each object's `__name__` match the name it is exported under, which isn't true for many objects in `jax.numpy`:

```
In [1]: import jax.numpy as jnp

In [2]: for name in dir(jnp):
   ...:     obj = getattr(jnp, name)
   ...:     objname = getattr(obj, '__name__', None)
   ...:     if objname is None:
   ...:         print(f'jnp.{name} has no name')
   ...:     elif objname != name:
   ...:         print(f'jnp.{name}.__name__ = {objname}')
   ...:
jnp.DeviceArray.__name__ = DeviceArrayBase
jnp.NINF has no name
jnp.NZERO has no name
jnp.PZERO has no name
jnp._NOT_IMPLEMENTED has no name
jnp.__builtins__ has no name
jnp.__cached__ has no name
jnp.__doc__ has no name
jnp.__file__ has no name
jnp.__loader__ has no name
jnp.__name__ has no name
jnp.__package__ has no name
jnp.__path__ has no name
jnp.__spec__ has no name
jnp.abs.__name__ = absolute
jnp.add_newdoc_ufunc.__name__ = _add_newdoc_ufunc
jnp.alltrue.__name__ = all
jnp.bitwise_not.__name__ = invert
jnp.c_ has no name
jnp.cdouble.__name__ = complex128
jnp.complex_.__name__ = complex128
jnp.conj.__name__ = conjugate
jnp.csingle.__name__ = complex64
jnp.cumproduct.__name__ = cumprod
jnp.degrees.__name__ = rad2deg
jnp.deprecate_with_doc.__name__ = <lambda>
jnp.divide.__name__ = true_divide
jnp.double.__name__ = float64
jnp.e has no name
jnp.empty.__name__ = zeros
jnp.empty_like.__name__ = zeros_like
jnp.euler_gamma has no name
jnp.fastCopyAndTranspose.__name__ = _fastCopyAndTranspose
jnp.fft.__name__ = jax.numpy.fft
jnp.float_.__name__ = float64
jnp.index_exp has no name
jnp.inf has no name
jnp.int_.__name__ = int64
jnp.lax_numpy.__name__ = jax._src.numpy.lax_numpy
jnp.linalg.__name__ = jax.numpy.linalg
jnp.mat.__name__ = asmatrix
jnp.max.__name__ = amax
jnp.mgrid has no name
jnp.min.__name__ = amin
jnp.mod.__name__ = remainder
jnp.nan has no name
jnp.newaxis has no name
jnp.ogrid has no name
jnp.operator_name has no name
jnp.pi has no name
jnp.product.__name__ = prod
jnp.r_ has no name
jnp.radians.__name__ = deg2rad
jnp.round.__name__ = around
jnp.round_.__name__ = around
jnp.row_stack.__name__ = vstack
jnp.s_ has no name
jnp.show_config.__name__ = show
jnp.single.__name__ = float32
jnp.sometrue.__name__ = any
```
|
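The aliases and constants in that output (`jnp.abs` is really `absolute`, and `jnp.pi` has no `__name__` at all) suggest keying the export on an explicit name rather than on `__name__`. A minimal sketch with a hypothetical `export_as` helper follows. It still builds `__all__` at runtime, so it addresses the naming problem but not necessarily static checking.

```python
from typing import Any, Dict, List, TypeVar

T = TypeVar("T")


def export_as(namespace: Dict[str, Any], name: str, obj: T) -> T:
    """Bind obj under `name` in namespace and record `name` in __all__."""
    namespace[name] = obj
    all_names: List[str] = namespace.setdefault("__all__", [])
    if name not in all_names:
        all_names.append(name)
    return obj


def absolute(x: float) -> float:
    return x if x >= 0 else -x


ns: Dict[str, Any] = {}
export_as(ns, "absolute", absolute)
export_as(ns, "abs", absolute)            # alias: ns["abs"].__name__ is still "absolute"
export_as(ns, "pi", 3.141592653589793)    # constants have no __name__ at all
print(ns["__all__"])  # ['absolute', 'abs', 'pi']
```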
BTW, I strongly support this change. It's annoying not being able to run … |

Hmm yes, that trick mainly works for functions and classes. But since we are talking about the public interface, and jax already does a very good job of keeping its namespaces clean, why don't you just export all symbols under their current names, as given by `dir()`? Does the example below catch everything?

```
➜ python
Python 3.8.11 (default, Jun 29 2021, 00:00:00)
[GCC 11.1.1 20210531 (Red Hat 11.1.1-3)] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import sys
>>> def _export_public(module_name):
...     module = sys.modules[module_name]
...     if not hasattr(module, '__all__'):
...         setattr(module, '__all__', [])
...     __all__ = module.__all__
...     for name in dir(module):
...         obj = getattr(module, name)
...         if not name.startswith("_"):
...             __all__.append(name)
...
>>> import jax
>>> jax.numpy.__all__
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
AttributeError: module 'jax.numpy' has no attribute '__all__'
>>> _export_public('jax.numpy')
>>> jax.numpy.__all__
['ComplexWarning', 'DeviceArray', 'NINF', 'NZERO', 'PZERO', 'abs', 'absolute', 'add', 'add_docstring', 'add_newdoc', 'add_newdoc_ufunc', 'alen', 'all', 'allclose', 'alltrue', 'amax', 'amin', 'angle', 'any', 'append', 'apply_along_axis', 'apply_over_axes', 'arange', 'arccos', 'arccosh', 'arcsin', 'arcsinh', 'arctan', 'arctan2', 'arctanh', 'argmax', 'argmin', 'argpartition', 'argsort', 'argwhere', 'around', 'array', 'array2string', 'array_equal', 'array_equiv', 'array_repr', 'array_split', 'array_str', 'asanyarray', 'asarray', 'asarray_chkfinite', 'ascontiguousarray', 'asfarray', 'asfortranarray', 'asmatrix', 'asscalar', 'atleast_1d', 'atleast_2d', 'atleast_3d', 'average', 'bartlett', 'base_repr', 'bfloat16', 'binary_repr', 'bincount', 'bitwise_and', 'bitwise_not', 'bitwise_or', 'bitwise_xor', 'blackman', 'block', 'bmat', 'bool_', 'broadcast_arrays', 'broadcast_shapes', 'broadcast_to', 'busday_count', 'busday_offset', 'byte_bounds', 'c_', 'can_cast', 'cbrt', 'cdouble', 'ceil', 'character', 'choose', 'clip', 'column_stack', 'common_type', 'compare_chararrays', 'complex128', 'complex64', 'complex_', 'complexfloating', 'compress', 'concatenate', 'conj', 'conjugate', 'convolve', 'copy', 'copysign', 'copyto', 'corrcoef', 'correlate', 'cos', 'cosh', 'count_nonzero', 'cov', 'cross', 'csingle', 'cumprod', 'cumproduct', 'cumsum', 'datetime_as_string', 'datetime_data', 'deg2rad', 'degrees', 'delete', 'deprecate', 'deprecate_with_doc', 'diag', 'diag_indices', 'diag_indices_from', 'diagflat', 'diagonal', 'diff', 'digitize', 'disp', 'divide', 'divmod', 'dot', 'double', 'dsplit', 'dstack', 'dtype', 'e', 'ediff1d', 'einsum', 'einsum_path', 'empty', 'empty_like', 'equal', 'euler_gamma', 'exp', 'exp2', 'expand_dims', 'expm1', 'extract', 'eye', 'fabs', 'fastCopyAndTranspose', 'fft', 'fill_diagonal', 'find_common_type', 'finfo', 'fix', 'flatnonzero', 'flexible', 'flip', 'fliplr', 'flipud', 'float16', 'float32', 'float64', 'float_', 'float_power', 'floating', 'floor', 'floor_divide', 
'fmax', 'fmin', 'fmod', 'format_float_positional', 'format_float_scientific', 'frexp', 'frombuffer', 'fromfile', 'fromfunction', 'fromiter', 'frompyfunc', 'fromregex', 'fromstring', 'full', 'full_like', 'gcd', 'genfromtxt', 'geomspace', 'get_array_wrap', 'get_include', 'get_printoptions', 'getbufsize', 'geterr', 'geterrcall', 'geterrobj', 'gradient', 'greater', 'greater_equal', 'hamming', 'hanning', 'heaviside', 'histogram', 'histogram2d', 'histogram_bin_edges', 'histogramdd', 'hsplit', 'hstack', 'hypot', 'i0', 'identity', 'iinfo', 'imag', 'in1d', 'index_exp', 'indices', 'inexact', 'inf', 'info', 'inner', 'insert', 'int16', 'int32', 'int64', 'int8', 'int_', 'integer', 'interp', 'intersect1d', 'invert', 'is_busday', 'isclose', 'iscomplex', 'iscomplexobj', 'isfinite', 'isfortran', 'isin', 'isinf', 'isnan', 'isnat', 'isneginf', 'isposinf', 'isreal', 'isrealobj', 'isscalar', 'issctype', 'issubclass_', 'issubdtype', 'issubsctype', 'iterable', 'ix_', 'kaiser', 'kron', 'lax_numpy', 'lcm', 'ldexp', 'left_shift', 'less', 'less_equal', 'lexsort', 'linalg', 'linspace', 'load', 'loads', 'loadtxt', 'log', 'log10', 'log1p', 'log2', 'logaddexp', 'logaddexp2', 'logical_and', 'logical_not', 'logical_or', 'logical_xor', 'logspace', 'lookfor', 'mafromtxt', 'mask_indices', 'mat', 'matmul', 'max', 'maximum', 'maximum_sctype', 'may_share_memory', 'mean', 'median', 'meshgrid', 'mgrid', 'min', 'min_scalar_type', 'minimum', 'mintypecode', 'mod', 'modf', 'moveaxis', 'msort', 'multiply', 'nan', 'nan_to_num', 'nanargmax', 'nanargmin', 'nancumprod', 'nancumsum', 'nanmax', 'nanmean', 'nanmedian', 'nanmin', 'nanpercentile', 'nanprod', 'nanquantile', 'nanstd', 'nansum', 'nanvar', 'ndarray', 'ndfromtxt', 'ndim', 'negative', 'nested_iters', 'newaxis', 'nextafter', 'nonzero', 'not_equal', 'number', 'obj2sctype', 'object_', 'ogrid', 'ones', 'ones_like', 'operator_name', 'outer', 'packbits', 'pad', 'partition', 'percentile', 'pi', 'piecewise', 'place', 'poly', 'polyadd', 'polyder', 'polydiv', 
'polyfit', 'polyint', 'polymul', 'polysub', 'polyval', 'positive', 'power', 'printoptions', 'prod', 'product', 'promote_types', 'ptp', 'put', 'put_along_axis', 'putmask', 'quantile', 'r_', 'rad2deg', 'radians', 'ravel', 'ravel_multi_index', 'real', 'real_if_close', 'recfromcsv', 'recfromtxt', 'reciprocal', 'remainder', 'repeat', 'require', 'reshape', 'resize', 'result_type', 'right_shift', 'rint', 'roll', 'rollaxis', 'roots', 'rot90', 'round', 'round_', 'row_stack', 's_', 'safe_eval', 'save', 'savetxt', 'savez', 'savez_compressed', 'sctype2char', 'searchsorted', 'select', 'set_numeric_ops', 'set_printoptions', 'set_string_function', 'setbufsize', 'setdiff1d', 'seterr', 'seterrcall', 'seterrobj', 'setxor1d', 'shape', 'shares_memory', 'show_config', 'sign', 'signbit', 'signedinteger', 'sin', 'sinc', 'single', 'sinh', 'size', 'sometrue', 'sort', 'sort_complex', 'source', 'spacing', 'split', 'sqrt', 'square', 'squeeze', 'stack', 'std', 'subtract', 'sum', 'swapaxes', 'take', 'take_along_axis', 'tan', 'tanh', 'tensordot', 'tile', 'trace', 'transpose', 'trapz', 'tri', 'tril', 'tril_indices', 'tril_indices_from', 'trim_zeros', 'triu', 'triu_indices', 'triu_indices_from', 'true_divide', 'trunc', 'typename', 'uint16', 'uint32', 'uint64', 'uint8', 'union1d', 'unique', 'unpackbits', 'unravel_index', 'unsignedinteger', 'unwrap', 'vander', 'var', 'vdot', 'vectorize', 'vsplit', 'vstack', 'where', 'who', 'zeros', 'zeros_like']
>>>
```
|
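One way to reconcile that enumeration with static checkers is to run it once as a code generator and commit the literal list it prints, instead of computing `__all__` at import time. A sketch follows; `render_all` is a hypothetical helper. Unlike `_export_public` above, it does not mutate the module, so calling it twice is harmless. Its output should still be reviewed by hand, because a `dir()`-style scan also picks up submodules and implementation modules such as `lax_numpy` in the list above.

```python
import types
from typing import List


def render_all(module: types.ModuleType) -> str:
    """Render a literal ``__all__ = [...]`` assignment for a module's public names.

    "Public" here just means "does not start with an underscore"; the result
    is meant to be reviewed and pasted into the package's __init__.py.
    """
    names: List[str] = sorted(
        name for name in vars(module) if not name.startswith("_")
    )
    lines = ["__all__ = ["]
    lines.extend(f"    {name!r}," for name in names)
    lines.append("]")
    return "\n".join(lines)


# Demo on a small in-memory module; with jax installed you could instead
# call render_all(jax.numpy) and review the result.
demo = types.ModuleType("demo")
demo.zeros = lambda n: [0.0] * n
demo.ones = lambda n: [1.0] * n
demo._private_helper = object()
print(render_all(demo))
```

For the demo module this prints a three-entry list containing `'ones'` and `'zeros'`, with `_private_helper` and the module's dunder attributes left out.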
How important do you think this is, given that jax currently has very few useful annotations for functions that accept and return arrays, and given the fact that it's not very clear how to effectively add such annotations? (see, e.g. the discussions at #943 & #6743) Is package-attribute-level checking worth worrying about when function-argument-level checking has no real clear solution in sight? |
If we want to have support for type hints in Jax eventually, it seems that there is no way around PSF :-)
That said, muting the mypy warning for Jax is perfectly fine with me if that's the resolution. |
I just tried adding this to the bottom of the file:

```python
__all__ = [name for name in globals() if not name.startswith('_')]
```

and it didn't seem to fix the issue above, despite the fact that … I think this means there's no avoiding listing every name in `__all__` explicitly. For what it's worth, JAX's mypy config already silences this warning for JAX's CI: https://github.com/google/jax/blob/729b21bda612ba414aff74072219d99b69494378/mypy.ini#L7-L8 |
That seems to be the case. I can confirm that only an explicit definition of the function names works.

```python
## ../jax/numpy/__init__.py
__all__ = [
    "zeros",
]
```

```python
## src/main.py
from jax import numpy as jnp

z = jnp.zeros(1)
```

```
$ mypy --python-executable "env/bin/python" --install-types --strict src
Success: no issues found in 2 source files
```

I've also tried numpy. It works just fine with mypy strict mode, at least with the latest version. You should consider enabling type checking for it.

```python
## src/main.py
import numpy as np

z = np.zeros(1)
```

```
$ mypy --python-executable "env/bin/python" --install-types --strict src
Success: no issues found in 2 source files
$ pip list
Package    Version     Location
---------- ----------- ------------------------
numpy      1.21.1
```
|
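Since an explicit `__all__` has to be maintained by hand, a small unit test can catch drift between the list and the module. A sketch with hypothetical helper names: one function reports names listed in `__all__` that the module doesn't define, the other reports public names the module defines but `__all__` omits.

```python
import types
from typing import List


def missing_from_module(module: types.ModuleType) -> List[str]:
    """Names listed in __all__ that the module does not actually define."""
    return [n for n in getattr(module, "__all__", []) if not hasattr(module, n)]


def missing_from_all(module: types.ModuleType) -> List[str]:
    """Public (non-underscore) names the module defines but __all__ omits."""
    listed = set(getattr(module, "__all__", []))
    return sorted(
        n for n in vars(module) if not n.startswith("_") and n not in listed
    )


demo = types.ModuleType("demo")
demo.zeros = lambda n: [0.0] * n
demo.ones = lambda n: [1.0] * n
demo.__all__ = ["zeros", "empty"]  # 'empty' is listed but never defined

print(missing_from_module(demo))  # ['empty']
print(missing_from_all(demo))     # ['ones']
```

In a real package, `missing_from_all` would also flag imported submodules and helpers, so an allowlist of intentionally private-but-unprefixed names may be needed.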
I'm not sure that works... I believe it's importing |
#7606 shows what is required for this with the |
True. Either NumPy seems to import all functions of a private module and use |
Another option #7607, just to compare these two. |
Would it be possible to add exports to `jax.config` and `jax.debug`? I'm getting errors with both mypy and pyright. |
The `... as ...` form tells the type checker that the name is exported. See #7570. PiperOrigin-RevId: 671311946
The `... as ...` form tells the type checker that the name is exported. See #7570. PiperOrigin-RevId: 671318047
Mypy requires the public API of a package to be exported using either `__all__` or `import ... as ...` syntax. Currently, projects that use jax need to turn off `--no-implicit-reexport` (for example, with `implicit_reexport = True` in their mypy config). There are some references on that matter.

Is it possible to add a little boilerplate to `__init__.py` to make it work? :-)