Skip to content

Commit

Permalink
Move jex.ffi to jax.ffi.
Browse files Browse the repository at this point in the history
  • Loading branch information
dfm committed Dec 20, 2024
1 parent 5031b6f commit 28687b0
Show file tree
Hide file tree
Showing 20 changed files with 493 additions and 402 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
* From {mod}`jax.interpreters.xla`, `abstractify` and `pytype_aval_mappings`
are now deprecated, having been replaced by symbols of the same name
in {mod}`jax.core`.
* The {mod}`jax.extend.ffi` submodule was moved to {mod}`jax.ffi`, and the
previous import path is deprecated.

* Deletions
* `jax_enable_memories` flag has been deleted and the behavior of that flag
Expand Down
63 changes: 31 additions & 32 deletions docs/ffi.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
"JAX's FFI support is provided in two parts:\n",
"\n",
"1. A header-only C++ library from XLA which is packaged as part of JAX as of v0.4.29 or available from the [openxla/xla](https://github.com/openxla/xla) project, and\n",
"2. A Python front end, available in the `jax.extend.ffi` submodule.\n",
"2. A Python front end, available in the `jax.ffi` submodule.\n",
"\n",
"In this tutorial we demonstrate the use of both of these components using a simple example, and then go on to discuss some lower-level extensions for more complicated use cases.\n",
"We start by presenting the FFI on CPU, and discuss generalizations to GPU or multi-device environments below.\n",
Expand Down Expand Up @@ -191,9 +191,9 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"With this compiled library in hand, we now need to register this handler with XLA via the {func}`~jax.extend.ffi.register_ffi_target` function.\n",
"With this compiled library in hand, we now need to register this handler with XLA via the {func}`~jax.ffi.register_ffi_target` function.\n",
"This function expects our handler (a function pointer to the C++ function `RmsNorm`) to be wrapped in a [`PyCapsule`](https://docs.python.org/3/c-api/capsule.html).\n",
"JAX provides a helper function {func}`~jax.extend.ffi.pycapsule` to help with this:"
"JAX provides a helper function {func}`~jax.ffi.pycapsule` to help with this:"
]
},
{
Expand All @@ -204,20 +204,19 @@
"source": [
"import ctypes\n",
"from pathlib import Path\n",
"import jax.extend as jex\n",
"\n",
"path = next(Path(\"ffi\").glob(\"librms_norm*\"))\n",
"rms_norm_lib = ctypes.cdll.LoadLibrary(path)\n",
"jex.ffi.register_ffi_target(\n",
" \"rms_norm\", jex.ffi.pycapsule(rms_norm_lib.RmsNorm), platform=\"cpu\")"
"jax.ffi.register_ffi_target(\n",
" \"rms_norm\", jax.ffi.pycapsule(rms_norm_lib.RmsNorm), platform=\"cpu\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"```{tip}\n",
"If you're familiar with the legacy \"custom call\" API, it's worth noting that you can also use {func}`~jax.extend.ffi.register_ffi_target` to register a custom call target by manually specifying the keyword argument `api_version=0`. The default `api_version` for {func}`~jax.extend.ffi.register_ffi_target` is `1`, the new \"typed\" FFI API that we're using here.\n",
"If you're familiar with the legacy \"custom call\" API, it's worth noting that you can also use {func}`~jax.ffi.register_ffi_target` to register a custom call target by manually specifying the keyword argument `api_version=0`. The default `api_version` for {func}`~jax.ffi.register_ffi_target` is `1`, the new \"typed\" FFI API that we're using here.\n",
"```\n",
"\n",
"**An alternative approach**:\n",
Expand Down Expand Up @@ -251,7 +250,7 @@
"# Assuming that we compiled a nanobind extension called `rms_norm`:\n",
"import rms_norm as rms_norm_lib\n",
"\n",
"jex.ffi.register_ffi_target(\"rms_norm\", rms_norm_lib.rms_norm(), platform=\"cpu\")\n",
"jax.ffi.register_ffi_target(\"rms_norm\", rms_norm_lib.rms_norm(), platform=\"cpu\")\n",
"```"
]
},
Expand All @@ -261,7 +260,7 @@
"source": [
"## Frontend code\n",
"\n",
"Now that we have registered our FFI handler, it is straightforward to call our C++ library from JAX using the {func}`~jax.extend.ffi.ffi_call` function:"
"Now that we have registered our FFI handler, it is straightforward to call our C++ library from JAX using the {func}`~jax.ffi.ffi_call` function:"
]
},
{
Expand All @@ -282,7 +281,7 @@
" if x.dtype != jnp.float32:\n",
" raise ValueError(\"Only the float32 dtype is implemented by rms_norm\")\n",
"\n",
" call = jex.ffi.ffi_call(\n",
" call = jax.ffi.ffi_call(\n",
" # The target name must be the same string as we used to register the target\n",
" # above in `register_custom_call_target`\n",
" \"rms_norm\",\n",
Expand Down Expand Up @@ -314,25 +313,25 @@
"metadata": {},
"source": [
"This code cell includes a lot of inline comments which should explain most of what is happening here, but there are a few points that are worth explicitly highlighting.\n",
"Most of the heavy lifting here is done by the {func}`~jax.extend.ffi.ffi_call` function, which tells JAX how to call the foreign function for a particular set of inputs.\n",
"It's important to note that the first argument to {func}`~jax.extend.ffi.ffi_call` must be a string that matches the target name that we used when calling `register_custom_call_target` above.\n",
"Most of the heavy lifting here is done by the {func}`~jax.ffi.ffi_call` function, which tells JAX how to call the foreign function for a particular set of inputs.\n",
"It's important to note that the first argument to {func}`~jax.ffi.ffi_call` must be a string that matches the target name that we used when calling `register_custom_call_target` above.\n",
"\n",
"Any attributes (defined using `Attr` in the C++ wrapper above) should be passed as keyword arguments to {func}`~jax.extend.ffi.ffi_call`.\n",
"Any attributes (defined using `Attr` in the C++ wrapper above) should be passed as keyword arguments to {func}`~jax.ffi.ffi_call`.\n",
"Note that we explicitly cast `eps` to `np.float32` because our FFI library expects a C `float`, and we can't use `jax.numpy` here, because these parameters must be static arguments.\n",
"\n",
"The `vmap_method` argument to {func}`~jax.extend.ffi.ffi_call` defines how this FFI call interacts with {func}`~jax.vmap` as described next.\n",
"The `vmap_method` argument to {func}`~jax.ffi.ffi_call` defines how this FFI call interacts with {func}`~jax.vmap` as described next.\n",
"\n",
"```{tip}\n",
"If you are familiar with the earlier \"custom call\" interface, you might be surprised that we're not passing the problem dimensions as parameters (batch size, etc.) to {func}`~jax.extend.ffi.ffi_call`.\n",
"If you are familiar with the earlier \"custom call\" interface, you might be surprised that we're not passing the problem dimensions as parameters (batch size, etc.) to {func}`~jax.ffi.ffi_call`.\n",
"In this earlier API, the backend had no mechanism for receiving metadata about the input arrays, but since the FFI includes dimension information with the `Buffer` objects, we no longer need to compute this using Python when lowering.\n",
"One major perk of this change is {func}`~jax.extend.ffi.ffi_call` can support some simple {func}`~jax.vmap` semantics out of the box, as discussed below.\n",
"One major perk of this change is {func}`~jax.ffi.ffi_call` can support some simple {func}`~jax.vmap` semantics out of the box, as discussed below.\n",
"```\n",
"\n",
"(ffi-call-vmap)=\n",
"### Batching with `vmap`\n",
"\n",
"{func}`~jax.extend.ffi.ffi_call` supports some simple {func}`~jax.vmap` semantics out of the box using the `vmap_method` parameter.\n",
"The docs for {func}`~jax.pure_callback` provide more details about the `vmap_method` parameter, and the same behavior applies to {func}`~jax.extend.ffi.ffi_call`.\n",
"{func}`~jax.ffi.ffi_call` supports some simple {func}`~jax.vmap` semantics out of the box using the `vmap_method` parameter.\n",
"The docs for {func}`~jax.pure_callback` provide more details about the `vmap_method` parameter, and the same behavior applies to {func}`~jax.ffi.ffi_call`.\n",
"\n",
"The simplest `vmap_method` is `\"sequential\"`.\n",
"In this case, when `vmap`ped, an `ffi_call` will be rewritten as a {func}`~jax.lax.scan` with the `ffi_call` in the body.\n",
Expand Down Expand Up @@ -395,7 +394,7 @@
"outputs": [],
"source": [
"def rms_norm_sequential(x, eps=1e-5):\n",
" return jex.ffi.ffi_call(\n",
" return jax.ffi.ffi_call(\n",
" \"rms_norm\",\n",
" jax.ShapeDtypeStruct(x.shape, x.dtype),\n",
" vmap_method=\"sequential\",\n",
Expand All @@ -418,9 +417,9 @@
"source": [
"### Differentiation\n",
"\n",
"Unlike with batching, {func}`~jax.extend.ffi.ffi_call` doesn't provide any default support for automatic differentiation (AD) of foreign functions.\n",
"Unlike with batching, {func}`~jax.ffi.ffi_call` doesn't provide any default support for automatic differentiation (AD) of foreign functions.\n",
"As far as JAX is concerned, the foreign function is a black box that can't be inspected to determine the appropriate behavior when differentiated.\n",
"Therefore, it is the {func}`~jax.extend.ffi.ffi_call` user's responsibility to define a custom derivative rule.\n",
"Therefore, it is the {func}`~jax.ffi.ffi_call` user's responsibility to define a custom derivative rule.\n",
"\n",
"More details about custom derivative rules can be found in the [custom derivatives tutorial](https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html), but the most common pattern used for implementing differentiation for foreign functions is to define a {func}`~jax.custom_vjp` which itself calls a foreign function.\n",
"In this case, we actually define two new FFI calls:\n",
Expand All @@ -429,7 +428,7 @@
"2. `rms_norm_bwd` takes the residuals and the output co-tangents, and returns the input co-tangents.\n",
"\n",
"We won't get into the details of the RMS normalization backwards pass, but take a look at the [C++ source code](https://github.com/jax-ml/jax/blob/main/examples/ffi/src/jax_ffi_example/rms_norm.cc) to see how these functions are implemented on the back end.\n",
"The main point to emphasize here is that the \"residual\" computed has a different shape than the primal output, therefore, in the {func}`~jax.extend.ffi.ffi_call` to `res_norm_fwd`, the output type has two elements with different shapes.\n",
"The main point to emphasize here is that the \"residual\" computed has a different shape than the primal output, therefore, in the {func}`~jax.ffi.ffi_call` to `res_norm_fwd`, the output type has two elements with different shapes.\n",
"\n",
"This custom derivative rule can be wired in as follows:"
]
Expand All @@ -440,16 +439,16 @@
"metadata": {},
"outputs": [],
"source": [
"jex.ffi.register_ffi_target(\n",
" \"rms_norm_fwd\", jex.ffi.pycapsule(rms_norm_lib.RmsNormFwd), platform=\"cpu\"\n",
"jax.ffi.register_ffi_target(\n",
" \"rms_norm_fwd\", jax.ffi.pycapsule(rms_norm_lib.RmsNormFwd), platform=\"cpu\"\n",
")\n",
"jex.ffi.register_ffi_target(\n",
" \"rms_norm_bwd\", jex.ffi.pycapsule(rms_norm_lib.RmsNormBwd), platform=\"cpu\"\n",
"jax.ffi.register_ffi_target(\n",
" \"rms_norm_bwd\", jax.ffi.pycapsule(rms_norm_lib.RmsNormBwd), platform=\"cpu\"\n",
")\n",
"\n",
"\n",
"def rms_norm_fwd(x, eps=1e-5):\n",
" y, res = jex.ffi.ffi_call(\n",
" y, res = jax.ffi.ffi_call(\n",
" \"rms_norm_fwd\",\n",
" (\n",
" jax.ShapeDtypeStruct(x.shape, x.dtype),\n",
Expand All @@ -466,7 +465,7 @@
" assert res.shape == ct.shape[:-1]\n",
" assert x.shape == ct.shape\n",
" return (\n",
" jex.ffi.ffi_call(\n",
" jax.ffi.ffi_call(\n",
" \"rms_norm_bwd\",\n",
" jax.ShapeDtypeStruct(ct.shape, ct.dtype),\n",
" vmap_method=\"broadcast_all\",\n",
Expand Down Expand Up @@ -533,7 +532,7 @@
"On the front end, the registration code would be updated to specify the appropriate platform:\n",
"\n",
"```python\n",
"jex.ffi.register_ffi_target(\n",
"jax.ffi.register_ffi_target(\n",
" \"rms_norm_cuda\", rms_norm_lib_cuda.rms_norm(), platform=\"CUDA\"\n",
")\n",
"```\n",
Expand All @@ -554,7 +553,7 @@
" out_type = jax.ShapeDtypeStruct(x.shape, x.dtype)\n",
"\n",
" def impl(target_name):\n",
" return lambda x: jex.ffi.ffi_call(\n",
" return lambda x: jax.ffi.ffi_call(\n",
" target_name,\n",
" out_type,\n",
" vmap_method=\"broadcast_all\",\n",
Expand Down Expand Up @@ -620,9 +619,9 @@
"This tutorial covers most of the basic steps that are required to get up and running with JAX's FFI, but advanced use cases may require more features.\n",
"We will leave these topics to future tutorials, but here are some possibly useful references:\n",
"\n",
"* **Supporting multiple dtypes**: In this tutorial's example, we restricted to only support `float32` inputs and outputs, but many use cases require supporting multiple different input types. One option to handle this is to register different FFI targets for all supported input types and then use Python to select the appropriate target for {func}`jax.extend.ffi.ffi_call` depending on the input types. But, this approach could get quickly unwieldy depending on the combinatorics of the supported cases. So it is also possible to define the C++ handler to accept `ffi::AnyBuffer` instead of `ffi::Buffer<Dtype>`. Then, the input buffer will include a `element_type()` method which can be used to define the appropriate dtype dispatching logic in the backend.\n",
"* **Supporting multiple dtypes**: In this tutorial's example, we restricted to only support `float32` inputs and outputs, but many use cases require supporting multiple different input types. One option to handle this is to register different FFI targets for all supported input types and then use Python to select the appropriate target for {func}`jax.ffi.ffi_call` depending on the input types. But, this approach could get quickly unwieldy depending on the combinatorics of the supported cases. So it is also possible to define the C++ handler to accept `ffi::AnyBuffer` instead of `ffi::Buffer<Dtype>`. Then, the input buffer will include a `element_type()` method which can be used to define the appropriate dtype dispatching logic in the backend.\n",
"\n",
"* **Sharding**: When using JAX's automatic data-dependent parallelism within {func}`~jax.jit`, FFI calls implemented using {func}`~jax.extend.ffi.ffi_call` don't have sufficient information to shard appropriately, so they result in a copy of the inputs to all devices and the FFI call gets executed on the full array on each device. To get around this limitation, you can use {func}`~jax.experimental.shard_map.shard_map` or {func}`~jax.experimental.custom_partitioning.custom_partitioning`.\n",
"* **Sharding**: When using JAX's automatic data-dependent parallelism within {func}`~jax.jit`, FFI calls implemented using {func}`~jax.ffi.ffi_call` don't have sufficient information to shard appropriately, so they result in a copy of the inputs to all devices and the FFI call gets executed on the full array on each device. To get around this limitation, you can use {func}`~jax.experimental.shard_map.shard_map` or {func}`~jax.experimental.custom_partitioning.custom_partitioning`.\n",
"\n",
"* **Stateful foreign functions**: It is also possible to use the FFI to wrap functions with associated state. There is a [low-level example included in the XLA test suite](https://github.com/openxla/xla/blob/737a7da3c5405583dc95773ac0bb11b1349fc9ea/xla/service/gpu/custom_call_test.cc#L794-L845), and a future tutorial will include more details."
]
Expand Down
Loading

0 comments on commit 28687b0

Please sign in to comment.