Consider defining public API explicitly #7570

Closed
manifest opened this issue Aug 10, 2021 · 17 comments · Fixed by #7606
Labels
enhancement New feature or request

Comments

@manifest

Mypy requires the public API of a package to be exported explicitly, using either __all__ or the import ... as ... syntax.

python3 -m venv env
source env/bin/activate
pip install -U pip mypy jax
python -m mypy --install-types --non-interactive --strict -c "from jax import numpy as jnp; xs = jnp.zeros(1)"

<string>:1: error: Module has no attribute "zeros"
Found 1 error in 1 file (checked 1 source file)

Currently, projects that use jax need to turn off --no-implicit-reexport (e.g. by passing --implicit-reexport).

Some references on the matter:

  • PEP-0484
    • Modules and variables imported into the stub are not considered exported from the stub unless the import uses the import ... as ... form or the equivalent from ... import ... as ... form.
    • However, as an exception to the previous bullet, all objects imported into a stub using from ... import * are considered exported.
  • mypy docs

    --no-implicit-reexport is always treated as enabled for stub files.

Is it possible to add a little boilerplate to __init__.py to make this work? :-)
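The rule quoted from PEP 484 can be approximated with the standard-library ast module. The following is a minimal sketch, not mypy's actual implementation: the helper explicitly_exported is hypothetical, and it only recognizes a literal __all__ list or tuple and the redundant-alias forms (import X as X, from Y import X as X). Star imports, __all__ += ..., and the many other cases mypy handles are deliberately ignored.

```python
import ast


def explicitly_exported(source: str) -> set[str]:
    """Approximate which names a strict type checker treats as re-exported.

    Simplified sketch: only a literal __all__ list/tuple and the
    redundant-alias import forms are recognized.
    """
    tree = ast.parse(source)
    exported: set[str] = set()
    for node in tree.body:
        # `import a as a` / `from m import a as a`: alias equals the original name.
        if isinstance(node, (ast.Import, ast.ImportFrom)):
            for alias in node.names:
                if alias.asname is not None and alias.asname == alias.name:
                    exported.add(alias.asname)
        # `__all__ = ["a", "b"]` written as a literal list or tuple of strings.
        elif isinstance(node, ast.Assign):
            targets = [t.id for t in node.targets if isinstance(t, ast.Name)]
            if "__all__" in targets and isinstance(node.value, (ast.List, ast.Tuple)):
                for elt in node.value.elts:
                    if isinstance(elt, ast.Constant) and isinstance(elt.value, str):
                        exported.add(elt.value)
    return exported


example = '''
from jax._src.numpy.lax_numpy import zeros
from jax._src.numpy.lax_numpy import ones as ones
'''
print(sorted(explicitly_exported(example)))  # ['ones']
```

The source is only parsed, never executed. The plain import of zeros is not reported as exported, which is consistent with the mypy error shown above.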

@manifest manifest added the bug Something isn't working label Aug 10, 2021
@jakevdp
Collaborator

jakevdp commented Aug 10, 2021

Thanks for the report – I'm a bit confused about what mypy requires here, because we currently do explicitly import zeros within the __init__.py file, which appears to be the recommendation in the docs you linked to: https://github.com/google/jax/blob/17a606a95d8f059e2a069dd1cfa8b2dfb8e93255/jax/numpy/__init__.py#L64

Do you know what we should be doing differently?

@manifest
Author

You can just list all these public functions in __all__, like so:

from jax._src.numpy.lax_numpy import zeros

__all__ = [
  "zeros",
  ...
]

That will make mypy happy.
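At runtime, a literal __all__ also determines what from package import * binds. A minimal, self-contained sketch; the module name fake_numpy is made up, and the module is built in memory and registered in sys.modules instead of being written to disk:

```python
import sys
import types

# Build a throwaway module in memory instead of writing a file to disk.
mod = types.ModuleType("fake_numpy")
exec(
    "def zeros(n): return [0] * n\n"
    "def _helper(): return None\n"
    "def ones(n): return [1] * n\n"
    "__all__ = ['zeros']\n",
    mod.__dict__,
)
sys.modules["fake_numpy"] = mod

namespace: dict = {}
exec("from fake_numpy import *", namespace)

# Only names listed in __all__ are bound by the star import.
print(sorted(k for k in namespace if k != "__builtins__"))  # ['zeros']
```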

@manifest
Author

Another option would be to write:

from jax._src.numpy.lax_numpy import zeros as zeros

From PEP-0484:

Modules and variables imported into the stub are not considered exported from the stub unless the import uses the import ... as ... form or the equivalent from ... import ... as ... form. (UPDATE: To clarify, the intention here is that only names imported using the form X as X will be exported, i.e. the name before and after as must be the same.)

As far as I understand, the form X as X really means from Y import X as X.

However, that will make pylint produce warnings, at least with its default settings.

@jakevdp
Collaborator

jakevdp commented Aug 10, 2021

Thanks - it's unfortunate that mypy requires such extensive duplication of boilerplate. The explicit imports take up 43 full lines; I don't relish the thought of writing import xyz as xyz several hundred times, nor is it appealing to have to keep two parallel lists of those imported functions in sync. We also want to avoid using __all__ along with import *, because it leaks internal APIs into the public package namespace.

Do you know if there's any other way to appease mypy here?

@manifest
Author

Unfortunately, these two seem to be the only options :-/

@PhilipVinc
Contributor

I find a decorator like https://stackoverflow.com/questions/41895077/export-decorator-that-manages-all useful to cut down on the boilerplate.

Maybe a similar idea could be used here?
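A minimal sketch of that decorator idea, assuming the usual pattern of appending fn.__name__ to the defining module's __all__ (the names export and demo_pkg are illustrative, not from JAX). It also shows the limitation discussed next: an alias such as abs = absolute is recorded under the function's __name__, not under the alias. And because __all__ is only filled in when the module runs, a static checker that does not execute code will not see these entries.

```python
import sys
import types


def export(fn):
    """Record fn.__name__ in the __all__ list of the module that defined fn."""
    module = sys.modules[fn.__module__]
    public = module.__dict__.setdefault("__all__", [])
    if fn.__name__ not in public:
        public.append(fn.__name__)
    return fn


# A throwaway in-memory module standing in for a package's __init__.py.
mod = types.ModuleType("demo_pkg")
sys.modules["demo_pkg"] = mod
mod.export = export
exec(
    "@export\n"
    "def absolute(x):\n"
    "    return x if x >= 0 else -x\n"
    "\n"
    "abs = absolute  # alias: a second public name for the same function\n",
    mod.__dict__,
)

print(mod.__all__)  # ['absolute']; the alias 'abs' is missing
```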

@jakevdp
Collaborator

jakevdp commented Aug 10, 2021

Good thought – although that particular solution requires that obj.__name__ exists and matches the name of obj in the namespace, so it wouldn't work for all entries:

In [1]: import jax.numpy as jnp
In [2]: for name in dir(jnp):
   ...:     obj = getattr(jnp, name)
   ...:     objname = getattr(obj, '__name__', None)
   ...:     if objname is None:
   ...:         print(f'jnp.{name} has no name')
   ...:     elif objname != name:
   ...:         print(f'jnp.{name}.__name__ = {objname}')
   ...:
jnp.DeviceArray.__name__ = DeviceArrayBase
jnp.NINF has no name
jnp.NZERO has no name
jnp.PZERO has no name
jnp._NOT_IMPLEMENTED has no name
jnp.__builtins__ has no name
jnp.__cached__ has no name
jnp.__doc__ has no name
jnp.__file__ has no name
jnp.__loader__ has no name
jnp.__name__ has no name
jnp.__package__ has no name
jnp.__path__ has no name
jnp.__spec__ has no name
jnp.abs.__name__ = absolute
jnp.add_newdoc_ufunc.__name__ = _add_newdoc_ufunc
jnp.alltrue.__name__ = all
jnp.bitwise_not.__name__ = invert
jnp.c_ has no name
jnp.cdouble.__name__ = complex128
jnp.complex_.__name__ = complex128
jnp.conj.__name__ = conjugate
jnp.csingle.__name__ = complex64
jnp.cumproduct.__name__ = cumprod
jnp.degrees.__name__ = rad2deg
jnp.deprecate_with_doc.__name__ = <lambda>
jnp.divide.__name__ = true_divide
jnp.double.__name__ = float64
jnp.e has no name
jnp.empty.__name__ = zeros
jnp.empty_like.__name__ = zeros_like
jnp.euler_gamma has no name
jnp.fastCopyAndTranspose.__name__ = _fastCopyAndTranspose
jnp.fft.__name__ = jax.numpy.fft
jnp.float_.__name__ = float64
jnp.index_exp has no name
jnp.inf has no name
jnp.int_.__name__ = int64
jnp.lax_numpy.__name__ = jax._src.numpy.lax_numpy
jnp.linalg.__name__ = jax.numpy.linalg
jnp.mat.__name__ = asmatrix
jnp.max.__name__ = amax
jnp.mgrid has no name
jnp.min.__name__ = amin
jnp.mod.__name__ = remainder
jnp.nan has no name
jnp.newaxis has no name
jnp.ogrid has no name
jnp.operator_name has no name
jnp.pi has no name
jnp.product.__name__ = prod
jnp.r_ has no name
jnp.radians.__name__ = deg2rad
jnp.round.__name__ = around
jnp.round_.__name__ = around
jnp.row_stack.__name__ = vstack
jnp.s_ has no name
jnp.show_config.__name__ = show
jnp.single.__name__ = float32
jnp.sometrue.__name__ = any

@PhilipVinc
Contributor

PhilipVinc commented Aug 11, 2021

BTW, I strongly support this change. It's annoying not being able to run mypy out of the box.

Hmm yes, that trick mainly works for functions and classes.

But since we are talking about the public interface, and jax already does a very good job of keeping its namespaces clean, why not just export all symbols under their current names as given by dir(), ignoring everything with a leading underscore, since that signals the name is internal?

Does the example below catch everything?

python
Python 3.8.11 (default, Jun 29 2021, 00:00:00) 
[GCC 11.1.1 20210531 (Red Hat 11.1.1-3)] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import sys
>>> def _export_public(module_name):
...     module = sys.modules[module_name]
...     if not hasattr(module, '__all__'):
...         setattr(module, '__all__', [])
...     __all__ = module.__all__
...     for name in dir(module): 
...         obj = getattr(module, name) 
...         if not name.startswith("_"):
...             __all__.append(name)
... 
>>> import jax
>>> jax.numpy.__all__
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
AttributeError: module 'jax.numpy' has no attribute '__all__'
>>> _export_public('jax.numpy')
>>> jax.numpy.__all__
['ComplexWarning', 'DeviceArray', 'NINF', 'NZERO', 'PZERO', 'abs', 'absolute', 'add', 'add_docstring', 'add_newdoc', 'add_newdoc_ufunc', 'alen', 'all', 'allclose', 'alltrue', 'amax', 'amin', 'angle', 'any', 'append', 'apply_along_axis', 'apply_over_axes', 'arange', 'arccos', 'arccosh', 'arcsin', 'arcsinh', 'arctan', 'arctan2', 'arctanh', 'argmax', 'argmin', 'argpartition', 'argsort', 'argwhere', 'around', 'array', 'array2string', 'array_equal', 'array_equiv', 'array_repr', 'array_split', 'array_str', 'asanyarray', 'asarray', 'asarray_chkfinite', 'ascontiguousarray', 'asfarray', 'asfortranarray', 'asmatrix', 'asscalar', 'atleast_1d', 'atleast_2d', 'atleast_3d', 'average', 'bartlett', 'base_repr', 'bfloat16', 'binary_repr', 'bincount', 'bitwise_and', 'bitwise_not', 'bitwise_or', 'bitwise_xor', 'blackman', 'block', 'bmat', 'bool_', 'broadcast_arrays', 'broadcast_shapes', 'broadcast_to', 'busday_count', 'busday_offset', 'byte_bounds', 'c_', 'can_cast', 'cbrt', 'cdouble', 'ceil', 'character', 'choose', 'clip', 'column_stack', 'common_type', 'compare_chararrays', 'complex128', 'complex64', 'complex_', 'complexfloating', 'compress', 'concatenate', 'conj', 'conjugate', 'convolve', 'copy', 'copysign', 'copyto', 'corrcoef', 'correlate', 'cos', 'cosh', 'count_nonzero', 'cov', 'cross', 'csingle', 'cumprod', 'cumproduct', 'cumsum', 'datetime_as_string', 'datetime_data', 'deg2rad', 'degrees', 'delete', 'deprecate', 'deprecate_with_doc', 'diag', 'diag_indices', 'diag_indices_from', 'diagflat', 'diagonal', 'diff', 'digitize', 'disp', 'divide', 'divmod', 'dot', 'double', 'dsplit', 'dstack', 'dtype', 'e', 'ediff1d', 'einsum', 'einsum_path', 'empty', 'empty_like', 'equal', 'euler_gamma', 'exp', 'exp2', 'expand_dims', 'expm1', 'extract', 'eye', 'fabs', 'fastCopyAndTranspose', 'fft', 'fill_diagonal', 'find_common_type', 'finfo', 'fix', 'flatnonzero', 'flexible', 'flip', 'fliplr', 'flipud', 'float16', 'float32', 'float64', 'float_', 'float_power', 'floating', 'floor', 'floor_divide', 
'fmax', 'fmin', 'fmod', 'format_float_positional', 'format_float_scientific', 'frexp', 'frombuffer', 'fromfile', 'fromfunction', 'fromiter', 'frompyfunc', 'fromregex', 'fromstring', 'full', 'full_like', 'gcd', 'genfromtxt', 'geomspace', 'get_array_wrap', 'get_include', 'get_printoptions', 'getbufsize', 'geterr', 'geterrcall', 'geterrobj', 'gradient', 'greater', 'greater_equal', 'hamming', 'hanning', 'heaviside', 'histogram', 'histogram2d', 'histogram_bin_edges', 'histogramdd', 'hsplit', 'hstack', 'hypot', 'i0', 'identity', 'iinfo', 'imag', 'in1d', 'index_exp', 'indices', 'inexact', 'inf', 'info', 'inner', 'insert', 'int16', 'int32', 'int64', 'int8', 'int_', 'integer', 'interp', 'intersect1d', 'invert', 'is_busday', 'isclose', 'iscomplex', 'iscomplexobj', 'isfinite', 'isfortran', 'isin', 'isinf', 'isnan', 'isnat', 'isneginf', 'isposinf', 'isreal', 'isrealobj', 'isscalar', 'issctype', 'issubclass_', 'issubdtype', 'issubsctype', 'iterable', 'ix_', 'kaiser', 'kron', 'lax_numpy', 'lcm', 'ldexp', 'left_shift', 'less', 'less_equal', 'lexsort', 'linalg', 'linspace', 'load', 'loads', 'loadtxt', 'log', 'log10', 'log1p', 'log2', 'logaddexp', 'logaddexp2', 'logical_and', 'logical_not', 'logical_or', 'logical_xor', 'logspace', 'lookfor', 'mafromtxt', 'mask_indices', 'mat', 'matmul', 'max', 'maximum', 'maximum_sctype', 'may_share_memory', 'mean', 'median', 'meshgrid', 'mgrid', 'min', 'min_scalar_type', 'minimum', 'mintypecode', 'mod', 'modf', 'moveaxis', 'msort', 'multiply', 'nan', 'nan_to_num', 'nanargmax', 'nanargmin', 'nancumprod', 'nancumsum', 'nanmax', 'nanmean', 'nanmedian', 'nanmin', 'nanpercentile', 'nanprod', 'nanquantile', 'nanstd', 'nansum', 'nanvar', 'ndarray', 'ndfromtxt', 'ndim', 'negative', 'nested_iters', 'newaxis', 'nextafter', 'nonzero', 'not_equal', 'number', 'obj2sctype', 'object_', 'ogrid', 'ones', 'ones_like', 'operator_name', 'outer', 'packbits', 'pad', 'partition', 'percentile', 'pi', 'piecewise', 'place', 'poly', 'polyadd', 'polyder', 'polydiv', 
'polyfit', 'polyint', 'polymul', 'polysub', 'polyval', 'positive', 'power', 'printoptions', 'prod', 'product', 'promote_types', 'ptp', 'put', 'put_along_axis', 'putmask', 'quantile', 'r_', 'rad2deg', 'radians', 'ravel', 'ravel_multi_index', 'real', 'real_if_close', 'recfromcsv', 'recfromtxt', 'reciprocal', 'remainder', 'repeat', 'require', 'reshape', 'resize', 'result_type', 'right_shift', 'rint', 'roll', 'rollaxis', 'roots', 'rot90', 'round', 'round_', 'row_stack', 's_', 'safe_eval', 'save', 'savetxt', 'savez', 'savez_compressed', 'sctype2char', 'searchsorted', 'select', 'set_numeric_ops', 'set_printoptions', 'set_string_function', 'setbufsize', 'setdiff1d', 'seterr', 'seterrcall', 'seterrobj', 'setxor1d', 'shape', 'shares_memory', 'show_config', 'sign', 'signbit', 'signedinteger', 'sin', 'sinc', 'single', 'sinh', 'size', 'sometrue', 'sort', 'sort_complex', 'source', 'spacing', 'split', 'sqrt', 'square', 'squeeze', 'stack', 'std', 'subtract', 'sum', 'swapaxes', 'take', 'take_along_axis', 'tan', 'tanh', 'tensordot', 'tile', 'trace', 'transpose', 'trapz', 'tri', 'tril', 'tril_indices', 'tril_indices_from', 'trim_zeros', 'triu', 'triu_indices', 'triu_indices_from', 'true_divide', 'trunc', 'typename', 'uint16', 'uint32', 'uint64', 'uint8', 'union1d', 'unique', 'unpackbits', 'unravel_index', 'unsignedinteger', 'unwrap', 'vander', 'var', 'vdot', 'vectorize', 'vsplit', 'vstack', 'where', 'who', 'zeros', 'zeros_like']
>>> 
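A slightly tidied version of the helper above, as a sketch: it drops the unused obj lookup, rebuilds __all__ instead of appending to it (so calling it twice does not create duplicates), and optionally skips submodules such as lax_numpy. The function name export_public and the include_submodules parameter are hypothetical. Note that this __all__ exists only at runtime; mypy analyzes source code and does not execute the helper.

```python
import sys
import types


def export_public(module_name, include_submodules=False):
    """Set module.__all__ to its non-underscore names (runtime only)."""
    module = sys.modules[module_name]
    names = []
    for name in dir(module):  # dir() returns names in sorted order
        if name.startswith("_"):
            continue
        if not include_submodules and isinstance(getattr(module, name), types.ModuleType):
            continue  # skip submodules, e.g. a leaked helper like lax_numpy
        names.append(name)
    module.__all__ = names  # replace rather than extend, so repeated calls are safe
    return names


# Demo on an in-memory module instead of jax.numpy.
demo = types.ModuleType("demo_mod")
exec("import math\ndef zeros(n): return [0] * n\ndef _private(): pass\npi = 3.14\n", demo.__dict__)
sys.modules["demo_mod"] = demo
print(export_public("demo_mod"))  # ['pi', 'zeros']
```

Skipping submodules is a trade-off: it drops leaked helpers like lax_numpy, but would also drop public submodules like fft and linalg, which would then need to be added explicitly.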

@jakevdp
Collaborator

jakevdp commented Aug 11, 2021

How important do you think this is, given that jax currently has very few useful annotations for functions that accept and return arrays, and given the fact that it's not very clear how to effectively add such annotations? (see, e.g., the discussions at #943 and #6743)

Is package-attribute-level checking worth worrying about when function-argument-level checking has no real clear solution in sight?

@manifest
Copy link
Author

If we want to have support for type hints in Jax eventually, it seems that there is no way around PSF :-)
Personally, I started to use __all__ just for that single purpose.

A Union, as proposed in #943, or some Generic type seems like a valid option for JAX arrays, but that can be addressed later.

That said, I'm perfectly fine with muting the mypy warning for JAX if that's the resolution.

@jakevdp
Collaborator

jakevdp commented Aug 12, 2021

I just tried adding this to the bottom of the file:

__all__ = [name for name in globals() if not name.startswith('_')]

and it didn't seem to fix the issue above, despite the fact that "zeros" appears in __all__. It looks like __all__ has to be defined literally rather than programmatically.

I think this means there's no way to avoid listing every name in the __init__ file twice: once to import it into the namespace, and a second time to satisfy mypy.

For what it's worth, JAX's mypy config already silences this warning for JAX's CI: https://github.com/google/jax/blob/729b21bda612ba414aff74072219d99b69494378/mypy.ini#L7-L8
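If __all__ must be written out literally, a unit test can at least keep the two lists from drifting apart. A minimal sketch, assuming a test suite can compare the literal __all__ against the module's non-underscore, non-module attributes; check_all_in_sync is a hypothetical helper, and the demo uses an in-memory module rather than jax.numpy:

```python
import types


def check_all_in_sync(module: types.ModuleType) -> tuple[set, set]:
    """Return (missing_from_all, stale_in_all) for a module with a literal __all__."""
    declared = set(getattr(module, "__all__", ()))
    public = {
        name
        for name, obj in vars(module).items()
        if not name.startswith("_") and not isinstance(obj, types.ModuleType)
    }
    return public - declared, declared - public


demo = types.ModuleType("demo_numpy")
exec(
    "def zeros(n): return [0] * n\n"
    "def ones(n): return [1] * n\n"
    "__all__ = ['zeros', 'empty']\n",
    demo.__dict__,
)
missing, stale = check_all_in_sync(demo)
print(sorted(missing), sorted(stale))  # ['ones'] ['empty']
```

In a real test one would assert that both sets are empty.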

@manifest
Author

That seems to be the case. I can confirm that only an explicit, literal list of the function names works.

## ../jax/numpy/__init__.py
__all__ = [
  "zeros",
]

## src/main.py
from jax import numpy as jnp

z = jnp.zeros(1)
mypy --python-executable "env/bin/python" --install-types --strict src
Success: no issues found in 2 source files

I've also tried numpy; at least its latest version works just fine with mypy in strict mode. You should consider enabling type checking for it.

## src/main.py
import numpy as np

z = np.zeros(1)
mypy --python-executable "env/bin/python" --install-types --strict src
Success: no issues found in 2 source files
pip list
Package    Version     Location
---------- ----------- ------------------------
numpy      1.21.1

@jakevdp
Collaborator

jakevdp commented Aug 12, 2021

I'm not sure that works... I believe it's importing func over and over, rather than the name stored within the variable func. Though you may be able to import via importlib to do this kind of thing...
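The importlib route might look like the following sketch (reexport is a hypothetical helper, and math stands in for a private implementation module). It binds the value of each name, which is what the loop was trying to do; but since the names are created dynamically, a static checker still cannot see them, so this alone would not fix the mypy error.

```python
import importlib


def reexport(namespace: dict, module_name: str, names: list[str]) -> None:
    """Bind each attribute of module_name into namespace under the same name."""
    module = importlib.import_module(module_name)
    for name in names:
        # getattr uses the *value* of `name`, unlike `from m import name`,
        # which would always import an attribute literally called "name".
        namespace[name] = getattr(module, name)


ns: dict = {}
reexport(ns, "math", ["sqrt", "pi"])
print(ns["sqrt"](16.0), ns["pi"] > 3)  # 4.0 True
```

Inside a package's __init__.py one would pass globals() as the namespace.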

@jakevdp
Collaborator

jakevdp commented Aug 12, 2021

#7606 shows what is required for this with the import X as X approach.
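With the import X as X approach, the redundant aliases do not have to be typed by hand. A minimal sketch of a generator script, assuming the aliases are emitted as one parenthesized from ... import (...) statement per source module; render_reexports is a hypothetical helper, not something taken from #7606:

```python
def render_reexports(module: str, names: list[str], indent: str = "    ") -> str:
    """Render a `from module import (X as X, ...)` statement for the given names."""
    body = "".join(f"{indent}{name} as {name},\n" for name in sorted(names))
    return f"from {module} import (\n{body})\n"


print(render_reexports("jax._src.numpy.lax_numpy", ["zeros", "ones"]))
```

Its output could be pasted into __init__.py or regenerated by a script. As mentioned above, pylint may warn about these aliases under its default settings.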

@manifest
Author

I'm not sure that works... I believe it's importing func over and over, rather than the name stored within the variable func. Though you may be able to import via importlib to do this kind of thing...

True. It's either import X as X or __all__ with exact strings.

NumPy seems to import all functions from a private module and use __all__ to define its public API.

@manifest
Author

Here is the other option, #7607, so the two can be compared.

@carlosgmartin
Contributor

Would it be possible to add exports to jax.config and jax.debug? I'm getting errors with mypy:

example.py:1: error: Module "jax.config" does not explicitly export attribute "config"  [attr-defined]
example.py:2: error: Module "jax.debug" does not explicitly export attribute "print"  [attr-defined]
example.py:2: error: Module "jax.debug" does not explicitly export attribute "callback"  [attr-defined]
example.py:2: error: Module "jax.debug" does not explicitly export attribute "breakpoint"  [attr-defined]

and pyright:

  /Users/carlos/Desktop/example.py:1:24 - error: "config" is not exported from module "jax.config"
    Import from "jax._src.config" instead (reportPrivateImportUsage)
  /Users/carlos/Desktop/example.py:2:23 - error: "print" is not exported from module "jax.debug"
    Import from "jax._src.debugging" instead (reportPrivateImportUsage)
  /Users/carlos/Desktop/example.py:2:30 - error: "callback" is not exported from module "jax.debug"
    Import from "jax._src.debugging" instead (reportPrivateImportUsage)
  /Users/carlos/Desktop/example.py:2:40 - error: "breakpoint" is not exported from module "jax.debug"
    Import from "jax._src.debugger.core" instead (reportPrivateImportUsage)

copybara-service bot pushed a commit that referenced this issue Sep 5, 2024
The `... as ...` form tells the type checker that the name is exported.
See #7570.

PiperOrigin-RevId: 671311946
copybara-service bot pushed a commit that referenced this issue Sep 5, 2024
The `... as ...` form tells the type checker that the name is exported.
See #7570.

PiperOrigin-RevId: 671318047
5 participants