@@ -1489,14 +1489,14 @@ def composite(
1489
1489
):
1490
1490
"""Composite with semantics defined by the decomposition function.
1491
1491
1492
- A composite is a higher-order JAX function that encapsulates an operation mad
1492
+ A composite is a higher-order JAX function that encapsulates an operation made
1493
1493
up (composed) of other JAX functions. The semantics of the op are implemented
1494
1494
by the ``decomposition`` function. In other words, the defined composite
1495
1495
function can be replaced with its decomposed implementation without changing
1496
1496
the semantics of the encapsulated operation.
1497
1497
1498
1498
The compiler can recognize specific composite operations by their ``name``,
1499
- ``version``, ``kawargs ``, and dtypes to emit more efficient code, potentially
1499
+ ``version``, ``kwargs ``, and dtypes to emit more efficient code, potentially
1500
1500
leveraging hardware-specific instructions or optimizations. If the compiler
1501
1501
doesn't recognize the composite, it falls back to compiling the
1502
1502
``decomposition`` function.
@@ -1505,31 +1505,32 @@ def composite(
1505
1505
be implemented as ``sin(x) / cos(x)``. A hardware-aware compiler could
1506
1506
recognize the "tangent" composite and emit a single ``tangent`` instruction
1507
1507
instead of three separate instructions (``sin``, ``divide``, and ``cos``).
1508
- With compilers for hardwares without dedicated tangent support, it would fall
1509
- back to compiling the decomposition.
1508
+ For hardware without dedicated tangent support, it would fall back to
1509
+ compiling the decomposition.
1510
1510
1511
- This is useful for preserving high level abstraction that would otherwise be
1512
- lost while lowering which allows for easier pattern-matching in low-level IR.
1511
+ This is useful for preserving high- level abstractions that would otherwise be
1512
+ lost while lowering, which allows for easier pattern-matching in low-level IR.
1513
1513
1514
1514
Args:
1515
1515
decomposition: function that implements the semantics of the composite op.
1516
1516
name: name of the encapsulated operation.
1517
1517
version: optional int to indicate semantic changes to the composite.
1518
1518
1519
1519
Returns:
1520
- out: callable composite function. Note that positional arguments to this
1521
- function should be interpreted as inputs and keyword arguments should be
1522
- interpreted as attributes of the op. Any keyword arguments that are passed
1523
- with ``None`` as a value will be omitted from the
1524
- ``composite_attributes``.
1520
+ Callable: Returns a composite function. Note that positional arguments to
1521
+ this function should be interpreted as inputs and keyword arguments should
1522
+ be interpreted as attributes of the op. Any keyword arguments that are
1523
+ passed with ``None`` as a value will be omitted from the
1524
+ ``composite_attributes``.
1525
1525
1526
1526
Examples:
1527
1527
Tangent kernel:
1528
+
1528
1529
>>> def my_tangent_composite(x):
1529
1530
... return lax.composite(
1530
- ... lambda x: lax.sin(x) / lax.cos(x), name=' my.tangent'
1531
+ ... lambda x: lax.sin(x) / lax.cos(x), name=" my.tangent"
1531
1532
... )(x)
1532
- ...
1533
+ >>>
1533
1534
>>> pi = jnp.pi
1534
1535
>>> x = jnp.array([0.0, pi / 4, 3 * pi / 4, pi])
1535
1536
>>> with jnp.printoptions(precision=3, suppress=True):
@@ -1538,9 +1539,10 @@ def composite(
1538
1539
[ 0. 1. -1. 0.]
1539
1540
[ 0. 1. -1. 0.]
1540
1541
1541
- The recommended way to create composites is via a decorator. Use `/` and `*`
1542
- in the function signature to be explicit about positional and keyword
1543
- arguments respectively:
1542
+ The recommended way to create composites is via a decorator. Use ``/`` and
1543
+ ``*`` in the function signature to be explicit about positional and keyword
1544
+ arguments, respectively:
1545
+
1544
1546
>>> @partial(lax.composite, name="my.softmax")
1545
1547
... def my_softmax_composite(x, /, *, axis):
1546
1548
... return jax.nn.softmax(x, axis)
0 commit comments