## RFC: OpenXLA PJRT Plugin

| Status        | Proposed |
| :------------ | :------------------------------------------------------ |
| **RFC #**     | [33](https://github.com/openxla/community/pull/33) |
| **Author(s)** | Skye Wanderman-Milne (skyewm@google.com), Jieying Luo (jieying@google.com), Jacques Pienaar (jpienaar@google.com) |
| **Sponsor**   | Stella Laurenzo (laurenzo@google.com), James Rubin (jamesrubin@google.com) |
| **Updated**   | 2023-01-23 |

## Objective

* Framework integration of a packaged compiler and runtime solution.

## Proposal

* Adopt PJRT as the supported device plugin mechanism for OpenXLA.
* Create a new repo, openxla/openxla-pjrt-plugin, for the OpenXLA PJRT plugin.

## Background: PJRT

[PJRT](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/pjrt/c/pjrt_c_api.h)
is a uniform Device API that we want to add to the OpenXLA ecosystem. The
long-term vision for PJRT is that: (1) frameworks (TensorFlow, PyTorch, JAX,
etc.) will call PJRT, which has device-specific implementations that are opaque
to the frameworks; (2) each device focuses on implementing the PJRT APIs and
can remain opaque to the frameworks.

![PJRT plugin integrating OpenXLA into ML frameworks](20230123-pjrt-plugin/frameworks.png)
<p align = "center"> PJRT provides a platform-independent interface for
compilers and corresponding runtimes. </p>

The PJRT API will provide an easy interface with which frameworks can integrate
a packaged compiler and runtime solution. It will be the supported interface
used by TensorFlow and JAX for all compiler and runtime integration; as such,
it will be easy for other compilers and runtimes that implement the PJRT
interface to integrate with these systems.

## PJRT plugin mechanism goal

The PJRT plugin mechanism should support the following features:

* Different devices (e.g. TPU, GPU) can have different implementations
  (through the
  [PJRT C API interface](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/pjrt/c/pjrt_c_api.h)).
* Registration of multiple PJRT plugins.
* Loading multiple PJRT plugins (e.g. both CPU and TPU) in the same process.
* Passing configuration values from the client (e.g. the JAX library) or from
  json files provided by the plugin (default configs).
* Plugin discovery and choosing which plugins to load.

## High-level plugin structure

![High-level plugin structure](20230123-pjrt-plugin/plugin-structure.png)

## Current Status

As of Dec 14, 2022:

* [LoadPjrtPlugin(plugin_name, library_path)](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/stream_executor/tpu/pjrt_api.cc#L73)
  can be used to load a PJRT plugin. We also provide a Python method
  [load_pjrt_plugin](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/python/xla.cc#L329)
  bound to it.
* [GetCApiClient(plugin_name)](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/pjrt/pjrt_c_api_client.cc#L1150)
  can be used to create a PJRT client. We also provide a Python method
  [get_c_api_client](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/python/xla.cc#L391)
  bound to it.
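
As a rough illustration, these bindings can be exercised as in the sketch
below. This is a hedged example: the module exposing the bindings (called
`xla_client` here) and the plugin name and library path are assumptions for
illustration, not a final API.

```python
# Hedged sketch -- `xla_client` stands in for the Python module built from
# xla.cc; the plugin name and .so path below are illustrative only.
import xla_client

# Load the plugin's shared library and register it under a plugin name.
xla_client.load_pjrt_plugin("my_device", "/path/to/pjrt-plugin-my_device.so")

# Create a PJRT C API client backed by the loaded plugin.
client = xla_client.get_c_api_client("my_device")
```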

## Design Considerations

### <a name="heading=h.782ksg6rl5bj"></a> Loading plugins

We will provide a low-level C++ method to load a PJRT plugin, which takes
`library_path` and `config_values` as inputs. A Python method
`load_pjrt_plugin` bound to it will be provided as well.

```c++
#include <dlfcn.h>

#include <map>
#include <string>

// Pseudocode sketch of the loading flow; Status, PJRT_Api, PluginInfo, and
// global_pjrt_plugin_map are described elsewhere in this RFC.
Status LoadPjrtPlugin(const std::string& library_path,
                      const std::map<std::string, std::string>& config_values) {
  void* library_handle = dlopen(library_path.c_str(), RTLD_NOW);
  // Every PJRT plugin is expected to export a GetPjrtApi symbol.
  auto function_ptr = reinterpret_cast<const PJRT_Api* (*)()>(
      dlsym(library_handle, "GetPjrtApi"));
  if (function_ptr != nullptr) {
    const PJRT_Api* api = function_ptr();
    // Derive the plugin name from the library path.
    std::string plugin_name = parse(library_path);
    if (auto it = global_pjrt_plugin_map.find(plugin_name);
        it != global_pjrt_plugin_map.end()) {
      // Loading the same plugin multiple times is allowed; return early.
      return OkStatus();
    }
    global_pjrt_plugin_map.emplace(plugin_name, PluginInfo(api, config_values));
  }
  return OkStatus();
}
```

* `GetPjrtApi` returns a `PJRT_Api*` pointer that contains the implementation
  of the C APIs.
* `global_pjrt_plugin_map` is a global `map<plugin_name, PluginInfo>` with the
  same lifetime as the program.
* `plugin_name` has the same meaning as `platform_name` in JAX.
* To store the config values, and to be future-proof for other information
  such as a version, we propose a `PluginInfo` class. This class is immutable
  after construction, and contains getter methods for the `PJRT_Api*` and the
  config values.

### Discovering and automatically loading plugins

To allow framework users to pip-install plugins without requiring further code
changes, we'll implement a Python discovery mechanism to automatically find and
load plugins. Plugin discovery will be based on the
[naming convention of the Python module](https://packaging.python.org/en/latest/guides/creating-and-discovering-plugins/#using-naming-convention)
and on full paths set in the environment variable `PJRT_PLUGIN_LIBRARY_PATH`
(which is added to allow users to manually specify directories and/or .so
files). Any modules found will be imported.

* Tentative naming convention for .so/json files:
  `pjrt-plugin-<plugin_name>.so` or `pjrt-plugin-<plugin_name>.json`.
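
For the module-naming half of discovery, a minimal sketch following the linked
packaging guide might look as follows; the `pjrt_plugin_` module prefix is an
assumed convention, not something decided by this RFC.

```python
import importlib
import pkgutil

# Find and import every installed top-level module whose name follows the
# assumed pjrt_plugin_ naming convention.
discovered_plugin_modules = {
    name: importlib.import_module(name)
    for _, name, _ in pkgutil.iter_modules()
    if name.startswith("pjrt_plugin_")
}
```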

There are two options for automatically loading plugins (a decision has not
been made yet):

Option 1: The plugin module updates `PJRT_PLUGIN_LIBRARY_PATH` on import. A
Python method `load_pjrt_plugins` will be added to discover .so/json files
related to PJRT plugins and load them.

```python
import os

def load_pjrt_plugins():
  for directory in os.environ["PJRT_PLUGIN_LIBRARY_PATH"].split(os.pathsep):
    for file_name in os.listdir(directory):
      path = os.path.join(directory, file_name)
      if file_name.endswith(".json"):
        # A json file supplies the library path and default config values;
        # parse_plugin_json stands in for that parsing step.
        library_path, config_values = parse_plugin_json(path)
        load_pjrt_plugin(library_path, config_values)
      elif file_name.endswith(".so"):
        load_pjrt_plugin(path)
```

Option 2: The plugin module is responsible for calling
[load_pjrt_plugin](#heading=h.782ksg6rl5bj) with its default options on
import, as in the sketch below.
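
Under Option 2, a hypothetical plugin package's `__init__.py` might look like
the following; the package layout, library file name, binding import, and empty
default config are all assumptions for illustration.

```python
import os

# Assumed binding module exposing load_pjrt_plugin (see "Loading plugins").
from xla_client import load_pjrt_plugin

# Hypothetical: load this package's bundled PJRT library with default options
# as a side effect of importing the plugin module.
_LIBRARY_PATH = os.path.join(
    os.path.dirname(__file__), "pjrt-plugin-my_device.so")
load_pjrt_plugin(_LIBRARY_PATH, config_values={})
```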

Open questions:

* Are there requirements about which files should not be loaded?

### <a name="heading=h.396bmv8gkskz"></a> Create PJRT client(s)

Frameworks decide which PJRT clients to create and use. For example, a
framework can create PJRT clients for all the plugins loaded, or create only a
subset of PJRT clients based on some priority rules. The choice can also come
from user configuration.

Two Python methods, bound to C++ methods, will be provided to facilitate
creating PJRT clients:

1. `get_loaded_plugin_names()`, which gets all loaded plugin names from
   `global_pjrt_plugin_map`.
2. `create_pjrt_client(plugin_name)`, which creates a PJRT C API client
   (similar to
   [GetCApiClient](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/pjrt/pjrt_c_api_client.cc#L1150)).
   It will retrieve the `PluginInfo` stored in `global_pjrt_plugin_map` and run
   `PJRT_Client_Create` (see the
   [Config values for creating PJRT client(s)](#heading=h.bjuf0soco0sj)
   section).

Open questions:

* What about plugin initialization steps that are not related to creating a
  PJRT client? For example, these two functions in
  [tpu_initializer_helper.cc](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/stream_executor/tpu/tpu_initializer_helper.cc#L269-L270)
  should be run when initializing a TPU. Currently they are called every time
  a PJRT TPU client is created. One solution is to place this initialization
  in the `GetPjrtApi` method, so that it runs the first time a plugin is
  loaded.
* Creating PJRT clients for every plugin that is found is undesirable, as it
  would increase resource utilization. Each framework decides what behavior it
  prefers.
* We may need to store loaded `PluginInfo` in a nested map `map<device_type,
  map<plugin_name, PluginInfo>>` if a framework wants to know which plugins
  are available for a device (e.g. the current PJRT GPU and IREE GPU) and only
  create one PJRT client per device, as sketched below.
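
As a hedged sketch of that nested-map idea (the registry contents and the
`pick_preferred_plugin` helper are hypothetical):

```python
# Hypothetical nested registry: device_type -> plugin_name -> PluginInfo.
plugins_by_device = {
    "gpu": {"pjrt_gpu": pjrt_gpu_info, "iree_gpu": iree_gpu_info},
    "tpu": {"pjrt_tpu": pjrt_tpu_info},
}

# Create at most one PJRT client per device type, using framework-specific
# priority rules to choose among the available plugins.
for device_type, candidates in plugins_by_device.items():
  plugin_name = pick_preferred_plugin(candidates)  # hypothetical helper
  client = create_pjrt_client(plugin_name)
```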

### <a name="heading=h.bjuf0soco0sj"></a> Config values for creating PJRT client(s)

`GetCApiClient` will be changed to take a map of config values. This map can be
passed to `PJRT_Client_Create` through
[PJRT_NamedValue](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/pjrt/c/pjrt_c_api.h#L210).
The default config values can (1) come from the json file, which will be stored
in `PluginInfo`, or (2) be provided by the plugin implementation during import.
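
From the framework side, this could surface roughly as in the sketch below.
This is a hedged example: whether `create_pjrt_client` itself takes a config
map is an open design point, and the key names shown are purely hypothetical.

```python
# Hypothetical call: the config map would be forwarded to PJRT_Client_Create
# as PJRT_NamedValue entries. Key names are illustrative, not defined by PJRT.
client = create_pjrt_client(
    "my_device",
    config_values={"visible_devices": "0,1", "max_inflight_computations": "4"},
)
```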

### Framework integration

All the methods mentioned above can be reused across frameworks. For example,
the JAX library can integrate as follows:

* Call [load_pjrt_plugins](#heading=h.782ksg6rl5bj) when
  [initializing backends](https://github.com/google/jax/blob/a66b3dcdd378b723e275a19258de826260b4c83e/jax/_src/lib/xla_bridge.py#L381).
* Call [get_loaded_plugin_names](#heading=h.396bmv8gkskz) to get the loaded
  PJRT `plugin_name`s, then apply framework-specific logic to decide whether
  to call [create_pjrt_client](#heading=h.396bmv8gkskz) to create each PJRT
  client.

```python
def create_pjrt_clients():
  loaded_plugin_names = get_loaded_plugin_names()
  for plugin_name in loaded_plugin_names:
    # Framework-specific logic decides whether to create a client.
    if should_create_pjrt_client(plugin_name):
      pjrt_client = create_pjrt_client(plugin_name)
```

For TensorFlow, discovery will be added
[here](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/api_template.__init__.py#L142),
and loading PJRT plugins will be added
[here](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/common_runtime/pluggable_device/pluggable_device_plugin_init.cc#L124).
Depending on the plugin, the client can be created implicitly the first time
it is used, or created in a plugin custom op kernel. Created PJRT clients will
be saved in the
[global TensorFlow ResourceManager](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/tfrt/common/pjrt_util.h#L26).

Whether multiple frameworks can create a client for the same device is up to
the specific hardware and plugin implementation. If the hardware only allows
exclusive access, then the software will abide by that constraint. Otherwise,
some plugins may use the hardware in an exclusive way (i.e. allocate all
memory). The OpenXLA plugin that we envision now will default to allowing
multiple clients and supporting on-demand memory allocation (with a possible
option for greedier access as an optimization for cases that need it).

For more information about PJRT, please consult the PJRT
[C header](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/pjrt/c/pjrt_c_api.h),
[C++ header](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/pjrt/pjrt_client.h),
[plugin mechanism design doc](https://docs.google.com/document/d/1FieBwiflD6DT9Z38wFIrnNxrUto1AOP3UuibQ4vpvEQ/edit),
and
[integration guide](https://docs.google.com/document/d/1EL62DkASMFZbk89K6MIPNGUFBTbkC21v5-uL5fB8VsM/edit).

## Open questions and future improvements from the comments

* Windows support.
* Test suites for PJRT implementations.
* How to handle versioning or variants (e.g. x86, x86-64, and ARM CPUs; GPUs
  with different CUDA compute capabilities; DPUs; etc.).
* How does a device broadcast its capabilities or supported API calls?
* Support unloading and reloading a PJRT plugin (hot-swap/plug).
* Single- vs. multi-process access/ownership (maybe covered by IFRT).
* Access to command queue/buffer management APIs; simple commands vs.
  transactional.

## Initial contribution

We will use the plugin in
[IREE samples](https://github.com/iree-org/iree-samples/tree/main/pjrt-plugin)
to seed the repo and build full PJRT support as part of the OpenXLA project.