From af0a915054ea4950f5c3da743a94f198fe11ce66 Mon Sep 17 00:00:00 2001 From: Abhishek Sharma Date: Sun, 26 Apr 2020 12:33:22 +0530 Subject: [PATCH 1/2] Make precision argument keyword only in jax.numpy --- jax/numpy/lax_numpy.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/jax/numpy/lax_numpy.py b/jax/numpy/lax_numpy.py index 0125ac50cd51..2758ff1e8b94 100644 --- a/jax/numpy/lax_numpy.py +++ b/jax/numpy/lax_numpy.py @@ -717,7 +717,7 @@ def trunc(x): return where(lax.lt(x, lax._const(x, 0)), lax.ceil(x), lax.floor(x)) -def _conv(x, y, mode, op, precision): +def _conv(x, y, mode, op, *, precision): if issubdtype(x.dtype, complexfloating) or issubdtype(y.dtype, complexfloating): raise NotImplementedError(f"{op}() does not support complex inputs") if ndim(x) != 1 or ndim(y) != 1: @@ -747,13 +747,13 @@ def _conv(x, y, mode, op, precision): @_wraps(onp.convolve, lax_description=_PRECISION_DOC) -def convolve(x, y, mode='full', precision=None): - return _conv(x, y, mode, 'convolve', precision) +def convolve(x, y, mode='full', *, precision=None): + return _conv(x, y, mode, 'convolve', precision=precision) @_wraps(onp.correlate, lax_description=_PRECISION_DOC) -def correlate(x, y, mode='valid', precision=None): - return _conv(x, y, mode, 'correlate', precision) +def correlate(x, y, mode='valid', *, precision=None): + return _conv(x, y, mode, 'correlate', precision=precision) def _normalize_float(x): @@ -2482,7 +2482,7 @@ def append(arr, values, axis=None): @_wraps(onp.dot, lax_description=_PRECISION_DOC) -def dot(a, b, precision=None): # pylint: disable=missing-docstring +def dot(a, b, *, precision=None): # pylint: disable=missing-docstring _check_arraylike("dot", a, b) a, b = _promote_dtypes(a, b) a_ndim, b_ndim = ndim(a), ndim(b) @@ -2500,7 +2500,7 @@ def dot(a, b, precision=None): # pylint: disable=missing-docstring @_wraps(onp.matmul, lax_description=_PRECISION_DOC) -def matmul(a, b, precision=None): # pylint: disable=missing-docstring +def matmul(a, b, *, precision=None): # pylint: disable=missing-docstring _check_arraylike("matmul", a, b) a_is_vec, b_is_vec = (ndim(a) == 1), (ndim(b) == 1) a = lax.reshape(a, (1,) + shape(a)) if a_is_vec else a @@ -2524,14 +2524,14 @@ def matmul(a, b, precision=None): # pylint: disable=missing-docstring @_wraps(onp.vdot, lax_description=_PRECISION_DOC) -def vdot(a, b, precision=None): +def vdot(a, b, *, precision=None): if issubdtype(_dtype(a), complexfloating): a = conj(a) return dot(a.ravel(), b.ravel(), precision=precision) @_wraps(onp.tensordot, lax_description=_PRECISION_DOC) -def tensordot(a, b, axes=2, precision=None): +def tensordot(a, b, axes=2, *, precision=None): _check_arraylike("tensordot", a, b) a_ndim = ndim(a) b_ndim = ndim(b) @@ -2576,7 +2576,7 @@ def einsum(*operands, **kwargs): operands, contractions = opt_einsum.contract_path( *operands, einsum_call=True, use_blas=True, optimize=optimize) contractions = tuple(data[:3] for data in contractions) - return _einsum(operands, contractions, precision) + return _einsum(operands, contractions, precision=precision) @_wraps(onp.einsum_path) def einsum_path(subscripts, *operands, **kwargs): @@ -2588,7 +2588,7 @@ def _removechars(s, chars): return s.translate(str.maketrans(dict.fromkeys(chars))) @partial(jit, static_argnums=(1, 2)) -def _einsum(operands, contractions, precision): +def _einsum(operands, contractions, *, precision): operands = list(_promote_dtypes(*operands)) def sum(x, axes): return lax.reduce(x, onp.array(0, x.dtype), @@ -2731,7 +2731,7 @@ def _movechars(s, src, dst): @_wraps(onp.inner, lax_description=_PRECISION_DOC) -def inner(a, b, precision=None): +def inner(a, b, *, precision=None): if ndim(a) == 0 or ndim(b) == 0: return a * b return tensordot(a, b, (-1, -1), precision=precision) From a0fa99b60c8e30eddefa9b8a7a4ea218fabb4fa5 Mon Sep 17 00:00:00 2001 From: Abhishek Sharma Date: Sun, 26 Apr 2020 14:17:13 +0530 Subject: [PATCH 2/2] Fix private functions --- jax/numpy/lax_numpy.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/jax/numpy/lax_numpy.py b/jax/numpy/lax_numpy.py index 2758ff1e8b94..dc9683dac1f9 100644 --- a/jax/numpy/lax_numpy.py +++ b/jax/numpy/lax_numpy.py @@ -717,7 +717,7 @@ def trunc(x): return where(lax.lt(x, lax._const(x, 0)), lax.ceil(x), lax.floor(x)) -def _conv(x, y, mode, op, *, precision): +def _conv(x, y, mode, op, precision): if issubdtype(x.dtype, complexfloating) or issubdtype(y.dtype, complexfloating): raise NotImplementedError(f"{op}() does not support complex inputs") if ndim(x) != 1 or ndim(y) != 1: @@ -748,12 +748,12 @@ def _conv(x, y, mode, op, *, precision): @_wraps(onp.convolve, lax_description=_PRECISION_DOC) def convolve(x, y, mode='full', *, precision=None): - return _conv(x, y, mode, 'convolve', precision=precision) + return _conv(x, y, mode, 'convolve', precision) @_wraps(onp.correlate, lax_description=_PRECISION_DOC) def correlate(x, y, mode='valid', *, precision=None): - return _conv(x, y, mode, 'correlate', precision=precision) + return _conv(x, y, mode, 'correlate', precision) def _normalize_float(x): @@ -2576,7 +2576,7 @@ def einsum(*operands, **kwargs): operands, contractions = opt_einsum.contract_path( *operands, einsum_call=True, use_blas=True, optimize=optimize) contractions = tuple(data[:3] for data in contractions) - return _einsum(operands, contractions, precision=precision) + return _einsum(operands, contractions, precision) @_wraps(onp.einsum_path) def einsum_path(subscripts, *operands, **kwargs): @@ -2588,7 +2588,7 @@ def _removechars(s, chars): return s.translate(str.maketrans(dict.fromkeys(chars))) @partial(jit, static_argnums=(1, 2)) -def _einsum(operands, contractions, *, precision): +def _einsum(operands, contractions, precision): operands = list(_promote_dtypes(*operands)) def sum(x, axes): return lax.reduce(x, onp.array(0, x.dtype),