From 8ffb8ed2e3f23df1ad14dae8bb9cfd31ec765e1e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Thu, 12 Dec 2024 00:01:05 +0100 Subject: [PATCH 01/29] organize platforms and toolchains --- deps/ReactantExtra/BUILD | 35 ------------------- deps/ReactantExtra/tools/platforms.bzl | 31 ++++++++++++++++ .../toolchains}/yggdrasil.bzl | 0 3 files changed, 31 insertions(+), 35 deletions(-) create mode 100644 deps/ReactantExtra/tools/platforms.bzl rename deps/ReactantExtra/{toolchain => tools/toolchains}/yggdrasil.bzl (100%) diff --git a/deps/ReactantExtra/BUILD b/deps/ReactantExtra/BUILD index c538bbb8a..9ea88a5da 100644 --- a/deps/ReactantExtra/BUILD +++ b/deps/ReactantExtra/BUILD @@ -315,41 +315,6 @@ cc_toolchain_config( ], ) - - -platform( - name = "darwin_x86_64", - constraint_values = [ - "@platforms//os:macos", - "@platforms//cpu:x86_64", - ], -) - - -platform( - name = "darwin_arm64", - constraint_values = [ - "@platforms//os:macos", - "@platforms//cpu:arm64", - ], -) - -platform( - name = "linux_x86_64", - constraint_values = [ - "@platforms//os:linux", - "@platforms//cpu:x86_64", - ], -) - -platform( - name = "linux_aarch64", - constraint_values = [ - "@platforms//os:linux", - "@platforms//cpu:aarch64", - ], -) - cc_library( name = "ReactantExtraLib", srcs = glob( diff --git a/deps/ReactantExtra/tools/platforms.bzl b/deps/ReactantExtra/tools/platforms.bzl new file mode 100644 index 000000000..da7dfc796 --- /dev/null +++ b/deps/ReactantExtra/tools/platforms.bzl @@ -0,0 +1,31 @@ +platform( + name = "darwin_x86_64", + constraint_values = [ + "@platforms//os:macos", + "@platforms//cpu:x86_64", + ], +) + +platform( + name = "darwin_arm64", + constraint_values = [ + "@platforms//os:macos", + "@platforms//cpu:arm64", + ], +) + +platform( + name = "linux_x86_64", + constraint_values = [ + "@platforms//os:linux", + "@platforms//cpu:x86_64", + ], +) + +platform( + name = "linux_aarch64", + constraint_values = [ + "@platforms//os:linux", + "@platforms//cpu:aarch64", + ], +) diff --git a/deps/ReactantExtra/toolchain/yggdrasil.bzl b/deps/ReactantExtra/tools/toolchains/yggdrasil.bzl similarity index 100% rename from deps/ReactantExtra/toolchain/yggdrasil.bzl rename to deps/ReactantExtra/tools/toolchains/yggdrasil.bzl From c59b7c48b864c0a45f93a55f8470759c3d30af2f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Thu, 12 Dec 2024 15:15:49 +0100 Subject: [PATCH 02/29] hardcode libcxxwrap_julia path --- deps/ReactantExtra/BUILD | 5 +++++ .../tools/toolchains/{yggdrasil.bzl => yggdrasil/BUILD} | 0 2 files changed, 5 insertions(+) rename deps/ReactantExtra/tools/toolchains/{yggdrasil.bzl => yggdrasil/BUILD} (100%) diff --git a/deps/ReactantExtra/BUILD b/deps/ReactantExtra/BUILD index 9ea88a5da..d80faa11e 100644 --- a/deps/ReactantExtra/BUILD +++ b/deps/ReactantExtra/BUILD @@ -339,6 +339,9 @@ cc_library( hdrs = glob([ "*.h", ]), + include_prefix = [ + "/Users/mofeing/.julia/artifacts/6a1f8b0d254a485be750499b732b476ddbee44c5/include/" + ], copts = [ "-Werror=unused-variable", "-Werror=unused-but-set-variable", @@ -391,6 +394,8 @@ cc_library( "-Wl,-exported_symbol,_ifrt_*", "-Wl,-exported_symbol,_RegisterCustomCallTarget", "-Wl,-exported_symbol,_ConvertLLVMToMLIR", +"-L/Users/mofeing/.julia/artifacts/6a1f8b0d254a485be750499b732b476ddbee44c5/lib/", +"-llibcxxwrap_julia.0.14.0.dylib", ]}), deps = [ "@enzyme//:EnzymeMLIR", diff --git a/deps/ReactantExtra/tools/toolchains/yggdrasil.bzl b/deps/ReactantExtra/tools/toolchains/yggdrasil/BUILD similarity index 100% rename from deps/ReactantExtra/tools/toolchains/yggdrasil.bzl rename to deps/ReactantExtra/tools/toolchains/yggdrasil/BUILD From e142d080dbe91cd63d70649edcc97b5ea5c9e6f1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Thu, 12 Dec 2024 15:17:25 +0100 Subject: [PATCH 03/29] format code --- deps/ReactantExtra/BUILD | 119 ++++++++++++++++++++------------------- 1 file changed, 60 insertions(+), 59 deletions(-) diff --git a/deps/ReactantExtra/BUILD b/deps/ReactantExtra/BUILD index d80faa11e..d749fcc34 100644 --- a/deps/ReactantExtra/BUILD +++ b/deps/ReactantExtra/BUILD @@ -352,53 +352,54 @@ cc_library( alwayslink = True, linkstatic = True, linkopts = select({ - "//conditions:default": [], - "@bazel_tools//src/conditions:darwin": [ -"-Wl,-exported_symbol,_stablehlo*", -"-Wl,-exported_symbol,_mlir*", -"-Wl,-exported_symbol,_InitializeLogs", -"-Wl,-exported_symbol,_SetLogLevel", -"-Wl,-exported_symbol,_SetModuleLogLevel", -"-Wl,-exported_symbol,_GetDefaultTargetTriple", -"-Wl,-exported_symbol,_enzymeActivityAttrGet", -"-Wl,-exported_symbol,_MakeCPUClient", -"-Wl,-exported_symbol,_MakeGPUClient", -"-Wl,-exported_symbol,_MakeTPUClient", -"-Wl,-exported_symbol,_LoadPjrtPlugin", -"-Wl,-exported_symbol,_InitializePjrtPlugin", -"-Wl,-exported_symbol,_GetCApiClient", -"-Wl,-exported_symbol,_ClientNumDevices", -"-Wl,-exported_symbol,_ClientNumAddressableDevices", -"-Wl,-exported_symbol,_ClientProcessIndex", -"-Wl,-exported_symbol,_ClientGetDevice", -"-Wl,-exported_symbol,_ClientGetAddressableDevice", -"-Wl,-exported_symbol,_ExecutableFree", -"-Wl,-exported_symbol,_BufferToDevice", -"-Wl,-exported_symbol,_BufferToClient", -"-Wl,-exported_symbol,_DeviceToClient", -"-Wl,-exported_symbol,_PjRtBufferFree", -"-Wl,-exported_symbol,_UnsafeBufferPointer", -"-Wl,-exported_symbol,_ArrayFromHostBuffer", -"-Wl,-exported_symbol,_BufferOnCPU", -"-Wl,-exported_symbol,_CopyBufferToDevice", -"-Wl,-exported_symbol,_BufferToHost", -"-Wl,-exported_symbol,_FreeClient", -"-Wl,-exported_symbol,_ClientCompile", -"-Wl,-exported_symbol,_LinkInModule", -"-Wl,-exported_symbol,_FreeFuture", -"-Wl,-exported_symbol,_FutureIsReady", -"-Wl,-exported_symbol,_FutureAwait", -"-Wl,-exported_symbol,_XLAExecute", -"-Wl,-exported_symbol,_RegisterDialects", -"-Wl,-exported_symbol,_InitializeRegistryAndPasses", -"-Wl,-exported_symbol,_ifrt_*", -"-Wl,-exported_symbol,_RegisterCustomCallTarget", -"-Wl,-exported_symbol,_ConvertLLVMToMLIR", -"-L/Users/mofeing/.julia/artifacts/6a1f8b0d254a485be750499b732b476ddbee44c5/lib/", -"-llibcxxwrap_julia.0.14.0.dylib", - ]}), + "//conditions:default": [], + "@bazel_tools//src/conditions:darwin": [ + "-Wl,-exported_symbol,_stablehlo*", + "-Wl,-exported_symbol,_mlir*", + "-Wl,-exported_symbol,_InitializeLogs", + "-Wl,-exported_symbol,_SetLogLevel", + "-Wl,-exported_symbol,_SetModuleLogLevel", + "-Wl,-exported_symbol,_GetDefaultTargetTriple", + "-Wl,-exported_symbol,_enzymeActivityAttrGet", + "-Wl,-exported_symbol,_MakeCPUClient", + "-Wl,-exported_symbol,_MakeGPUClient", + "-Wl,-exported_symbol,_MakeTPUClient", + "-Wl,-exported_symbol,_LoadPjrtPlugin", + "-Wl,-exported_symbol,_InitializePjrtPlugin", + "-Wl,-exported_symbol,_GetCApiClient", + "-Wl,-exported_symbol,_ClientNumDevices", + "-Wl,-exported_symbol,_ClientNumAddressableDevices", + "-Wl,-exported_symbol,_ClientProcessIndex", + "-Wl,-exported_symbol,_ClientGetDevice", + "-Wl,-exported_symbol,_ClientGetAddressableDevice", + "-Wl,-exported_symbol,_ExecutableFree", + "-Wl,-exported_symbol,_BufferToDevice", + "-Wl,-exported_symbol,_BufferToClient", + "-Wl,-exported_symbol,_DeviceToClient", + "-Wl,-exported_symbol,_PjRtBufferFree", + "-Wl,-exported_symbol,_UnsafeBufferPointer", + "-Wl,-exported_symbol,_ArrayFromHostBuffer", + "-Wl,-exported_symbol,_BufferOnCPU", + "-Wl,-exported_symbol,_CopyBufferToDevice", + "-Wl,-exported_symbol,_BufferToHost", + "-Wl,-exported_symbol,_FreeClient", + "-Wl,-exported_symbol,_ClientCompile", + "-Wl,-exported_symbol,_LinkInModule", + "-Wl,-exported_symbol,_FreeFuture", + "-Wl,-exported_symbol,_FutureIsReady", + "-Wl,-exported_symbol,_FutureAwait", + "-Wl,-exported_symbol,_XLAExecute", + "-Wl,-exported_symbol,_RegisterDialects", + "-Wl,-exported_symbol,_InitializeRegistryAndPasses", + "-Wl,-exported_symbol,_ifrt_*", + "-Wl,-exported_symbol,_RegisterCustomCallTarget", + "-Wl,-exported_symbol,_ConvertLLVMToMLIR", + "-L/Users/mofeing/.julia/artifacts/6a1f8b0d254a485be750499b732b476ddbee44c5/lib/", + "-llibcxxwrap_julia.0.14.0.dylib", + ] + }), deps = [ - "@enzyme//:EnzymeMLIR", + "@enzyme//:EnzymeMLIR", "@llvm-project//mlir:AffineDialect", "@llvm-project//mlir:AllPassesAndDialects", "@llvm-project//mlir:ArithDialect", @@ -446,8 +447,8 @@ cc_library( "@xla//xla/service/cpu:cpu_transfer_manager", "@xla//xla/pjrt/gpu:se_gpu_pjrt_client", - - "@xla//xla/tsl/protobuf:protos_all_cc_impl", + + "@xla//xla/tsl/protobuf:protos_all_cc_impl", "@xla//xla/tsl/framework:allocator_registry_impl", "@xla//xla/pjrt:status_casters", @@ -456,7 +457,7 @@ cc_library( "@xla//xla/python/ifrt/hlo:hlo_program", "@xla//xla/ffi:call_frame", "@com_google_protobuf//:protobuf", - "@xla//xla/tsl/profiler/backends/cpu:annotation_stack_impl", + "@xla//xla/tsl/profiler/backends/cpu:annotation_stack_impl", "@xla//xla/tsl/profiler/backends/cpu:traceme_recorder_impl", "@xla//xla/tsl/profiler/utils:time_utils_impl", "@tsl//tsl/platform:env_impl", @@ -469,16 +470,16 @@ cc_library( "@com_google_absl//absl/log:globals", "@llvm-project//mlir:CAPIIRObjects", ] + select({ - "@xla//xla/tsl:is_cuda_enabled_and_oss":[ - "@xla//xla/stream_executor/cuda:all_runtime", - "@xla//xla/service/gpu/model:hlo_op_profiles", - "@xla//xla/service/gpu/model:hlo_op_profile_proto_cc_impl", - "@xla//xla/service/gpu:nvptx_compiler", - "@xla//xla/service/gpu:amdgpu_compiler", - "@xla//xla/service/gpu:gpu_transfer_manager", - "@xla//xla/stream_executor:kernel", - ], - "//conditions:default": [], + "@xla//xla/tsl:is_cuda_enabled_and_oss":[ + "@xla//xla/stream_executor/cuda:all_runtime", + "@xla//xla/service/gpu/model:hlo_op_profiles", + "@xla//xla/service/gpu/model:hlo_op_profile_proto_cc_impl", + "@xla//xla/service/gpu:nvptx_compiler", + "@xla//xla/service/gpu:amdgpu_compiler", + "@xla//xla/service/gpu:gpu_transfer_manager", + "@xla//xla/stream_executor:kernel", + ], + "//conditions:default": [], }), ) From 8c36fd5a1086b096f7f855018458fdccefcd17da Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Thu, 12 Dec 2024 15:17:56 +0100 Subject: [PATCH 04/29] add CxxWrap dependency --- Project.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/Project.toml b/Project.toml index 9a1277d9d..e3d883c05 100644 --- a/Project.toml +++ b/Project.toml @@ -6,6 +6,7 @@ version = "0.2.10" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" +CxxWrap = "1f15a43c-97ca-5a2a-ae31-89f07a497df4" Downloads = "f43a241f-c20a-4ad4-852c-f6b1247861c6" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" @@ -42,6 +43,7 @@ AbstractFFTs = "1.5" Adapt = "4" ArrayInterface = "7.10" CEnum = "0.4, 0.5" +CxxWrap = "0.16.0" Downloads = "1.6" Enzyme = "0.13.22" EnzymeCore = "0.8.8" From 873cf439f47843460f671d715d6afdd6b0a3aa29 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Thu, 12 Dec 2024 17:54:29 +0100 Subject: [PATCH 05/29] remove outdated hardcoded symbolic links --- deps/clang | 2 -- deps/clang++ | 2 -- deps/gcc | 3 --- 3 files changed, 7 deletions(-) delete mode 100755 deps/clang delete mode 100755 deps/clang++ delete mode 100755 deps/gcc diff --git a/deps/clang b/deps/clang deleted file mode 100755 index 77df2a34c..000000000 --- a/deps/clang +++ /dev/null @@ -1,2 +0,0 @@ -#!/bin/bash -/home/wmoses/llvms/llvm16/install/bin/clang -I/usr/include/x86_64-linux-gnu/c++/11 -L/home/wmoses/llvms/llvm16/build/lib/x86_64-unknown-linux-gnu -stdlib=libc++ -v "$@" diff --git a/deps/clang++ b/deps/clang++ deleted file mode 100755 index 25b16f719..000000000 --- a/deps/clang++ +++ /dev/null @@ -1,2 +0,0 @@ -#!/bin/bash -/home/wmoses/llvms/llvm16/build/bin/clang++ -I/usr/include/x86_64-linux-gnu/c++/11 -I/usr/include/c++/11 -I/usr/include/x86_64-linux-gnu/c++/11 -L/usr/lib/x86_64-linux-gnu "$@" diff --git a/deps/gcc b/deps/gcc deleted file mode 100755 index d92c9c10d..000000000 --- a/deps/gcc +++ /dev/null @@ -1,3 +0,0 @@ -#!/bin/bash -# /usr/local/cuda/bin/nvcc "$@" -/home/wmoses/llvms/llvm16/install/bin/clang -Xclang -fcuda-allow-variadic-functions -I/usr/include/c++/11 -I/usr/include/x86_64-linux-gnu/c++/11 -Wno-unused-command-line-argument -L/usr/lib/gcc/x86_64-linux-gnu/11 -static-libstdc++ "$@" || /home/wmoses/llvms/llvm16/install/bin/clang -Xclang -fcuda-allow-variadic-functions -I/usr/include/c++/11 -I/usr/include/x86_64-linux-gnu/c++/11 -Wno-unused-command-line-argument -L/usr/lib/gcc/x86_64-linux-gnu/11 -static-libstdc++ -g0 "$@" -g0 From 94eec05fd8bf42325705b3082121b98228e84aa5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Thu, 12 Dec 2024 17:56:53 +0100 Subject: [PATCH 06/29] add third party bazel wrapper to libcxxwrap_julia --- deps/ReactantExtra/WORKSPACE | 5 +++++ deps/ReactantExtra/third_party/BUILD | 1 + .../third_party/libcxxwrap_julia/BUILD | 3 +++ .../libcxxwrap_julia/libcxxwrap_julia.BUILD | 14 ++++++++++++++ .../third_party/libcxxwrap_julia/workspace.bzl | 9 +++++++++ 5 files changed, 32 insertions(+) create mode 100644 deps/ReactantExtra/third_party/BUILD create mode 100644 deps/ReactantExtra/third_party/libcxxwrap_julia/BUILD create mode 100644 deps/ReactantExtra/third_party/libcxxwrap_julia/libcxxwrap_julia.BUILD create mode 100644 deps/ReactantExtra/third_party/libcxxwrap_julia/workspace.bzl diff --git a/deps/ReactantExtra/WORKSPACE b/deps/ReactantExtra/WORKSPACE index 174cc6715..de0ef5e96 100644 --- a/deps/ReactantExtra/WORKSPACE +++ b/deps/ReactantExtra/WORKSPACE @@ -1,5 +1,10 @@ +workspace(name = "Reactant") + load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") +load("//third_party/libcxxwrap_julia:workspace.bzl", "libcxxwrap_julia_deps") +libcxxwrap_julia_deps() + NSYNC_COMMIT = "82b118aa7ace3132e517e2c467f8732978cf4023" NSYNC_SHA256 = "" http_archive( diff --git a/deps/ReactantExtra/third_party/BUILD b/deps/ReactantExtra/third_party/BUILD new file mode 100644 index 000000000..b2c46fa7a --- /dev/null +++ b/deps/ReactantExtra/third_party/BUILD @@ -0,0 +1 @@ +licenses(["notice"]) \ No newline at end of file diff --git a/deps/ReactantExtra/third_party/libcxxwrap_julia/BUILD b/deps/ReactantExtra/third_party/libcxxwrap_julia/BUILD new file mode 100644 index 000000000..9b495c067 --- /dev/null +++ b/deps/ReactantExtra/third_party/libcxxwrap_julia/BUILD @@ -0,0 +1,3 @@ +exports_files(srcs = [ + "workspace.bzl", +]) diff --git a/deps/ReactantExtra/third_party/libcxxwrap_julia/libcxxwrap_julia.BUILD b/deps/ReactantExtra/third_party/libcxxwrap_julia/libcxxwrap_julia.BUILD new file mode 100644 index 000000000..6a4f170ec --- /dev/null +++ b/deps/ReactantExtra/third_party/libcxxwrap_julia/libcxxwrap_julia.BUILD @@ -0,0 +1,14 @@ +# TODO libcxxwrap_julia has LICENSE.md file in share/licenses/libcxxwrap_julia/LICENSE.md +licenses(["notice"]) + +package(default_visibility = ["//visibility:public"]) + +filegroup( + name = "include", + srcs = glob(["include/**/*.h"]), +) + +filegroup( + name = "libs", + srcs = glob(["lib/libcxxwrap_julia*.dylib"]), +) diff --git a/deps/ReactantExtra/third_party/libcxxwrap_julia/workspace.bzl b/deps/ReactantExtra/third_party/libcxxwrap_julia/workspace.bzl new file mode 100644 index 000000000..280baf7bb --- /dev/null +++ b/deps/ReactantExtra/third_party/libcxxwrap_julia/workspace.bzl @@ -0,0 +1,9 @@ +"""Loads the libcxxwrap_julia library.""" + +def libcxxwrap_julia_deps(): + # TODO change this to download the real artifacts or build them from source + native.new_local_repository( + name = "libcxxwrap_julia", + path = "/Users/mofeing/.julia/artifacts/6a1f8b0d254a485be750499b732b476ddbee44c5/", + build_file = "//third_party/libcxxwrap_julia:libcxxwrap_julia.BUILD", + ) From 75ecbc311b3455157b08298b7462543c862ac1df Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Thu, 12 Dec 2024 17:57:17 +0100 Subject: [PATCH 07/29] readd platforms --- deps/ReactantExtra/BUILD | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/deps/ReactantExtra/BUILD b/deps/ReactantExtra/BUILD index d749fcc34..02551c20a 100644 --- a/deps/ReactantExtra/BUILD +++ b/deps/ReactantExtra/BUILD @@ -315,6 +315,39 @@ cc_toolchain_config( ], ) +platform( + name = "darwin_x86_64", + constraint_values = [ + "@platforms//os:macos", + "@platforms//cpu:x86_64", + ], +) + + +platform( + name = "darwin_arm64", + constraint_values = [ + "@platforms//os:macos", + "@platforms//cpu:arm64", + ], +) + +platform( + name = "linux_x86_64", + constraint_values = [ + "@platforms//os:linux", + "@platforms//cpu:x86_64", + ], +) + +platform( + name = "linux_aarch64", + constraint_values = [ + "@platforms//os:linux", + "@platforms//cpu:aarch64", + ], +) + cc_library( name = "ReactantExtraLib", srcs = glob( From 2901505ec65fce23cf3e8bb49d28f85d89a69a02 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Thu, 12 Dec 2024 18:06:46 +0100 Subject: [PATCH 08/29] Prepare wrapping on Julia side --- src/IFRT.jl | 8 ++++++++ src/Reactant.jl | 2 ++ 2 files changed, 10 insertions(+) create mode 100644 src/IFRT.jl diff --git a/src/IFRT.jl b/src/IFRT.jl new file mode 100644 index 000000000..8e7f1a1c3 --- /dev/null +++ b/src/IFRT.jl @@ -0,0 +1,8 @@ +module IFRT + +using CxxWrap +using Reactant_jll + +@wrapmodule(() -> joinpath(Reactant_jll.libdir, :reactant_module_ifrt)) + +end diff --git a/src/Reactant.jl b/src/Reactant.jl index e7c8805de..91e74b17f 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -146,6 +146,8 @@ function Enzyme.make_zero( return res end +include("IFRT.jl") + using .Compiler: @compile, @code_hlo, @jit, traced_getfield, create_result, compile export ConcreteRArray, ConcreteRNumber, @compile, @code_hlo, @jit, @trace From 5d384b4457ae25c6481291ba06a9b4aeda2054d2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Thu, 12 Dec 2024 18:06:56 +0100 Subject: [PATCH 09/29] some small fixes --- deps/ReactantExtra/BUILD | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/deps/ReactantExtra/BUILD b/deps/ReactantExtra/BUILD index 02551c20a..f484e1575 100644 --- a/deps/ReactantExtra/BUILD +++ b/deps/ReactantExtra/BUILD @@ -353,8 +353,7 @@ cc_library( srcs = glob( [ "*.cpp", - ], - + ] ) + [ # "@com_google_protobuf//:src/google/protobuf/io/coded_stream.cc", "@xla//xla:xla.pb.cc", @@ -372,15 +371,13 @@ cc_library( hdrs = glob([ "*.h", ]), - include_prefix = [ - "/Users/mofeing/.julia/artifacts/6a1f8b0d254a485be750499b732b476ddbee44c5/include/" - ], copts = [ "-Werror=unused-variable", "-Werror=unused-but-set-variable", "-Werror=return-type", "-Werror=unused-result", - "-Wno-error=stringop-truncation" + "-Wno-error=stringop-truncation", + # "-I$(location @libcxxwrap_julia//:include)" ], alwayslink = True, linkstatic = True, @@ -432,6 +429,7 @@ cc_library( ] }), deps = [ + # "@libcxxwrap_julia", "@enzyme//:EnzymeMLIR", "@llvm-project//mlir:AffineDialect", "@llvm-project//mlir:AllPassesAndDialects", From e28a8ce4b717a788215b96219a18b44c7226b219 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Thu, 12 Dec 2024 18:22:01 +0100 Subject: [PATCH 10/29] Start prototyping IFRT module wrapper binding --- deps/ReactantExtra/API.cpp | 53 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 53 insertions(+) diff --git a/deps/ReactantExtra/API.cpp b/deps/ReactantExtra/API.cpp index 3ae7a7ebf..314ff7af3 100644 --- a/deps/ReactantExtra/API.cpp +++ b/deps/ReactantExtra/API.cpp @@ -90,6 +90,8 @@ #include "xla/python/pjrt_ifrt/pjrt_executable.h" #include "xla/python/pjrt_ifrt/pjrt_compiler.h" +#include "jlcxx/jlcxx.hpp" + using namespace mlir; using namespace llvm; using namespace xla; @@ -1482,4 +1484,55 @@ extern "C" void ifrt_pjrt_compiler_free(ifrt::PjRtCompiler* compiler) { } #pragma endregion +JLCXX_MODULE reactant_module_ifrt(jlcxx::Module& mod) { + // mod.add_type("Value") + // .method("client", &ifrt::Value::client) + // .method("get_ready_future", &ifrt::Value::GetReadyFuture) + // .method("delete!", &ifrt::Value::Delete) + // .method("isdeleted", &ifrt::Value::IsDeleted) + // .method("debug_string", &ifrt::Value::DebugString); + + mod.add_bits("DTypeKind", jlcxx::julia_type("CppEnum")); + mod.set_const("DTypeKindInvalid", ifrt::DType::Kind::kInvalid); + mod.set_const("DTypeKindPred", ifrt::DType::Kind::kPred); + mod.set_const("DTypeKindS2", ifrt::DType::Kind::kS2); + mod.set_const("DTypeKindS4", ifrt::DType::Kind::kS4); + mod.set_const("DTypeKindS8", ifrt::DType::Kind::kS8); + mod.set_const("DTypeKindS16", ifrt::DType::Kind::kS16); + mod.set_const("DTypeKindS32", ifrt::DType::Kind::kS32); + mod.set_const("DTypeKindS64", ifrt::DType::Kind::kS64); + mod.set_const("DTypeKindU2", ifrt::DType::Kind::kU2); + mod.set_const("DTypeKindU4", ifrt::DType::Kind::kU4); + mod.set_const("DTypeKindU8", ifrt::DType::Kind::kU8); + mod.set_const("DTypeKindU16", ifrt::DType::Kind::kU16); + mod.set_const("DTypeKindU32", ifrt::DType::Kind::kU32); + mod.set_const("DTypeKindU64", ifrt::DType::Kind::kU64); + mod.set_const("DTypeKindF16", ifrt::DType::Kind::kF16); + mod.set_const("DTypeKindF32", ifrt::DType::Kind::kF32); + mod.set_const("DTypeKindF64", ifrt::DType::Kind::kF64); + mod.set_const("DTypeKindBF16", ifrt::DType::Kind::kBF16); + mod.set_const("DTypeKindC64", ifrt::DType::Kind::kC64); + mod.set_const("DTypeKindC128", ifrt::DType::Kind::kC128); + mod.set_const("DTypeKindToken", ifrt::DType::Kind::kToken); + mod.set_const("DTypeKindOpaque", ifrt::DType::Kind::kOpaque); + mod.set_const("DTypeKindF8E3M4", ifrt::DType::Kind::kF8E3M4); + mod.set_const("DTypeKindF8E4M3", ifrt::DType::Kind::kF8E4M3); + mod.set_const("DTypeKindF8E4M3FN", ifrt::DType::Kind::kF8E4M3FN); + mod.set_const("DTypeKindF8E4M3B11FNUZ", ifrt::DType::Kind::kF8E4M3B11FNUZ); + mod.set_const("DTypeKindF8E4M3FNUZ", ifrt::DType::Kind::kF8E4M3FNUZ); + mod.set_const("DTypeKindF8E5M2", ifrt::DType::Kind::kF8E5M2); + mod.set_const("DTypeKindF8E5M2FNUZ", ifrt::DType::Kind::kF8E5M2FNUZ); + mod.set_const("DTypeKindString", ifrt::DType::Kind::kString); + + mod.add_type("DType") + .constructor(&ifrt::DType::DType) + .method("kind", &ifrt::DType::kind) + .method("byte_size", &ifrt::DType::byte_size) + .method("bit_size", &ifrt::DType::bit_size); + mod.set_override_module(jl_base_module); + mod.method("==", [](ifrt::DType* a, ifrt::DType* b) { return *a == *b; }); + mod.method("!=", [](ifrt::DType* a, ifrt::DType* b) { return *a != *b; }); + mod.unset_override_module(); +} + #pragma endregion From 0d76f17d0dc742be524eccf89cde8a1347fe84a8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Thu, 12 Dec 2024 19:04:46 +0100 Subject: [PATCH 11/29] first step on moving externals to modular organization --- .../third_party/build_bazel_rules_apple/BUILD | 0 .../build_bazel_rules_apple/workspace.bzl | 17 +++++++ deps/ReactantExtra/third_party/enzyme/BUILD | 0 .../third_party/enzyme/workspace.bzl | 12 +++++ .../ReactantExtra/third_party/enzyme_ad/BUILD | 0 .../third_party/enzyme_ad/workspace.bzl | 14 ++++++ deps/ReactantExtra/third_party/jax/BUILD | 0 .../third_party/jax/workspace.bzl | 18 +++++++ deps/ReactantExtra/third_party/nsync/BUILD | 0 .../third_party/nsync/workspace.bzl | 14 ++++++ deps/ReactantExtra/third_party/rules_cc/BUILD | 0 .../third_party/rules_cc/workspace.bzl | 13 +++++ deps/ReactantExtra/third_party/upb/BUILD | 0 .../third_party/upb/workspace.bzl | 15 ++++++ deps/ReactantExtra/third_party/xla/BUILD | 0 .../third_party/xla/workspace.bzl | 50 +++++++++++++++++++ 16 files changed, 153 insertions(+) create mode 100644 deps/ReactantExtra/third_party/build_bazel_rules_apple/BUILD create mode 100644 deps/ReactantExtra/third_party/build_bazel_rules_apple/workspace.bzl create mode 100644 deps/ReactantExtra/third_party/enzyme/BUILD create mode 100644 deps/ReactantExtra/third_party/enzyme/workspace.bzl create mode 100644 deps/ReactantExtra/third_party/enzyme_ad/BUILD create mode 100644 deps/ReactantExtra/third_party/enzyme_ad/workspace.bzl create mode 100644 deps/ReactantExtra/third_party/jax/BUILD create mode 100644 deps/ReactantExtra/third_party/jax/workspace.bzl create mode 100644 deps/ReactantExtra/third_party/nsync/BUILD create mode 100644 deps/ReactantExtra/third_party/nsync/workspace.bzl create mode 100644 deps/ReactantExtra/third_party/rules_cc/BUILD create mode 100644 deps/ReactantExtra/third_party/rules_cc/workspace.bzl create mode 100644 deps/ReactantExtra/third_party/upb/BUILD create mode 100644 deps/ReactantExtra/third_party/upb/workspace.bzl create mode 100644 deps/ReactantExtra/third_party/xla/BUILD create mode 100644 deps/ReactantExtra/third_party/xla/workspace.bzl diff --git a/deps/ReactantExtra/third_party/build_bazel_rules_apple/BUILD b/deps/ReactantExtra/third_party/build_bazel_rules_apple/BUILD new file mode 100644 index 000000000..e69de29bb diff --git a/deps/ReactantExtra/third_party/build_bazel_rules_apple/workspace.bzl b/deps/ReactantExtra/third_party/build_bazel_rules_apple/workspace.bzl new file mode 100644 index 000000000..1d15ee7d4 --- /dev/null +++ b/deps/ReactantExtra/third_party/build_bazel_rules_apple/workspace.bzl @@ -0,0 +1,17 @@ +"""Loads bazel rules for apple.""" + +load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") + +def repo(): + http_archive( + name = "build_bazel_rules_apple", + sha256 = "34c41bfb59cdaea29ac2df5a2fa79e5add609c71bb303b2ebb10985f93fa20e7", + url = "https://github.com/bazelbuild/rules_apple/releases/download/3.1.1/rules_apple.3.1.1.tar.gz", + ) + + load( + "@build_bazel_rules_apple//apple:repositories.bzl", + "apple_rules_dependencies", + ) + + apple_rules_dependencies() diff --git a/deps/ReactantExtra/third_party/enzyme/BUILD b/deps/ReactantExtra/third_party/enzyme/BUILD new file mode 100644 index 000000000..e69de29bb diff --git a/deps/ReactantExtra/third_party/enzyme/workspace.bzl b/deps/ReactantExtra/third_party/enzyme/workspace.bzl new file mode 100644 index 000000000..c7db5aef4 --- /dev/null +++ b/deps/ReactantExtra/third_party/enzyme/workspace.bzl @@ -0,0 +1,12 @@ +"""Loads Enzyme.""" + +load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") +load("@enzyme_ad//:workspace.bzl", "ENZYME_COMMIT", "ENZYME_SHA256") + +def repo(): + http_archive( + name = "enzyme", + sha256 = ENZYME_SHA256, + strip_prefix = "Enzyme-" + ENZYME_COMMIT + "/enzyme", + urls = ["https://github.com/EnzymeAD/Enzyme/archive/{commit}.tar.gz".format(commit = ENZYME_COMMIT)], + ) diff --git a/deps/ReactantExtra/third_party/enzyme_ad/BUILD b/deps/ReactantExtra/third_party/enzyme_ad/BUILD new file mode 100644 index 000000000..e69de29bb diff --git a/deps/ReactantExtra/third_party/enzyme_ad/workspace.bzl b/deps/ReactantExtra/third_party/enzyme_ad/workspace.bzl new file mode 100644 index 000000000..c2565b507 --- /dev/null +++ b/deps/ReactantExtra/third_party/enzyme_ad/workspace.bzl @@ -0,0 +1,14 @@ +"""Loads Enzyme-JAX.""" + +load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") + +ENZYMEXLA_COMMIT = "f6587e37ff7298f2a1a273b08c24d69fca7ff30f" +ENZYMEXLA_SHA256 = "" + +def repo(): + http_archive( + name = "enzyme_ad", + sha256 = ENZYMEXLA_SHA256, + strip_prefix = "Enzyme-JAX-" + ENZYMEXLA_COMMIT, + urls = ["https://github.com/EnzymeAD/Enzyme-JAX/archive/{commit}.tar.gz".format(commit = ENZYMEXLA_COMMIT)], + ) diff --git a/deps/ReactantExtra/third_party/jax/BUILD b/deps/ReactantExtra/third_party/jax/BUILD new file mode 100644 index 000000000..e69de29bb diff --git a/deps/ReactantExtra/third_party/jax/workspace.bzl b/deps/ReactantExtra/third_party/jax/workspace.bzl new file mode 100644 index 000000000..00ce15ec9 --- /dev/null +++ b/deps/ReactantExtra/third_party/jax/workspace.bzl @@ -0,0 +1,18 @@ +"""Loads Enzyme-JAX.""" + +load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") +load("@enzyme_ad//:workspace.bzl", "JAX_COMMIT", "JAX_SHA256") + +def repo(): + http_archive( + name = "jax", + sha256 = JAX_SHA256, + strip_prefix = "jax-" + JAX_COMMIT, + urls = ["https://github.com/google/jax/archive/{commit}.tar.gz".format(commit = JAX_COMMIT)], + patch_args = ["-p1"], + patches = ["@enzyme_ad//:patches/jax.patch"], + ) + + load("@jax//third_party/xla:workspace.bzl", jax_xla_workspace = "repo") + + jax_xla_workspace() diff --git a/deps/ReactantExtra/third_party/nsync/BUILD b/deps/ReactantExtra/third_party/nsync/BUILD new file mode 100644 index 000000000..e69de29bb diff --git a/deps/ReactantExtra/third_party/nsync/workspace.bzl b/deps/ReactantExtra/third_party/nsync/workspace.bzl new file mode 100644 index 000000000..de0aaf538 --- /dev/null +++ b/deps/ReactantExtra/third_party/nsync/workspace.bzl @@ -0,0 +1,14 @@ +"""Loads nsync.""" + +load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") + +NSYNC_COMMIT = "82b118aa7ace3132e517e2c467f8732978cf4023" +NSYNC_SHA256 = "" + +def repo(): + http_archive( + name = "nsync", + sha256 = NSYNC_SHA256, + strip_prefix = "nsync-" + NSYNC_COMMIT, + urls = ["https://github.com/wsmoses/nsync/archive/{commit}.tar.gz".format(commit = NSYNC_COMMIT)], + ) diff --git a/deps/ReactantExtra/third_party/rules_cc/BUILD b/deps/ReactantExtra/third_party/rules_cc/BUILD new file mode 100644 index 000000000..e69de29bb diff --git a/deps/ReactantExtra/third_party/rules_cc/workspace.bzl b/deps/ReactantExtra/third_party/rules_cc/workspace.bzl new file mode 100644 index 000000000..9c54c9b8b --- /dev/null +++ b/deps/ReactantExtra/third_party/rules_cc/workspace.bzl @@ -0,0 +1,13 @@ +"""Loads bazel rules_cc.""" + +load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") + +def repo(): + http_archive( + name = "rules_cc", + sha256 = "85723d827f080c5e927334f1fb18a294c0b3f94fee6d6b45945f5cdae6ea0fd4", + strip_prefix = "rules_cc-c8c38f8c710cbbf834283e4777916b68261b359c", + urls = [ + "https://github.com/bazelbuild/rules_cc/archive/c8c38f8c710cbbf834283e4777916b68261b359c.tar.gz", + ], + ) diff --git a/deps/ReactantExtra/third_party/upb/BUILD b/deps/ReactantExtra/third_party/upb/BUILD new file mode 100644 index 000000000..e69de29bb diff --git a/deps/ReactantExtra/third_party/upb/workspace.bzl b/deps/ReactantExtra/third_party/upb/workspace.bzl new file mode 100644 index 000000000..0b1d9eb53 --- /dev/null +++ b/deps/ReactantExtra/third_party/upb/workspace.bzl @@ -0,0 +1,15 @@ +"""Loads upb.""" + +load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") + +def repo(): + http_archive( + name = "upb", + sha256 = "61d0417abd60e65ed589c9deee7c124fe76a4106831f6ad39464e1525cef1454", + strip_prefix = "upb-9effcbcb27f0a665f9f345030188c0b291e32482", + patch_cmds = [ + "sed -i.bak0 's/@bazel_tools\\/\\/platforms:windows/@platforms\\/\\/os:windows/g' BUILD", + "sed -i.bak0 's/-Werror//g' BUILD", + ], + url = "https://github.com/protocolbuffers/upb/archive/9effcbcb27f0a665f9f345030188c0b291e32482.tar.gz", + ) diff --git a/deps/ReactantExtra/third_party/xla/BUILD b/deps/ReactantExtra/third_party/xla/BUILD new file mode 100644 index 000000000..e69de29bb diff --git a/deps/ReactantExtra/third_party/xla/workspace.bzl b/deps/ReactantExtra/third_party/xla/workspace.bzl new file mode 100644 index 000000000..014c3844b --- /dev/null +++ b/deps/ReactantExtra/third_party/xla/workspace.bzl @@ -0,0 +1,50 @@ +"""Loads XLA.""" + +load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") +load("@enzyme_ad//:workspace.bzl", "XLA_PATCHES") +load("@jax//third_party/xla:workspace.bzl", "XLA_COMMIT", "XLA_SHA256") + +XLA_PATCHES = XLA_PATCHES + [ + """ + sed -i.bak0 "s/__cpp_lib_hardware_interference_size/HW_INTERFERENCE_SIZE/g" xla/backends/cpu/runtime/thunk_executor.h + """, + """ + sed -i.bak0 "s/__cpp_lib_hardware_interference_size/HW_INTERFERENCE_SIZE/g" xla/stream_executor/host/host_kernel.cc + """, + """ + sed -i.bak0 "s/patch_cmds = \\[/patch_cmds = \\[\\\"find . -type f -name config.bzl -exec sed -i.bak0 's\\/HAVE_LINK_H=1\\/HAVE_LINK_H=0\\/g' {} +\\\",/g" third_party/llvm/workspace.bzl + """, + """ + sed -i.bak0 "s/patch_cmds = \\[/patch_cmds = \\[\\\"find . -type f -name config.bzl -exec sed -i.bak0 's\\/LLVM_ENABLE_THREADS=1\\/LLVM_ENABLE_THREADS=0\\/g' {} +\\\",/g" third_party/llvm/workspace.bzl + """, + """ + sed -i.bak0 "s/patch_cmds = \\[/patch_cmds = \\[\\\"find . -type f -name config.bzl -exec sed -i.bak0 's\\/HAVE_MALLINFO=1\\/DONT_HAVE_ANY_MALLINFO=0\\/g' {} +\\\",/g" third_party/llvm/workspace.bzl + """, + """ + sed -i.bak0 "s/patch_cmds = \\[/patch_cmds = \\[\\\"find . -type f -name config.bzl -exec sed -i.bak0 's\\/HAVE_PTHREAD_GETNAME_NP=1\\/FAKE_HAVE_PTHREAD_GETNAME_NP=0\\/g' {} +\\\",/g" third_party/llvm/workspace.bzl + """, + """ + sed -i.bak0 "s/patch_cmds = \\[/patch_cmds = \\[\\\"find . -type f -name config.bzl -exec sed -i.bak0 's\\/HAVE_PTHREAD_SETNAME_NP=1\\/FAKE_HAVE_PTHREAD_SETNAME_NP=0\\/g' {} +\\\",/g" third_party/llvm/workspace.bzl + """, + """ + sed -i.bak0 "s/patch_cmds = \\[/patch_cmds = \\[\\\"find . -type f -name config.h -exec sed -i.bak0 's\\/ENABLE_CRASH_OVERRIDES 1\\/ENABLE_CRASH_OVERRIDES 0\\/g' {} +\\\",/g" third_party/llvm/workspace.bzl + """, + """ + sed -i.bak0 "s/patch_cmds = \\[/patch_cmds = \\[\\\"find . -type f -name config.h -exec sed -i.bak0 's\\/HAVE_PTHREAD_GETNAME_NP\\/FAKE_HAVE_PTHREAD_GETNAME_NP\\/g' {} +\\\",/g" third_party/llvm/workspace.bzl + """, + """ + sed -i.bak0 "s/patch_cmds = \\[/patch_cmds = \\[\\\"find . -type f -name config.h -exec sed -i.bak0 's\\/HAVE_PTHREAD_SETNAME_NP\\/FAKE_HAVE_PTHREAD_SETNAME_NP\\/g' {} +\\\",/g" third_party/llvm/workspace.bzl + """, + # """ + # sed -i.bak0 "s/patch_cmds = \\[/patch_cmds = \\['find . -type f -name BUILD.bazel -exec sed -i.bak0 \\\\\\'s\\/\\\"CAPIIR\\\",\\/\\\"CAPIIR\\\",alwayslink=1,\\/g\\\\\\\\' {} +',/g" third_party/llvm/workspace.bzl + # """, +] + +def repo(): + http_archive( + name = "xla", + sha256 = XLA_SHA256, + strip_prefix = "xla-" + XLA_COMMIT, + urls = ["https://github.com/wsmoses/xla/archive/{commit}.tar.gz".format(commit = XLA_COMMIT)], + patch_cmds = XLA_PATCHES, + ) From 3cebe5de397c0a924879e709b6bd3922a2a733c1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Fri, 13 Dec 2024 01:00:22 +0100 Subject: [PATCH 12/29] refactor libcxxwrap_julia on top of `cc_import` --- .../libcxxwrap_julia/libcxxwrap_julia.BUILD | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/deps/ReactantExtra/third_party/libcxxwrap_julia/libcxxwrap_julia.BUILD b/deps/ReactantExtra/third_party/libcxxwrap_julia/libcxxwrap_julia.BUILD index 6a4f170ec..6c4bd3f2c 100644 --- a/deps/ReactantExtra/third_party/libcxxwrap_julia/libcxxwrap_julia.BUILD +++ b/deps/ReactantExtra/third_party/libcxxwrap_julia/libcxxwrap_julia.BUILD @@ -3,12 +3,9 @@ licenses(["notice"]) package(default_visibility = ["//visibility:public"]) -filegroup( - name = "include", - srcs = glob(["include/**/*.h"]), -) - -filegroup( - name = "libs", - srcs = glob(["lib/libcxxwrap_julia*.dylib"]), +cc_import( + name = "libcxxwrap_julia", + hdrs = glob(["include/jlcxx/*.hpp"]), + shared_library = "lib/libcxxwrap_julia.dylib", + visibility = ["//visibility:public"], ) From 596f96afc86a18c7f421983435c04cc30e812f49 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Fri, 13 Dec 2024 01:00:54 +0100 Subject: [PATCH 13/29] use modular workspaces --- deps/ReactantExtra/WORKSPACE | 145 +++--------------- .../build_bazel_rules_apple/workspace.bzl | 7 - .../third_party/jax/workspace.bzl | 4 - .../libcxxwrap_julia/workspace.bzl | 2 +- .../third_party/xla/workspace.bzl | 50 ++---- 5 files changed, 38 insertions(+), 170 deletions(-) diff --git a/deps/ReactantExtra/WORKSPACE b/deps/ReactantExtra/WORKSPACE index de0ef5e96..a8c606a99 100644 --- a/deps/ReactantExtra/WORKSPACE +++ b/deps/ReactantExtra/WORKSPACE @@ -2,103 +2,26 @@ workspace(name = "Reactant") load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") -load("//third_party/libcxxwrap_julia:workspace.bzl", "libcxxwrap_julia_deps") -libcxxwrap_julia_deps() - -NSYNC_COMMIT = "82b118aa7ace3132e517e2c467f8732978cf4023" -NSYNC_SHA256 = "" -http_archive( - name = "nsync", - sha256 = NSYNC_SHA256, - strip_prefix = "nsync-" + NSYNC_COMMIT, - urls = ["https://github.com/wsmoses/nsync/archive/{commit}.tar.gz".format(commit = NSYNC_COMMIT)], -) +load("//third_party/libcxxwrap_julia:workspace.bzl", libcxxwrap_julia_workspace = "repo") +libcxxwrap_julia_workspace() -ENZYMEXLA_COMMIT = "f6587e37ff7298f2a1a273b08c24d69fca7ff30f" -ENZYMEXLA_SHA256 = "" +load("//third_party/nsync:workspace.bzl", nsync_workspace = "repo") +nsync_workspace() -http_archive( - name = "enzyme_ad", - sha256 = ENZYMEXLA_SHA256, - strip_prefix = "Enzyme-JAX-" + ENZYMEXLA_COMMIT, - urls = ["https://github.com/EnzymeAD/Enzyme-JAX/archive/{commit}.tar.gz".format(commit = ENZYMEXLA_COMMIT)], -) +load("//third_party/enzyme_ad:workspace.bzl", enzyme_ad_workspace = "repo") +enzyme_ad_workspace() -load("@enzyme_ad//:workspace.bzl", "JAX_COMMIT", "JAX_SHA256", "ENZYME_COMMIT", "ENZYME_SHA256", "XLA_PATCHES") - -XLA_PATCHES = XLA_PATCHES + [ -""" -sed -i.bak0 "s/__cpp_lib_hardware_interference_size/HW_INTERFERENCE_SIZE/g" xla/backends/cpu/runtime/thunk_executor.h -""", -""" -sed -i.bak0 "s/__cpp_lib_hardware_interference_size/HW_INTERFERENCE_SIZE/g" xla/stream_executor/host/host_kernel.cc -""", -""" -sed -i.bak0 "s/patch_cmds = \\[/patch_cmds = \\[\\\"find . -type f -name config.bzl -exec sed -i.bak0 's\\/HAVE_LINK_H=1\\/HAVE_LINK_H=0\\/g' {} +\\\",/g" third_party/llvm/workspace.bzl -""", -""" -sed -i.bak0 "s/patch_cmds = \\[/patch_cmds = \\[\\\"find . -type f -name config.bzl -exec sed -i.bak0 's\\/LLVM_ENABLE_THREADS=1\\/LLVM_ENABLE_THREADS=0\\/g' {} +\\\",/g" third_party/llvm/workspace.bzl -""", -""" -sed -i.bak0 "s/patch_cmds = \\[/patch_cmds = \\[\\\"find . -type f -name config.bzl -exec sed -i.bak0 's\\/HAVE_MALLINFO=1\\/DONT_HAVE_ANY_MALLINFO=0\\/g' {} +\\\",/g" third_party/llvm/workspace.bzl -""", -""" -sed -i.bak0 "s/patch_cmds = \\[/patch_cmds = \\[\\\"find . -type f -name config.bzl -exec sed -i.bak0 's\\/HAVE_PTHREAD_GETNAME_NP=1\\/FAKE_HAVE_PTHREAD_GETNAME_NP=0\\/g' {} +\\\",/g" third_party/llvm/workspace.bzl -""", -""" -sed -i.bak0 "s/patch_cmds = \\[/patch_cmds = \\[\\\"find . -type f -name config.bzl -exec sed -i.bak0 's\\/HAVE_PTHREAD_SETNAME_NP=1\\/FAKE_HAVE_PTHREAD_SETNAME_NP=0\\/g' {} +\\\",/g" third_party/llvm/workspace.bzl -""", -""" -sed -i.bak0 "s/patch_cmds = \\[/patch_cmds = \\[\\\"find . -type f -name config.h -exec sed -i.bak0 's\\/ENABLE_CRASH_OVERRIDES 1\\/ENABLE_CRASH_OVERRIDES 0\\/g' {} +\\\",/g" third_party/llvm/workspace.bzl -""", -""" -sed -i.bak0 "s/patch_cmds = \\[/patch_cmds = \\[\\\"find . -type f -name config.h -exec sed -i.bak0 's\\/HAVE_PTHREAD_GETNAME_NP\\/FAKE_HAVE_PTHREAD_GETNAME_NP\\/g' {} +\\\",/g" third_party/llvm/workspace.bzl -""", -""" -sed -i.bak0 "s/patch_cmds = \\[/patch_cmds = \\[\\\"find . -type f -name config.h -exec sed -i.bak0 's\\/HAVE_PTHREAD_SETNAME_NP\\/FAKE_HAVE_PTHREAD_SETNAME_NP\\/g' {} +\\\",/g" third_party/llvm/workspace.bzl -""", -# """ -# sed -i.bak0 "s/patch_cmds = \\[/patch_cmds = \\['find . -type f -name BUILD.bazel -exec sed -i.bak0 \\\\\\'s\\/\\\"CAPIIR\\\",\\/\\\"CAPIIR\\\",alwayslink=1,\\/g\\\\\\\\' {} +',/g" third_party/llvm/workspace.bzl -# """, -] - -http_archive( - name = "rules_cc", - sha256 = "85723d827f080c5e927334f1fb18a294c0b3f94fee6d6b45945f5cdae6ea0fd4", - strip_prefix = "rules_cc-c8c38f8c710cbbf834283e4777916b68261b359c", - urls = [ - "https://github.com/bazelbuild/rules_cc/archive/c8c38f8c710cbbf834283e4777916b68261b359c.tar.gz", - ], -) +load("//third_party/rules_cc:workspace.bzl", rules_cc_workspace = "repo") +rules_cc_workspace() load("@rules_cc//cc:repositories.bzl", "rules_cc_dependencies") - rules_cc_dependencies() -LLVM_TARGETS = select({ - "@bazel_tools//src/conditions:windows": ["AMDGPU", "NVPTX"], - "@bazel_tools//src/conditions:darwin": [], - "//conditions:default": ["AMDGPU", "NVPTX"], -}) + ["AArch64", "X86", "ARM"] +load("//third_party/jax:workspace.bzl", jax_workspace = "repo") +jax_workspace() -http_archive( - name = "jax", - sha256 = JAX_SHA256, - strip_prefix = "jax-" + JAX_COMMIT, - urls = ["https://github.com/google/jax/archive/{commit}.tar.gz".format(commit = JAX_COMMIT)], - patch_args = ["-p1"], - patches = ["@enzyme_ad//:patches/jax.patch"], -) - -load("@jax//third_party/xla:workspace.bzl", "XLA_COMMIT", "XLA_SHA256") - -http_archive( - name = "xla", - sha256 = XLA_SHA256, - strip_prefix = "xla-" + XLA_COMMIT, - urls = ["https://github.com/wsmoses/xla/archive/{commit}.tar.gz".format(commit = XLA_COMMIT)], - patch_cmds = XLA_PATCHES -) +load("//third_party/xla:workspace.bzl", xla_workspace = "repo") +xla_workspace() load("@xla//third_party/py:python_init_rules.bzl", "python_init_rules") python_init_rules() @@ -124,49 +47,29 @@ load("@xla//third_party/py:python_init_rules.bzl", "python_init_rules") python_init_rules() load("@rules_python//python:repositories.bzl", "py_repositories") - py_repositories() load("@rules_python//python/pip_install:repositories.bzl", "pip_install_dependencies") - pip_install_dependencies() -http_archive( - name = "enzyme", - sha256 = ENZYME_SHA256, - strip_prefix = "Enzyme-" + ENZYME_COMMIT + "/enzyme", - urls = ["https://github.com/EnzymeAD/Enzyme/archive/{commit}.tar.gz".format(commit = ENZYME_COMMIT)], -) +load("//third_party/enzyme:workspace.bzl", enzyme_workspace = "repo") +enzyme_workspace() -http_archive( - name = "build_bazel_rules_apple", - sha256 = "34c41bfb59cdaea29ac2df5a2fa79e5add609c71bb303b2ebb10985f93fa20e7", - url = "https://github.com/bazelbuild/rules_apple/releases/download/3.1.1/rules_apple.3.1.1.tar.gz", -) +load("//third_party/build_bazel_rules_apple:workspace.bzl", build_bazel_rules_apple_workspace = "repo") +build_bazel_rules_apple_workspace() load( "@build_bazel_rules_apple//apple:repositories.bzl", "apple_rules_dependencies", ) - apple_rules_dependencies() - -http_archive( - name = "upb", - sha256 = "61d0417abd60e65ed589c9deee7c124fe76a4106831f6ad39464e1525cef1454", - strip_prefix = "upb-9effcbcb27f0a665f9f345030188c0b291e32482", - patch_cmds = [ - "sed -i.bak0 's/@bazel_tools\\/\\/platforms:windows/@platforms\\/\\/os:windows/g' BUILD", - "sed -i.bak0 's/-Werror//g' BUILD" - ], - url = "https://github.com/protocolbuffers/upb/archive/9effcbcb27f0a665f9f345030188c0b291e32482.tar.gz" -) +load("//third_party/upb:workspace.bzl", upb_workspace = "repo") +upb_workspace() load("@jax//third_party/xla:workspace.bzl", jax_xla_workspace = "repo") jax_xla_workspace() - load("@xla//:workspace4.bzl", "xla_workspace4") xla_workspace4() @@ -175,6 +78,12 @@ xla_workspace3() load("@xla//:workspace2.bzl", "xla_workspace2") +LLVM_TARGETS = select({ + "@bazel_tools//src/conditions:windows": ["AMDGPU", "NVPTX"], + "@bazel_tools//src/conditions:darwin": [], + "//conditions:default": ["AMDGPU", "NVPTX"], +}) + ["AArch64", "X86", "ARM"] + load("@llvm-raw//utils/bazel:configure.bzl", "llvm_configure") llvm_configure(name = "llvm-project", targets = LLVM_TARGETS) xla_workspace2() @@ -192,7 +101,6 @@ load( "@tsl//third_party/gpus/cuda/hermetic:cuda_json_init_repository.bzl", "cuda_json_init_repository", ) - cuda_json_init_repository() load( @@ -205,11 +113,9 @@ load( "cuda_redist_init_repositories", "cudnn_redist_init_repository", ) - cuda_redist_init_repositories( cuda_redistributions = CUDA_REDISTRIBUTIONS, ) - cudnn_redist_init_repository( cudnn_redistributions = CUDNN_REDISTRIBUTIONS, ) @@ -218,19 +124,16 @@ load( "@tsl//third_party/gpus/cuda/hermetic:cuda_configure.bzl", "cuda_configure", ) - cuda_configure(name = "local_config_cuda") load( "@tsl//third_party/nccl/hermetic:nccl_redist_init_repository.bzl", "nccl_redist_init_repository", ) - nccl_redist_init_repository() load( "@tsl//third_party/nccl/hermetic:nccl_configure.bzl", "nccl_configure", ) - nccl_configure(name = "local_config_nccl") diff --git a/deps/ReactantExtra/third_party/build_bazel_rules_apple/workspace.bzl b/deps/ReactantExtra/third_party/build_bazel_rules_apple/workspace.bzl index 1d15ee7d4..f4350d99a 100644 --- a/deps/ReactantExtra/third_party/build_bazel_rules_apple/workspace.bzl +++ b/deps/ReactantExtra/third_party/build_bazel_rules_apple/workspace.bzl @@ -8,10 +8,3 @@ def repo(): sha256 = "34c41bfb59cdaea29ac2df5a2fa79e5add609c71bb303b2ebb10985f93fa20e7", url = "https://github.com/bazelbuild/rules_apple/releases/download/3.1.1/rules_apple.3.1.1.tar.gz", ) - - load( - "@build_bazel_rules_apple//apple:repositories.bzl", - "apple_rules_dependencies", - ) - - apple_rules_dependencies() diff --git a/deps/ReactantExtra/third_party/jax/workspace.bzl b/deps/ReactantExtra/third_party/jax/workspace.bzl index 00ce15ec9..53e0cac75 100644 --- a/deps/ReactantExtra/third_party/jax/workspace.bzl +++ b/deps/ReactantExtra/third_party/jax/workspace.bzl @@ -12,7 +12,3 @@ def repo(): patch_args = ["-p1"], patches = ["@enzyme_ad//:patches/jax.patch"], ) - - load("@jax//third_party/xla:workspace.bzl", jax_xla_workspace = "repo") - - jax_xla_workspace() diff --git a/deps/ReactantExtra/third_party/libcxxwrap_julia/workspace.bzl b/deps/ReactantExtra/third_party/libcxxwrap_julia/workspace.bzl index 280baf7bb..0a7063e0f 100644 --- a/deps/ReactantExtra/third_party/libcxxwrap_julia/workspace.bzl +++ b/deps/ReactantExtra/third_party/libcxxwrap_julia/workspace.bzl @@ -1,6 +1,6 @@ """Loads the libcxxwrap_julia library.""" -def libcxxwrap_julia_deps(): +def repo(): # TODO change this to download the real artifacts or build them from source native.new_local_repository( name = "libcxxwrap_julia", diff --git a/deps/ReactantExtra/third_party/xla/workspace.bzl b/deps/ReactantExtra/third_party/xla/workspace.bzl index 014c3844b..dd2479db3 100644 --- a/deps/ReactantExtra/third_party/xla/workspace.bzl +++ b/deps/ReactantExtra/third_party/xla/workspace.bzl @@ -4,47 +4,23 @@ load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") load("@enzyme_ad//:workspace.bzl", "XLA_PATCHES") load("@jax//third_party/xla:workspace.bzl", "XLA_COMMIT", "XLA_SHA256") -XLA_PATCHES = XLA_PATCHES + [ - """ - sed -i.bak0 "s/__cpp_lib_hardware_interference_size/HW_INTERFERENCE_SIZE/g" xla/backends/cpu/runtime/thunk_executor.h - """, - """ - sed -i.bak0 "s/__cpp_lib_hardware_interference_size/HW_INTERFERENCE_SIZE/g" xla/stream_executor/host/host_kernel.cc - """, - """ - sed -i.bak0 "s/patch_cmds = \\[/patch_cmds = \\[\\\"find . -type f -name config.bzl -exec sed -i.bak0 's\\/HAVE_LINK_H=1\\/HAVE_LINK_H=0\\/g' {} +\\\",/g" third_party/llvm/workspace.bzl - """, - """ - sed -i.bak0 "s/patch_cmds = \\[/patch_cmds = \\[\\\"find . -type f -name config.bzl -exec sed -i.bak0 's\\/LLVM_ENABLE_THREADS=1\\/LLVM_ENABLE_THREADS=0\\/g' {} +\\\",/g" third_party/llvm/workspace.bzl - """, - """ - sed -i.bak0 "s/patch_cmds = \\[/patch_cmds = \\[\\\"find . -type f -name config.bzl -exec sed -i.bak0 's\\/HAVE_MALLINFO=1\\/DONT_HAVE_ANY_MALLINFO=0\\/g' {} +\\\",/g" third_party/llvm/workspace.bzl - """, - """ - sed -i.bak0 "s/patch_cmds = \\[/patch_cmds = \\[\\\"find . -type f -name config.bzl -exec sed -i.bak0 's\\/HAVE_PTHREAD_GETNAME_NP=1\\/FAKE_HAVE_PTHREAD_GETNAME_NP=0\\/g' {} +\\\",/g" third_party/llvm/workspace.bzl - """, - """ - sed -i.bak0 "s/patch_cmds = \\[/patch_cmds = \\[\\\"find . -type f -name config.bzl -exec sed -i.bak0 's\\/HAVE_PTHREAD_SETNAME_NP=1\\/FAKE_HAVE_PTHREAD_SETNAME_NP=0\\/g' {} +\\\",/g" third_party/llvm/workspace.bzl - """, - """ - sed -i.bak0 "s/patch_cmds = \\[/patch_cmds = \\[\\\"find . -type f -name config.h -exec sed -i.bak0 's\\/ENABLE_CRASH_OVERRIDES 1\\/ENABLE_CRASH_OVERRIDES 0\\/g' {} +\\\",/g" third_party/llvm/workspace.bzl - """, - """ - sed -i.bak0 "s/patch_cmds = \\[/patch_cmds = \\[\\\"find . -type f -name config.h -exec sed -i.bak0 's\\/HAVE_PTHREAD_GETNAME_NP\\/FAKE_HAVE_PTHREAD_GETNAME_NP\\/g' {} +\\\",/g" third_party/llvm/workspace.bzl - """, - """ - sed -i.bak0 "s/patch_cmds = \\[/patch_cmds = \\[\\\"find . -type f -name config.h -exec sed -i.bak0 's\\/HAVE_PTHREAD_SETNAME_NP\\/FAKE_HAVE_PTHREAD_SETNAME_NP\\/g' {} +\\\",/g" third_party/llvm/workspace.bzl - """, - # """ - # sed -i.bak0 "s/patch_cmds = \\[/patch_cmds = \\['find . -type f -name BUILD.bazel -exec sed -i.bak0 \\\\\\'s\\/\\\"CAPIIR\\\",\\/\\\"CAPIIR\\\",alwayslink=1,\\/g\\\\\\\\' {} +',/g" third_party/llvm/workspace.bzl - # """, -] - def repo(): http_archive( name = "xla", sha256 = XLA_SHA256, strip_prefix = "xla-" + XLA_COMMIT, urls = ["https://github.com/wsmoses/xla/archive/{commit}.tar.gz".format(commit = XLA_COMMIT)], - patch_cmds = XLA_PATCHES, + patch_cmds = XLA_PATCHES + [ + """sed -i.bak0 "s/__cpp_lib_hardware_interference_size/HW_INTERFERENCE_SIZE/g" xla/backends/cpu/runtime/thunk_executor.h""", + """sed -i.bak0 "s/__cpp_lib_hardware_interference_size/HW_INTERFERENCE_SIZE/g" xla/stream_executor/host/host_kernel.cc""", + """sed -i.bak0 "s/patch_cmds = \\[/patch_cmds = \\[\\\"find . -type f -name config.bzl -exec sed -i.bak0 's\\/HAVE_LINK_H=1\\/HAVE_LINK_H=0\\/g' {} +\\\",/g" third_party/llvm/workspace.bzl""", + """sed -i.bak0 "s/patch_cmds = \\[/patch_cmds = \\[\\\"find . -type f -name config.bzl -exec sed -i.bak0 's\\/LLVM_ENABLE_THREADS=1\\/LLVM_ENABLE_THREADS=0\\/g' {} +\\\",/g" third_party/llvm/workspace.bzl""", + """sed -i.bak0 "s/patch_cmds = \\[/patch_cmds = \\[\\\"find . -type f -name config.bzl -exec sed -i.bak0 's\\/HAVE_MALLINFO=1\\/DONT_HAVE_ANY_MALLINFO=0\\/g' {} +\\\",/g" third_party/llvm/workspace.bzl""", + """sed -i.bak0 "s/patch_cmds = \\[/patch_cmds = \\[\\\"find . -type f -name config.bzl -exec sed -i.bak0 's\\/HAVE_PTHREAD_GETNAME_NP=1\\/FAKE_HAVE_PTHREAD_GETNAME_NP=0\\/g' {} +\\\",/g" third_party/llvm/workspace.bzl""", + """sed -i.bak0 "s/patch_cmds = \\[/patch_cmds = \\[\\\"find . -type f -name config.bzl -exec sed -i.bak0 's\\/HAVE_PTHREAD_SETNAME_NP=1\\/FAKE_HAVE_PTHREAD_SETNAME_NP=0\\/g' {} +\\\",/g" third_party/llvm/workspace.bzl""", + """sed -i.bak0 "s/patch_cmds = \\[/patch_cmds = \\[\\\"find . -type f -name config.h -exec sed -i.bak0 's\\/ENABLE_CRASH_OVERRIDES 1\\/ENABLE_CRASH_OVERRIDES 0\\/g' {} +\\\",/g" third_party/llvm/workspace.bzl""", + """sed -i.bak0 "s/patch_cmds = \\[/patch_cmds = \\[\\\"find . -type f -name config.h -exec sed -i.bak0 's\\/HAVE_PTHREAD_GETNAME_NP\\/FAKE_HAVE_PTHREAD_GETNAME_NP\\/g' {} +\\\",/g" third_party/llvm/workspace.bzl""", + """sed -i.bak0 "s/patch_cmds = \\[/patch_cmds = \\[\\\"find . -type f -name config.h -exec sed -i.bak0 's\\/HAVE_PTHREAD_SETNAME_NP\\/FAKE_HAVE_PTHREAD_SETNAME_NP\\/g' {} +\\\",/g" third_party/llvm/workspace.bzl""", + # """sed -i.bak0 "s/patch_cmds = \\[/patch_cmds = \\['find . -type f -name BUILD.bazel -exec sed -i.bak0 \\\\\\'s\\/\\\"CAPIIR\\\",\\/\\\"CAPIIR\\\",alwayslink=1,\\/g\\\\\\\\' {} +',/g" third_party/llvm/workspace.bzl""", + ], ) From b9bbb3822feb6cd69990611bd05360062afd713b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Fri, 13 Dec 2024 01:04:06 +0100 Subject: [PATCH 14/29] add `libcxxwrap_julia` as dependency --- deps/ReactantExtra/BUILD | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/deps/ReactantExtra/BUILD b/deps/ReactantExtra/BUILD index f484e1575..8dc18e56f 100644 --- a/deps/ReactantExtra/BUILD +++ b/deps/ReactantExtra/BUILD @@ -424,12 +424,10 @@ cc_library( "-Wl,-exported_symbol,_ifrt_*", "-Wl,-exported_symbol,_RegisterCustomCallTarget", "-Wl,-exported_symbol,_ConvertLLVMToMLIR", - "-L/Users/mofeing/.julia/artifacts/6a1f8b0d254a485be750499b732b476ddbee44c5/lib/", - "-llibcxxwrap_julia.0.14.0.dylib", - ] + ], }), deps = [ - # "@libcxxwrap_julia", + "@libcxxwrap_julia", "@enzyme//:EnzymeMLIR", "@llvm-project//mlir:AffineDialect", "@llvm-project//mlir:AllPassesAndDialects", From 535721da12fe18b7129cb628e1a717b56c469ec8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Fri, 13 Dec 2024 15:29:28 +0100 Subject: [PATCH 15/29] hardcode julia dep --- deps/ReactantExtra/BUILD | 4 +++- deps/ReactantExtra/WORKSPACE | 3 +++ deps/ReactantExtra/third_party/julia/BUILD | 0 deps/ReactantExtra/third_party/julia/julia.BUILD | 10 ++++++++++ deps/ReactantExtra/third_party/julia/workspace.bzl | 9 +++++++++ .../libcxxwrap_julia/libcxxwrap_julia.BUILD | 6 +++++- 6 files changed, 30 insertions(+), 2 deletions(-) create mode 100644 deps/ReactantExtra/third_party/julia/BUILD create mode 100644 deps/ReactantExtra/third_party/julia/julia.BUILD create mode 100644 deps/ReactantExtra/third_party/julia/workspace.bzl diff --git a/deps/ReactantExtra/BUILD b/deps/ReactantExtra/BUILD index 8dc18e56f..5447a8b54 100644 --- a/deps/ReactantExtra/BUILD +++ b/deps/ReactantExtra/BUILD @@ -377,7 +377,9 @@ cc_library( "-Werror=return-type", "-Werror=unused-result", "-Wno-error=stringop-truncation", - # "-I$(location @libcxxwrap_julia//:include)" + # "-I$(location @libcxxwrap_julia//:include)", + "-Iexternal/libcxxwrap_julia/include/", + "-Iexternal/julia/include/julia/", ], alwayslink = True, linkstatic = True, diff --git a/deps/ReactantExtra/WORKSPACE b/deps/ReactantExtra/WORKSPACE index a8c606a99..9ed0b83b0 100644 --- a/deps/ReactantExtra/WORKSPACE +++ b/deps/ReactantExtra/WORKSPACE @@ -2,6 +2,9 @@ workspace(name = "Reactant") load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") +load("//third_party/julia:workspace.bzl", julia_workspace = "repo") +julia_workspace() + load("//third_party/libcxxwrap_julia:workspace.bzl", libcxxwrap_julia_workspace = "repo") libcxxwrap_julia_workspace() diff --git a/deps/ReactantExtra/third_party/julia/BUILD b/deps/ReactantExtra/third_party/julia/BUILD new file mode 100644 index 000000000..e69de29bb diff --git a/deps/ReactantExtra/third_party/julia/julia.BUILD b/deps/ReactantExtra/third_party/julia/julia.BUILD new file mode 100644 index 000000000..7df4ebe20 --- /dev/null +++ b/deps/ReactantExtra/third_party/julia/julia.BUILD @@ -0,0 +1,10 @@ +licenses(["notice"]) + +package(default_visibility = ["//visibility:public"]) + +cc_import( + name = "julia", + hdrs = glob(["include/**/*"]), + includes = ["include"], + shared_library = "lib/libjulia.dylib", +) diff --git a/deps/ReactantExtra/third_party/julia/workspace.bzl b/deps/ReactantExtra/third_party/julia/workspace.bzl new file mode 100644 index 000000000..019e52d1a --- /dev/null +++ b/deps/ReactantExtra/third_party/julia/workspace.bzl @@ -0,0 +1,9 @@ +"""Loads julia.""" + +def repo(): + # TODO change this to download the real artifacts or build them from source? + native.new_local_repository( + name = "julia", + path = "/Users/mofeing/.julia/juliaup/julia-1.11.2+0.aarch64.apple.darwin14/", + build_file = "//third_party/julia:julia.BUILD", + ) diff --git a/deps/ReactantExtra/third_party/libcxxwrap_julia/libcxxwrap_julia.BUILD b/deps/ReactantExtra/third_party/libcxxwrap_julia/libcxxwrap_julia.BUILD index 6c4bd3f2c..1c9dba4fe 100644 --- a/deps/ReactantExtra/third_party/libcxxwrap_julia/libcxxwrap_julia.BUILD +++ b/deps/ReactantExtra/third_party/libcxxwrap_julia/libcxxwrap_julia.BUILD @@ -5,7 +5,11 @@ package(default_visibility = ["//visibility:public"]) cc_import( name = "libcxxwrap_julia", - hdrs = glob(["include/jlcxx/*.hpp"]), + hdrs = glob(["include/**/*.hpp"]), + includes = ["include"], shared_library = "lib/libcxxwrap_julia.dylib", visibility = ["//visibility:public"], + deps = [ + "@julia", + ], ) From 2effd55a8c0d0e2ad12796be39f037ff3ebd694d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Tue, 17 Dec 2024 12:27:04 +0100 Subject: [PATCH 16/29] export `reactant_*` functions --- deps/ReactantExtra/BUILD | 1 + 1 file changed, 1 insertion(+) diff --git a/deps/ReactantExtra/BUILD b/deps/ReactantExtra/BUILD index 5447a8b54..9805781d0 100644 --- a/deps/ReactantExtra/BUILD +++ b/deps/ReactantExtra/BUILD @@ -426,6 +426,7 @@ cc_library( "-Wl,-exported_symbol,_ifrt_*", "-Wl,-exported_symbol,_RegisterCustomCallTarget", "-Wl,-exported_symbol,_ConvertLLVMToMLIR", + "-Wl,-exported_symbol,_reactant_*", ], }), deps = [ From 736028a5b6b2495131c784b2ef230499575b646a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Tue, 17 Dec 2024 15:56:01 +0100 Subject: [PATCH 17/29] downgrade libcxxwrap_julia to v0.13.3 --- deps/ReactantExtra/third_party/libcxxwrap_julia/workspace.bzl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deps/ReactantExtra/third_party/libcxxwrap_julia/workspace.bzl b/deps/ReactantExtra/third_party/libcxxwrap_julia/workspace.bzl index 0a7063e0f..cc81001c1 100644 --- a/deps/ReactantExtra/third_party/libcxxwrap_julia/workspace.bzl +++ b/deps/ReactantExtra/third_party/libcxxwrap_julia/workspace.bzl @@ -4,6 +4,6 @@ def repo(): # TODO change this to download the real artifacts or build them from source native.new_local_repository( name = "libcxxwrap_julia", - path = "/Users/mofeing/.julia/artifacts/6a1f8b0d254a485be750499b732b476ddbee44c5/", + path = "/Users/mofeing/.julia/artifacts/4997cdb1f8db7f55d750afcad5db88e3bb4a7819/", build_file = "//third_party/libcxxwrap_julia:libcxxwrap_julia.BUILD", ) From 70d25ebfbeac1d94bf815add79412fcfb1f10998 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Tue, 17 Dec 2024 15:56:27 +0100 Subject: [PATCH 18/29] fix major version when linking to libcxxwrap_julia --- .../third_party/libcxxwrap_julia/libcxxwrap_julia.BUILD | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deps/ReactantExtra/third_party/libcxxwrap_julia/libcxxwrap_julia.BUILD b/deps/ReactantExtra/third_party/libcxxwrap_julia/libcxxwrap_julia.BUILD index 1c9dba4fe..26d668986 100644 --- a/deps/ReactantExtra/third_party/libcxxwrap_julia/libcxxwrap_julia.BUILD +++ b/deps/ReactantExtra/third_party/libcxxwrap_julia/libcxxwrap_julia.BUILD @@ -7,7 +7,7 @@ cc_import( name = "libcxxwrap_julia", hdrs = glob(["include/**/*.hpp"]), includes = ["include"], - shared_library = "lib/libcxxwrap_julia.dylib", + shared_library = "lib/libcxxwrap_julia.0.dylib", visibility = ["//visibility:public"], deps = [ "@julia", From 338ed3ea1d827ca504742a14cce57bd37f1cd487 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Tue, 17 Dec 2024 15:56:38 +0100 Subject: [PATCH 19/29] remove legacy export --- deps/ReactantExtra/third_party/libcxxwrap_julia/BUILD | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/deps/ReactantExtra/third_party/libcxxwrap_julia/BUILD b/deps/ReactantExtra/third_party/libcxxwrap_julia/BUILD index 9b495c067..8b1378917 100644 --- a/deps/ReactantExtra/third_party/libcxxwrap_julia/BUILD +++ b/deps/ReactantExtra/third_party/libcxxwrap_julia/BUILD @@ -1,3 +1 @@ -exports_files(srcs = [ - "workspace.bzl", -]) + From 88964e841fd25e42df669b57e9efea30407d4b13 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Tue, 17 Dec 2024 16:09:21 +0100 Subject: [PATCH 20/29] fix cxx wrapping module instantiation --- src/IFRT.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/IFRT.jl b/src/IFRT.jl index 8e7f1a1c3..37e90e097 100644 --- a/src/IFRT.jl +++ b/src/IFRT.jl @@ -3,6 +3,6 @@ module IFRT using CxxWrap using Reactant_jll -@wrapmodule(() -> joinpath(Reactant_jll.libdir, :reactant_module_ifrt)) +@wrapmodule(() -> Reactant_jll.libReactantExtra, :reactant_module_ifrt) end From e0d4ed3d19f93e193a5b8ce2799f72cfb20e3658 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Tue, 17 Dec 2024 17:18:38 +0100 Subject: [PATCH 21/29] move `API.cpp` to new `src/` folder to start modularizing code --- deps/ReactantExtra/BUILD | 6 +----- deps/ReactantExtra/{ => src}/API.cpp | 0 2 files changed, 1 insertion(+), 5 deletions(-) rename deps/ReactantExtra/{ => src}/API.cpp (100%) diff --git a/deps/ReactantExtra/BUILD b/deps/ReactantExtra/BUILD index 9805781d0..d07238ae2 100644 --- a/deps/ReactantExtra/BUILD +++ b/deps/ReactantExtra/BUILD @@ -350,11 +350,7 @@ platform( cc_library( name = "ReactantExtraLib", - srcs = glob( - [ - "*.cpp", - ] - ) + [ + srcs = glob(["src/*.cpp"]) + [ # "@com_google_protobuf//:src/google/protobuf/io/coded_stream.cc", "@xla//xla:xla.pb.cc", "@xla//xla:xla_data.pb.cc", diff --git a/deps/ReactantExtra/API.cpp b/deps/ReactantExtra/src/API.cpp similarity index 100% rename from deps/ReactantExtra/API.cpp rename to deps/ReactantExtra/src/API.cpp From d1ef025aa72a8d2169ae1fc7a36bba426cf8016a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Tue, 17 Dec 2024 19:26:22 +0100 Subject: [PATCH 22/29] export `register_julia_module` from libcxxwrap_julia --- deps/ReactantExtra/BUILD | 1 + 1 file changed, 1 insertion(+) diff --git a/deps/ReactantExtra/BUILD b/deps/ReactantExtra/BUILD index d07238ae2..f3a6814d0 100644 --- a/deps/ReactantExtra/BUILD +++ b/deps/ReactantExtra/BUILD @@ -423,6 +423,7 @@ cc_library( "-Wl,-exported_symbol,_RegisterCustomCallTarget", "-Wl,-exported_symbol,_ConvertLLVMToMLIR", "-Wl,-exported_symbol,_reactant_*", + "-Wl,-exported_symbol,_register_julia_module", ], }), deps = [ From f3c9de3260c38844b1c1d5def804f69fa1e8f2b3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Tue, 17 Dec 2024 19:28:19 +0100 Subject: [PATCH 23/29] move IFRT ffi to "IFRT.cpp" file --- deps/ReactantExtra/src/API.cpp | 943 -------------------------------- deps/ReactantExtra/src/IFRT.cpp | 905 ++++++++++++++++++++++++++++++ 2 files changed, 905 insertions(+), 943 deletions(-) create mode 100644 deps/ReactantExtra/src/IFRT.cpp diff --git a/deps/ReactantExtra/src/API.cpp b/deps/ReactantExtra/src/API.cpp index 314ff7af3..eee5cc1cb 100644 --- a/deps/ReactantExtra/src/API.cpp +++ b/deps/ReactantExtra/src/API.cpp @@ -60,38 +60,6 @@ #include "llvm-c/TargetMachine.h" -// IFRT -#include "xla/python/ifrt/value.h" -#include "xla/python/ifrt/tuple.h" -#include "xla/python/ifrt/dtype.h" -#include "xla/python/ifrt/shape.h" -#include "xla/python/ifrt/index.h" -#include "xla/python/ifrt/index_domain.h" -#include "xla/python/ifrt/memory.h" -#include "xla/python/ifrt/device.h" -#include "xla/python/ifrt/sharding.h" -#include "xla/python/ifrt/array.h" -#include "xla/python/ifrt/topology.h" -#include "xla/python/ifrt/client.h" -#include "xla/python/ifrt/host_callback.h" -#include "xla/python/ifrt/executable.h" -#include "xla/python/ifrt/hlo/hlo_program.h" -#include "xla/python/ifrt/compiler.h" - -// IFRT - PJRT -#include "xla/python/pjrt_ifrt/pjrt_dtype.h" -#include "xla/python/pjrt_ifrt/pjrt_tuple.h" -#include "xla/python/pjrt_ifrt/pjrt_memory.h" -#include "xla/python/pjrt_ifrt/pjrt_device.h" -#include "xla/python/pjrt_ifrt/pjrt_array.h" -#include "xla/python/pjrt_ifrt/pjrt_topology.h" -#include "xla/python/pjrt_ifrt/pjrt_client.h" -#include "xla/python/pjrt_ifrt/pjrt_host_callback.h" -#include "xla/python/pjrt_ifrt/pjrt_executable.h" -#include "xla/python/pjrt_ifrt/pjrt_compiler.h" - -#include "jlcxx/jlcxx.hpp" - using namespace mlir; using namespace llvm; using namespace xla; @@ -625,914 +593,3 @@ extern "C" MlirOperation LinkInModule(MlirModule prevModC, MlirModule newModC, c newMod.getBody()->getOperations()); return wrap(entryFn); } - -#pragma region xla::ifrt - -#pragma region xla::ifrt::Value -extern "C" ifrt::Client* ifrt_value_client(ifrt::Value* value) { - return value->client(); -} - -extern "C" ifrt::Future<> ifrt_value_get_ready_future(ifrt::Value* value) { - return value->GetReadyFuture(); -} - -extern "C" ifrt::Future<> ifrt_value_delete(ifrt::Value* value) { - return value->Delete(); -} - -extern "C" bool ifrt_value_is_deleted(ifrt::Value* value) { - return value->IsDeleted(); -} - -extern "C" const char* ifrt_value_debug_string(ifrt::Value* value) { - return cstr_from_string(value->DebugString()); -} -#pragma endregion - -#pragma region xla::ifrt::Tuple -extern "C" int ifrt_tuple_arity(ifrt::Tuple* tuple) { - return tuple->Arity(); -} - -// TODO ifrt::Tuple::Unpack -#pragma endregion - -#pragma region xla::ifrt::PjRtTuple -extern "C" ifrt::PjRtTuple* ifrt_pjrt_tuple_ctor(ifrt::PjRtCompatibleClient* client, ifrt::Value* values, int nvalues) { - auto values_ptr = new tsl::RCReference[nvalues]; - for (int i=0; i(); - values_ptr[i].reset(&values[i]); - } - auto span = absl::Span>(values_ptr, nvalues); - return xla::ValueOrThrow(ifrt::PjRtTuple::Create(client, span)).release(); -} - -extern "C" void ifrt_pjrt_tuple_free(ifrt::PjRtTuple* tuple) { - delete tuple; -} -#pragma endregion - -#pragma region xla::ifrt::DType -extern "C" ifrt::DType* ifrt_dtype_ctor(ifrt::DType::Kind kind) { - return new ifrt::DType(kind); -} - -extern "C" void ifrt_dtype_free(ifrt::DType* dtype) { - delete dtype; -} - -extern "C" ifrt::DType::Kind ifrt_dtype_kind(ifrt::DType* dtype) { - return dtype->kind(); -} - -extern "C" bool ifrt_dtype_eq(ifrt::DType* dtype1, ifrt::DType* dtype2) { - return *dtype1 == *dtype2; -} - -extern "C" bool ifrt_dtype_ne(ifrt::DType* dtype1, ifrt::DType* dtype2) { - return *dtype1 != *dtype2; -} - -// Returns -1 if not aligned to a byte boundary or there is no fixed size -extern "C" int ifrt_dtype_byte_size(ifrt::DType* dtype) { - auto byte_size = dtype->byte_size(); - if (byte_size.has_value()) { - return byte_size.value(); - } - return -1; -} - -// Returns -1 if there is no fixed size -extern "C" int ifrt_dtype_bit_size(ifrt::DType* dtype) { - auto bit_size = dtype->bit_size(); - if (bit_size.has_value()) { - return bit_size.value(); - } - return -1; -} - -extern "C" const char* ifrt_dtype_debug_string(ifrt::DType* dtype) { - return cstr_from_string(dtype->DebugString()); -} - -// xla::PrimitiveType is a enum, so we use int to represent it on Julia side -extern "C" xla::PrimitiveType ifrt_to_primitive_type(ifrt::DType* dtype) { - return xla::ValueOrThrow(ifrt::ToPrimitiveType(*dtype)); -} - -// xla::PrimitiveType is a enum, so we use int to represent it on Julia side -extern "C" ifrt::DType* ifrt_to_dtype(xla::PrimitiveType primitive_type) { - auto dtype = xla::ValueOrThrow(ifrt::ToDType(primitive_type)); - return new ifrt::DType(dtype.kind()); -} -#pragma endregion - -#pragma region xla::ifrt::Shape -extern "C" ifrt::Shape* ifrt_shape_ctor(const int64_t* dims, size_t dims_size) { - return new ifrt::Shape(absl::Span(dims, dims_size)); -} - -extern "C" void ifrt_shape_free(ifrt::Shape* shape) { - delete shape; -} - -extern "C" const int64_t* ifrt_shape_dims(ifrt::Shape* shape) { - return shape->dims().data(); -} - -extern "C" int64_t ifrt_shape_dims_num_elements(ifrt::Shape* shape) { - return shape->num_elements(); -} - -extern "C" const char* ifrt_shape_debug_string(ifrt::Shape* shape) { - return cstr_from_string(shape->DebugString()); -} -#pragma endregion - -#pragma region xla::ifrt::DynamicShape -extern "C" ifrt::DynamicShape* ifrt_dynamicshape_ctor(ifrt::Shape* shape, const bool* dynamic_dims_mask) { - auto tag = ifrt::BoundedDynamicShapeTag(absl::Span(dynamic_dims_mask, shape->dims().size())); - auto dynshape = xla::ValueOrThrow(ifrt::DynamicShape::Create(*shape, tag)); - return new ifrt::DynamicShape(dynshape); -} - -extern "C" void ifrt_dynamicshape_free(ifrt::DynamicShape* shape) { - delete shape; -} - -// TODO ifrt::DynamicShape::GetTag - -extern "C" bool ifrt_dynamicshape_eq(ifrt::DynamicShape* shape1, ifrt::DynamicShape* shape2) { - return *shape1 == *shape2; -} - -extern "C" bool ifrt_dynamicshape_ne(ifrt::DynamicShape* shape1, ifrt::DynamicShape* shape2) { - return *shape1 != *shape2; -} - -extern "C" ifrt::Shape* ifrt_dynamicshape_get_padded_shape(ifrt::DynamicShape* shape) { - auto padshape = xla::ValueOrThrow(shape->GetPaddedShape()); - return new ifrt::Shape(padshape); -} - -extern "C" bool ifrt_dynamicshape_is_dynamic_dim(ifrt::DynamicShape* shape, int dimension) { - return shape->IsDynamicDim(dimension); -} - -extern "C" const char* ifrt_dynamicshape_debug_string(ifrt::DynamicShape* shape) { - return cstr_from_string(shape->DebugString()); -} -#pragma endregion - -#pragma region xla::ifrt::Index -extern "C" ifrt::Index* ifrt_index_ctor(const int64_t* elements, size_t elements_size) { - return new ifrt::Index(absl::Span(elements, elements_size)); -} - -extern "C" ifrt::Index* ifrt_index_zeros(int num_elements) { - return new ifrt::Index(ifrt::Index::Zeros(num_elements)); -} - -extern "C" void ifrt_index_free(ifrt::Index* index) { - delete index; -} - -extern "C" const int64_t* ifrt_index_elements(ifrt::Index* index) { - return index->elements().data(); -} - -extern "C" int ifrt_index_count(ifrt::Index* index) { - return index->elements().size(); -} - -extern "C" bool ifrt_index_eq(ifrt::Index* index1, ifrt::Index* index2) { - return *index1 == *index2; -} - -extern "C" bool ifrt_index_ne(ifrt::Index* index1, ifrt::Index* index2) { - return *index1 != *index2; -} - -extern "C" ifrt::Index* ifrt_index_add(ifrt::Index* index, ifrt::Index* offset) { - return new ifrt::Index(*index + *offset); -} - -extern "C" ifrt::Index* ifrt_index_sub(ifrt::Index* index, ifrt::Index* offset) { - return new ifrt::Index(*index - *offset); -} - -// WARN we're not checking if the multiplier has the same size as the index -extern "C" ifrt::Index* ifrt_index_mul(ifrt::Index* index, const int64_t* multiplier) { - return new ifrt::Index(*index * absl::Span(multiplier, ifrt_index_count(index))); -} - -extern "C" void ifrt_index_add_inplace(ifrt::Index* index, ifrt::Index* offset) { - *index += *offset; -} - -extern "C" void ifrt_index_sub_inplace(ifrt::Index* index, ifrt::Index* offset) { - *index -= *offset; -} - -extern "C" void ifrt_index_mul_inplace(ifrt::Index* index, const int64_t* multiplier) { - *index *= absl::Span(multiplier, ifrt_index_count(index)); -} - -extern "C" const char* ifrt_index_debug_string(ifrt::Index* index) { - return cstr_from_string(index->DebugString()); -} -#pragma endregion - -#pragma region xla::ifrt::IndexDomain -extern "C" ifrt::IndexDomain* ifrt_indexdomain_ctor(ifrt::Shape* shape) { - return new ifrt::IndexDomain(*shape); -} - -extern "C" ifrt::IndexDomain* ifrt_indexdomain_ctor_with_origin(ifrt::Index* origin, ifrt::Shape* shape) { - return new ifrt::IndexDomain(*origin, *shape); -} - -extern "C" void ifrt_indexdomain_free(ifrt::IndexDomain* index_domain) { - delete index_domain; -} - -extern "C" const ifrt::Index* ifrt_indexdomain_origin(ifrt::IndexDomain* index_domain) { - return &index_domain->origin(); -} - -extern "C" const ifrt::Shape* ifrt_indexdomain_shape(ifrt::IndexDomain* index_domain) { - return &index_domain->shape(); -} - -extern "C" bool ifrt_indexdomain_eq(ifrt::IndexDomain* index_domain1, ifrt::IndexDomain* index_domain2) { - return *index_domain1 == *index_domain2; -} - -extern "C" bool ifrt_indexdomain_ne(ifrt::IndexDomain* index_domain1, ifrt::IndexDomain* index_domain2) { - return *index_domain1 != *index_domain2; -} - -extern "C" ifrt::IndexDomain* ifrt_indexdomain_add(ifrt::IndexDomain* index_domain, ifrt::Index* offset) { - return new ifrt::IndexDomain(*index_domain + *offset); -} - -extern "C" ifrt::IndexDomain* ifrt_indexdomain_sub(ifrt::IndexDomain* index_domain, ifrt::Index* offset) { - return new ifrt::IndexDomain(*index_domain - *offset); -} - -extern "C" void ifrt_indexdomain_add_inplace(ifrt::IndexDomain* index_domain, ifrt::Index* offset) { - *index_domain += *offset; -} - -extern "C" void ifrt_indexdomain_sub_inplace(ifrt::IndexDomain* index_domain, ifrt::Index* offset) { - *index_domain -= *offset; -} - -extern "C" const char* ifrt_indexdomain_debug_string(ifrt::IndexDomain* index_domain) { - return cstr_from_string(index_domain->DebugString()); -} -#pragma endregion - -#pragma region xla::ifrt::MemoryKind -// Pass a nullptr to create a `MemoryKind` with no memory chosen. -extern "C" ifrt::MemoryKind* ifrt_memorykind_ctor(const char* memory_kind) { - if (memory_kind == nullptr) - return new ifrt::MemoryKind(); - return new ifrt::MemoryKind(std::string(memory_kind)); -} - -extern "C" void ifrt_memorykind_free(ifrt::MemoryKind* memory_kind) { - delete memory_kind; -} - -extern "C" bool ifrt_memorykind_eq(ifrt::MemoryKind* mk1, ifrt::MemoryKind* mk2) { - return *mk1 == *mk2; -} - -extern "C" bool ifrt_memorykind_ne(ifrt::MemoryKind* mk1, ifrt::MemoryKind* mk2) { - return *mk1 != *mk2; -} - -extern "C" const char* ifrt_memorykind_string(ifrt::MemoryKind* memory_kind) { - if (memory_kind->memory_kind().has_value()) - return cstr_from_string(memory_kind->memory_kind().value()); - else - return nullptr; -} - -extern "C" ifrt::MemoryKind* ifrt_memorykind_canonicalize(ifrt::MemoryKind* memory_kind, ifrt::Device* device) { - return new ifrt::MemoryKind(CanonicalizeMemoryKind(*memory_kind, device)); -} -#pragma endregion - -#pragma region xla::ifrt::Memory -// MemoryId is a struct with a single int32_t field --> check out xla/python/ifrt/memory.h -extern "C" ifrt::MemoryId ifrt_memory_id(ifrt::Memory* memory) { - return memory->Id(); -} - -extern "C" const ifrt::MemoryKind* ifrt_memory_kind(ifrt::Memory* memory) { - return &(memory->Kind()); -} - -extern "C" const char* ifrt_memory_to_string(ifrt::Memory* memory) { - return cstr_from_string(memory->ToString()); -} - -extern "C" const char* ifrt_memory_debug_string(ifrt::Memory* memory) { - return cstr_from_string(memory->DebugString()); -} - -extern "C" std::tuple ifrt_memory_devices(ifrt::Memory* memory) { - auto devices = memory->Devices(); - return std::make_tuple(devices.size(), devices.data()); -} -#pragma endregion - -#pragma region xla::ifrt::PjRtMemory -extern "C" ifrt::PjRtMemory* ifrt_pjrt_memory_ctor(ifrt::PjRtClient* client, xla::PjRtMemorySpace* memory_space) { - return new ifrt::PjRtMemory(client, memory_space); -} - -extern "C" void ifrt_pjrt_memory_free(ifrt::PjRtMemory* memory) { - delete memory; -} - -extern "C" ifrt::PjRtClient* ifrt_pjrt_memory_client(ifrt::PjRtMemory* memory) { - return memory->client(); -} - -extern "C" xla::PjRtMemorySpace* ifrt_pjrt_memory_space(ifrt::PjRtMemory* memory) { - return memory->pjrt_memory(); -} -#pragma endregion - -#pragma region xla::ifrt::Device -extern "C" ifrt::Client* ifrt_device_client(ifrt::Device* device) { - return device->client(); -} - -// DeviceId is a struct with a single int32_t field --> check out xla/pjrt/pjrt_common.h -extern "C" ifrt::DeviceId ifrt_device_id(ifrt::Device* device) { - return device->Id(); -} - -// TODO ifrt_device_attributes - -extern "C" const char* ifrt_device_kind(ifrt::Device* device) { - return cstr_from_string(device->Kind()); -} - -extern "C" const char* ifrt_device_to_string(ifrt::Device* device) { - return cstr_from_string(device->ToString()); -} - -extern "C" const char* ifrt_device_debug_string(ifrt::Device* device) { - return cstr_from_string(device->DebugString()); -} - -extern "C" ifrt::Memory* ifrt_device_default_memory(ifrt::Device* device) { - return xla::ValueOrThrow(device->DefaultMemory()); -} - -// TODO ifrt_device_memories - -extern "C" bool ifrt_device_is_addressable(ifrt::Device* device) { - return device->IsAddressable(); -} - -extern "C" int ifrt_device_process_index(ifrt::Device* device) { - return device->ProcessIndex(); -} -#pragma endregion - -#pragma region xla::ifrt::PjRtDevice -// DeviceId is a struct with a single int32_t field --> check out xla/pjrt/pjrt_common.h -// TODO support `attributes` parameter -extern "C" ifrt::PjRtDevice* ifrt_pjrt_device_ctor(ifrt::PjRtClient* client, ifrt::DeviceId device_id, const char* kind, const char* to_string, const char* debug_string, int process_index, xla::PjRtDevice* pjrt_device) { - return new ifrt::PjRtDevice(client, device_id, kind, to_string, debug_string, process_index, absl::flat_hash_map(), pjrt_device); -} - -extern "C" void ifrt_pjrt_device_free(ifrt::PjRtDevice* device) { - delete device; -} - -extern "C" xla::PjRtDevice* ifrt_pjrt_device_pjrt_device(ifrt::PjRtDevice* device) { - return device->pjrt_device(); -} -#pragma endregion - -#pragma region xla::ifrt::Sharding -// TODO ifrt_sharding_devices -// TODO ifrt_sharding_memory_kind - -// extern "C" void ifrt_sharding_disassemble(ifrt::Sharding* sharding, ifrt::Shape* shape, char** error) { -// auto status = sharding->Disassemble(*shape); -// if (!status.ok()) { -// auto str = status.message(); -// char* err = (char*)malloc(str.size()+1); -// memcpy(err, str.data(), str.size()+1); -// *error = err; -// } -// } - -// TODO ifrt_sharding_disassemble_dynamic_shape -// TODO ifrt_sharding_index_domains - -extern "C" const char* ifrt_sharding_debug_string(ifrt::Sharding* sharding) { - return cstr_from_string(sharding->DebugString()); -} -#pragma endregion - -#pragma region xla::ifrt::Array -extern "C" ifrt::DType* ifrt_array_dtype(ifrt::Array* array) { - return new ifrt::DType(array->dtype()); -} - -extern "C" const ifrt::Shape* ifrt_array_shape(ifrt::Array* array) { - return &(array->shape()); -} - -extern "C" const ifrt::Sharding* ifrt_array_sharding(ifrt::Array* array) { - return &(array->sharding()); -} - -extern "C" PjRtLayout* ifrt_array_layout(ifrt::Array* array) { - return xla::ValueOrThrow(array->layout()).release(); -} - -// TODO xla::ifrt::Array::DisassembleIntoSingleDeviceArrays -// TODO xla::ifrt::Array::FullyReplicatedShard - -extern "C" ifrt::Future<> ifrt_array_copy_to_host_buffer(ifrt::Array* array, void* data, const int64_t* byte_strides, int semantics) { - return array->CopyToHostBuffer(data, absl::Span(byte_strides, array->shape().num_elements()), ifrt::ArrayCopySemantics(semantics)); -} -#pragma endregion - -#pragma region xla::ifrt::PjRtArray -// TODO constructors / `Create` - -extern "C" std::tuple ifrt_pjrt_array_pjrt_buffers(ifrt::PjRtArray* array) { - auto buffers = array->pjrt_buffers(); - auto buffers_ptr = new xla::PjRtBuffer*[buffers.size()]; - for (int i=0; iplatform_name()); -} - -extern "C" const char* ifrt_topology_platform_version(ifrt::Topology* topology) { - return cstr_from_string(topology->platform_version()); -} - -// returns PjRtPlatformId which is a type alias for uint64_t -extern "C" uint64_t ifrt_topology_platform_id(ifrt::Topology* topology) { - return topology->platform_id(); -} - -extern "C" std::tuple ifrt_topology_device_descriptions(ifrt::Topology* topology) { - auto descriptions = topology->DeviceDescriptions(); - auto descriptions_ptr = new const xla::PjRtDeviceDescription*[descriptions.size()]; - for (int i=0; iSerialize())); -} - -// TODO xla::ifrt::Topology::Attributes - -#pragma endregion - -#pragma region xla::ifrt::PjRtTopology -extern "C" ifrt::PjRtTopology* ifrt_pjrt_topology_ctor(const xla::PjRtTopologyDescription* description) { - return new ifrt::PjRtTopology(std::shared_ptr{description}); -} - -extern "C" const xla::PjRtTopologyDescription* ifrt_pjrt_topology_description(ifrt::PjRtTopology* topology) { - return topology->description().get(); -} -#pragma endregion - -#pragma region xla::ifrt::Client -extern "C" int ifrt_client_device_count(ifrt::Client* client) { - return client->device_count(); -} - -extern "C" int ifrt_client_addressable_device_count(ifrt::Client* client) { - return client->addressable_device_count(); -} - -extern "C" ifrt::Device* const* ifrt_client_devices(ifrt::Client* client) { - return client->devices().data(); -} - -extern "C" ifrt::Device* const* ifrt_client_addressable_devices(ifrt::Client* client) { - return client->addressable_devices().data(); -} - -extern "C" int ifrt_client_process_index(ifrt::Client* client) { - return client->process_index(); -} - -// TODO xla::ifrt::Client::GetDefaultDeviceAssignment - -extern "C" ifrt::Device* ifrt_client_lookup_device(ifrt::Client* client, int device_id) { - return xla::ValueOrThrow(client->LookupDevice(ifrt::DeviceId(device_id))); -} - -extern "C" ifrt::Device* ifrt_client_lookup_addressable_device(ifrt::Client* client, int device_id) { - return xla::ValueOrThrow(client->LookupAddressableDevice(device_id)); -} - -extern "C" ifrt::Compiler* ifrt_client_default_compiler(ifrt::Client* client) { - return client->GetDefaultCompiler(); -} - -// TODO ifrt_client_topology_for_devices -// TODO ifrt_client_default_layout_for_device -#pragma endregion - -#pragma region xla::ifrt::PjRtClient -// TODO support more parameters of `PjRtClient::CreateOptions` -extern "C" ifrt::PjRtClient* ifrt_pjrt_client_ctor(xla::PjRtClient* pjrt_client) { - return xla::ValueOrThrow(ifrt::PjRtClient::Create(ifrt::PjRtClient::CreateOptions{std::shared_ptr{pjrt_client}})).release(); -} - -extern "C" void ifrt_pjrt_client_free(ifrt::PjRtClient* client) { - delete client; -} - -extern "C" xla::PjRtClient* ifrt_pjrt_client_pjrt_client(ifrt::PjRtClient* client) { - return client->pjrt_client(); -} - -// TODO there are problems with using `make_shared -// extern "C" ifrt::PjRtCompatibleArray* ifrt_pjrt_client_create_pjrt_array(ifrt::PjRtClient* client, xla::PjRtBuffer* pjrt_buffer) { -// auto buffer_ptr = std::make_shared(*pjrt_buffer); -// return xla::ValueOrThrow(client->CreatePjRtArray(buffer_ptr)).release(); -// } - -// TODO extern "C" ifrt::PjRtCompatibleArray* ifrt_pjrt_client_create_pjrt_array_from_buffers(ifrt::Shape* shape, ifrt::PjRtBuffer** pjrt_buffers, int num_buffers) {} - -extern "C" ifrt::PjRtCompatibleDevice* ifrt_pjrt_client_lookup_pjrt_device(ifrt::PjRtClient* client, xla::PjRtDevice* pjrt_device) { - return xla::ValueOrThrow(client->LookupPjRtDevice(pjrt_device)); -} - -extern "C" ifrt::PjRtCompatibleMemory* ifrt_pjrt_client_lookup_pjrt_memory(ifrt::PjRtClient* client, xla::PjRtMemorySpace* pjrt_memory_space) { - return xla::ValueOrThrow(client->LookupPjRtMemory(pjrt_memory_space)); -} -#pragma endregion - -#pragma region xla::ifrt::HostCallback -extern "C" const char* ifrt_hostcallback_serialize(ifrt::HostCallback* host_callback) { - return cstr_from_string(host_callback->Serialize()); -} -#pragma endregion - -#pragma region xla::ifrt::LoadedHostCallback -extern "C" ifrt::Client* ifrt_loadedhostcallback_client(ifrt::LoadedHostCallback* host_callback) { - return host_callback->client(); -} - -extern "C" const char* ifrt_loadedhostcallback_serialize(ifrt::LoadedHostCallback* host_callback) { - // auto msg = ; - return cstr_from_string(xla::ValueOrThrow(host_callback->Serialize())); -} -#pragma endregion - -#pragma region xla::ifrt::PjRtHostSendAndRecvLoadedHostCallback -extern "C" ifrt::PjRtHostSendAndRecvLoadedHostCallback* ifrt_pjrt_hostsendandrecv_loadhostcallback_ctor(ifrt::PjRtClient* client, xla::HostCallback* host_callback) { - auto xla_callback_ptr = std::make_unique(*host_callback); - return new ifrt::PjRtHostSendAndRecvLoadedHostCallback(client, std::move(xla_callback_ptr)); -} - -extern "C" void ifrt_pjrt_hostsendandrecv_loadhostcallback_free(ifrt::PjRtHostSendAndRecvLoadedHostCallback* host_callback) { - delete host_callback; -} - -extern "C" xla::HostCallback* ifrt_pjrt_hostsendandrecv_loadhostcallback_host_callback(ifrt::PjRtHostSendAndRecvLoadedHostCallback* host_callback) { - return new xla::HostCallback(host_callback->host_callback()); -} -#pragma endregion - -#pragma region xla::ifrt::Executable -extern "C" const char* ifrt_executable_name(ifrt::Executable* executable) { - return cstr_from_string(executable->name()); -} - -extern "C" const char* ifrt_executable_fingerprint(ifrt::Executable* executable) { - auto result = xla::ValueOrThrow(executable->Fingerprint()); - if (!result.has_value()) return ""; - return cstr_from_string(result.value()); -} - -extern "C" const char* ifrt_executable_serialize(ifrt::Executable* executable) { - return cstr_from_string(xla::ValueOrThrow(executable->Serialize())); -} - -extern "C" int ifrt_executable_num_devices(ifrt::Executable* executable) { - return executable->num_devices(); -} - -extern "C" int64_t ifrt_executable_size(ifrt::Executable* executable) { - return executable->SizeOfGeneratedCodeInBytes(); -} - -// TODO xla::ifrt::Executable::GetCompiledMemoryStats - -extern "C" std::tuple ifrt_executable_parameter_shardings(ifrt::Executable* executable) { - auto shardings = executable->GetParameterShardings(); - if (!shardings.has_value()) return std::make_tuple(0, nullptr); - return std::make_tuple(shardings.value().size(), shardings.value().data()); -} - -extern "C" std::tuple ifrt_executable_output_shardings(ifrt::Executable* executable) { - auto shardings = executable->GetOutputShardings(); - if (!shardings.has_value()) return std::make_tuple(0, nullptr); - return std::make_tuple(shardings.value().size(), shardings.value().data()); -} - -extern "C" std::tuple ifrt_executable_parameter_layouts(ifrt::Executable* executable) { - auto layouts = xla::ValueOrThrow(executable->GetParameterLayouts()); - auto layouts_ptr = new xla::PjRtLayout*[layouts.size()]; - for (int i=0; i ifrt_executable_output_layouts(ifrt::Executable* executable) { - auto layouts = xla::ValueOrThrow(executable->GetOutputLayouts()); - auto layouts_ptr = new xla::PjRtLayout*[layouts.size()]; - for (int i=0; i ifrt_executable_hlo_modules(ifrt::Executable* executable) { - auto modules = xla::ValueOrThrow(executable->GetHloModules()); - auto modules_ptr = new xla::HloModule*[modules.size()]; - for (int i=0; i(*pjrt_executable); -// auto options = std::make_unique(*compile_options); -// return xla::ValueOrThrow(ifrt::PjRtExecutable::Create(pjrt_executable_shared, std::move(options))).release(); -// } - -extern "C" void ifrt_pjrt_executable_free(ifrt::PjRtExecutable* executable) { - delete executable; -} - -extern "C" xla::PjRtExecutable* ifrt_pjrt_executable_pjrt_executable(ifrt::PjRtExecutable* executable) { - return executable->pjrt_executable(); -} -#pragma endregion - -#pragma region xla::ifrt::LoadedExecutable -extern "C" ifrt::Client* ifrt_loadedexecutable_client(ifrt::LoadedExecutable* executable) { - return executable->client(); -} - -extern "C" const char* ifrt_loadedexecutable_name(ifrt::LoadedExecutable* executable) { - return cstr_from_string(executable->name()); -} - -extern "C" const char* ifrt_loadedexecutable_fingerprint(ifrt::LoadedExecutable* executable) { - auto result = xla::ValueOrThrow(executable->Fingerprint()); - if (!result.has_value()) return ""; - return cstr_from_string(result.value()); -} - -extern "C" const char* ifrt_loadedexecutable_serialize(ifrt::LoadedExecutable* executable) { - return cstr_from_string(xla::ValueOrThrow(executable->Serialize())); -} - -extern "C" ifrt::Future<> ifrt_loadedexecutable_get_ready_future(ifrt::LoadedExecutable* executable) { - return executable->GetReadyFuture(); -} - -extern "C" int ifrt_loadedexecutable_num_devices(ifrt::LoadedExecutable* executable) { - return executable->num_devices(); -} - -extern "C" int64_t ifrt_loadedexecutable_size(ifrt::LoadedExecutable* executable) { - return executable->SizeOfGeneratedCodeInBytes(); -} - -// TODO xla::ifrt::GetCompiledMemoryStats - -extern "C" std::tuple ifrt_loadedexecutable_parameter_shardings(ifrt::LoadedExecutable* executable) { - auto shardings = executable->GetParameterShardings(); - if (!shardings.has_value()) return std::make_tuple(0, nullptr); - return std::make_tuple(shardings.value().size(), shardings.value().data()); -} - -extern "C" std::tuple ifrt_loadedexecutable_output_shardings(ifrt::LoadedExecutable* executable) { - auto shardings = executable->GetOutputShardings(); - if (!shardings.has_value()) return std::make_tuple(0, nullptr); - return std::make_tuple(shardings.value().size(), shardings.value().data()); -} - -extern "C" std::tuple ifrt_loadedexecutable_parameter_layouts(ifrt::LoadedExecutable* executable) { - auto layouts = xla::ValueOrThrow(executable->GetParameterLayouts()); - auto layouts_ptr = new xla::PjRtLayout*[layouts.size()]; - for (int i=0; i ifrt_loadedexecutable_output_layouts(ifrt::LoadedExecutable* executable) { - auto layouts = xla::ValueOrThrow(executable->GetOutputLayouts()); - auto layouts_ptr = new xla::PjRtLayout*[layouts.size()]; - for (int i=0; i ifrt_loadedexecutable_hlo_modules(ifrt::LoadedExecutable* executable) { - auto modules = xla::ValueOrThrow(executable->GetHloModules()); - auto modules_ptr = new xla::HloModule*[modules.size()]; - for (int i=0; i** futures, size_t futures_size) { -// std::vector arguments(args, args + args_size); -// std::vector result(results, results + results_size); -// std::vector*> future(futures, futures + futures_size); -// return xla::ValueOrThrow(executable->Execute(arguments, result, future)); -// } - -extern "C" ifrt::Future<> ifrt_loadedexecutable_delete(ifrt::LoadedExecutable* executable) { - return executable->Delete(); -} - -extern "C" bool ifrt_loadedexecutable_is_deleted(ifrt::LoadedExecutable* executable) { - return executable->IsDeleted(); -} - -extern "C" std::tuple ifrt_loadedexecutable_addressable_devices(ifrt::LoadedExecutable* executable) { - auto devices = executable->addressable_devices(); - return std::make_tuple(devices.size(), devices.data()); -} - -// TODO auxiliary functions for xla::ifrt::LoadedExecutable::ExecuteResult -#pragma endregion - -#pragma region xla::ifrt::PjRtLoadedExecutable -// TODO add support for LoadedHostCallback -// TODO there are problems with using `make_shared -// extern "C" ifrt::LoadedExecutable* ifrt_pjrt_loadedexecutable_ctor(ifrt::PjRtCompatibleClient* client, xla::PjRtLoadedExecutable* pjrt_loaded_executable) { -// auto pjrt_loaded_executable_ptr = std::make_shared(*pjrt_loaded_executable); -// return xla::ValueOrThrow(ifrt::PjRtLoadedExecutable::Create(client, pjrt_loaded_executable_ptr, std::vector>())).release(); -// } - -// TODO add support for LoadedHostCallback -extern "C" ifrt::LoadedExecutable* ifrt_pjrt_loadedexecutable_ctor_from_mlir_module(ifrt::PjRtCompatibleClient* client, mlir::ModuleOp* module, xla::CompileOptions* compile_options) { - return xla::ValueOrThrow(ifrt::PjRtLoadedExecutable::Create(client, *module, *compile_options, std::vector>())).release(); -} - -extern "C" void ifrt_pjrt_loadedexecutable_free(ifrt::PjRtLoadedExecutable* executable) { - delete executable; -} - -extern "C" xla::PjRtLoadedExecutable* ifrt_pjrt_loadedexecutable_pjrt_loadedexecutable(ifrt::PjRtLoadedExecutable* executable) { - return executable->pjrt_loaded_executable(); -} -#pragma endregion - -#pragma region xla::ifrt::CustomCallProgram -#pragma endregion - -#pragma region xla::ifrt::HloProgram -extern "C" ifrt::HloProgram* ifrt_hloprogram_ctor() { - return new ifrt::HloProgram(); -} - -extern "C" ifrt::HloProgram* ifrt_hloprogram_ctor_with_module(mlir::ModuleOp* module) { - return new ifrt::HloProgram(*module); -} - -// extern "C" ifrt::HloProgram* ifrt_hloprogram_ctor_with_context_and_module(mlir::MLIRContext* context, mlir::ModuleOp* module) { -// auto context_ptr = std::make_unique(*context); -// return new ifrt::HloProgram(std::move(context_ptr), *module); -// } -#pragma endregion - -#pragma region xla::ifrt::Compiler -extern "C" ifrt::LoadedExecutable* ifrt_compiler_compile(ifrt::Compiler* compiler, ifrt::Program* program) { - // apparently ifrt::CompileOptions is a legacy artifact so we don't use it and set directly to the default - auto program_ptr = std::make_unique(*program); - auto options = std::make_unique(); - return xla::ValueOrThrow(compiler->Compile(std::move(program_ptr), std::move(options))).release(); -} - -extern "C" ifrt::Executable* ifrt_compiler_compile_with_topology(ifrt::Compiler* compiler, ifrt::Program* program, const ifrt::Topology* topology) { - // apparently ifrt::CompileOptions is a legacy artifact so we don't use it and set directly to the default - auto options = std::make_unique(); - auto program_ptr = std::make_unique(*program); - auto exec_ptr = xla::ValueOrThrow(compiler->Compile(std::move(program_ptr), *topology, std::move(options))).release(); - return exec_ptr; -} - -extern "C" ifrt::LoadedExecutable* ifrt_compiler_deserialize_loadedexecutable(ifrt::Compiler* compiler, const char* data) { - // apparently ifrt::DeserializeExecutableOptions is a legacy artifact so we don't use it and set directly to the default - auto options = std::make_unique(); - return xla::ValueOrThrow(compiler->DeserializeLoadedExecutable(std::string(data), std::move(options))).release(); -} -#pragma endregion - -#pragma region xla::ifrt::PjRtCompiler -extern "C" ifrt::PjRtCompiler* ifrt_pjrt_compiler_ctor(ifrt::PjRtClient* client) { - return new ifrt::PjRtCompiler(client); -} - -extern "C" void ifrt_pjrt_compiler_free(ifrt::PjRtCompiler* compiler) { - delete compiler; -} -#pragma endregion - -JLCXX_MODULE reactant_module_ifrt(jlcxx::Module& mod) { - // mod.add_type("Value") - // .method("client", &ifrt::Value::client) - // .method("get_ready_future", &ifrt::Value::GetReadyFuture) - // .method("delete!", &ifrt::Value::Delete) - // .method("isdeleted", &ifrt::Value::IsDeleted) - // .method("debug_string", &ifrt::Value::DebugString); - - mod.add_bits("DTypeKind", jlcxx::julia_type("CppEnum")); - mod.set_const("DTypeKindInvalid", ifrt::DType::Kind::kInvalid); - mod.set_const("DTypeKindPred", ifrt::DType::Kind::kPred); - mod.set_const("DTypeKindS2", ifrt::DType::Kind::kS2); - mod.set_const("DTypeKindS4", ifrt::DType::Kind::kS4); - mod.set_const("DTypeKindS8", ifrt::DType::Kind::kS8); - mod.set_const("DTypeKindS16", ifrt::DType::Kind::kS16); - mod.set_const("DTypeKindS32", ifrt::DType::Kind::kS32); - mod.set_const("DTypeKindS64", ifrt::DType::Kind::kS64); - mod.set_const("DTypeKindU2", ifrt::DType::Kind::kU2); - mod.set_const("DTypeKindU4", ifrt::DType::Kind::kU4); - mod.set_const("DTypeKindU8", ifrt::DType::Kind::kU8); - mod.set_const("DTypeKindU16", ifrt::DType::Kind::kU16); - mod.set_const("DTypeKindU32", ifrt::DType::Kind::kU32); - mod.set_const("DTypeKindU64", ifrt::DType::Kind::kU64); - mod.set_const("DTypeKindF16", ifrt::DType::Kind::kF16); - mod.set_const("DTypeKindF32", ifrt::DType::Kind::kF32); - mod.set_const("DTypeKindF64", ifrt::DType::Kind::kF64); - mod.set_const("DTypeKindBF16", ifrt::DType::Kind::kBF16); - mod.set_const("DTypeKindC64", ifrt::DType::Kind::kC64); - mod.set_const("DTypeKindC128", ifrt::DType::Kind::kC128); - mod.set_const("DTypeKindToken", ifrt::DType::Kind::kToken); - mod.set_const("DTypeKindOpaque", ifrt::DType::Kind::kOpaque); - mod.set_const("DTypeKindF8E3M4", ifrt::DType::Kind::kF8E3M4); - mod.set_const("DTypeKindF8E4M3", ifrt::DType::Kind::kF8E4M3); - mod.set_const("DTypeKindF8E4M3FN", ifrt::DType::Kind::kF8E4M3FN); - mod.set_const("DTypeKindF8E4M3B11FNUZ", ifrt::DType::Kind::kF8E4M3B11FNUZ); - mod.set_const("DTypeKindF8E4M3FNUZ", ifrt::DType::Kind::kF8E4M3FNUZ); - mod.set_const("DTypeKindF8E5M2", ifrt::DType::Kind::kF8E5M2); - mod.set_const("DTypeKindF8E5M2FNUZ", ifrt::DType::Kind::kF8E5M2FNUZ); - mod.set_const("DTypeKindString", ifrt::DType::Kind::kString); - - mod.add_type("DType") - .constructor(&ifrt::DType::DType) - .method("kind", &ifrt::DType::kind) - .method("byte_size", &ifrt::DType::byte_size) - .method("bit_size", &ifrt::DType::bit_size); - mod.set_override_module(jl_base_module); - mod.method("==", [](ifrt::DType* a, ifrt::DType* b) { return *a == *b; }); - mod.method("!=", [](ifrt::DType* a, ifrt::DType* b) { return *a != *b; }); - mod.unset_override_module(); -} - -#pragma endregion diff --git a/deps/ReactantExtra/src/IFRT.cpp b/deps/ReactantExtra/src/IFRT.cpp new file mode 100644 index 000000000..3f3cda513 --- /dev/null +++ b/deps/ReactantExtra/src/IFRT.cpp @@ -0,0 +1,905 @@ +#include "jlcxx/jlcxx.hpp" + +// IFRT +#include "xla/python/ifrt/value.h" +#include "xla/python/ifrt/tuple.h" +#include "xla/python/ifrt/dtype.h" +#include "xla/python/ifrt/shape.h" +#include "xla/python/ifrt/index.h" +#include "xla/python/ifrt/index_domain.h" +#include "xla/python/ifrt/memory.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/sharding.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/topology.h" +#include "xla/python/ifrt/client.h" +#include "xla/python/ifrt/host_callback.h" +#include "xla/python/ifrt/executable.h" +#include "xla/python/ifrt/hlo/hlo_program.h" +#include "xla/python/ifrt/compiler.h" + +// IFRT - PJRT +#include "xla/python/pjrt_ifrt/pjrt_dtype.h" +#include "xla/python/pjrt_ifrt/pjrt_tuple.h" +#include "xla/python/pjrt_ifrt/pjrt_memory.h" +#include "xla/python/pjrt_ifrt/pjrt_device.h" +#include "xla/python/pjrt_ifrt/pjrt_array.h" +#include "xla/python/pjrt_ifrt/pjrt_topology.h" +#include "xla/python/pjrt_ifrt/pjrt_client.h" +#include "xla/python/pjrt_ifrt/pjrt_host_callback.h" +#include "xla/python/pjrt_ifrt/pjrt_executable.h" +#include "xla/python/pjrt_ifrt/pjrt_compiler.h" + +using namespace xla; + +// #pragma region xla::ifrt + +// #pragma region xla::ifrt::Value +// extern "C" ifrt::Client* ifrt_value_client(ifrt::Value* value) { +// return value->client(); +// } + +// extern "C" ifrt::Future<> ifrt_value_get_ready_future(ifrt::Value* value) { +// return value->GetReadyFuture(); +// } + +// extern "C" ifrt::Future<> ifrt_value_delete(ifrt::Value* value) { +// return value->Delete(); +// } + +// extern "C" bool ifrt_value_is_deleted(ifrt::Value* value) { +// return value->IsDeleted(); +// } + +// extern "C" const char* ifrt_value_debug_string(ifrt::Value* value) { +// return cstr_from_string(value->DebugString()); +// } +// #pragma endregion + +// #pragma region xla::ifrt::Tuple +// extern "C" int ifrt_tuple_arity(ifrt::Tuple* tuple) { +// return tuple->Arity(); +// } + +// // TODO ifrt::Tuple::Unpack +// #pragma endregion + +// #pragma region xla::ifrt::PjRtTuple +// extern "C" ifrt::PjRtTuple* ifrt_pjrt_tuple_ctor(ifrt::PjRtCompatibleClient* client, ifrt::Value* values, int nvalues) { +// auto values_ptr = new tsl::RCReference[nvalues]; +// for (int i=0; i(); +// values_ptr[i].reset(&values[i]); +// } +// auto span = absl::Span>(values_ptr, nvalues); +// return xla::ValueOrThrow(ifrt::PjRtTuple::Create(client, span)).release(); +// } + +// extern "C" void ifrt_pjrt_tuple_free(ifrt::PjRtTuple* tuple) { +// delete tuple; +// } +// #pragma endregion + +// #pragma region xla::ifrt::Shape +// extern "C" ifrt::Shape* ifrt_shape_ctor(const int64_t* dims, size_t dims_size) { +// return new ifrt::Shape(absl::Span(dims, dims_size)); +// } + +// extern "C" void ifrt_shape_free(ifrt::Shape* shape) { +// delete shape; +// } + +// extern "C" const int64_t* ifrt_shape_dims(ifrt::Shape* shape) { +// return shape->dims().data(); +// } + +// extern "C" int64_t ifrt_shape_dims_num_elements(ifrt::Shape* shape) { +// return shape->num_elements(); +// } + +// extern "C" const char* ifrt_shape_debug_string(ifrt::Shape* shape) { +// return cstr_from_string(shape->DebugString()); +// } +// #pragma endregion + +// #pragma region xla::ifrt::DynamicShape +// extern "C" ifrt::DynamicShape* ifrt_dynamicshape_ctor(ifrt::Shape* shape, const bool* dynamic_dims_mask) { +// auto tag = ifrt::BoundedDynamicShapeTag(absl::Span(dynamic_dims_mask, shape->dims().size())); +// auto dynshape = xla::ValueOrThrow(ifrt::DynamicShape::Create(*shape, tag)); +// return new ifrt::DynamicShape(dynshape); +// } + +// extern "C" void ifrt_dynamicshape_free(ifrt::DynamicShape* shape) { +// delete shape; +// } + +// // TODO ifrt::DynamicShape::GetTag + +// extern "C" bool ifrt_dynamicshape_eq(ifrt::DynamicShape* shape1, ifrt::DynamicShape* shape2) { +// return *shape1 == *shape2; +// } + +// extern "C" bool ifrt_dynamicshape_ne(ifrt::DynamicShape* shape1, ifrt::DynamicShape* shape2) { +// return *shape1 != *shape2; +// } + +// extern "C" ifrt::Shape* ifrt_dynamicshape_get_padded_shape(ifrt::DynamicShape* shape) { +// auto padshape = xla::ValueOrThrow(shape->GetPaddedShape()); +// return new ifrt::Shape(padshape); +// } + +// extern "C" bool ifrt_dynamicshape_is_dynamic_dim(ifrt::DynamicShape* shape, int dimension) { +// return shape->IsDynamicDim(dimension); +// } + +// extern "C" const char* ifrt_dynamicshape_debug_string(ifrt::DynamicShape* shape) { +// return cstr_from_string(shape->DebugString()); +// } +// #pragma endregion + +// #pragma region xla::ifrt::Index +// extern "C" ifrt::Index* ifrt_index_ctor(const int64_t* elements, size_t elements_size) { +// return new ifrt::Index(absl::Span(elements, elements_size)); +// } + +// extern "C" ifrt::Index* ifrt_index_zeros(int num_elements) { +// return new ifrt::Index(ifrt::Index::Zeros(num_elements)); +// } + +// extern "C" void ifrt_index_free(ifrt::Index* index) { +// delete index; +// } + +// extern "C" const int64_t* ifrt_index_elements(ifrt::Index* index) { +// return index->elements().data(); +// } + +// extern "C" int ifrt_index_count(ifrt::Index* index) { +// return index->elements().size(); +// } + +// extern "C" bool ifrt_index_eq(ifrt::Index* index1, ifrt::Index* index2) { +// return *index1 == *index2; +// } + +// extern "C" bool ifrt_index_ne(ifrt::Index* index1, ifrt::Index* index2) { +// return *index1 != *index2; +// } + +// extern "C" ifrt::Index* ifrt_index_add(ifrt::Index* index, ifrt::Index* offset) { +// return new ifrt::Index(*index + *offset); +// } + +// extern "C" ifrt::Index* ifrt_index_sub(ifrt::Index* index, ifrt::Index* offset) { +// return new ifrt::Index(*index - *offset); +// } + +// // WARN we're not checking if the multiplier has the same size as the index +// extern "C" ifrt::Index* ifrt_index_mul(ifrt::Index* index, const int64_t* multiplier) { +// return new ifrt::Index(*index * absl::Span(multiplier, ifrt_index_count(index))); +// } + +// extern "C" void ifrt_index_add_inplace(ifrt::Index* index, ifrt::Index* offset) { +// *index += *offset; +// } + +// extern "C" void ifrt_index_sub_inplace(ifrt::Index* index, ifrt::Index* offset) { +// *index -= *offset; +// } + +// extern "C" void ifrt_index_mul_inplace(ifrt::Index* index, const int64_t* multiplier) { +// *index *= absl::Span(multiplier, ifrt_index_count(index)); +// } + +// extern "C" const char* ifrt_index_debug_string(ifrt::Index* index) { +// return cstr_from_string(index->DebugString()); +// } +// #pragma endregion + +// #pragma region xla::ifrt::IndexDomain +// extern "C" ifrt::IndexDomain* ifrt_indexdomain_ctor(ifrt::Shape* shape) { +// return new ifrt::IndexDomain(*shape); +// } + +// extern "C" ifrt::IndexDomain* ifrt_indexdomain_ctor_with_origin(ifrt::Index* origin, ifrt::Shape* shape) { +// return new ifrt::IndexDomain(*origin, *shape); +// } + +// extern "C" void ifrt_indexdomain_free(ifrt::IndexDomain* index_domain) { +// delete index_domain; +// } + +// extern "C" const ifrt::Index* ifrt_indexdomain_origin(ifrt::IndexDomain* index_domain) { +// return &index_domain->origin(); +// } + +// extern "C" const ifrt::Shape* ifrt_indexdomain_shape(ifrt::IndexDomain* index_domain) { +// return &index_domain->shape(); +// } + +// extern "C" bool ifrt_indexdomain_eq(ifrt::IndexDomain* index_domain1, ifrt::IndexDomain* index_domain2) { +// return *index_domain1 == *index_domain2; +// } + +// extern "C" bool ifrt_indexdomain_ne(ifrt::IndexDomain* index_domain1, ifrt::IndexDomain* index_domain2) { +// return *index_domain1 != *index_domain2; +// } + +// extern "C" ifrt::IndexDomain* ifrt_indexdomain_add(ifrt::IndexDomain* index_domain, ifrt::Index* offset) { +// return new ifrt::IndexDomain(*index_domain + *offset); +// } + +// extern "C" ifrt::IndexDomain* ifrt_indexdomain_sub(ifrt::IndexDomain* index_domain, ifrt::Index* offset) { +// return new ifrt::IndexDomain(*index_domain - *offset); +// } + +// extern "C" void ifrt_indexdomain_add_inplace(ifrt::IndexDomain* index_domain, ifrt::Index* offset) { +// *index_domain += *offset; +// } + +// extern "C" void ifrt_indexdomain_sub_inplace(ifrt::IndexDomain* index_domain, ifrt::Index* offset) { +// *index_domain -= *offset; +// } + +// extern "C" const char* ifrt_indexdomain_debug_string(ifrt::IndexDomain* index_domain) { +// return cstr_from_string(index_domain->DebugString()); +// } +// #pragma endregion + +// #pragma region xla::ifrt::MemoryKind +// // Pass a nullptr to create a `MemoryKind` with no memory chosen. +// extern "C" ifrt::MemoryKind* ifrt_memorykind_ctor(const char* memory_kind) { +// if (memory_kind == nullptr) +// return new ifrt::MemoryKind(); +// return new ifrt::MemoryKind(std::string(memory_kind)); +// } + +// extern "C" void ifrt_memorykind_free(ifrt::MemoryKind* memory_kind) { +// delete memory_kind; +// } + +// extern "C" bool ifrt_memorykind_eq(ifrt::MemoryKind* mk1, ifrt::MemoryKind* mk2) { +// return *mk1 == *mk2; +// } + +// extern "C" bool ifrt_memorykind_ne(ifrt::MemoryKind* mk1, ifrt::MemoryKind* mk2) { +// return *mk1 != *mk2; +// } + +// extern "C" const char* ifrt_memorykind_string(ifrt::MemoryKind* memory_kind) { +// if (memory_kind->memory_kind().has_value()) +// return cstr_from_string(memory_kind->memory_kind().value()); +// else +// return nullptr; +// } + +// extern "C" ifrt::MemoryKind* ifrt_memorykind_canonicalize(ifrt::MemoryKind* memory_kind, ifrt::Device* device) { +// return new ifrt::MemoryKind(CanonicalizeMemoryKind(*memory_kind, device)); +// } +// #pragma endregion + +// #pragma region xla::ifrt::Memory +// // MemoryId is a struct with a single int32_t field --> check out xla/python/ifrt/memory.h +// extern "C" ifrt::MemoryId ifrt_memory_id(ifrt::Memory* memory) { +// return memory->Id(); +// } + +// extern "C" const ifrt::MemoryKind* ifrt_memory_kind(ifrt::Memory* memory) { +// return &(memory->Kind()); +// } + +// extern "C" const char* ifrt_memory_to_string(ifrt::Memory* memory) { +// return cstr_from_string(memory->ToString()); +// } + +// extern "C" const char* ifrt_memory_debug_string(ifrt::Memory* memory) { +// return cstr_from_string(memory->DebugString()); +// } + +// extern "C" std::tuple ifrt_memory_devices(ifrt::Memory* memory) { +// auto devices = memory->Devices(); +// return std::make_tuple(devices.size(), devices.data()); +// } +// #pragma endregion + +// #pragma region xla::ifrt::PjRtMemory +// extern "C" ifrt::PjRtMemory* ifrt_pjrt_memory_ctor(ifrt::PjRtClient* client, xla::PjRtMemorySpace* memory_space) { +// return new ifrt::PjRtMemory(client, memory_space); +// } + +// extern "C" void ifrt_pjrt_memory_free(ifrt::PjRtMemory* memory) { +// delete memory; +// } + +// extern "C" ifrt::PjRtClient* ifrt_pjrt_memory_client(ifrt::PjRtMemory* memory) { +// return memory->client(); +// } + +// extern "C" xla::PjRtMemorySpace* ifrt_pjrt_memory_space(ifrt::PjRtMemory* memory) { +// return memory->pjrt_memory(); +// } +// #pragma endregion + +// #pragma region xla::ifrt::Device +// extern "C" ifrt::Client* ifrt_device_client(ifrt::Device* device) { +// return device->client(); +// } + +// // DeviceId is a struct with a single int32_t field --> check out xla/pjrt/pjrt_common.h +// extern "C" ifrt::DeviceId ifrt_device_id(ifrt::Device* device) { +// return device->Id(); +// } + +// // TODO ifrt_device_attributes + +// extern "C" const char* ifrt_device_kind(ifrt::Device* device) { +// return cstr_from_string(device->Kind()); +// } + +// extern "C" const char* ifrt_device_to_string(ifrt::Device* device) { +// return cstr_from_string(device->ToString()); +// } + +// extern "C" const char* ifrt_device_debug_string(ifrt::Device* device) { +// return cstr_from_string(device->DebugString()); +// } + +// extern "C" ifrt::Memory* ifrt_device_default_memory(ifrt::Device* device) { +// return xla::ValueOrThrow(device->DefaultMemory()); +// } + +// // TODO ifrt_device_memories + +// extern "C" bool ifrt_device_is_addressable(ifrt::Device* device) { +// return device->IsAddressable(); +// } + +// extern "C" int ifrt_device_process_index(ifrt::Device* device) { +// return device->ProcessIndex(); +// } +// #pragma endregion + +// #pragma region xla::ifrt::PjRtDevice +// // DeviceId is a struct with a single int32_t field --> check out xla/pjrt/pjrt_common.h +// // TODO support `attributes` parameter +// extern "C" ifrt::PjRtDevice* ifrt_pjrt_device_ctor(ifrt::PjRtClient* client, ifrt::DeviceId device_id, const char* kind, const char* to_string, const char* debug_string, int process_index, xla::PjRtDevice* pjrt_device) { +// return new ifrt::PjRtDevice(client, device_id, kind, to_string, debug_string, process_index, absl::flat_hash_map(), pjrt_device); +// } + +// extern "C" void ifrt_pjrt_device_free(ifrt::PjRtDevice* device) { +// delete device; +// } + +// extern "C" xla::PjRtDevice* ifrt_pjrt_device_pjrt_device(ifrt::PjRtDevice* device) { +// return device->pjrt_device(); +// } +// #pragma endregion + +// #pragma region xla::ifrt::Sharding +// // TODO ifrt_sharding_devices +// // TODO ifrt_sharding_memory_kind + +// // extern "C" void ifrt_sharding_disassemble(ifrt::Sharding* sharding, ifrt::Shape* shape, char** error) { +// // auto status = sharding->Disassemble(*shape); +// // if (!status.ok()) { +// // auto str = status.message(); +// // char* err = (char*)malloc(str.size()+1); +// // memcpy(err, str.data(), str.size()+1); +// // *error = err; +// // } +// // } + +// // TODO ifrt_sharding_disassemble_dynamic_shape +// // TODO ifrt_sharding_index_domains + +// extern "C" const char* ifrt_sharding_debug_string(ifrt::Sharding* sharding) { +// return cstr_from_string(sharding->DebugString()); +// } +// #pragma endregion + +// #pragma region xla::ifrt::Array +// extern "C" ifrt::DType* ifrt_array_dtype(ifrt::Array* array) { +// return new ifrt::DType(array->dtype()); +// } + +// extern "C" const ifrt::Shape* ifrt_array_shape(ifrt::Array* array) { +// return &(array->shape()); +// } + +// extern "C" const ifrt::Sharding* ifrt_array_sharding(ifrt::Array* array) { +// return &(array->sharding()); +// } + +// extern "C" PjRtLayout* ifrt_array_layout(ifrt::Array* array) { +// return xla::ValueOrThrow(array->layout()).release(); +// } + +// // TODO xla::ifrt::Array::DisassembleIntoSingleDeviceArrays +// // TODO xla::ifrt::Array::FullyReplicatedShard + +// extern "C" ifrt::Future<> ifrt_array_copy_to_host_buffer(ifrt::Array* array, void* data, const int64_t* byte_strides, int semantics) { +// return array->CopyToHostBuffer(data, absl::Span(byte_strides, array->shape().num_elements()), ifrt::ArrayCopySemantics(semantics)); +// } +// #pragma endregion + +// #pragma region xla::ifrt::PjRtArray +// // TODO constructors / `Create` + +// extern "C" std::tuple ifrt_pjrt_array_pjrt_buffers(ifrt::PjRtArray* array) { +// auto buffers = array->pjrt_buffers(); +// auto buffers_ptr = new xla::PjRtBuffer*[buffers.size()]; +// for (int i=0; iplatform_name()); +// } + +// extern "C" const char* ifrt_topology_platform_version(ifrt::Topology* topology) { +// return cstr_from_string(topology->platform_version()); +// } + +// // returns PjRtPlatformId which is a type alias for uint64_t +// extern "C" uint64_t ifrt_topology_platform_id(ifrt::Topology* topology) { +// return topology->platform_id(); +// } + +// extern "C" std::tuple ifrt_topology_device_descriptions(ifrt::Topology* topology) { +// auto descriptions = topology->DeviceDescriptions(); +// auto descriptions_ptr = new const xla::PjRtDeviceDescription*[descriptions.size()]; +// for (int i=0; iSerialize())); +// } + +// // TODO xla::ifrt::Topology::Attributes + +// #pragma endregion + +// #pragma region xla::ifrt::PjRtTopology +// extern "C" ifrt::PjRtTopology* ifrt_pjrt_topology_ctor(const xla::PjRtTopologyDescription* description) { +// return new ifrt::PjRtTopology(std::shared_ptr{description}); +// } + +// extern "C" const xla::PjRtTopologyDescription* ifrt_pjrt_topology_description(ifrt::PjRtTopology* topology) { +// return topology->description().get(); +// } +// #pragma endregion + +// #pragma region xla::ifrt::Client +// extern "C" int ifrt_client_device_count(ifrt::Client* client) { +// return client->device_count(); +// } + +// extern "C" int ifrt_client_addressable_device_count(ifrt::Client* client) { +// return client->addressable_device_count(); +// } + +// extern "C" ifrt::Device* const* ifrt_client_devices(ifrt::Client* client) { +// return client->devices().data(); +// } + +// extern "C" ifrt::Device* const* ifrt_client_addressable_devices(ifrt::Client* client) { +// return client->addressable_devices().data(); +// } + +// extern "C" int ifrt_client_process_index(ifrt::Client* client) { +// return client->process_index(); +// } + +// // TODO xla::ifrt::Client::GetDefaultDeviceAssignment + +// extern "C" ifrt::Device* ifrt_client_lookup_device(ifrt::Client* client, int device_id) { +// return xla::ValueOrThrow(client->LookupDevice(ifrt::DeviceId(device_id))); +// } + +// extern "C" ifrt::Device* ifrt_client_lookup_addressable_device(ifrt::Client* client, int device_id) { +// return xla::ValueOrThrow(client->LookupAddressableDevice(device_id)); +// } + +// extern "C" ifrt::Compiler* ifrt_client_default_compiler(ifrt::Client* client) { +// return client->GetDefaultCompiler(); +// } + +// // TODO ifrt_client_topology_for_devices +// // TODO ifrt_client_default_layout_for_device +// #pragma endregion + +// #pragma region xla::ifrt::PjRtClient +// // TODO support more parameters of `PjRtClient::CreateOptions` +// extern "C" ifrt::PjRtClient* ifrt_pjrt_client_ctor(xla::PjRtClient* pjrt_client) { +// return xla::ValueOrThrow(ifrt::PjRtClient::Create(ifrt::PjRtClient::CreateOptions{std::shared_ptr{pjrt_client}})).release(); +// } + +// extern "C" void ifrt_pjrt_client_free(ifrt::PjRtClient* client) { +// delete client; +// } + +// extern "C" xla::PjRtClient* ifrt_pjrt_client_pjrt_client(ifrt::PjRtClient* client) { +// return client->pjrt_client(); +// } + +// // TODO there are problems with using `make_shared +// // extern "C" ifrt::PjRtCompatibleArray* ifrt_pjrt_client_create_pjrt_array(ifrt::PjRtClient* client, xla::PjRtBuffer* pjrt_buffer) { +// // auto buffer_ptr = std::make_shared(*pjrt_buffer); +// // return xla::ValueOrThrow(client->CreatePjRtArray(buffer_ptr)).release(); +// // } + +// // TODO extern "C" ifrt::PjRtCompatibleArray* ifrt_pjrt_client_create_pjrt_array_from_buffers(ifrt::Shape* shape, ifrt::PjRtBuffer** pjrt_buffers, int num_buffers) {} + +// extern "C" ifrt::PjRtCompatibleDevice* ifrt_pjrt_client_lookup_pjrt_device(ifrt::PjRtClient* client, xla::PjRtDevice* pjrt_device) { +// return xla::ValueOrThrow(client->LookupPjRtDevice(pjrt_device)); +// } + +// extern "C" ifrt::PjRtCompatibleMemory* ifrt_pjrt_client_lookup_pjrt_memory(ifrt::PjRtClient* client, xla::PjRtMemorySpace* pjrt_memory_space) { +// return xla::ValueOrThrow(client->LookupPjRtMemory(pjrt_memory_space)); +// } +// #pragma endregion + +// #pragma region xla::ifrt::HostCallback +// extern "C" const char* ifrt_hostcallback_serialize(ifrt::HostCallback* host_callback) { +// return cstr_from_string(host_callback->Serialize()); +// } +// #pragma endregion + +// #pragma region xla::ifrt::LoadedHostCallback +// extern "C" ifrt::Client* ifrt_loadedhostcallback_client(ifrt::LoadedHostCallback* host_callback) { +// return host_callback->client(); +// } + +// extern "C" const char* ifrt_loadedhostcallback_serialize(ifrt::LoadedHostCallback* host_callback) { +// // auto msg = ; +// return cstr_from_string(xla::ValueOrThrow(host_callback->Serialize())); +// } +// #pragma endregion + +// #pragma region xla::ifrt::PjRtHostSendAndRecvLoadedHostCallback +// extern "C" ifrt::PjRtHostSendAndRecvLoadedHostCallback* ifrt_pjrt_hostsendandrecv_loadhostcallback_ctor(ifrt::PjRtClient* client, xla::HostCallback* host_callback) { +// auto xla_callback_ptr = std::make_unique(*host_callback); +// return new ifrt::PjRtHostSendAndRecvLoadedHostCallback(client, std::move(xla_callback_ptr)); +// } + +// extern "C" void ifrt_pjrt_hostsendandrecv_loadhostcallback_free(ifrt::PjRtHostSendAndRecvLoadedHostCallback* host_callback) { +// delete host_callback; +// } + +// extern "C" xla::HostCallback* ifrt_pjrt_hostsendandrecv_loadhostcallback_host_callback(ifrt::PjRtHostSendAndRecvLoadedHostCallback* host_callback) { +// return new xla::HostCallback(host_callback->host_callback()); +// } +// #pragma endregion + +// #pragma region xla::ifrt::Executable +// extern "C" const char* ifrt_executable_name(ifrt::Executable* executable) { +// return cstr_from_string(executable->name()); +// } + +// extern "C" const char* ifrt_executable_fingerprint(ifrt::Executable* executable) { +// auto result = xla::ValueOrThrow(executable->Fingerprint()); +// if (!result.has_value()) return ""; +// return cstr_from_string(result.value()); +// } + +// extern "C" const char* ifrt_executable_serialize(ifrt::Executable* executable) { +// return cstr_from_string(xla::ValueOrThrow(executable->Serialize())); +// } + +// extern "C" int ifrt_executable_num_devices(ifrt::Executable* executable) { +// return executable->num_devices(); +// } + +// extern "C" int64_t ifrt_executable_size(ifrt::Executable* executable) { +// return executable->SizeOfGeneratedCodeInBytes(); +// } + +// // TODO xla::ifrt::Executable::GetCompiledMemoryStats + +// extern "C" std::tuple ifrt_executable_parameter_shardings(ifrt::Executable* executable) { +// auto shardings = executable->GetParameterShardings(); +// if (!shardings.has_value()) return std::make_tuple(0, nullptr); +// return std::make_tuple(shardings.value().size(), shardings.value().data()); +// } + +// extern "C" std::tuple ifrt_executable_output_shardings(ifrt::Executable* executable) { +// auto shardings = executable->GetOutputShardings(); +// if (!shardings.has_value()) return std::make_tuple(0, nullptr); +// return std::make_tuple(shardings.value().size(), shardings.value().data()); +// } + +// extern "C" std::tuple ifrt_executable_parameter_layouts(ifrt::Executable* executable) { +// auto layouts = xla::ValueOrThrow(executable->GetParameterLayouts()); +// auto layouts_ptr = new xla::PjRtLayout*[layouts.size()]; +// for (int i=0; i ifrt_executable_output_layouts(ifrt::Executable* executable) { +// auto layouts = xla::ValueOrThrow(executable->GetOutputLayouts()); +// auto layouts_ptr = new xla::PjRtLayout*[layouts.size()]; +// for (int i=0; i ifrt_executable_hlo_modules(ifrt::Executable* executable) { +// auto modules = xla::ValueOrThrow(executable->GetHloModules()); +// auto modules_ptr = new xla::HloModule*[modules.size()]; +// for (int i=0; i(*pjrt_executable); +// // auto options = std::make_unique(*compile_options); +// // return xla::ValueOrThrow(ifrt::PjRtExecutable::Create(pjrt_executable_shared, std::move(options))).release(); +// // } + +// extern "C" void ifrt_pjrt_executable_free(ifrt::PjRtExecutable* executable) { +// delete executable; +// } + +// extern "C" xla::PjRtExecutable* ifrt_pjrt_executable_pjrt_executable(ifrt::PjRtExecutable* executable) { +// return executable->pjrt_executable(); +// } +// #pragma endregion + +// #pragma region xla::ifrt::LoadedExecutable +// extern "C" ifrt::Client* ifrt_loadedexecutable_client(ifrt::LoadedExecutable* executable) { +// return executable->client(); +// } + +// extern "C" const char* ifrt_loadedexecutable_name(ifrt::LoadedExecutable* executable) { +// return cstr_from_string(executable->name()); +// } + +// extern "C" const char* ifrt_loadedexecutable_fingerprint(ifrt::LoadedExecutable* executable) { +// auto result = xla::ValueOrThrow(executable->Fingerprint()); +// if (!result.has_value()) return ""; +// return cstr_from_string(result.value()); +// } + +// extern "C" const char* ifrt_loadedexecutable_serialize(ifrt::LoadedExecutable* executable) { +// return cstr_from_string(xla::ValueOrThrow(executable->Serialize())); +// } + +// extern "C" ifrt::Future<> ifrt_loadedexecutable_get_ready_future(ifrt::LoadedExecutable* executable) { +// return executable->GetReadyFuture(); +// } + +// extern "C" int ifrt_loadedexecutable_num_devices(ifrt::LoadedExecutable* executable) { +// return executable->num_devices(); +// } + +// extern "C" int64_t ifrt_loadedexecutable_size(ifrt::LoadedExecutable* executable) { +// return executable->SizeOfGeneratedCodeInBytes(); +// } + +// // TODO xla::ifrt::GetCompiledMemoryStats + +// extern "C" std::tuple ifrt_loadedexecutable_parameter_shardings(ifrt::LoadedExecutable* executable) { +// auto shardings = executable->GetParameterShardings(); +// if (!shardings.has_value()) return std::make_tuple(0, nullptr); +// return std::make_tuple(shardings.value().size(), shardings.value().data()); +// } + +// extern "C" std::tuple ifrt_loadedexecutable_output_shardings(ifrt::LoadedExecutable* executable) { +// auto shardings = executable->GetOutputShardings(); +// if (!shardings.has_value()) return std::make_tuple(0, nullptr); +// return std::make_tuple(shardings.value().size(), shardings.value().data()); +// } + +// extern "C" std::tuple ifrt_loadedexecutable_parameter_layouts(ifrt::LoadedExecutable* executable) { +// auto layouts = xla::ValueOrThrow(executable->GetParameterLayouts()); +// auto layouts_ptr = new xla::PjRtLayout*[layouts.size()]; +// for (int i=0; i ifrt_loadedexecutable_output_layouts(ifrt::LoadedExecutable* executable) { +// auto layouts = xla::ValueOrThrow(executable->GetOutputLayouts()); +// auto layouts_ptr = new xla::PjRtLayout*[layouts.size()]; +// for (int i=0; i ifrt_loadedexecutable_hlo_modules(ifrt::LoadedExecutable* executable) { +// auto modules = xla::ValueOrThrow(executable->GetHloModules()); +// auto modules_ptr = new xla::HloModule*[modules.size()]; +// for (int i=0; i** futures, size_t futures_size) { +// // std::vector arguments(args, args + args_size); +// // std::vector result(results, results + results_size); +// // std::vector*> future(futures, futures + futures_size); +// // return xla::ValueOrThrow(executable->Execute(arguments, result, future)); +// // } + +// extern "C" ifrt::Future<> ifrt_loadedexecutable_delete(ifrt::LoadedExecutable* executable) { +// return executable->Delete(); +// } + +// extern "C" bool ifrt_loadedexecutable_is_deleted(ifrt::LoadedExecutable* executable) { +// return executable->IsDeleted(); +// } + +// extern "C" std::tuple ifrt_loadedexecutable_addressable_devices(ifrt::LoadedExecutable* executable) { +// auto devices = executable->addressable_devices(); +// return std::make_tuple(devices.size(), devices.data()); +// } + +// // TODO auxiliary functions for xla::ifrt::LoadedExecutable::ExecuteResult +// #pragma endregion + +// #pragma region xla::ifrt::PjRtLoadedExecutable +// // TODO add support for LoadedHostCallback +// // TODO there are problems with using `make_shared +// // extern "C" ifrt::LoadedExecutable* ifrt_pjrt_loadedexecutable_ctor(ifrt::PjRtCompatibleClient* client, xla::PjRtLoadedExecutable* pjrt_loaded_executable) { +// // auto pjrt_loaded_executable_ptr = std::make_shared(*pjrt_loaded_executable); +// // return xla::ValueOrThrow(ifrt::PjRtLoadedExecutable::Create(client, pjrt_loaded_executable_ptr, std::vector>())).release(); +// // } + +// // TODO add support for LoadedHostCallback +// extern "C" ifrt::LoadedExecutable* ifrt_pjrt_loadedexecutable_ctor_from_mlir_module(ifrt::PjRtCompatibleClient* client, mlir::ModuleOp* module, xla::CompileOptions* compile_options) { +// return xla::ValueOrThrow(ifrt::PjRtLoadedExecutable::Create(client, *module, *compile_options, std::vector>())).release(); +// } + +// extern "C" void ifrt_pjrt_loadedexecutable_free(ifrt::PjRtLoadedExecutable* executable) { +// delete executable; +// } + +// extern "C" xla::PjRtLoadedExecutable* ifrt_pjrt_loadedexecutable_pjrt_loadedexecutable(ifrt::PjRtLoadedExecutable* executable) { +// return executable->pjrt_loaded_executable(); +// } +// #pragma endregion + +// #pragma region xla::ifrt::CustomCallProgram +// #pragma endregion + +// #pragma region xla::ifrt::HloProgram +// extern "C" ifrt::HloProgram* ifrt_hloprogram_ctor() { +// return new ifrt::HloProgram(); +// } + +// extern "C" ifrt::HloProgram* ifrt_hloprogram_ctor_with_module(mlir::ModuleOp* module) { +// return new ifrt::HloProgram(*module); +// } + +// // extern "C" ifrt::HloProgram* ifrt_hloprogram_ctor_with_context_and_module(mlir::MLIRContext* context, mlir::ModuleOp* module) { +// // auto context_ptr = std::make_unique(*context); +// // return new ifrt::HloProgram(std::move(context_ptr), *module); +// // } +// #pragma endregion + +// #pragma region xla::ifrt::Compiler +// extern "C" ifrt::LoadedExecutable* ifrt_compiler_compile(ifrt::Compiler* compiler, ifrt::Program* program) { +// // apparently ifrt::CompileOptions is a legacy artifact so we don't use it and set directly to the default +// auto program_ptr = std::make_unique(*program); +// auto options = std::make_unique(); +// return xla::ValueOrThrow(compiler->Compile(std::move(program_ptr), std::move(options))).release(); +// } + +// extern "C" ifrt::Executable* ifrt_compiler_compile_with_topology(ifrt::Compiler* compiler, ifrt::Program* program, const ifrt::Topology* topology) { +// // apparently ifrt::CompileOptions is a legacy artifact so we don't use it and set directly to the default +// auto options = std::make_unique(); +// auto program_ptr = std::make_unique(*program); +// auto exec_ptr = xla::ValueOrThrow(compiler->Compile(std::move(program_ptr), *topology, std::move(options))).release(); +// return exec_ptr; +// } + +// extern "C" ifrt::LoadedExecutable* ifrt_compiler_deserialize_loadedexecutable(ifrt::Compiler* compiler, const char* data) { +// // apparently ifrt::DeserializeExecutableOptions is a legacy artifact so we don't use it and set directly to the default +// auto options = std::make_unique(); +// return xla::ValueOrThrow(compiler->DeserializeLoadedExecutable(std::string(data), std::move(options))).release(); +// } +// #pragma endregion + +// #pragma region xla::ifrt::PjRtCompiler +// extern "C" ifrt::PjRtCompiler* ifrt_pjrt_compiler_ctor(ifrt::PjRtClient* client) { +// return new ifrt::PjRtCompiler(client); +// } + +// extern "C" void ifrt_pjrt_compiler_free(ifrt::PjRtCompiler* compiler) { +// delete compiler; +// } +// #pragma endregion + +// #pragma endregion + +JLCXX_MODULE reactant_module_ifrt(jlcxx::Module& mod) { + // mod.add_type("Value") + // .method("client", &ifrt::Value::client) + // .method("get_ready_future", &ifrt::Value::GetReadyFuture) + // .method("delete!", &ifrt::Value::Delete) + // .method("isdeleted", &ifrt::Value::IsDeleted) + // .method("debug_string", &ifrt::Value::DebugString); + + mod.add_bits("DTypeKind", jlcxx::julia_type("CppEnum")); + mod.set_const("DTypeKindInvalid", ifrt::DType::Kind::kInvalid); + mod.set_const("DTypeKindPred", ifrt::DType::Kind::kPred); + mod.set_const("DTypeKindS2", ifrt::DType::Kind::kS2); + mod.set_const("DTypeKindS4", ifrt::DType::Kind::kS4); + mod.set_const("DTypeKindS8", ifrt::DType::Kind::kS8); + mod.set_const("DTypeKindS16", ifrt::DType::Kind::kS16); + mod.set_const("DTypeKindS32", ifrt::DType::Kind::kS32); + mod.set_const("DTypeKindS64", ifrt::DType::Kind::kS64); + mod.set_const("DTypeKindU2", ifrt::DType::Kind::kU2); + mod.set_const("DTypeKindU4", ifrt::DType::Kind::kU4); + mod.set_const("DTypeKindU8", ifrt::DType::Kind::kU8); + mod.set_const("DTypeKindU16", ifrt::DType::Kind::kU16); + mod.set_const("DTypeKindU32", ifrt::DType::Kind::kU32); + mod.set_const("DTypeKindU64", ifrt::DType::Kind::kU64); + mod.set_const("DTypeKindF16", ifrt::DType::Kind::kF16); + mod.set_const("DTypeKindF32", ifrt::DType::Kind::kF32); + mod.set_const("DTypeKindF64", ifrt::DType::Kind::kF64); + mod.set_const("DTypeKindBF16", ifrt::DType::Kind::kBF16); + mod.set_const("DTypeKindC64", ifrt::DType::Kind::kC64); + mod.set_const("DTypeKindC128", ifrt::DType::Kind::kC128); + mod.set_const("DTypeKindToken", ifrt::DType::Kind::kToken); + // mod.set_const("DTypeKindOpaque", ifrt::DType::Kind::kOpaque); + mod.set_const("DTypeKindF8E3M4", ifrt::DType::Kind::kF8E3M4); + mod.set_const("DTypeKindF8E4M3", ifrt::DType::Kind::kF8E4M3); + mod.set_const("DTypeKindF8E4M3FN", ifrt::DType::Kind::kF8E4M3FN); + mod.set_const("DTypeKindF8E4M3B11FNUZ", ifrt::DType::Kind::kF8E4M3B11FNUZ); + mod.set_const("DTypeKindF8E4M3FNUZ", ifrt::DType::Kind::kF8E4M3FNUZ); + mod.set_const("DTypeKindF8E5M2", ifrt::DType::Kind::kF8E5M2); + mod.set_const("DTypeKindF8E5M2FNUZ", ifrt::DType::Kind::kF8E5M2FNUZ); + mod.set_const("DTypeKindString", ifrt::DType::Kind::kString); + + // TODO destructor?? + // mod.add_type("DType") + // .constructor(jlcxx::finalize_policy::no); + // .method("kind", &ifrt::DType::kind) + // .method("byte_size", &ifrt::DType::byte_size) + // .method("bit_size", &ifrt::DType::bit_size); + // mod.set_override_module(jl_base_module); + // mod.method("==", [](ifrt::DType* a, ifrt::DType* b) { return *a == *b; }); + // mod.method("!=", [](ifrt::DType* a, ifrt::DType* b) { return *a != *b; }); + // mod.method("copy", [](const ifrt::DType& x) { return ifrt::DType(x); }); + // // mod.method("string", &ifrt::DType::DebugString); + // mod.unset_override_module(); + + // TODO conversion from/to `xla::PrimitiveType` using `ToPrimitiveType`,`ToDType` +// mod.add_type("Shape") +// // .constructor<...>() // TODO explicit constructor +// ; +// mod.set_override_module(jl_base_module); +// mod.method("==", [](ifrt::Shape* a, ifrt::Shape* b) { return *a == *b; }); +// mod.method("!=", [](ifrt::Shape* a, ifrt::Shape* b) { return *a != *b; }); +// mod.method("copy", [](const ifrt::Shape& x) { return ifrt::Shape(x); }); +// mod.method("string", &ifrt::Shape::DebugString); +// mod.method("size", &ifrt::Shape::dims); +// mod.method("length", &ifrt::Shape::num_elements); +// mod.unset_override_module(); +} \ No newline at end of file From 89233ea26c467892273183a6f017694b5d55eb57 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Fri, 20 Dec 2024 11:39:25 +0100 Subject: [PATCH 24/29] fix symbol visibility --- deps/ReactantExtra/BUILD | 44 +----------------- deps/ReactantExtra/src/IFRT.cpp | 80 +++++++++++++++++---------------- src/IFRT.jl | 4 ++ 3 files changed, 47 insertions(+), 81 deletions(-) diff --git a/deps/ReactantExtra/BUILD b/deps/ReactantExtra/BUILD index f3a6814d0..154d047c2 100644 --- a/deps/ReactantExtra/BUILD +++ b/deps/ReactantExtra/BUILD @@ -382,48 +382,8 @@ cc_library( linkopts = select({ "//conditions:default": [], "@bazel_tools//src/conditions:darwin": [ - "-Wl,-exported_symbol,_stablehlo*", - "-Wl,-exported_symbol,_mlir*", - "-Wl,-exported_symbol,_InitializeLogs", - "-Wl,-exported_symbol,_SetLogLevel", - "-Wl,-exported_symbol,_SetModuleLogLevel", - "-Wl,-exported_symbol,_GetDefaultTargetTriple", - "-Wl,-exported_symbol,_enzymeActivityAttrGet", - "-Wl,-exported_symbol,_MakeCPUClient", - "-Wl,-exported_symbol,_MakeGPUClient", - "-Wl,-exported_symbol,_MakeTPUClient", - "-Wl,-exported_symbol,_LoadPjrtPlugin", - "-Wl,-exported_symbol,_InitializePjrtPlugin", - "-Wl,-exported_symbol,_GetCApiClient", - "-Wl,-exported_symbol,_ClientNumDevices", - "-Wl,-exported_symbol,_ClientNumAddressableDevices", - "-Wl,-exported_symbol,_ClientProcessIndex", - "-Wl,-exported_symbol,_ClientGetDevice", - "-Wl,-exported_symbol,_ClientGetAddressableDevice", - "-Wl,-exported_symbol,_ExecutableFree", - "-Wl,-exported_symbol,_BufferToDevice", - "-Wl,-exported_symbol,_BufferToClient", - "-Wl,-exported_symbol,_DeviceToClient", - "-Wl,-exported_symbol,_PjRtBufferFree", - "-Wl,-exported_symbol,_UnsafeBufferPointer", - "-Wl,-exported_symbol,_ArrayFromHostBuffer", - "-Wl,-exported_symbol,_BufferOnCPU", - "-Wl,-exported_symbol,_CopyBufferToDevice", - "-Wl,-exported_symbol,_BufferToHost", - "-Wl,-exported_symbol,_FreeClient", - "-Wl,-exported_symbol,_ClientCompile", - "-Wl,-exported_symbol,_LinkInModule", - "-Wl,-exported_symbol,_FreeFuture", - "-Wl,-exported_symbol,_FutureIsReady", - "-Wl,-exported_symbol,_FutureAwait", - "-Wl,-exported_symbol,_XLAExecute", - "-Wl,-exported_symbol,_RegisterDialects", - "-Wl,-exported_symbol,_InitializeRegistryAndPasses", - "-Wl,-exported_symbol,_ifrt_*", - "-Wl,-exported_symbol,_RegisterCustomCallTarget", - "-Wl,-exported_symbol,_ConvertLLVMToMLIR", - "-Wl,-exported_symbol,_reactant_*", - "-Wl,-exported_symbol,_register_julia_module", + "-fvisibility=default", + "-Wl,-unexported_symbol,_*llvm*", ], }), deps = [ diff --git a/deps/ReactantExtra/src/IFRT.cpp b/deps/ReactantExtra/src/IFRT.cpp index 3f3cda513..bb57f460b 100644 --- a/deps/ReactantExtra/src/IFRT.cpp +++ b/deps/ReactantExtra/src/IFRT.cpp @@ -1,36 +1,38 @@ #include "jlcxx/jlcxx.hpp" +#include // IFRT -#include "xla/python/ifrt/value.h" -#include "xla/python/ifrt/tuple.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/client.h" +#include "xla/python/ifrt/compiler.h" +#include "xla/python/ifrt/device.h" #include "xla/python/ifrt/dtype.h" -#include "xla/python/ifrt/shape.h" +#include "xla/python/ifrt/executable.h" +#include "xla/python/ifrt/hlo/hlo_program.h" +#include "xla/python/ifrt/host_callback.h" #include "xla/python/ifrt/index.h" #include "xla/python/ifrt/index_domain.h" #include "xla/python/ifrt/memory.h" -#include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/shape.h" #include "xla/python/ifrt/sharding.h" -#include "xla/python/ifrt/array.h" #include "xla/python/ifrt/topology.h" -#include "xla/python/ifrt/client.h" -#include "xla/python/ifrt/host_callback.h" -#include "xla/python/ifrt/executable.h" -#include "xla/python/ifrt/hlo/hlo_program.h" -#include "xla/python/ifrt/compiler.h" +#include "xla/python/ifrt/tuple.h" +#include "xla/python/ifrt/value.h" // IFRT - PJRT -#include "xla/python/pjrt_ifrt/pjrt_dtype.h" -#include "xla/python/pjrt_ifrt/pjrt_tuple.h" -#include "xla/python/pjrt_ifrt/pjrt_memory.h" -#include "xla/python/pjrt_ifrt/pjrt_device.h" #include "xla/python/pjrt_ifrt/pjrt_array.h" -#include "xla/python/pjrt_ifrt/pjrt_topology.h" #include "xla/python/pjrt_ifrt/pjrt_client.h" -#include "xla/python/pjrt_ifrt/pjrt_host_callback.h" -#include "xla/python/pjrt_ifrt/pjrt_executable.h" #include "xla/python/pjrt_ifrt/pjrt_compiler.h" +#include "xla/python/pjrt_ifrt/pjrt_device.h" +#include "xla/python/pjrt_ifrt/pjrt_dtype.h" +#include "xla/python/pjrt_ifrt/pjrt_executable.h" +#include "xla/python/pjrt_ifrt/pjrt_host_callback.h" +#include "xla/python/pjrt_ifrt/pjrt_memory.h" +#include "xla/python/pjrt_ifrt/pjrt_topology.h" +#include "xla/python/pjrt_ifrt/pjrt_tuple.h" using namespace xla; +using namespace xla::ifrt; // #pragma region xla::ifrt @@ -837,7 +839,8 @@ using namespace xla; // #pragma endregion -JLCXX_MODULE reactant_module_ifrt(jlcxx::Module& mod) { +JLCXX_MODULE reactant_module_ifrt(jlcxx::Module& mod) +{ // mod.add_type("Value") // .method("client", &ifrt::Value::client) // .method("get_ready_future", &ifrt::Value::GetReadyFuture) @@ -877,29 +880,28 @@ JLCXX_MODULE reactant_module_ifrt(jlcxx::Module& mod) { mod.set_const("DTypeKindF8E5M2FNUZ", ifrt::DType::Kind::kF8E5M2FNUZ); mod.set_const("DTypeKindString", ifrt::DType::Kind::kString); - // TODO destructor?? - // mod.add_type("DType") - // .constructor(jlcxx::finalize_policy::no); - // .method("kind", &ifrt::DType::kind) - // .method("byte_size", &ifrt::DType::byte_size) - // .method("bit_size", &ifrt::DType::bit_size); - // mod.set_override_module(jl_base_module); - // mod.method("==", [](ifrt::DType* a, ifrt::DType* b) { return *a == *b; }); + mod.add_type("DType") + .constructor() + .method("kind", &ifrt::DType::kind); + // .method("byte_size", &ifrt::DType::byte_size) + // .method("bit_size", &ifrt::DType::bit_size); + mod.set_override_module(jl_base_module); + // mod.method("==", [](ifrt::DType& a, ifrt::DType& b) { return a == b; }); // mod.method("!=", [](ifrt::DType* a, ifrt::DType* b) { return *a != *b; }); // mod.method("copy", [](const ifrt::DType& x) { return ifrt::DType(x); }); - // // mod.method("string", &ifrt::DType::DebugString); - // mod.unset_override_module(); + mod.method("string", [](const ifrt::DType& x) { return x.DebugString(); }); + mod.unset_override_module(); // TODO conversion from/to `xla::PrimitiveType` using `ToPrimitiveType`,`ToDType` -// mod.add_type("Shape") -// // .constructor<...>() // TODO explicit constructor -// ; -// mod.set_override_module(jl_base_module); -// mod.method("==", [](ifrt::Shape* a, ifrt::Shape* b) { return *a == *b; }); -// mod.method("!=", [](ifrt::Shape* a, ifrt::Shape* b) { return *a != *b; }); -// mod.method("copy", [](const ifrt::Shape& x) { return ifrt::Shape(x); }); -// mod.method("string", &ifrt::Shape::DebugString); -// mod.method("size", &ifrt::Shape::dims); -// mod.method("length", &ifrt::Shape::num_elements); -// mod.unset_override_module(); + // mod.add_type("Shape") + // // .constructor<...>() // TODO explicit constructor + // ; + // mod.set_override_module(jl_base_module); + // mod.method("==", [](ifrt::Shape* a, ifrt::Shape* b) { return *a == *b; }); + // mod.method("!=", [](ifrt::Shape* a, ifrt::Shape* b) { return *a != *b; }); + // mod.method("copy", [](const ifrt::Shape& x) { return ifrt::Shape(x); }); + // mod.method("string", &ifrt::Shape::DebugString); + // mod.method("size", &ifrt::Shape::dims); + // mod.method("length", &ifrt::Shape::num_elements); + // mod.unset_override_module(); } \ No newline at end of file diff --git a/src/IFRT.jl b/src/IFRT.jl index 37e90e097..a16c934c1 100644 --- a/src/IFRT.jl +++ b/src/IFRT.jl @@ -5,4 +5,8 @@ using Reactant_jll @wrapmodule(() -> Reactant_jll.libReactantExtra, :reactant_module_ifrt) +function __init__() + @initcxx +end + end From dae95070195a5729582110870f8eb22904e881e7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Sat, 21 Dec 2024 10:43:23 +0100 Subject: [PATCH 25/29] update --- deps/ReactantExtra/src/IFRT.cpp | 669 +++++++++++--------------------- 1 file changed, 237 insertions(+), 432 deletions(-) diff --git a/deps/ReactantExtra/src/IFRT.cpp b/deps/ReactantExtra/src/IFRT.cpp index bb57f460b..570b99204 100644 --- a/deps/ReactantExtra/src/IFRT.cpp +++ b/deps/ReactantExtra/src/IFRT.cpp @@ -31,41 +31,14 @@ #include "xla/python/pjrt_ifrt/pjrt_topology.h" #include "xla/python/pjrt_ifrt/pjrt_tuple.h" -using namespace xla; +// Utils +#include "xla/pjrt/status_casters.h" + +// using namespace xla; using namespace xla::ifrt; // #pragma region xla::ifrt -// #pragma region xla::ifrt::Value -// extern "C" ifrt::Client* ifrt_value_client(ifrt::Value* value) { -// return value->client(); -// } - -// extern "C" ifrt::Future<> ifrt_value_get_ready_future(ifrt::Value* value) { -// return value->GetReadyFuture(); -// } - -// extern "C" ifrt::Future<> ifrt_value_delete(ifrt::Value* value) { -// return value->Delete(); -// } - -// extern "C" bool ifrt_value_is_deleted(ifrt::Value* value) { -// return value->IsDeleted(); -// } - -// extern "C" const char* ifrt_value_debug_string(ifrt::Value* value) { -// return cstr_from_string(value->DebugString()); -// } -// #pragma endregion - -// #pragma region xla::ifrt::Tuple -// extern "C" int ifrt_tuple_arity(ifrt::Tuple* tuple) { -// return tuple->Arity(); -// } - -// // TODO ifrt::Tuple::Unpack -// #pragma endregion - // #pragma region xla::ifrt::PjRtTuple // extern "C" ifrt::PjRtTuple* ifrt_pjrt_tuple_ctor(ifrt::PjRtCompatibleClient* client, ifrt::Value* values, int nvalues) { // auto values_ptr = new tsl::RCReference[nvalues]; @@ -82,228 +55,6 @@ using namespace xla::ifrt; // } // #pragma endregion -// #pragma region xla::ifrt::Shape -// extern "C" ifrt::Shape* ifrt_shape_ctor(const int64_t* dims, size_t dims_size) { -// return new ifrt::Shape(absl::Span(dims, dims_size)); -// } - -// extern "C" void ifrt_shape_free(ifrt::Shape* shape) { -// delete shape; -// } - -// extern "C" const int64_t* ifrt_shape_dims(ifrt::Shape* shape) { -// return shape->dims().data(); -// } - -// extern "C" int64_t ifrt_shape_dims_num_elements(ifrt::Shape* shape) { -// return shape->num_elements(); -// } - -// extern "C" const char* ifrt_shape_debug_string(ifrt::Shape* shape) { -// return cstr_from_string(shape->DebugString()); -// } -// #pragma endregion - -// #pragma region xla::ifrt::DynamicShape -// extern "C" ifrt::DynamicShape* ifrt_dynamicshape_ctor(ifrt::Shape* shape, const bool* dynamic_dims_mask) { -// auto tag = ifrt::BoundedDynamicShapeTag(absl::Span(dynamic_dims_mask, shape->dims().size())); -// auto dynshape = xla::ValueOrThrow(ifrt::DynamicShape::Create(*shape, tag)); -// return new ifrt::DynamicShape(dynshape); -// } - -// extern "C" void ifrt_dynamicshape_free(ifrt::DynamicShape* shape) { -// delete shape; -// } - -// // TODO ifrt::DynamicShape::GetTag - -// extern "C" bool ifrt_dynamicshape_eq(ifrt::DynamicShape* shape1, ifrt::DynamicShape* shape2) { -// return *shape1 == *shape2; -// } - -// extern "C" bool ifrt_dynamicshape_ne(ifrt::DynamicShape* shape1, ifrt::DynamicShape* shape2) { -// return *shape1 != *shape2; -// } - -// extern "C" ifrt::Shape* ifrt_dynamicshape_get_padded_shape(ifrt::DynamicShape* shape) { -// auto padshape = xla::ValueOrThrow(shape->GetPaddedShape()); -// return new ifrt::Shape(padshape); -// } - -// extern "C" bool ifrt_dynamicshape_is_dynamic_dim(ifrt::DynamicShape* shape, int dimension) { -// return shape->IsDynamicDim(dimension); -// } - -// extern "C" const char* ifrt_dynamicshape_debug_string(ifrt::DynamicShape* shape) { -// return cstr_from_string(shape->DebugString()); -// } -// #pragma endregion - -// #pragma region xla::ifrt::Index -// extern "C" ifrt::Index* ifrt_index_ctor(const int64_t* elements, size_t elements_size) { -// return new ifrt::Index(absl::Span(elements, elements_size)); -// } - -// extern "C" ifrt::Index* ifrt_index_zeros(int num_elements) { -// return new ifrt::Index(ifrt::Index::Zeros(num_elements)); -// } - -// extern "C" void ifrt_index_free(ifrt::Index* index) { -// delete index; -// } - -// extern "C" const int64_t* ifrt_index_elements(ifrt::Index* index) { -// return index->elements().data(); -// } - -// extern "C" int ifrt_index_count(ifrt::Index* index) { -// return index->elements().size(); -// } - -// extern "C" bool ifrt_index_eq(ifrt::Index* index1, ifrt::Index* index2) { -// return *index1 == *index2; -// } - -// extern "C" bool ifrt_index_ne(ifrt::Index* index1, ifrt::Index* index2) { -// return *index1 != *index2; -// } - -// extern "C" ifrt::Index* ifrt_index_add(ifrt::Index* index, ifrt::Index* offset) { -// return new ifrt::Index(*index + *offset); -// } - -// extern "C" ifrt::Index* ifrt_index_sub(ifrt::Index* index, ifrt::Index* offset) { -// return new ifrt::Index(*index - *offset); -// } - -// // WARN we're not checking if the multiplier has the same size as the index -// extern "C" ifrt::Index* ifrt_index_mul(ifrt::Index* index, const int64_t* multiplier) { -// return new ifrt::Index(*index * absl::Span(multiplier, ifrt_index_count(index))); -// } - -// extern "C" void ifrt_index_add_inplace(ifrt::Index* index, ifrt::Index* offset) { -// *index += *offset; -// } - -// extern "C" void ifrt_index_sub_inplace(ifrt::Index* index, ifrt::Index* offset) { -// *index -= *offset; -// } - -// extern "C" void ifrt_index_mul_inplace(ifrt::Index* index, const int64_t* multiplier) { -// *index *= absl::Span(multiplier, ifrt_index_count(index)); -// } - -// extern "C" const char* ifrt_index_debug_string(ifrt::Index* index) { -// return cstr_from_string(index->DebugString()); -// } -// #pragma endregion - -// #pragma region xla::ifrt::IndexDomain -// extern "C" ifrt::IndexDomain* ifrt_indexdomain_ctor(ifrt::Shape* shape) { -// return new ifrt::IndexDomain(*shape); -// } - -// extern "C" ifrt::IndexDomain* ifrt_indexdomain_ctor_with_origin(ifrt::Index* origin, ifrt::Shape* shape) { -// return new ifrt::IndexDomain(*origin, *shape); -// } - -// extern "C" void ifrt_indexdomain_free(ifrt::IndexDomain* index_domain) { -// delete index_domain; -// } - -// extern "C" const ifrt::Index* ifrt_indexdomain_origin(ifrt::IndexDomain* index_domain) { -// return &index_domain->origin(); -// } - -// extern "C" const ifrt::Shape* ifrt_indexdomain_shape(ifrt::IndexDomain* index_domain) { -// return &index_domain->shape(); -// } - -// extern "C" bool ifrt_indexdomain_eq(ifrt::IndexDomain* index_domain1, ifrt::IndexDomain* index_domain2) { -// return *index_domain1 == *index_domain2; -// } - -// extern "C" bool ifrt_indexdomain_ne(ifrt::IndexDomain* index_domain1, ifrt::IndexDomain* index_domain2) { -// return *index_domain1 != *index_domain2; -// } - -// extern "C" ifrt::IndexDomain* ifrt_indexdomain_add(ifrt::IndexDomain* index_domain, ifrt::Index* offset) { -// return new ifrt::IndexDomain(*index_domain + *offset); -// } - -// extern "C" ifrt::IndexDomain* ifrt_indexdomain_sub(ifrt::IndexDomain* index_domain, ifrt::Index* offset) { -// return new ifrt::IndexDomain(*index_domain - *offset); -// } - -// extern "C" void ifrt_indexdomain_add_inplace(ifrt::IndexDomain* index_domain, ifrt::Index* offset) { -// *index_domain += *offset; -// } - -// extern "C" void ifrt_indexdomain_sub_inplace(ifrt::IndexDomain* index_domain, ifrt::Index* offset) { -// *index_domain -= *offset; -// } - -// extern "C" const char* ifrt_indexdomain_debug_string(ifrt::IndexDomain* index_domain) { -// return cstr_from_string(index_domain->DebugString()); -// } -// #pragma endregion - -// #pragma region xla::ifrt::MemoryKind -// // Pass a nullptr to create a `MemoryKind` with no memory chosen. -// extern "C" ifrt::MemoryKind* ifrt_memorykind_ctor(const char* memory_kind) { -// if (memory_kind == nullptr) -// return new ifrt::MemoryKind(); -// return new ifrt::MemoryKind(std::string(memory_kind)); -// } - -// extern "C" void ifrt_memorykind_free(ifrt::MemoryKind* memory_kind) { -// delete memory_kind; -// } - -// extern "C" bool ifrt_memorykind_eq(ifrt::MemoryKind* mk1, ifrt::MemoryKind* mk2) { -// return *mk1 == *mk2; -// } - -// extern "C" bool ifrt_memorykind_ne(ifrt::MemoryKind* mk1, ifrt::MemoryKind* mk2) { -// return *mk1 != *mk2; -// } - -// extern "C" const char* ifrt_memorykind_string(ifrt::MemoryKind* memory_kind) { -// if (memory_kind->memory_kind().has_value()) -// return cstr_from_string(memory_kind->memory_kind().value()); -// else -// return nullptr; -// } - -// extern "C" ifrt::MemoryKind* ifrt_memorykind_canonicalize(ifrt::MemoryKind* memory_kind, ifrt::Device* device) { -// return new ifrt::MemoryKind(CanonicalizeMemoryKind(*memory_kind, device)); -// } -// #pragma endregion - -// #pragma region xla::ifrt::Memory -// // MemoryId is a struct with a single int32_t field --> check out xla/python/ifrt/memory.h -// extern "C" ifrt::MemoryId ifrt_memory_id(ifrt::Memory* memory) { -// return memory->Id(); -// } - -// extern "C" const ifrt::MemoryKind* ifrt_memory_kind(ifrt::Memory* memory) { -// return &(memory->Kind()); -// } - -// extern "C" const char* ifrt_memory_to_string(ifrt::Memory* memory) { -// return cstr_from_string(memory->ToString()); -// } - -// extern "C" const char* ifrt_memory_debug_string(ifrt::Memory* memory) { -// return cstr_from_string(memory->DebugString()); -// } - -// extern "C" std::tuple ifrt_memory_devices(ifrt::Memory* memory) { -// auto devices = memory->Devices(); -// return std::make_tuple(devices.size(), devices.data()); -// } -// #pragma endregion - // #pragma region xla::ifrt::PjRtMemory // extern "C" ifrt::PjRtMemory* ifrt_pjrt_memory_ctor(ifrt::PjRtClient* client, xla::PjRtMemorySpace* memory_space) { // return new ifrt::PjRtMemory(client, memory_space); @@ -322,45 +73,6 @@ using namespace xla::ifrt; // } // #pragma endregion -// #pragma region xla::ifrt::Device -// extern "C" ifrt::Client* ifrt_device_client(ifrt::Device* device) { -// return device->client(); -// } - -// // DeviceId is a struct with a single int32_t field --> check out xla/pjrt/pjrt_common.h -// extern "C" ifrt::DeviceId ifrt_device_id(ifrt::Device* device) { -// return device->Id(); -// } - -// // TODO ifrt_device_attributes - -// extern "C" const char* ifrt_device_kind(ifrt::Device* device) { -// return cstr_from_string(device->Kind()); -// } - -// extern "C" const char* ifrt_device_to_string(ifrt::Device* device) { -// return cstr_from_string(device->ToString()); -// } - -// extern "C" const char* ifrt_device_debug_string(ifrt::Device* device) { -// return cstr_from_string(device->DebugString()); -// } - -// extern "C" ifrt::Memory* ifrt_device_default_memory(ifrt::Device* device) { -// return xla::ValueOrThrow(device->DefaultMemory()); -// } - -// // TODO ifrt_device_memories - -// extern "C" bool ifrt_device_is_addressable(ifrt::Device* device) { -// return device->IsAddressable(); -// } - -// extern "C" int ifrt_device_process_index(ifrt::Device* device) { -// return device->ProcessIndex(); -// } -// #pragma endregion - // #pragma region xla::ifrt::PjRtDevice // // DeviceId is a struct with a single int32_t field --> check out xla/pjrt/pjrt_common.h // // TODO support `attributes` parameter @@ -377,53 +89,6 @@ using namespace xla::ifrt; // } // #pragma endregion -// #pragma region xla::ifrt::Sharding -// // TODO ifrt_sharding_devices -// // TODO ifrt_sharding_memory_kind - -// // extern "C" void ifrt_sharding_disassemble(ifrt::Sharding* sharding, ifrt::Shape* shape, char** error) { -// // auto status = sharding->Disassemble(*shape); -// // if (!status.ok()) { -// // auto str = status.message(); -// // char* err = (char*)malloc(str.size()+1); -// // memcpy(err, str.data(), str.size()+1); -// // *error = err; -// // } -// // } - -// // TODO ifrt_sharding_disassemble_dynamic_shape -// // TODO ifrt_sharding_index_domains - -// extern "C" const char* ifrt_sharding_debug_string(ifrt::Sharding* sharding) { -// return cstr_from_string(sharding->DebugString()); -// } -// #pragma endregion - -// #pragma region xla::ifrt::Array -// extern "C" ifrt::DType* ifrt_array_dtype(ifrt::Array* array) { -// return new ifrt::DType(array->dtype()); -// } - -// extern "C" const ifrt::Shape* ifrt_array_shape(ifrt::Array* array) { -// return &(array->shape()); -// } - -// extern "C" const ifrt::Sharding* ifrt_array_sharding(ifrt::Array* array) { -// return &(array->sharding()); -// } - -// extern "C" PjRtLayout* ifrt_array_layout(ifrt::Array* array) { -// return xla::ValueOrThrow(array->layout()).release(); -// } - -// // TODO xla::ifrt::Array::DisassembleIntoSingleDeviceArrays -// // TODO xla::ifrt::Array::FullyReplicatedShard - -// extern "C" ifrt::Future<> ifrt_array_copy_to_host_buffer(ifrt::Array* array, void* data, const int64_t* byte_strides, int semantics) { -// return array->CopyToHostBuffer(data, absl::Span(byte_strides, array->shape().num_elements()), ifrt::ArrayCopySemantics(semantics)); -// } -// #pragma endregion - // #pragma region xla::ifrt::PjRtArray // // TODO constructors / `Create` @@ -437,39 +102,6 @@ using namespace xla::ifrt; // } // #pragma endregion -// #pragma region xla::ifrt::Topology -// extern "C" const char* ifrt_topology_platform_name(ifrt::Topology* topology) { -// return cstr_from_string(topology->platform_name()); -// } - -// extern "C" const char* ifrt_topology_platform_version(ifrt::Topology* topology) { -// return cstr_from_string(topology->platform_version()); -// } - -// // returns PjRtPlatformId which is a type alias for uint64_t -// extern "C" uint64_t ifrt_topology_platform_id(ifrt::Topology* topology) { -// return topology->platform_id(); -// } - -// extern "C" std::tuple ifrt_topology_device_descriptions(ifrt::Topology* topology) { -// auto descriptions = topology->DeviceDescriptions(); -// auto descriptions_ptr = new const xla::PjRtDeviceDescription*[descriptions.size()]; -// for (int i=0; iSerialize())); -// } - -// // TODO xla::ifrt::Topology::Attributes - -// #pragma endregion - // #pragma region xla::ifrt::PjRtTopology // extern "C" ifrt::PjRtTopology* ifrt_pjrt_topology_ctor(const xla::PjRtTopologyDescription* description) { // return new ifrt::PjRtTopology(std::shared_ptr{description}); @@ -841,67 +473,240 @@ using namespace xla::ifrt; JLCXX_MODULE reactant_module_ifrt(jlcxx::Module& mod) { - // mod.add_type("Value") - // .method("client", &ifrt::Value::client) - // .method("get_ready_future", &ifrt::Value::GetReadyFuture) - // .method("delete!", &ifrt::Value::Delete) - // .method("isdeleted", &ifrt::Value::IsDeleted) - // .method("debug_string", &ifrt::Value::DebugString); - - mod.add_bits("DTypeKind", jlcxx::julia_type("CppEnum")); - mod.set_const("DTypeKindInvalid", ifrt::DType::Kind::kInvalid); - mod.set_const("DTypeKindPred", ifrt::DType::Kind::kPred); - mod.set_const("DTypeKindS2", ifrt::DType::Kind::kS2); - mod.set_const("DTypeKindS4", ifrt::DType::Kind::kS4); - mod.set_const("DTypeKindS8", ifrt::DType::Kind::kS8); - mod.set_const("DTypeKindS16", ifrt::DType::Kind::kS16); - mod.set_const("DTypeKindS32", ifrt::DType::Kind::kS32); - mod.set_const("DTypeKindS64", ifrt::DType::Kind::kS64); - mod.set_const("DTypeKindU2", ifrt::DType::Kind::kU2); - mod.set_const("DTypeKindU4", ifrt::DType::Kind::kU4); - mod.set_const("DTypeKindU8", ifrt::DType::Kind::kU8); - mod.set_const("DTypeKindU16", ifrt::DType::Kind::kU16); - mod.set_const("DTypeKindU32", ifrt::DType::Kind::kU32); - mod.set_const("DTypeKindU64", ifrt::DType::Kind::kU64); - mod.set_const("DTypeKindF16", ifrt::DType::Kind::kF16); - mod.set_const("DTypeKindF32", ifrt::DType::Kind::kF32); - mod.set_const("DTypeKindF64", ifrt::DType::Kind::kF64); - mod.set_const("DTypeKindBF16", ifrt::DType::Kind::kBF16); - mod.set_const("DTypeKindC64", ifrt::DType::Kind::kC64); - mod.set_const("DTypeKindC128", ifrt::DType::Kind::kC128); - mod.set_const("DTypeKindToken", ifrt::DType::Kind::kToken); - // mod.set_const("DTypeKindOpaque", ifrt::DType::Kind::kOpaque); - mod.set_const("DTypeKindF8E3M4", ifrt::DType::Kind::kF8E3M4); - mod.set_const("DTypeKindF8E4M3", ifrt::DType::Kind::kF8E4M3); - mod.set_const("DTypeKindF8E4M3FN", ifrt::DType::Kind::kF8E4M3FN); - mod.set_const("DTypeKindF8E4M3B11FNUZ", ifrt::DType::Kind::kF8E4M3B11FNUZ); - mod.set_const("DTypeKindF8E4M3FNUZ", ifrt::DType::Kind::kF8E4M3FNUZ); - mod.set_const("DTypeKindF8E5M2", ifrt::DType::Kind::kF8E5M2); - mod.set_const("DTypeKindF8E5M2FNUZ", ifrt::DType::Kind::kF8E5M2FNUZ); - mod.set_const("DTypeKindString", ifrt::DType::Kind::kString); - - mod.add_type("DType") - .constructor() - .method("kind", &ifrt::DType::kind); - // .method("byte_size", &ifrt::DType::byte_size) - // .method("bit_size", &ifrt::DType::bit_size); + mod.map_type("Int32"); + mod.map_type("Int32"); + mod.map_type("UInt64"); // TODO move to PjRT.cpp + + auto wrap_future = mod.add_type>("Future"); + auto wrap_value = mod.add_type("Value"); + auto wrap_tuple = mod.add_type("Tuple"); + auto wrap_dtype = mod.add_type("DType"); + auto wrap_shape = mod.add_type("Shape"); + auto wrap_boundeddynamicshapetag = mod.add_type("BoundedDynamicShapeTag"); + auto wrap_dynamicshape = mod.add_type("DynamicShape"); + auto wrap_index = mod.add_type("Index"); + auto wrap_indexdomain = mod.add_type("IndexDomain"); + auto wrap_memorykind = mod.add_type("MemoryKind"); + auto wrap_memory = mod.add_type("Memory"); + auto wrap_device = mod.add_type("Device"); + auto wrap_pjrtdevice = mod.add_type("PjRtDevice"); + auto wrap_sharding = mod.add_type("Sharding"); + auto wrap_array = mod.add_type("Array"); + auto wrap_pjrtarray = mod.add_type("PjRtArray"); + auto wrap_topology = mod.add_type("Topology"); + auto wrap_pjrttopology = mod.add_type("PjRtTopology"); + auto wrap_client = mod.add_type("Client"); + // auto wrap_pjrtclient = mod.add_type("PjRtClient"); + auto wrap_hostcallback = mod.add_type("HostCallback"); + auto wrap_loadedhostcallback = mod.add_type("LoadedHostCallback"); + auto wrap_pjrt_hostsendandrecv_loadedhostcallback = mod.add_type("PjRtHostSendAndRecvLoadedHostCallback"); + auto wrap_executable = mod.add_type("Executable"); + auto wrap_pjrtexecutable = mod.add_type("PjRtExecutable"); + auto wrap_loadedexecutable = mod.add_type("LoadedExecutable"); + auto wrap_pjrtloadedexecutable = mod.add_type("PjRtLoadedExecutable"); + // auto wrap_customcallprogram = mod.add_type("CustomCallProgram"); + auto wrap_hloprogram = mod.add_type("HloProgram"); + auto wrap_compiler = mod.add_type("Compiler"); + auto wrap_pjrtcompiler = mod.add_type("PjRtCompiler"); + + // Value (virtual) + wrap_value.method("client", &Value::client) + .method("get_ready_future", &Value::GetReadyFuture) + .method("delete!", &Value::Delete) + .method("isdeleted", &Value::IsDeleted); + mod.set_override_module(jl_base_module); - // mod.method("==", [](ifrt::DType& a, ifrt::DType& b) { return a == b; }); - // mod.method("!=", [](ifrt::DType* a, ifrt::DType* b) { return *a != *b; }); - // mod.method("copy", [](const ifrt::DType& x) { return ifrt::DType(x); }); - mod.method("string", [](const ifrt::DType& x) { return x.DebugString(); }); + wrap_value.method("string", &Value::DebugString); mod.unset_override_module(); + // Tuple + // TODO Unpack + mod.set_override_module(jl_base_module); + wrap_tuple.method("length", &Tuple::Arity); + mod.unset_override_module(); + + // DType::Kind + mod.add_bits("DTypeKind", jlcxx::julia_type("CppEnum")); + mod.set_const("DTypeKindInvalid", DType::Kind::kInvalid); + mod.set_const("DTypeKindPred", DType::Kind::kPred); + mod.set_const("DTypeKindS2", DType::Kind::kS2); + mod.set_const("DTypeKindS4", DType::Kind::kS4); + mod.set_const("DTypeKindS8", DType::Kind::kS8); + mod.set_const("DTypeKindS16", DType::Kind::kS16); + mod.set_const("DTypeKindS32", DType::Kind::kS32); + mod.set_const("DTypeKindS64", DType::Kind::kS64); + mod.set_const("DTypeKindU2", DType::Kind::kU2); + mod.set_const("DTypeKindU4", DType::Kind::kU4); + mod.set_const("DTypeKindU8", DType::Kind::kU8); + mod.set_const("DTypeKindU16", DType::Kind::kU16); + mod.set_const("DTypeKindU32", DType::Kind::kU32); + mod.set_const("DTypeKindU64", DType::Kind::kU64); + mod.set_const("DTypeKindF16", DType::Kind::kF16); + mod.set_const("DTypeKindF32", DType::Kind::kF32); + mod.set_const("DTypeKindF64", DType::Kind::kF64); + mod.set_const("DTypeKindBF16", DType::Kind::kBF16); + mod.set_const("DTypeKindC64", DType::Kind::kC64); + mod.set_const("DTypeKindC128", DType::Kind::kC128); + mod.set_const("DTypeKindToken", DType::Kind::kToken); + // mod.set_const("DTypeKindOpaque", DType::Kind::kOpaque); + mod.set_const("DTypeKindF8E3M4", DType::Kind::kF8E3M4); + mod.set_const("DTypeKindF8E4M3", DType::Kind::kF8E4M3); + mod.set_const("DTypeKindF8E4M3FN", DType::Kind::kF8E4M3FN); + mod.set_const("DTypeKindF8E4M3B11FNUZ", DType::Kind::kF8E4M3B11FNUZ); + mod.set_const("DTypeKindF8E4M3FNUZ", DType::Kind::kF8E4M3FNUZ); + mod.set_const("DTypeKindF8E5M2", DType::Kind::kF8E5M2); + mod.set_const("DTypeKindF8E5M2FNUZ", DType::Kind::kF8E5M2FNUZ); + mod.set_const("DTypeKindString", DType::Kind::kString); + + // DType // TODO conversion from/to `xla::PrimitiveType` using `ToPrimitiveType`,`ToDType` - // mod.add_type("Shape") - // // .constructor<...>() // TODO explicit constructor - // ; - // mod.set_override_module(jl_base_module); - // mod.method("==", [](ifrt::Shape* a, ifrt::Shape* b) { return *a == *b; }); - // mod.method("!=", [](ifrt::Shape* a, ifrt::Shape* b) { return *a != *b; }); - // mod.method("copy", [](const ifrt::Shape& x) { return ifrt::Shape(x); }); - // mod.method("string", &ifrt::Shape::DebugString); - // mod.method("size", &ifrt::Shape::dims); - // mod.method("length", &ifrt::Shape::num_elements); - // mod.unset_override_module(); -} \ No newline at end of file + wrap_dtype + .constructor() + .method("kind", &DType::kind); + // .method("byte_size", &DType::byte_size) + // .method("bit_size", &DType::bit_size); + mod.set_override_module(jl_base_module); + // mod.method("==", [](DType& a, DType& b) { return a == b; }); + // mod.method("!=", [](DType* a, DType* b) { return *a != *b; }); + // mod.method("copy", [](const DType& x) { return DType(x); }); + mod.method("string", [](const DType& x) { return x.DebugString(); }); + mod.unset_override_module(); + + // Shape + // wrap_shape + // .constructor([](std::vector dims) { + // return new Shape(dims); + // }); + mod.set_override_module(jl_base_module); + // mod.method("==", [](Shape* a, Shape* b) { return *a == *b; }); + // mod.method("!=", [](Shape* a, Shape* b) { return *a != *b; }); + // mod.method("copy", [](const Shape& x) { return Shape(x); }); + mod.method("string", [](const Shape& x) { return x.DebugString(); }); + // mod.method("size", [](const Shape& x) { return x.dims(); }); + mod.method("length", [](const Shape& x) { return x.num_elements(); }); + mod.unset_override_module(); + + // DynamicShape + // TODO implement remaining methods + wrap_dynamicshape + .method("isdyndim", &DynamicShape::IsDynamicDim); + + mod.set_override_module(jl_base_module); + mod.method("string", [](const DynamicShape& x) { return x.DebugString(); }); + mod.unset_override_module(); + + // Index + // TODO how do we overload +=, -=, *=? + // wrap_index + // .constructor>([](std::vector elements) { ... }); + mod.set_override_module(jl_base_module); + mod.method("zeros", &Index::Zeros); + // mod.method("==", &Index::operator==); + // mod.method("!=", &Index::operator!=); + mod.method("+", [](const Index& a, const Index& b) { return a + b; }); + mod.method("-", [](const Index& a, const Index& b) { return a - b; }); + // mod.method("*", [](const Index& a, std::vector mul) { return a * mul; }); + mod.method("string", [](const Index& x) { return x.DebugString(); }); + mod.unset_override_module(); + + // IndexDomain + // TODO how do we overload +=, -=, *=? + wrap_indexdomain + .constructor() + .constructor() + .method("origin", &IndexDomain::origin) + .method("shape", &IndexDomain::shape); + + mod.set_override_module(jl_base_module); + mod.method("+", [](const IndexDomain& x, const Index& offset) { return x + offset; }); + mod.method("-", [](const IndexDomain& x, const Index& offset) { return x - offset; }); + mod.method("string", [](const IndexDomain& x) { return x.DebugString(); }); + mod.unset_override_module(); + + // MemoryKind + // TODO `memory_kind` returns optional + wrap_memorykind + .constructor<>() + .constructor([](const std::string& name) { return new MemoryKind(name); }); + + mod.set_override_module(jl_base_module); + // mod.method("string", [](const MemoryKind& x) { return x.DebugString(); }); + mod.unset_override_module(); + + // TODO `CanonicalizeMemoryKind` + + // Memory (virtual) + // TODO `Devices` + wrap_memory + .method("id", &Memory::Id) + .method("kind", &Memory::Kind) + // .method("devices", [](const Memory& x) { + // auto devices_span = x.Devices(); + // return std::vector(devices_span.begin(), devices_span.end()); + // }) + ; + + mod.set_override_module(jl_base_module); + wrap_memory.method("string", [](const Memory& x) { return std::string(x.ToString()); }); + mod.unset_override_module(); + + // Device (virtual) + // TODO `Memories` + wrap_device + .method("client", &Device::client) + .method("id", &Device::Id) + // .method("attributes", &Device::Attributes) + .method("kind", [](const Device& x) { return std::string(x.Kind()); }) + .method("isaddressable", &Device::IsAddressable) + .method("process_index", &Device::ProcessIndex); + + // Sharding (virtual) + mod.add_bits("SingleDeviceShardSemantics", jlcxx::julia_type("CppEnum")); + mod.set_const("SingleDeviceShardSemanticsAddressable", SingleDeviceShardSemantics::kAddressableShards); + mod.set_const("SingleDeviceShardSemanticsAll", SingleDeviceShardSemantics::kAllShards); + + wrap_sharding + // .method("devices", ...) + .method("kind", &Sharding::memory_kind) + .method("is_fully_replicated", &Sharding::IsFullyReplicated) + .method("has_same_partitioning", &Sharding::HasSamePartitioning) + // .method("with_device_assignment", &Sharding::WithDeviceAssignment) + // .method("disassemble", &Sharding::Disassemble) + // .method("IndexDomains", &Sharding::IndexDomains) + .method("get_shard_shape", [](const Sharding& x, const Shape& shape) { return xla::ValueOrThrow(x.GetShardShape(shape)); }); + ; + + mod.set_override_module(jl_base_module); + wrap_sharding.method("string", [](const Sharding& x) { return x.DebugString(); }); + mod.unset_override_module(); + + // TODO SingleDeviceSharding, OpaqueSharding, ConcreteSharding, ConcreteEvenSharding, ShardingParamSharding + + // Array (virtual) + mod.add_bits("ArrayCopySemantics", jlcxx::julia_type("CppEnum")); + mod.set_const("ArrayCopySemanticsAlwaysCopy", ArrayCopySemantics::kAlwaysCopy); + mod.set_const("ArrayCopySemanticsReuseInput", ArrayCopySemantics::kReuseInput); + mod.set_const("ArrayCopySemanticsDonateInput", ArrayCopySemantics::kDonateInput); + + wrap_array + .method("dtype", &Array::dtype) + .method("shape", &Array::shape) + .method("sharding", &Array::sharding) + // .method("shared_ptr_sharding", &Array::shared_ptr_sharding) + // .method("layout", &Array::layout) + // .method("disassemble", &Array::DisassembleIntoSingleDeviceArrays) + // .method("replicate", &Array::FullyReplicatedShard) + // .method("copy_to_host_buffer", &Array::CopyToHostBuffer) + ; + + // Topology (virtual) + wrap_topology + .method("platform_name", [](const Topology& x) { return std::string(x.platform_name()); }) + .method("platform_version", [](const Topology& x) { return std::string(x.platform_version()); }) + .method("platform_id", &Topology::platform_id) + // .method("descriptions", &Topology::DeviceDescriptions) + // .method("layout", &Topology::GetDefaultLayout) + // .method("serialize", &Topology::Serialize) + // .method("Attributes", &Topology::Attributes) + ; +} From fed422e799500d533bd1ed05ed7e05911ccafcff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Sun, 22 Dec 2024 11:15:04 +0100 Subject: [PATCH 26/29] refactor `DType` wrap --- deps/ReactantExtra/src/IFRT.cpp | 457 +------------------------------- src/IFRT.jl | 40 +++ 2 files changed, 53 insertions(+), 444 deletions(-) diff --git a/deps/ReactantExtra/src/IFRT.cpp b/deps/ReactantExtra/src/IFRT.cpp index 570b99204..3c10c2bc6 100644 --- a/deps/ReactantExtra/src/IFRT.cpp +++ b/deps/ReactantExtra/src/IFRT.cpp @@ -34,442 +34,10 @@ // Utils #include "xla/pjrt/status_casters.h" -// using namespace xla; using namespace xla::ifrt; -// #pragma region xla::ifrt - -// #pragma region xla::ifrt::PjRtTuple -// extern "C" ifrt::PjRtTuple* ifrt_pjrt_tuple_ctor(ifrt::PjRtCompatibleClient* client, ifrt::Value* values, int nvalues) { -// auto values_ptr = new tsl::RCReference[nvalues]; -// for (int i=0; i(); -// values_ptr[i].reset(&values[i]); -// } -// auto span = absl::Span>(values_ptr, nvalues); -// return xla::ValueOrThrow(ifrt::PjRtTuple::Create(client, span)).release(); -// } - -// extern "C" void ifrt_pjrt_tuple_free(ifrt::PjRtTuple* tuple) { -// delete tuple; -// } -// #pragma endregion - -// #pragma region xla::ifrt::PjRtMemory -// extern "C" ifrt::PjRtMemory* ifrt_pjrt_memory_ctor(ifrt::PjRtClient* client, xla::PjRtMemorySpace* memory_space) { -// return new ifrt::PjRtMemory(client, memory_space); -// } - -// extern "C" void ifrt_pjrt_memory_free(ifrt::PjRtMemory* memory) { -// delete memory; -// } - -// extern "C" ifrt::PjRtClient* ifrt_pjrt_memory_client(ifrt::PjRtMemory* memory) { -// return memory->client(); -// } - -// extern "C" xla::PjRtMemorySpace* ifrt_pjrt_memory_space(ifrt::PjRtMemory* memory) { -// return memory->pjrt_memory(); -// } -// #pragma endregion - -// #pragma region xla::ifrt::PjRtDevice -// // DeviceId is a struct with a single int32_t field --> check out xla/pjrt/pjrt_common.h -// // TODO support `attributes` parameter -// extern "C" ifrt::PjRtDevice* ifrt_pjrt_device_ctor(ifrt::PjRtClient* client, ifrt::DeviceId device_id, const char* kind, const char* to_string, const char* debug_string, int process_index, xla::PjRtDevice* pjrt_device) { -// return new ifrt::PjRtDevice(client, device_id, kind, to_string, debug_string, process_index, absl::flat_hash_map(), pjrt_device); -// } - -// extern "C" void ifrt_pjrt_device_free(ifrt::PjRtDevice* device) { -// delete device; -// } - -// extern "C" xla::PjRtDevice* ifrt_pjrt_device_pjrt_device(ifrt::PjRtDevice* device) { -// return device->pjrt_device(); -// } -// #pragma endregion - -// #pragma region xla::ifrt::PjRtArray -// // TODO constructors / `Create` - -// extern "C" std::tuple ifrt_pjrt_array_pjrt_buffers(ifrt::PjRtArray* array) { -// auto buffers = array->pjrt_buffers(); -// auto buffers_ptr = new xla::PjRtBuffer*[buffers.size()]; -// for (int i=0; i{description}); -// } - -// extern "C" const xla::PjRtTopologyDescription* ifrt_pjrt_topology_description(ifrt::PjRtTopology* topology) { -// return topology->description().get(); -// } -// #pragma endregion - -// #pragma region xla::ifrt::Client -// extern "C" int ifrt_client_device_count(ifrt::Client* client) { -// return client->device_count(); -// } - -// extern "C" int ifrt_client_addressable_device_count(ifrt::Client* client) { -// return client->addressable_device_count(); -// } - -// extern "C" ifrt::Device* const* ifrt_client_devices(ifrt::Client* client) { -// return client->devices().data(); -// } - -// extern "C" ifrt::Device* const* ifrt_client_addressable_devices(ifrt::Client* client) { -// return client->addressable_devices().data(); -// } - -// extern "C" int ifrt_client_process_index(ifrt::Client* client) { -// return client->process_index(); -// } - -// // TODO xla::ifrt::Client::GetDefaultDeviceAssignment - -// extern "C" ifrt::Device* ifrt_client_lookup_device(ifrt::Client* client, int device_id) { -// return xla::ValueOrThrow(client->LookupDevice(ifrt::DeviceId(device_id))); -// } - -// extern "C" ifrt::Device* ifrt_client_lookup_addressable_device(ifrt::Client* client, int device_id) { -// return xla::ValueOrThrow(client->LookupAddressableDevice(device_id)); -// } - -// extern "C" ifrt::Compiler* ifrt_client_default_compiler(ifrt::Client* client) { -// return client->GetDefaultCompiler(); -// } - -// // TODO ifrt_client_topology_for_devices -// // TODO ifrt_client_default_layout_for_device -// #pragma endregion - -// #pragma region xla::ifrt::PjRtClient -// // TODO support more parameters of `PjRtClient::CreateOptions` -// extern "C" ifrt::PjRtClient* ifrt_pjrt_client_ctor(xla::PjRtClient* pjrt_client) { -// return xla::ValueOrThrow(ifrt::PjRtClient::Create(ifrt::PjRtClient::CreateOptions{std::shared_ptr{pjrt_client}})).release(); -// } - -// extern "C" void ifrt_pjrt_client_free(ifrt::PjRtClient* client) { -// delete client; -// } - -// extern "C" xla::PjRtClient* ifrt_pjrt_client_pjrt_client(ifrt::PjRtClient* client) { -// return client->pjrt_client(); -// } - -// // TODO there are problems with using `make_shared -// // extern "C" ifrt::PjRtCompatibleArray* ifrt_pjrt_client_create_pjrt_array(ifrt::PjRtClient* client, xla::PjRtBuffer* pjrt_buffer) { -// // auto buffer_ptr = std::make_shared(*pjrt_buffer); -// // return xla::ValueOrThrow(client->CreatePjRtArray(buffer_ptr)).release(); -// // } - -// // TODO extern "C" ifrt::PjRtCompatibleArray* ifrt_pjrt_client_create_pjrt_array_from_buffers(ifrt::Shape* shape, ifrt::PjRtBuffer** pjrt_buffers, int num_buffers) {} - -// extern "C" ifrt::PjRtCompatibleDevice* ifrt_pjrt_client_lookup_pjrt_device(ifrt::PjRtClient* client, xla::PjRtDevice* pjrt_device) { -// return xla::ValueOrThrow(client->LookupPjRtDevice(pjrt_device)); -// } - -// extern "C" ifrt::PjRtCompatibleMemory* ifrt_pjrt_client_lookup_pjrt_memory(ifrt::PjRtClient* client, xla::PjRtMemorySpace* pjrt_memory_space) { -// return xla::ValueOrThrow(client->LookupPjRtMemory(pjrt_memory_space)); -// } -// #pragma endregion - -// #pragma region xla::ifrt::HostCallback -// extern "C" const char* ifrt_hostcallback_serialize(ifrt::HostCallback* host_callback) { -// return cstr_from_string(host_callback->Serialize()); -// } -// #pragma endregion - -// #pragma region xla::ifrt::LoadedHostCallback -// extern "C" ifrt::Client* ifrt_loadedhostcallback_client(ifrt::LoadedHostCallback* host_callback) { -// return host_callback->client(); -// } - -// extern "C" const char* ifrt_loadedhostcallback_serialize(ifrt::LoadedHostCallback* host_callback) { -// // auto msg = ; -// return cstr_from_string(xla::ValueOrThrow(host_callback->Serialize())); -// } -// #pragma endregion - -// #pragma region xla::ifrt::PjRtHostSendAndRecvLoadedHostCallback -// extern "C" ifrt::PjRtHostSendAndRecvLoadedHostCallback* ifrt_pjrt_hostsendandrecv_loadhostcallback_ctor(ifrt::PjRtClient* client, xla::HostCallback* host_callback) { -// auto xla_callback_ptr = std::make_unique(*host_callback); -// return new ifrt::PjRtHostSendAndRecvLoadedHostCallback(client, std::move(xla_callback_ptr)); -// } - -// extern "C" void ifrt_pjrt_hostsendandrecv_loadhostcallback_free(ifrt::PjRtHostSendAndRecvLoadedHostCallback* host_callback) { -// delete host_callback; -// } - -// extern "C" xla::HostCallback* ifrt_pjrt_hostsendandrecv_loadhostcallback_host_callback(ifrt::PjRtHostSendAndRecvLoadedHostCallback* host_callback) { -// return new xla::HostCallback(host_callback->host_callback()); -// } -// #pragma endregion - -// #pragma region xla::ifrt::Executable -// extern "C" const char* ifrt_executable_name(ifrt::Executable* executable) { -// return cstr_from_string(executable->name()); -// } - -// extern "C" const char* ifrt_executable_fingerprint(ifrt::Executable* executable) { -// auto result = xla::ValueOrThrow(executable->Fingerprint()); -// if (!result.has_value()) return ""; -// return cstr_from_string(result.value()); -// } - -// extern "C" const char* ifrt_executable_serialize(ifrt::Executable* executable) { -// return cstr_from_string(xla::ValueOrThrow(executable->Serialize())); -// } - -// extern "C" int ifrt_executable_num_devices(ifrt::Executable* executable) { -// return executable->num_devices(); -// } - -// extern "C" int64_t ifrt_executable_size(ifrt::Executable* executable) { -// return executable->SizeOfGeneratedCodeInBytes(); -// } - -// // TODO xla::ifrt::Executable::GetCompiledMemoryStats - -// extern "C" std::tuple ifrt_executable_parameter_shardings(ifrt::Executable* executable) { -// auto shardings = executable->GetParameterShardings(); -// if (!shardings.has_value()) return std::make_tuple(0, nullptr); -// return std::make_tuple(shardings.value().size(), shardings.value().data()); -// } - -// extern "C" std::tuple ifrt_executable_output_shardings(ifrt::Executable* executable) { -// auto shardings = executable->GetOutputShardings(); -// if (!shardings.has_value()) return std::make_tuple(0, nullptr); -// return std::make_tuple(shardings.value().size(), shardings.value().data()); -// } - -// extern "C" std::tuple ifrt_executable_parameter_layouts(ifrt::Executable* executable) { -// auto layouts = xla::ValueOrThrow(executable->GetParameterLayouts()); -// auto layouts_ptr = new xla::PjRtLayout*[layouts.size()]; -// for (int i=0; i ifrt_executable_output_layouts(ifrt::Executable* executable) { -// auto layouts = xla::ValueOrThrow(executable->GetOutputLayouts()); -// auto layouts_ptr = new xla::PjRtLayout*[layouts.size()]; -// for (int i=0; i ifrt_executable_hlo_modules(ifrt::Executable* executable) { -// auto modules = xla::ValueOrThrow(executable->GetHloModules()); -// auto modules_ptr = new xla::HloModule*[modules.size()]; -// for (int i=0; i(*pjrt_executable); -// // auto options = std::make_unique(*compile_options); -// // return xla::ValueOrThrow(ifrt::PjRtExecutable::Create(pjrt_executable_shared, std::move(options))).release(); -// // } - -// extern "C" void ifrt_pjrt_executable_free(ifrt::PjRtExecutable* executable) { -// delete executable; -// } - -// extern "C" xla::PjRtExecutable* ifrt_pjrt_executable_pjrt_executable(ifrt::PjRtExecutable* executable) { -// return executable->pjrt_executable(); -// } -// #pragma endregion - -// #pragma region xla::ifrt::LoadedExecutable -// extern "C" ifrt::Client* ifrt_loadedexecutable_client(ifrt::LoadedExecutable* executable) { -// return executable->client(); -// } - -// extern "C" const char* ifrt_loadedexecutable_name(ifrt::LoadedExecutable* executable) { -// return cstr_from_string(executable->name()); -// } - -// extern "C" const char* ifrt_loadedexecutable_fingerprint(ifrt::LoadedExecutable* executable) { -// auto result = xla::ValueOrThrow(executable->Fingerprint()); -// if (!result.has_value()) return ""; -// return cstr_from_string(result.value()); -// } - -// extern "C" const char* ifrt_loadedexecutable_serialize(ifrt::LoadedExecutable* executable) { -// return cstr_from_string(xla::ValueOrThrow(executable->Serialize())); -// } - -// extern "C" ifrt::Future<> ifrt_loadedexecutable_get_ready_future(ifrt::LoadedExecutable* executable) { -// return executable->GetReadyFuture(); -// } - -// extern "C" int ifrt_loadedexecutable_num_devices(ifrt::LoadedExecutable* executable) { -// return executable->num_devices(); -// } - -// extern "C" int64_t ifrt_loadedexecutable_size(ifrt::LoadedExecutable* executable) { -// return executable->SizeOfGeneratedCodeInBytes(); -// } - -// // TODO xla::ifrt::GetCompiledMemoryStats - -// extern "C" std::tuple ifrt_loadedexecutable_parameter_shardings(ifrt::LoadedExecutable* executable) { -// auto shardings = executable->GetParameterShardings(); -// if (!shardings.has_value()) return std::make_tuple(0, nullptr); -// return std::make_tuple(shardings.value().size(), shardings.value().data()); -// } - -// extern "C" std::tuple ifrt_loadedexecutable_output_shardings(ifrt::LoadedExecutable* executable) { -// auto shardings = executable->GetOutputShardings(); -// if (!shardings.has_value()) return std::make_tuple(0, nullptr); -// return std::make_tuple(shardings.value().size(), shardings.value().data()); -// } - -// extern "C" std::tuple ifrt_loadedexecutable_parameter_layouts(ifrt::LoadedExecutable* executable) { -// auto layouts = xla::ValueOrThrow(executable->GetParameterLayouts()); -// auto layouts_ptr = new xla::PjRtLayout*[layouts.size()]; -// for (int i=0; i ifrt_loadedexecutable_output_layouts(ifrt::LoadedExecutable* executable) { -// auto layouts = xla::ValueOrThrow(executable->GetOutputLayouts()); -// auto layouts_ptr = new xla::PjRtLayout*[layouts.size()]; -// for (int i=0; i ifrt_loadedexecutable_hlo_modules(ifrt::LoadedExecutable* executable) { -// auto modules = xla::ValueOrThrow(executable->GetHloModules()); -// auto modules_ptr = new xla::HloModule*[modules.size()]; -// for (int i=0; i** futures, size_t futures_size) { -// // std::vector arguments(args, args + args_size); -// // std::vector result(results, results + results_size); -// // std::vector*> future(futures, futures + futures_size); -// // return xla::ValueOrThrow(executable->Execute(arguments, result, future)); -// // } - -// extern "C" ifrt::Future<> ifrt_loadedexecutable_delete(ifrt::LoadedExecutable* executable) { -// return executable->Delete(); -// } - -// extern "C" bool ifrt_loadedexecutable_is_deleted(ifrt::LoadedExecutable* executable) { -// return executable->IsDeleted(); -// } - -// extern "C" std::tuple ifrt_loadedexecutable_addressable_devices(ifrt::LoadedExecutable* executable) { -// auto devices = executable->addressable_devices(); -// return std::make_tuple(devices.size(), devices.data()); -// } - -// // TODO auxiliary functions for xla::ifrt::LoadedExecutable::ExecuteResult -// #pragma endregion - -// #pragma region xla::ifrt::PjRtLoadedExecutable -// // TODO add support for LoadedHostCallback -// // TODO there are problems with using `make_shared -// // extern "C" ifrt::LoadedExecutable* ifrt_pjrt_loadedexecutable_ctor(ifrt::PjRtCompatibleClient* client, xla::PjRtLoadedExecutable* pjrt_loaded_executable) { -// // auto pjrt_loaded_executable_ptr = std::make_shared(*pjrt_loaded_executable); -// // return xla::ValueOrThrow(ifrt::PjRtLoadedExecutable::Create(client, pjrt_loaded_executable_ptr, std::vector>())).release(); -// // } - -// // TODO add support for LoadedHostCallback -// extern "C" ifrt::LoadedExecutable* ifrt_pjrt_loadedexecutable_ctor_from_mlir_module(ifrt::PjRtCompatibleClient* client, mlir::ModuleOp* module, xla::CompileOptions* compile_options) { -// return xla::ValueOrThrow(ifrt::PjRtLoadedExecutable::Create(client, *module, *compile_options, std::vector>())).release(); -// } - -// extern "C" void ifrt_pjrt_loadedexecutable_free(ifrt::PjRtLoadedExecutable* executable) { -// delete executable; -// } - -// extern "C" xla::PjRtLoadedExecutable* ifrt_pjrt_loadedexecutable_pjrt_loadedexecutable(ifrt::PjRtLoadedExecutable* executable) { -// return executable->pjrt_loaded_executable(); -// } -// #pragma endregion - -// #pragma region xla::ifrt::CustomCallProgram -// #pragma endregion - -// #pragma region xla::ifrt::HloProgram -// extern "C" ifrt::HloProgram* ifrt_hloprogram_ctor() { -// return new ifrt::HloProgram(); -// } - -// extern "C" ifrt::HloProgram* ifrt_hloprogram_ctor_with_module(mlir::ModuleOp* module) { -// return new ifrt::HloProgram(*module); -// } - -// // extern "C" ifrt::HloProgram* ifrt_hloprogram_ctor_with_context_and_module(mlir::MLIRContext* context, mlir::ModuleOp* module) { -// // auto context_ptr = std::make_unique(*context); -// // return new ifrt::HloProgram(std::move(context_ptr), *module); -// // } -// #pragma endregion - -// #pragma region xla::ifrt::Compiler -// extern "C" ifrt::LoadedExecutable* ifrt_compiler_compile(ifrt::Compiler* compiler, ifrt::Program* program) { -// // apparently ifrt::CompileOptions is a legacy artifact so we don't use it and set directly to the default -// auto program_ptr = std::make_unique(*program); -// auto options = std::make_unique(); -// return xla::ValueOrThrow(compiler->Compile(std::move(program_ptr), std::move(options))).release(); -// } - -// extern "C" ifrt::Executable* ifrt_compiler_compile_with_topology(ifrt::Compiler* compiler, ifrt::Program* program, const ifrt::Topology* topology) { -// // apparently ifrt::CompileOptions is a legacy artifact so we don't use it and set directly to the default -// auto options = std::make_unique(); -// auto program_ptr = std::make_unique(*program); -// auto exec_ptr = xla::ValueOrThrow(compiler->Compile(std::move(program_ptr), *topology, std::move(options))).release(); -// return exec_ptr; -// } - -// extern "C" ifrt::LoadedExecutable* ifrt_compiler_deserialize_loadedexecutable(ifrt::Compiler* compiler, const char* data) { -// // apparently ifrt::DeserializeExecutableOptions is a legacy artifact so we don't use it and set directly to the default -// auto options = std::make_unique(); -// return xla::ValueOrThrow(compiler->DeserializeLoadedExecutable(std::string(data), std::move(options))).release(); -// } -// #pragma endregion - -// #pragma region xla::ifrt::PjRtCompiler -// extern "C" ifrt::PjRtCompiler* ifrt_pjrt_compiler_ctor(ifrt::PjRtClient* client) { -// return new ifrt::PjRtCompiler(client); -// } - -// extern "C" void ifrt_pjrt_compiler_free(ifrt::PjRtCompiler* compiler) { -// delete compiler; -// } -// #pragma endregion - -// #pragma endregion +#define JLCXX_CLASS_DEF_EQ(WRAP, CLASS) WRAP.method("==", &CLASS::operator==); +#define JLCXX_CLASS_DEF_NE(WRAP, CLASS) WRAP.method("!=", &CLASS::operator!=); JLCXX_MODULE reactant_module_ifrt(jlcxx::Module& mod) { @@ -480,7 +48,7 @@ JLCXX_MODULE reactant_module_ifrt(jlcxx::Module& mod) auto wrap_future = mod.add_type>("Future"); auto wrap_value = mod.add_type("Value"); auto wrap_tuple = mod.add_type("Tuple"); - auto wrap_dtype = mod.add_type("DType"); + auto wrap_dtype = mod.add_type("DType"); // NOTE this could be a `map_type` instead? needs experimentation auto wrap_shape = mod.add_type("Shape"); auto wrap_boundeddynamicshapetag = mod.add_type("BoundedDynamicShapeTag"); auto wrap_dynamicshape = mod.add_type("DynamicShape"); @@ -509,7 +77,7 @@ JLCXX_MODULE reactant_module_ifrt(jlcxx::Module& mod) auto wrap_compiler = mod.add_type("Compiler"); auto wrap_pjrtcompiler = mod.add_type("PjRtCompiler"); - // Value (virtual) + // Value wrap_value.method("client", &Value::client) .method("get_ready_future", &Value::GetReadyFuture) .method("delete!", &Value::Delete) @@ -520,7 +88,7 @@ JLCXX_MODULE reactant_module_ifrt(jlcxx::Module& mod) mod.unset_override_module(); // Tuple - // TODO Unpack + // TODO `Unpack` might not be as interesting to offer as it is mod.set_override_module(jl_base_module); wrap_tuple.method("length", &Tuple::Arity); mod.unset_override_module(); @@ -559,16 +127,17 @@ JLCXX_MODULE reactant_module_ifrt(jlcxx::Module& mod) mod.set_const("DTypeKindString", DType::Kind::kString); // DType - // TODO conversion from/to `xla::PrimitiveType` using `ToPrimitiveType`,`ToDType` + // TODO conversion from/to `xla::PrimitiveType` using `ToPrimitiveType`,`ToDType` => might require PjRT with CxxWrap + // TODO fix return of `optional` on `byte_size`, `bit_size` wrap_dtype .constructor() - .method("kind", &DType::kind); - // .method("byte_size", &DType::byte_size) - // .method("bit_size", &DType::bit_size); + .method("kind", [](const DType& x) { return x.kind(); }) + .method("byte_size", [](const DType& x) { return x.byte_size().value_or(0); }) + .method("bit_size", [](const DType& x) { return x.bit_size().value_or(0); }); + mod.set_override_module(jl_base_module); - // mod.method("==", [](DType& a, DType& b) { return a == b; }); - // mod.method("!=", [](DType* a, DType* b) { return *a != *b; }); - // mod.method("copy", [](const DType& x) { return DType(x); }); + JLCXX_CLASS_DEF_EQ(wrap_dtype, DType) + JLCXX_CLASS_DEF_NE(wrap_dtype, DType) mod.method("string", [](const DType& x) { return x.DebugString(); }); mod.unset_override_module(); diff --git a/src/IFRT.jl b/src/IFRT.jl index a16c934c1..19b69fef4 100644 --- a/src/IFRT.jl +++ b/src/IFRT.jl @@ -9,4 +9,44 @@ function __init__() @initcxx end +# NOTE some DType kinds lack a corresponding Julia type +function Base.convert(::Type{DType}, type::Type) + kind = if type === Bool + DTypeKindPred + elseif type === Int8 + DTypeKindS8 + elseif type === Int16 + DTypeKindS16 + elseif type === Int32 + DTypeKindS32 + elseif type === Int64 + DTypeKindS64 + elseif type === UInt8 + DTypeKindU8 + elseif type === UInt16 + DTypeKindU16 + elseif type === UInt32 + DTypeKindU32 + elseif type === UInt64 + DTypeKindU64 + elseif type === Float16 + DTypeKindF16 + elseif type === Float32 + DTypeKindF32 + elseif type === Float64 + DTypeKindF64 + elseif type === ComplexF32 + DTypeKindC64 + elseif type === ComplexF64 + DTypeKindC128 + elseif type === String + DTypeKindString + else + @warn "`$type` can not be converted to DType" + DTypeKindInvalid + end + + return DType(kind) +end + end From 9f83d5d349b7825a4e2a0920f27605df02b692a2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Sun, 22 Dec 2024 13:43:39 +0100 Subject: [PATCH 27/29] refactor `Index` --- deps/ReactantExtra/src/IFRT.cpp | 25 +++++++++++++++++-------- 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/deps/ReactantExtra/src/IFRT.cpp b/deps/ReactantExtra/src/IFRT.cpp index 3c10c2bc6..007f15667 100644 --- a/deps/ReactantExtra/src/IFRT.cpp +++ b/deps/ReactantExtra/src/IFRT.cpp @@ -38,6 +38,8 @@ using namespace xla::ifrt; #define JLCXX_CLASS_DEF_EQ(WRAP, CLASS) WRAP.method("==", &CLASS::operator==); #define JLCXX_CLASS_DEF_NE(WRAP, CLASS) WRAP.method("!=", &CLASS::operator!=); +#define JLCXX_CLASS_DEF_ADD(WRAP, CLASS) WRAP.method("+", &CLASS::operator+); +#define JLCXX_CLASS_DEF_SUB(WRAP, CLASS) WRAP.method("-", &CLASS::operator-); JLCXX_MODULE reactant_module_ifrt(jlcxx::Module& mod) { @@ -166,15 +168,20 @@ JLCXX_MODULE reactant_module_ifrt(jlcxx::Module& mod) // Index // TODO how do we overload +=, -=, *=? - // wrap_index - // .constructor>([](std::vector elements) { ... }); + // TODO call `crete_index_zeros` from `zeros(::Type{Index}, ...)` + wrap_index + .constructor([](std::vector& elements) { + return new Index(absl::Span(elements)); + }) + .method("elements", [](const Index& x) { return std::vector(x.elements().begin(), x.elements().end()); }); + mod.method("create_index_zeros", &Index::Zeros); + mod.set_override_module(jl_base_module); - mod.method("zeros", &Index::Zeros); - // mod.method("==", &Index::operator==); - // mod.method("!=", &Index::operator!=); - mod.method("+", [](const Index& a, const Index& b) { return a + b; }); - mod.method("-", [](const Index& a, const Index& b) { return a - b; }); - // mod.method("*", [](const Index& a, std::vector mul) { return a * mul; }); + JLCXX_CLASS_DEF_EQ(wrap_index, Index) + JLCXX_CLASS_DEF_NE(wrap_index, Index) + JLCXX_CLASS_DEF_ADD(wrap_index, Index) + JLCXX_CLASS_DEF_SUB(wrap_index, Index) + mod.method("*", [](const Index& a, std::vector mul) { return a * absl::Span(mul); }); mod.method("string", [](const Index& x) { return x.DebugString(); }); mod.unset_override_module(); @@ -187,6 +194,8 @@ JLCXX_MODULE reactant_module_ifrt(jlcxx::Module& mod) .method("shape", &IndexDomain::shape); mod.set_override_module(jl_base_module); + JLCXX_CLASS_DEF_EQ(wrap_indexdomain, IndexDomain) + JLCXX_CLASS_DEF_NE(wrap_indexdomain, IndexDomain) mod.method("+", [](const IndexDomain& x, const Index& offset) { return x + offset; }); mod.method("-", [](const IndexDomain& x, const Index& offset) { return x - offset; }); mod.method("string", [](const IndexDomain& x) { return x.DebugString(); }); From b403bb3b4619d594f9d00dca252f94a7c983f00b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Sun, 22 Dec 2024 15:13:36 +0100 Subject: [PATCH 28/29] update --- deps/ReactantExtra/src/IFRT.cpp | 88 +++++++++++++++++++++++---------- 1 file changed, 61 insertions(+), 27 deletions(-) diff --git a/deps/ReactantExtra/src/IFRT.cpp b/deps/ReactantExtra/src/IFRT.cpp index 007f15667..35dad8b85 100644 --- a/deps/ReactantExtra/src/IFRT.cpp +++ b/deps/ReactantExtra/src/IFRT.cpp @@ -36,11 +36,27 @@ using namespace xla::ifrt; +namespace jlcxx { +// template +// struct julia_type_factory> { +// static inline jl_datatype_t* julia_type() +// { +// jl_datatype_t* union_params_types[2] = { julia_base_type(), jl_nothing_type }; +// auto union_nothing = apply_type(jlcxx::julia_type("Union"), jl_nothing_type); +// // return apply_type(jlcxx::julia_type("Union"), reinterpret_cast(&union_params_types), 2); +// return apply_type(jlcxx::julia_type("Union"), julia_base_type()); +// } +// }; +} + #define JLCXX_CLASS_DEF_EQ(WRAP, CLASS) WRAP.method("==", &CLASS::operator==); #define JLCXX_CLASS_DEF_NE(WRAP, CLASS) WRAP.method("!=", &CLASS::operator!=); #define JLCXX_CLASS_DEF_ADD(WRAP, CLASS) WRAP.method("+", &CLASS::operator+); #define JLCXX_CLASS_DEF_SUB(WRAP, CLASS) WRAP.method("-", &CLASS::operator-); +#define JLCXX_CLASS_DEF_DBGSTR(WRAP, CLASS) WRAP.method("string", &CLASS::DebugString); +// TODO refactor `DebugString` calls for `AbslStringify` +// TODO impl calls to `hash` using `AbslHashValue` JLCXX_MODULE reactant_module_ifrt(jlcxx::Module& mod) { mod.map_type("Int32"); @@ -86,7 +102,7 @@ JLCXX_MODULE reactant_module_ifrt(jlcxx::Module& mod) .method("isdeleted", &Value::IsDeleted); mod.set_override_module(jl_base_module); - wrap_value.method("string", &Value::DebugString); + JLCXX_CLASS_DEF_DBGSTR(wrap_value, Value) mod.unset_override_module(); // Tuple @@ -140,7 +156,7 @@ JLCXX_MODULE reactant_module_ifrt(jlcxx::Module& mod) mod.set_override_module(jl_base_module); JLCXX_CLASS_DEF_EQ(wrap_dtype, DType) JLCXX_CLASS_DEF_NE(wrap_dtype, DType) - mod.method("string", [](const DType& x) { return x.DebugString(); }); + JLCXX_CLASS_DEF_DBGSTR(wrap_dtype, DType) mod.unset_override_module(); // Shape @@ -149,10 +165,10 @@ JLCXX_MODULE reactant_module_ifrt(jlcxx::Module& mod) // return new Shape(dims); // }); mod.set_override_module(jl_base_module); - // mod.method("==", [](Shape* a, Shape* b) { return *a == *b; }); - // mod.method("!=", [](Shape* a, Shape* b) { return *a != *b; }); + JLCXX_CLASS_DEF_EQ(wrap_shape, Shape) + JLCXX_CLASS_DEF_NE(wrap_shape, Shape) // mod.method("copy", [](const Shape& x) { return Shape(x); }); - mod.method("string", [](const Shape& x) { return x.DebugString(); }); + JLCXX_CLASS_DEF_DBGSTR(wrap_shape, Shape) // mod.method("size", [](const Shape& x) { return x.dims(); }); mod.method("length", [](const Shape& x) { return x.num_elements(); }); mod.unset_override_module(); @@ -163,7 +179,7 @@ JLCXX_MODULE reactant_module_ifrt(jlcxx::Module& mod) .method("isdyndim", &DynamicShape::IsDynamicDim); mod.set_override_module(jl_base_module); - mod.method("string", [](const DynamicShape& x) { return x.DebugString(); }); + JLCXX_CLASS_DEF_DBGSTR(wrap_dynamicshape, DynamicShape) mod.unset_override_module(); // Index @@ -182,7 +198,7 @@ JLCXX_MODULE reactant_module_ifrt(jlcxx::Module& mod) JLCXX_CLASS_DEF_ADD(wrap_index, Index) JLCXX_CLASS_DEF_SUB(wrap_index, Index) mod.method("*", [](const Index& a, std::vector mul) { return a * absl::Span(mul); }); - mod.method("string", [](const Index& x) { return x.DebugString(); }); + JLCXX_CLASS_DEF_DBGSTR(wrap_index, Index) mod.unset_override_module(); // IndexDomain @@ -196,39 +212,39 @@ JLCXX_MODULE reactant_module_ifrt(jlcxx::Module& mod) mod.set_override_module(jl_base_module); JLCXX_CLASS_DEF_EQ(wrap_indexdomain, IndexDomain) JLCXX_CLASS_DEF_NE(wrap_indexdomain, IndexDomain) - mod.method("+", [](const IndexDomain& x, const Index& offset) { return x + offset; }); - mod.method("-", [](const IndexDomain& x, const Index& offset) { return x - offset; }); - mod.method("string", [](const IndexDomain& x) { return x.DebugString(); }); + JLCXX_CLASS_DEF_ADD(wrap_indexdomain, IndexDomain) + JLCXX_CLASS_DEF_SUB(wrap_indexdomain, IndexDomain) + JLCXX_CLASS_DEF_DBGSTR(wrap_indexdomain, IndexDomain) mod.unset_override_module(); // MemoryKind - // TODO `memory_kind` returns optional wrap_memorykind .constructor<>() - .constructor([](const std::string& name) { return new MemoryKind(name); }); + .constructor([](const std::string& name) { return new MemoryKind(name); }) + .method("canonicalize", [](MemoryKind& x, Device& dev) { return CanonicalizeMemoryKind(x, &dev); }); mod.set_override_module(jl_base_module); - // mod.method("string", [](const MemoryKind& x) { return x.DebugString(); }); + JLCXX_CLASS_DEF_EQ(wrap_memorykind, MemoryKind) + JLCXX_CLASS_DEF_NE(wrap_memorykind, MemoryKind) + wrap_memorykind.method("string", [](const MemoryKind& x) { return std::string(x.memory_kind().value_or("")); }); mod.unset_override_module(); - // TODO `CanonicalizeMemoryKind` - - // Memory (virtual) - // TODO `Devices` + // Memory + // TODO check if `Devices` is correct (why does it return a span of pointers?) wrap_memory + .constructor<>() .method("id", &Memory::Id) .method("kind", &Memory::Kind) - // .method("devices", [](const Memory& x) { - // auto devices_span = x.Devices(); - // return std::vector(devices_span.begin(), devices_span.end()); - // }) - ; + .method("devices", [](const Memory& x) { + auto devices_span = x.Devices(); + return std::vector(devices_span.begin(), devices_span.end()); + }); mod.set_override_module(jl_base_module); - wrap_memory.method("string", [](const Memory& x) { return std::string(x.ToString()); }); + JLCXX_CLASS_DEF_DBGSTR(wrap_memory, Memory) mod.unset_override_module(); - // Device (virtual) + // Device // TODO `Memories` wrap_device .method("client", &Device::client) @@ -238,7 +254,7 @@ JLCXX_MODULE reactant_module_ifrt(jlcxx::Module& mod) .method("isaddressable", &Device::IsAddressable) .method("process_index", &Device::ProcessIndex); - // Sharding (virtual) + // Sharding mod.add_bits("SingleDeviceShardSemantics", jlcxx::julia_type("CppEnum")); mod.set_const("SingleDeviceShardSemanticsAddressable", SingleDeviceShardSemantics::kAddressableShards); mod.set_const("SingleDeviceShardSemanticsAll", SingleDeviceShardSemantics::kAllShards); @@ -260,7 +276,7 @@ JLCXX_MODULE reactant_module_ifrt(jlcxx::Module& mod) // TODO SingleDeviceSharding, OpaqueSharding, ConcreteSharding, ConcreteEvenSharding, ShardingParamSharding - // Array (virtual) + // Array mod.add_bits("ArrayCopySemantics", jlcxx::julia_type("CppEnum")); mod.set_const("ArrayCopySemanticsAlwaysCopy", ArrayCopySemantics::kAlwaysCopy); mod.set_const("ArrayCopySemanticsReuseInput", ArrayCopySemantics::kReuseInput); @@ -277,7 +293,7 @@ JLCXX_MODULE reactant_module_ifrt(jlcxx::Module& mod) // .method("copy_to_host_buffer", &Array::CopyToHostBuffer) ; - // Topology (virtual) + // Topology wrap_topology .method("platform_name", [](const Topology& x) { return std::string(x.platform_name()); }) .method("platform_version", [](const Topology& x) { return std::string(x.platform_version()); }) @@ -287,4 +303,22 @@ JLCXX_MODULE reactant_module_ifrt(jlcxx::Module& mod) // .method("serialize", &Topology::Serialize) // .method("Attributes", &Topology::Attributes) ; + + // Client + + // HostCallback + + // LoadedHostCallback + + // PjRtHostSendAndRecvLoadedHostCallback + + // Executable + + // LoadedExecutable + + // CustomCallProgram + + // HloProgram + + // Compiler } From df7720b2af75fff569666ff992c0a6c9c30ebd09 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Sun, 22 Dec 2024 16:12:45 +0100 Subject: [PATCH 29/29] update --- deps/ReactantExtra/src/IFRT.cpp | 40 ++++++++++++++++++++++++--------- 1 file changed, 29 insertions(+), 11 deletions(-) diff --git a/deps/ReactantExtra/src/IFRT.cpp b/deps/ReactantExtra/src/IFRT.cpp index 35dad8b85..f1ee2b5b8 100644 --- a/deps/ReactantExtra/src/IFRT.cpp +++ b/deps/ReactantExtra/src/IFRT.cpp @@ -61,7 +61,7 @@ JLCXX_MODULE reactant_module_ifrt(jlcxx::Module& mod) { mod.map_type("Int32"); mod.map_type("Int32"); - mod.map_type("UInt64"); // TODO move to PjRT.cpp + // mod.map_type("UInt64"); // TODO move to PjRT.cpp auto wrap_future = mod.add_type>("Future"); auto wrap_value = mod.add_type("Value"); @@ -72,6 +72,7 @@ JLCXX_MODULE reactant_module_ifrt(jlcxx::Module& mod) auto wrap_dynamicshape = mod.add_type("DynamicShape"); auto wrap_index = mod.add_type("Index"); auto wrap_indexdomain = mod.add_type("IndexDomain"); + auto wrap_attributemap = mod.add_type("AttributeMap"); auto wrap_memorykind = mod.add_type("MemoryKind"); auto wrap_memory = mod.add_type("Memory"); auto wrap_device = mod.add_type("Device"); @@ -217,6 +218,8 @@ JLCXX_MODULE reactant_module_ifrt(jlcxx::Module& mod) JLCXX_CLASS_DEF_DBGSTR(wrap_indexdomain, IndexDomain) mod.unset_override_module(); + // TODO AttributeMap + // MemoryKind wrap_memorykind .constructor<>() @@ -230,30 +233,39 @@ JLCXX_MODULE reactant_module_ifrt(jlcxx::Module& mod) mod.unset_override_module(); // Memory - // TODO check if `Devices` is correct (why does it return a span of pointers?) + // TODO check if `Devices` is correct (why does it return a span of pointers?) => errors due to `std::vector` conversion wrap_memory - .constructor<>() .method("id", &Memory::Id) .method("kind", &Memory::Kind) - .method("devices", [](const Memory& x) { - auto devices_span = x.Devices(); - return std::vector(devices_span.begin(), devices_span.end()); - }); + // .method("devices", [](const Memory& x) { + // auto devices_span = x.Devices(); + // return std::vector(devices_span.begin(), devices_span.end()); + // }) + ; mod.set_override_module(jl_base_module); JLCXX_CLASS_DEF_DBGSTR(wrap_memory, Memory) mod.unset_override_module(); // Device - // TODO `Memories` + // TODO `Attributes`, check if `Memories` is ok wrap_device - .method("client", &Device::client) + .method("client", &Device::client) // why does it return a pointer? .method("id", &Device::Id) // .method("attributes", &Device::Attributes) .method("kind", [](const Device& x) { return std::string(x.Kind()); }) + .method("default_memory", [](const Device& x) { return xla::ValueOrThrow(x.DefaultMemory()); }) + .method("memories", [](const Device& x) { + auto mems = x.Memories(); + return std::vector(mems.begin(), mems.end()); + }) .method("isaddressable", &Device::IsAddressable) .method("process_index", &Device::ProcessIndex); + mod.set_override_module(jl_base_module); + JLCXX_CLASS_DEF_DBGSTR(wrap_device, Device) + mod.unset_override_module(); + // Sharding mod.add_bits("SingleDeviceShardSemantics", jlcxx::julia_type("CppEnum")); mod.set_const("SingleDeviceShardSemanticsAddressable", SingleDeviceShardSemantics::kAddressableShards); @@ -307,12 +319,16 @@ JLCXX_MODULE reactant_module_ifrt(jlcxx::Module& mod) // Client // HostCallback + wrap_hostcallback.method("serialize", &HostCallback::Serialize); // LoadedHostCallback - - // PjRtHostSendAndRecvLoadedHostCallback + wrap_loadedhostcallback + .method("client", &LoadedHostCallback::client) + .method("serialize", [](const LoadedHostCallback& x) { return xla::ValueOrThrow(x.Serialize()); }); // Executable + wrap_executable + .method("name", [](const Executable& x) { return std::string(x.name()); }); // LoadedExecutable @@ -320,5 +336,7 @@ JLCXX_MODULE reactant_module_ifrt(jlcxx::Module& mod) // HloProgram + // PluginProgram + // Compiler }