Skip to content

Commit

Permalink
Removes enforce precision in all codebase (pt. 2 -doc-)
Browse files Browse the repository at this point in the history
  • Loading branch information
joanrue committed Aug 21, 2024
1 parent 36efdb8 commit c78773c
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 26 deletions.
6 changes: 0 additions & 6 deletions doc/api/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -411,13 +411,7 @@ pyxu.runtime

.. autosummary::

~pyxu.runtime.coerce
~pyxu.runtime.CWidth
~pyxu.runtime.enforce_precision
~pyxu.runtime.EnforcePrecision
~pyxu.runtime.getCoerceState
~pyxu.runtime.getPrecision
~pyxu.runtime.Precision
~pyxu.runtime.Width

pyxu.util
Expand Down
30 changes: 12 additions & 18 deletions doc/fair/dev_notes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,9 @@ abide by the following set of rules:

- In the case of N-D inputs, the output should have the same number of dimensions as the input.

- It should be decorated with :py:func:`~pyxu.runtime.enforce_precision`. Together with the
:py:class:`~pyxu.runtime.Precision` context manager, the former controls the numerical precision (e.g. *single*,
- It should control the numerical precision (e.g. *single*,
*double*) of the inputs/outputs. If possible, the computation performed by the method itself should also be carried
out at the user-specified precision, accessible via :py:func:`~pyxu.runtime.getPrecision`.
out at the input array's precision.

- Whenever possible, it should be compatible with the array modules supported by Pyxu. (Use
:py:func:`~pyxu.info.deps.supported_array_modules` for an up-to-date list). :py:func:`~pyxu.util.get_array_module`
Expand All @@ -39,29 +38,24 @@ As an example, consider the following code snippet, defining the median operator
import pyxu.util as pxu
class Median(pxa.Map):
def __init__(self, dim: int):
super().__init__(shape=(1, dim))
def __init__(self, dim_shape: tuple):
super().__init__(dim_shape=dim_shape, codim_shape=1)
@pxrt.enforce_precision(i="arr") # enforce input/output precision.
def apply(self, arr):
xp = pxu.get_array_module(arr) # find array module of `arr`.
return xp.median(arr, axis=-1, keepdims=True) # median() is applied to the last axis.
axis = tuple(range(-len(self.dim_shape), 0)) # apply median to core dimensions
return xp.median(arr, axis=axis, keepdims=False) # apply is vectorized to batch dimensions
This operator can then be fed various arrays as inputs:

.. code-block:: python3
import pyxu.info.deps as pxd
N = 5
op = Median(N)
N = () # batch size
dim_shape = (4, 3)
op = Median(dim_shape)
for xp in pxd.supported_array_modules():
out = op.apply(xp.arange(2*N).reshape(2, N)) # apply the operator to various array types.
If called from within the :py:class:`~pyxu.runtime.Precision` context manager, the decorated ``apply()`` method will
automatically *coerce* the input/output to the user-specified precision:

.. code-block:: python3
with pxrt.Precision(pxrt.Width.SINGLE):
out = op.apply(np.arange(N)) # float32 computation
for width in pxrt.Width:
arr = xp.random.normal(size=(N + dim_shape)).astype(width.value)
out = op.apply(arr) # apply the operator to various array types.
4 changes: 2 additions & 2 deletions src/pyxu/operator/func/norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def __init__(self, dim_shape: pxt.NDArrayShape):
)
self.lipschitz = np.inf

# prox(): runtime-coerce & vectorize
# prox(): vectorize
vectorize = pxu.vectorize(
i="arr",
dim_shape=self.dim_shape,
Expand Down Expand Up @@ -185,7 +185,7 @@ def __init__(self, dim_shape: pxt.NDArrayShape):
)
self.lipschitz = 1

# prox(): runtime-coerce & vectorize
# prox(): vectorize
vectorize = pxu.vectorize(
i="arr",
dim_shape=self.dim_shape,
Expand Down

0 comments on commit c78773c

Please sign in to comment.