Skip to content

Commit d191927

Browse files
ghpvnistGoogle-ML-Automation
authored andcommittedMar 11, 2025·
Fix syntax error and typos for composite primitive docstring.
PiperOrigin-RevId: 735808000
1 parent 6f7ce9d commit d191927

File tree

2 files changed

+19
-16
lines changed

2 files changed

+19
-16
lines changed
 

‎docs/jax.lax.rst

+1
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ Operators
5858
clz
5959
collapse
6060
complex
61+
composite
6162
concatenate
6263
conj
6364
conv

‎jax/_src/lax/lax.py

+18-16
Original file line numberDiff line numberDiff line change
@@ -1489,14 +1489,14 @@ def composite(
14891489
):
14901490
"""Composite with semantics defined by the decomposition function.
14911491
1492-
A composite is a higher-order JAX function that encapsulates an operation mad
1492+
A composite is a higher-order JAX function that encapsulates an operation made
14931493
up (composed) of other JAX functions. The semantics of the op are implemented
14941494
by the ``decomposition`` function. In other words, the defined composite
14951495
function can be replaced with its decomposed implementation without changing
14961496
the semantics of the encapsulated operation.
14971497
14981498
The compiler can recognize specific composite operations by their ``name``,
1499-
``version``, ``kawargs``, and dtypes to emit more efficient code, potentially
1499+
``version``, ``kwargs``, and dtypes to emit more efficient code, potentially
15001500
leveraging hardware-specific instructions or optimizations. If the compiler
15011501
doesn't recognize the composite, it falls back to compiling the
15021502
``decomposition`` function.
@@ -1505,31 +1505,32 @@ def composite(
15051505
be implemented as ``sin(x) / cos(x)``. A hardware-aware compiler could
15061506
recognize the "tangent" composite and emit a single ``tangent`` instruction
15071507
instead of three separate instructions (``sin``, ``divide``, and ``cos``).
1508-
With compilers for hardwares without dedicated tangent support, it would fall
1509-
back to compiling the decomposition.
1508+
For hardware without dedicated tangent support, it would fall back to
1509+
compiling the decomposition.
15101510
1511-
This is useful for preserving high level abstraction that would otherwise be
1512-
lost while lowering which allows for easier pattern-matching in low-level IR.
1511+
This is useful for preserving high-level abstractions that would otherwise be
1512+
lost while lowering, which allows for easier pattern-matching in low-level IR.
15131513
15141514
Args:
15151515
decomposition: function that implements the semantics of the composite op.
15161516
name: name of the encapsulated operation.
15171517
version: optional int to indicate semantic changes to the composite.
15181518
15191519
Returns:
1520-
out: callable composite function. Note that positional arguments to this
1521-
function should be interpreted as inputs and keyword arguments should be
1522-
interpreted as attributes of the op. Any keyword arguments that are passed
1523-
with ``None`` as a value will be omitted from the
1524-
``composite_attributes``.
1520+
Callable: Returns a composite function. Note that positional arguments to
1521+
this function should be interpreted as inputs and keyword arguments should
1522+
be interpreted as attributes of the op. Any keyword arguments that are
1523+
passed with ``None`` as a value will be omitted from the
1524+
``composite_attributes``.
15251525
15261526
Examples:
15271527
Tangent kernel:
1528+
15281529
>>> def my_tangent_composite(x):
15291530
... return lax.composite(
1530-
... lambda x: lax.sin(x) / lax.cos(x), name='my.tangent'
1531+
... lambda x: lax.sin(x) / lax.cos(x), name="my.tangent"
15311532
... )(x)
1532-
...
1533+
>>>
15331534
>>> pi = jnp.pi
15341535
>>> x = jnp.array([0.0, pi / 4, 3 * pi / 4, pi])
15351536
>>> with jnp.printoptions(precision=3, suppress=True):
@@ -1538,9 +1539,10 @@ def composite(
15381539
[ 0. 1. -1. 0.]
15391540
[ 0. 1. -1. 0.]
15401541
1541-
The recommended way to create composites is via a decorator. Use `/` and `*`
1542-
in the function signature to be explicit about positional and keyword
1543-
arguments respectively:
1542+
The recommended way to create composites is via a decorator. Use ``/`` and
1543+
``*`` in the function signature to be explicit about positional and keyword
1544+
arguments, respectively:
1545+
15441546
>>> @partial(lax.composite, name="my.softmax")
15451547
... def my_softmax_composite(x, /, *, axis):
15461548
... return jax.nn.softmax(x, axis)

0 commit comments

Comments
 (0)
Please sign in to comment.