Skip to content

Commit

Permalink
Move JAX example to public XLA:CPU API
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 698143471
  • Loading branch information
changm authored and Google-ML-Automation committed Nov 19, 2024
1 parent 3161a28 commit 42fbd30
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 3 deletions.
7 changes: 6 additions & 1 deletion examples/jax_cpp/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,13 @@ cc_binary(
"@tsl//tsl/platform:platform_port",
"@xla//xla:literal",
"@xla//xla:literal_util",
"@xla//xla/hlo/builder:xla_computation",
"@xla//xla/hlo/ir:hlo",
"@xla//xla/pjrt:pjrt_client",
"@xla//xla/pjrt/cpu:cpu_client",
"@xla//xla/pjrt:pjrt_executable",
"@xla//xla/pjrt/plugin/xla_cpu:cpu_client_options",
"@xla//xla/pjrt/plugin/xla_cpu:xla_cpu_pjrt_client",
"@xla//xla/service:hlo_module_config",
"@xla//xla/tools:hlo_module_loader",
],
)
12 changes: 10 additions & 2 deletions examples/jax_cpp/main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,21 @@ limitations under the License.
// }
// )

#include <functional>
#include <memory>
#include <string>
#include <vector>

#include "third_party/absl/status/statusor.h"
#include "xla/hlo/builder/xla_computation.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/literal.h"
#include "xla/literal_util.h"
#include "xla/pjrt/cpu/cpu_client.h"
#include "xla/pjrt/pjrt_client.h"
#include "xla/pjrt/pjrt_executable.h"
#include "xla/pjrt/plugin/xla_cpu/cpu_client_options.h"
#include "xla/pjrt/plugin/xla_cpu/xla_cpu_pjrt_client.h"
#include "xla/service/hlo_module_config.h"
#include "xla/tools/hlo_module_loader.h"
#include "tsl/platform/init_main.h"
#include "tsl/platform/logging.h"
Expand All @@ -66,8 +72,10 @@ int main(int argc, char** argv) {
// Run it using JAX C++ Runtime (PJRT).

// Get a CPU client.
xla::CpuClientOptions options;
options.asynchronous = true;
std::unique_ptr<xla::PjRtClient> client =
xla::GetTfrtCpuClient(/*asynchronous=*/true).value();
xla::GetXlaPjrtCpuClient(options).value();

// Compile XlaComputation to PjRtExecutable.
xla::XlaComputation xla_computation(test_module_proto);
Expand Down

0 comments on commit 42fbd30

Please sign in to comment.