-
Notifications
You must be signed in to change notification settings - Fork 2.9k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implement the JAX transfer guard API
Adds `--jax_transfer_guard` flag and `jax.transfer_guard()` context manager that allows logging or disallowing unintended transfers. The API distinguishes between two types of transfers: * explicit transfers: `jax.device_put*()` and `jax.device_get()` calls. * implicit transfers: Other transfers (e.g., printing a `DeviceArray`). The transfer guard can take an action based on its guard level: * "allow": Silently allow all transfers (default; same as the previous behavior). * "log": Log and allow implicit transfers. Silently allow explicit transfers. * "disallow": Disallow implicit transfers. Silently allow explicit transfers. * "log_explicit": Log and allow all transfers. * "disallow_explicit": Disallow all transfers. The API also allows fine-control the transfer guard level of individual transfer directions. Their flag and context manager names are suffixed with the transfer direction: * "host_to_device": Converting a Python value into a `DeviceBuffer`. * "device_to_device": Copying a `DeviceBuffer` to a different device. * "device_to_host": Fetching the value of a `DeviceBuffer`. Example: ``` x = jnp.array(1) y = jnp.array(2) z = jnp.array(3) print(x) # No error with jax.transfer_guard("disallow"): print(x) # No error; x is already fetched print(jax.device_get(y)) # No error print(z) # Error! ``` PiperOrigin-RevId: 428590081
- Loading branch information
Showing
6 changed files
with
429 additions
and
18 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.