From 0cd27bd910f4262ae288483d26cf37c21681f5c0 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 6 Dec 2025 17:04:21 +0000 Subject: [PATCH 01/20] Initial plan From baab7850d83354c266e8f57702ddf1eb186bfc1d Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 6 Dec 2025 17:18:37 +0000 Subject: [PATCH 02/20] Add export_to_enzymeax function for JAX integration Co-authored-by: avik-pal <30564094+avik-pal@users.noreply.github.com> --- src/serialization/EnzymeJAX.jl | 365 +++++++++++++++++++++++++++++ src/serialization/Serialization.jl | 18 ++ test/export_enzymeax.jl | 57 +++++ 3 files changed, 440 insertions(+) create mode 100644 src/serialization/EnzymeJAX.jl create mode 100644 test/export_enzymeax.jl diff --git a/src/serialization/EnzymeJAX.jl b/src/serialization/EnzymeJAX.jl new file mode 100644 index 0000000000..19c433b805 --- /dev/null +++ b/src/serialization/EnzymeJAX.jl @@ -0,0 +1,365 @@ +module EnzymeJAX + +using ..Reactant: Reactant, Compiler, MLIR + +const NUMPY_SIMPLE_TYPES = Dict( + Bool => "np.bool_", + Int8 => "np.int8", + Int16 => "np.int16", + Int32 => "np.int32", + Int64 => "np.int64", + UInt8 => "np.uint8", + UInt16 => "np.uint16", + UInt32 => "np.uint32", + UInt64 => "np.uint64", + Float16 => "np.float16", + Float32 => "np.float32", + Float64 => "np.float64", + ComplexF16 => "np.complex64", # Note: NumPy doesn't have float16 complex + ComplexF32 => "np.complex64", + ComplexF64 => "np.complex128", +) + +""" + export_to_enzymeax( + f, + args...; + output_dir::String=".", + function_name::String="exported_function", + ) + +Export a Julia function to EnzymeJAX format for use in Python/JAX. + +This function: +1. Compiles the function to StableHLO via `Reactant.@code_hlo` +2. Saves the MLIR/StableHLO code to a `.mlir` file +3. Saves input arrays to `.npy` files (transposed to account for row-major vs column-major) +4. Generates a Python script with the function wrapped for EnzymeJAX's `hlo_call` + +## Arguments + + - `f`: The Julia function to export + - `args...`: The arguments to the function (used to infer types and shapes) + +## Keyword Arguments + + - `output_dir::String="."`: Directory where output files will be saved + - `function_name::String="exported_function"`: Base name for generated files + +## Returns + +A tuple `(mlir_path, python_path, input_paths)` containing paths to: + - The generated `.mlir` file + - The generated `.py` file + - A vector of paths to input `.npy` files + +## Example + +```julia +using Reactant + +# Define a simple function +function my_function(x, y) + return x .+ y +end + +# Create some example inputs +x = Reactant.to_rarray(Float32[1, 2, 3]) +y = Reactant.to_rarray(Float32[4, 5, 6]) + +# Export to EnzymeJAX +mlir_path, python_path, input_paths = Reactant.Serialization.export_to_enzymeax( + my_function, x, y; + output_dir="/tmp/exported", + function_name="my_function" +) +``` + +Then in Python: +```python +# Run the generated Python script +from exported.my_function import run_my_function +import jax + +result = jax.jit(run_my_function)(*inputs) +``` +""" +function export_to_enzymeax( + f, + args...; + output_dir::String=".", + function_name::String="exported_function", +) + # Create output directory if it doesn't exist + mkpath(output_dir) + + # Generate the StableHLO/MLIR code using compile_mlir directly + mod, mlir_fn_res = Compiler.compile_mlir( + f, args; + shardy_passes=:none + ) + hlo_code = string(mod) + + # Save MLIR code + mlir_path = joinpath(output_dir, "$(function_name).mlir") + write(mlir_path, hlo_code) + + # Process and save inputs + input_paths = String[] + input_info = [] + + for (i, arg) in enumerate(args) + # Convert to array if needed + arr = _to_array(arg) + + # Save the input (transposed for row-major Python/NumPy) + input_path = joinpath(output_dir, "$(function_name)_input_$(i).npy") + _save_transposed_array(input_path, arr) + push!(input_paths, input_path) + + # Store shape and dtype info (in Julia's column-major ordering) + push!(input_info, (shape=size(arr), dtype=eltype(arr))) + end + + # Generate Python script + python_path = joinpath(output_dir, "$(function_name).py") + _generate_python_script(python_path, function_name, mlir_path, input_paths, input_info) + + return (mlir_path, python_path, input_paths) +end + +""" +Convert Reactant types to regular Julia arrays for saving. +""" +function _to_array(x::Reactant.ConcreteRArray) + return Array(x) +end + +function _to_array(x::Reactant.ConcreteRNumber) + return [x.data] +end + +function _to_array(x::AbstractArray) + return Array(x) +end + +function _to_array(x::Number) + return [x] +end + +function _to_array(x::Tuple) + error("Tuple arguments are not yet supported. Please flatten your arguments.") +end + +function _to_array(x::NamedTuple) + error("NamedTuple arguments are not yet supported. Please flatten your arguments.") +end + +""" +Save an array to a .npy file, transposing to account for row-major vs column-major ordering. +""" +function _save_transposed_array(path::String, arr::AbstractArray) + # For multi-dimensional arrays, we need to reverse the dimensions for Python/NumPy + # Julia: column-major (fastest changing index first) + # Python: row-major (fastest changing index last) + transposed = permutedims(arr, reverse(1:ndims(arr))) + + # Use a simple .npy writer + # NPY format v1.0: magic (6 bytes) + version (2 bytes) + header_len (2 bytes) + header + data + open(path, "w") do io + # Magic number for .npy format + write(io, UInt8[0x93, 0x4E, 0x55, 0x4D, 0x50, 0x59]) + # Version 1.0 + write(io, UInt8[0x01, 0x00]) + + # Prepare header + dtype_str = _numpy_dtype_string(eltype(arr)) + shape_str = join(size(transposed), ", ") + header = "{'descr': '$(dtype_str)', 'fortran_order': False, 'shape': ($(shape_str),)}" + + # Pad header to be aligned on 64 bytes + header_len = length(header) + 1 # +1 for newline + total_len = 10 + header_len # 10 = magic(6) + version(2) + header_len(2) + padding = (64 - (total_len % 64)) % 64 + header = header * " "^padding * "\n" + header_len = length(header) + + # Write header length (little-endian UInt16) + write(io, UInt16(header_len)) + # Write header + write(io, header) + # Write data + write(io, vec(transposed)) + end +end + +""" +Get NumPy dtype string for a Julia type. +""" +function _numpy_dtype_string(::Type{Bool}) + return "|b1" +end + +function _numpy_dtype_string(::Type{Int8}) + return "|i1" +end + +function _numpy_dtype_string(::Type{UInt8}) + return "|u1" +end + +function _numpy_dtype_string(::Type{Int16}) + return " 0 + end + + println("✓ All export_to_enzymeax tests passed!") + println(" - MLIR file created: $(mlir_path)") + println(" - Python file created: $(python_path)") + println(" - Input files created: $(length(input_paths))") + finally + # Clean up + rm(tmpdir; recursive=true, force=true) + end +end From fe4150c9edd929bb508ed4beacc242f01aac8d99 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 6 Dec 2025 17:21:59 +0000 Subject: [PATCH 03/20] Add comprehensive tests for export_to_enzymeax Co-authored-by: avik-pal <30564094+avik-pal@users.noreply.github.com> --- test/export_enzymeax_comprehensive.jl | 115 ++++++++++++++++++++++++++ 1 file changed, 115 insertions(+) create mode 100644 test/export_enzymeax_comprehensive.jl diff --git a/test/export_enzymeax_comprehensive.jl b/test/export_enzymeax_comprehensive.jl new file mode 100644 index 0000000000..0f1fc9c132 --- /dev/null +++ b/test/export_enzymeax_comprehensive.jl @@ -0,0 +1,115 @@ +using Reactant +using Test + +@testset "Export to EnzymeJAX - Multi-dimensional Arrays" begin + tmpdir = mktempdir() + + try + # Define a function with 2D arrays + function matrix_multiply(x, y) + return x * y + end + + # Create 2D arrays - Julia uses column-major order + x = Reactant.to_rarray(Float32[1 2 3; 4 5 6]) # 2x3 matrix + y = Reactant.to_rarray(Float32[7 8; 9 10; 11 12]) # 3x2 matrix + + # Export to EnzymeJAX + mlir_path, python_path, input_paths = Reactant.Serialization.export_to_enzymeax( + matrix_multiply, x, y; + output_dir=tmpdir, + function_name="matrix_multiply" + ) + + @test isfile(mlir_path) + @test isfile(python_path) + @test length(input_paths) == 2 + + # Read Python file and check for correct shape information + python_content = read(python_path, String) + + # The shapes should be transposed for Python (row-major) + # Julia x: (2, 3) -> Python: (3, 2) + # Julia y: (3, 2) -> Python: (2, 3) + @test occursin("(3, 2)", python_content) # Transposed shape of x + @test occursin("(2, 3)", python_content) # Transposed shape of y + + println("✓ Multi-dimensional array export test passed!") + + finally + rm(tmpdir; recursive=true, force=true) + end +end + +@testset "Export to EnzymeJAX - 3D Arrays" begin + tmpdir = mktempdir() + + try + # Define a function with 3D arrays (like image data) + function add_3d(x, y) + return x .+ y + end + + # Create 3D arrays - e.g., (height, width, channels, batch) + # Julia: (28, 28, 1, 4) -> Python: (4, 1, 28, 28) + x = Reactant.to_rarray(rand(Float32, 28, 28, 1, 4)) + y = Reactant.to_rarray(rand(Float32, 28, 28, 1, 4)) + + # Export to EnzymeJAX + mlir_path, python_path, input_paths = Reactant.Serialization.export_to_enzymeax( + add_3d, x, y; + output_dir=tmpdir, + function_name="add_3d" + ) + + @test isfile(mlir_path) + @test isfile(python_path) + + # Check that Python file mentions the transposed shape + python_content = read(python_path, String) + @test occursin("(4, 1, 28, 28)", python_content) + + println("✓ 3D array export test passed!") + + finally + rm(tmpdir; recursive=true, force=true) + end +end + +@testset "Export to EnzymeJAX - File Content Verification" begin + tmpdir = mktempdir() + + try + function simple_fn(x) + return x .* 2.0f0 + end + + x = Reactant.to_rarray(Float32[1.0, 2.0, 3.0, 4.0]) + + mlir_path, python_path, input_paths = Reactant.Serialization.export_to_enzymeax( + simple_fn, x; + output_dir=tmpdir, + function_name="test_fn" + ) + + # Verify MLIR contains necessary elements + mlir_content = read(mlir_path, String) + @test occursin("module", mlir_content) + + # Verify Python file structure + python_content = read(python_path, String) + @test occursin("import jax", python_content) + @test occursin("import numpy as np", python_content) + @test occursin("from enzyme_ad.jax import hlo_call", python_content) + @test occursin("def run_test_fn(arg1)", python_content) + @test occursin("source=_hlo_code", python_content) + @test occursin("jax.jit(run_test_fn)", python_content) + + println("✓ File content verification test passed!") + + finally + rm(tmpdir; recursive=true, force=true) + end +end + +println("\n✅ All comprehensive tests passed!") From 390176c0fedd82236b425e67ef19ba4c6f9db4b7 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 6 Dec 2025 17:29:51 +0000 Subject: [PATCH 04/20] Fix code review issues: binary mode, ComplexF16 support, Python indentation Co-authored-by: avik-pal <30564094+avik-pal@users.noreply.github.com> --- src/serialization/EnzymeJAX.jl | 150 +++++++++++++++------------------ test/runtests.jl | 1 + 2 files changed, 70 insertions(+), 81 deletions(-) diff --git a/src/serialization/EnzymeJAX.jl b/src/serialization/EnzymeJAX.jl index 19c433b805..ec4f0b391f 100644 --- a/src/serialization/EnzymeJAX.jl +++ b/src/serialization/EnzymeJAX.jl @@ -15,7 +15,6 @@ const NUMPY_SIMPLE_TYPES = Dict( Float16 => "np.float16", Float32 => "np.float32", Float64 => "np.float64", - ComplexF16 => "np.complex64", # Note: NumPy doesn't have float16 complex ComplexF32 => "np.complex64", ComplexF64 => "np.complex128", ) @@ -177,10 +176,11 @@ function _save_transposed_array(path::String, arr::AbstractArray) shape_str = join(size(transposed), ", ") header = "{'descr': '$(dtype_str)', 'fortran_order': False, 'shape': ($(shape_str),)}" - # Pad header to be aligned on 64 bytes + # Pad header to be aligned on 64 bytes (16-byte alignment for v1.0) + # Total size needs to be divisible by 16 header_len = length(header) + 1 # +1 for newline total_len = 10 + header_len # 10 = magic(6) + version(2) + header_len(2) - padding = (64 - (total_len % 64)) % 64 + padding = (16 - (total_len % 16)) % 16 header = header * " "^padding * "\n" header_len = length(header) @@ -191,6 +191,7 @@ function _save_transposed_array(path::String, arr::AbstractArray) # Write data write(io, vec(transposed)) end + return nothing end """ @@ -267,99 +268,86 @@ function _generate_python_script( mlir_rel = relpath(mlir_path, output_dir) input_rels = [relpath(p, output_dir) for p in input_paths] - # Start building the Python script - script = """ - \"\"\" - Auto-generated Python script for calling exported Julia/Reactant function via EnzymeJAX. + # Build the Python script without leading indentation + lines = String[] - This script was generated by Reactant.Serialization.export_to_enzymeax(). - \"\"\" + # Header + push!(lines, "\"\"\"") + push!(lines, "Auto-generated Python script for calling exported Julia/Reactant function via EnzymeJAX.") + push!(lines, "") + push!(lines, "This script was generated by Reactant.Serialization.export_to_enzymeax().") + push!(lines, "\"\"\"") + push!(lines, "") + push!(lines, "from enzyme_ad.jax import hlo_call") + push!(lines, "import jax") + push!(lines, "import jax.numpy as jnp") + push!(lines, "import numpy as np") + push!(lines, "import os") + push!(lines, "") + push!(lines, "# Get the directory of this script") + push!(lines, "_script_dir = os.path.dirname(os.path.abspath(__file__))") + push!(lines, "") + push!(lines, "# Load the MLIR/StableHLO code") + push!(lines, "with open(os.path.join(_script_dir, \"$(mlir_rel)\"), \"r\") as f:") + push!(lines, " _hlo_code = f.read()") + push!(lines, "") - from enzyme_ad.jax import hlo_call - import jax - import jax.numpy as jnp - import numpy as np - import os - - # Get the directory of this script - _script_dir = os.path.dirname(os.path.abspath(__file__)) - - # Load the MLIR/StableHLO code - with open(os.path.join(_script_dir, "$(mlir_rel)"), "r") as f: - _hlo_code = f.read() - - """ - - # Add function to load inputs - script *= """ - def load_inputs(): - \"\"\"Load the example inputs that were exported from Julia.\"\"\" - inputs = [] - """ - - for (i, input_rel) in enumerate(input_rels) - script *= """ - inputs.append(np.load(os.path.join(_script_dir, "$(input_rel)"))) - """ + # Function to load inputs + push!(lines, "def load_inputs():") + push!(lines, " \"\"\"Load the example inputs that were exported from Julia.\"\"\"") + push!(lines, " inputs = []") + for input_rel in input_rels + push!(lines, " inputs.append(np.load(os.path.join(_script_dir, \"$(input_rel)\")))") end + push!(lines, " return tuple(inputs)") + push!(lines, "") - script *= """ - return tuple(inputs) - - """ - - # Add the main function that calls the HLO code + # Main function arg_names = ["arg$i" for i in 1:length(input_paths)] arg_list = join(arg_names, ", ") - script *= """ - def run_$(function_name)($(arg_list)): - \"\"\" - Call the exported Julia function via EnzymeJAX. - - Args: - """ + push!(lines, "def run_$(function_name)($(arg_list)):") + push!(lines, " \"\"\"") + push!(lines, " Call the exported Julia function via EnzymeJAX.") + push!(lines, " ") + push!(lines, " Args:") for (i, info) in enumerate(input_info) # Note: shapes are already transposed for Python python_shape = reverse(info.shape) - script *= """ - $(arg_names[i]): Array of shape $(python_shape) and dtype $(NUMPY_SIMPLE_TYPES[info.dtype]) - """ + push!(lines, " $(arg_names[i]): Array of shape $(python_shape) and dtype $(NUMPY_SIMPLE_TYPES[info.dtype])") end - script *= """ - - Returns: - The result of calling the exported function. - - Note: - All inputs must be in row-major (Python/NumPy) order. If you're passing - arrays from Julia, make sure to transpose them first using: - `permutedims(arr, reverse(1:ndims(arr)))` - \"\"\" - return hlo_call( - $(arg_list), - source=_hlo_code, - ) - - """ + push!(lines, " ") + push!(lines, " Returns:") + push!(lines, " The result of calling the exported function.") + push!(lines, " ") + push!(lines, " Note:") + push!(lines, " All inputs must be in row-major (Python/NumPy) order. If you're passing") + push!(lines, " arrays from Julia, make sure to transpose them first using:") + push!(lines, " `permutedims(arr, reverse(1:ndims(arr)))`") + push!(lines, " \"\"\"") + push!(lines, " return hlo_call(") + push!(lines, " $(arg_list),") + push!(lines, " source=_hlo_code,") + push!(lines, " )") + push!(lines, "") - # Add a main block for testing - script *= """ - if __name__ == "__main__": - # Load the example inputs - inputs = load_inputs() - - # Run the function (with JIT compilation) - print("Running $(function_name) with JIT compilation...") - result = jax.jit(run_$(function_name))(*inputs) - print("Result:", result) - print("Result shape:", result.shape if hasattr(result, 'shape') else 'scalar') - print("Result dtype:", result.dtype if hasattr(result, 'dtype') else type(result)) - """ + # Main block + push!(lines, "if __name__ == \"__main__\":") + push!(lines, " # Load the example inputs") + push!(lines, " inputs = load_inputs()") + push!(lines, " ") + push!(lines, " # Run the function (with JIT compilation)") + push!(lines, " print(\"Running $(function_name) with JIT compilation...\")") + push!(lines, " result = jax.jit(run_$(function_name))(*inputs)") + push!(lines, " print(\"Result:\", result)") + push!(lines, " print(\"Result shape:\", result.shape if hasattr(result, 'shape') else 'scalar')") + push!(lines, " print(\"Result dtype:\", result.dtype if hasattr(result, 'dtype') else type(result))") - write(python_path, script) + # Write the script + write(python_path, join(lines, "\n") * "\n") + return nothing end end # module diff --git a/test/runtests.jl b/test/runtests.jl index 4000cd40a8..97a258d29c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -38,6 +38,7 @@ const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all")) @safetestset "Cluster Detection" include("cluster_detector.jl") @safetestset "Config" include("config.jl") @safetestset "Batching" include("batching.jl") + @safetestset "Export to EnzymeJAX" include("export_enzymeax.jl") @safetestset "QA" include("qa.jl") end From 9605ef51c20669ba6f7b611da3ab85b87fe41103 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 6 Dec 2025 17:31:56 +0000 Subject: [PATCH 05/20] Add documentation for export_to_enzymeax function Co-authored-by: avik-pal <30564094+avik-pal@users.noreply.github.com> --- docs/src/api/serialization.md | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/docs/src/api/serialization.md b/docs/src/api/serialization.md index 2620085f58..6e9affed04 100644 --- a/docs/src/api/serialization.md +++ b/docs/src/api/serialization.md @@ -27,3 +27,23 @@ or [TensorFlow Hub](https://tensorflow.org/hub). Refer to the ```@docs Reactant.Serialization.export_as_tf_saved_model ``` + +## Exporting to JAX via EnzymeAD + +!!! note "No Dependencies Required" + + Unlike TensorFlow SavedModel export, exporting to JAX via EnzymeAD does not require any + Python dependencies at build time. It generates standalone files that can be used with + EnzymeAD/JAX in Python. + +This export functionality generates: +1. A `.mlir` file containing the StableHLO representation of your Julia function +2. Example input `.npy` files with properly transposed arrays (column-major → row-major) +3. A Python script that wraps the function for use with `enzyme_ad.jax.hlo_call` + +The generated Python script can be immediately used with JAX and EnzymeAD without any +additional Julia dependencies. + +```@docs +Reactant.Serialization.export_to_enzymeax +``` From 5ef177158955ab553eacd6b6bdb9c87fa465f3d1 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 6 Dec 2025 13:35:44 -0500 Subject: [PATCH 06/20] chore: run formatter --- src/serialization/EnzymeJAX.jl | 91 ++++++++++++++++++++-------------- 1 file changed, 53 insertions(+), 38 deletions(-) diff --git a/src/serialization/EnzymeJAX.jl b/src/serialization/EnzymeJAX.jl index ec4f0b391f..c26d9f1e43 100644 --- a/src/serialization/EnzymeJAX.jl +++ b/src/serialization/EnzymeJAX.jl @@ -84,46 +84,40 @@ result = jax.jit(run_my_function)(*inputs) ``` """ function export_to_enzymeax( - f, - args...; - output_dir::String=".", - function_name::String="exported_function", + f, args...; output_dir::String=".", function_name::String="exported_function" ) # Create output directory if it doesn't exist mkpath(output_dir) - + # Generate the StableHLO/MLIR code using compile_mlir directly - mod, mlir_fn_res = Compiler.compile_mlir( - f, args; - shardy_passes=:none - ) + mod, mlir_fn_res = Compiler.compile_mlir(f, args; shardy_passes=:none) hlo_code = string(mod) - + # Save MLIR code mlir_path = joinpath(output_dir, "$(function_name).mlir") write(mlir_path, hlo_code) - + # Process and save inputs input_paths = String[] input_info = [] - + for (i, arg) in enumerate(args) # Convert to array if needed arr = _to_array(arg) - + # Save the input (transposed for row-major Python/NumPy) input_path = joinpath(output_dir, "$(function_name)_input_$(i).npy") _save_transposed_array(input_path, arr) push!(input_paths, input_path) - + # Store shape and dtype info (in Julia's column-major ordering) push!(input_info, (shape=size(arr), dtype=eltype(arr))) end - + # Generate Python script python_path = joinpath(output_dir, "$(function_name).py") _generate_python_script(python_path, function_name, mlir_path, input_paths, input_info) - + return (mlir_path, python_path, input_paths) end @@ -147,11 +141,13 @@ function _to_array(x::Number) end function _to_array(x::Tuple) - error("Tuple arguments are not yet supported. Please flatten your arguments.") + return error("Tuple arguments are not yet supported. Please flatten your arguments.") end function _to_array(x::NamedTuple) - error("NamedTuple arguments are not yet supported. Please flatten your arguments.") + return error( + "NamedTuple arguments are not yet supported. Please flatten your arguments." + ) end """ @@ -162,7 +158,7 @@ function _save_transposed_array(path::String, arr::AbstractArray) # Julia: column-major (fastest changing index first) # Python: row-major (fastest changing index last) transposed = permutedims(arr, reverse(1:ndims(arr))) - + # Use a simple .npy writer # NPY format v1.0: magic (6 bytes) + version (2 bytes) + header_len (2 bytes) + header + data open(path, "w") do io @@ -170,12 +166,12 @@ function _save_transposed_array(path::String, arr::AbstractArray) write(io, UInt8[0x93, 0x4E, 0x55, 0x4D, 0x50, 0x59]) # Version 1.0 write(io, UInt8[0x01, 0x00]) - + # Prepare header dtype_str = _numpy_dtype_string(eltype(arr)) shape_str = join(size(transposed), ", ") header = "{'descr': '$(dtype_str)', 'fortran_order': False, 'shape': ($(shape_str),)}" - + # Pad header to be aligned on 64 bytes (16-byte alignment for v1.0) # Total size needs to be divisible by 16 header_len = length(header) + 1 # +1 for newline @@ -183,7 +179,7 @@ function _save_transposed_array(path::String, arr::AbstractArray) padding = (16 - (total_len % 16)) % 16 header = header * " "^padding * "\n" header_len = length(header) - + # Write header length (little-endian UInt16) write(io, UInt16(header_len)) # Write header @@ -267,15 +263,20 @@ function _generate_python_script( output_dir = dirname(python_path) mlir_rel = relpath(mlir_path, output_dir) input_rels = [relpath(p, output_dir) for p in input_paths] - + # Build the Python script without leading indentation lines = String[] - + # Header push!(lines, "\"\"\"") - push!(lines, "Auto-generated Python script for calling exported Julia/Reactant function via EnzymeJAX.") + push!( + lines, + "Auto-generated Python script for calling exported Julia/Reactant function via EnzymeJAX.", + ) push!(lines, "") - push!(lines, "This script was generated by Reactant.Serialization.export_to_enzymeax().") + push!( + lines, "This script was generated by Reactant.Serialization.export_to_enzymeax()." + ) push!(lines, "\"\"\"") push!(lines, "") push!(lines, "from enzyme_ad.jax import hlo_call") @@ -291,39 +292,47 @@ function _generate_python_script( push!(lines, "with open(os.path.join(_script_dir, \"$(mlir_rel)\"), \"r\") as f:") push!(lines, " _hlo_code = f.read()") push!(lines, "") - + # Function to load inputs push!(lines, "def load_inputs():") push!(lines, " \"\"\"Load the example inputs that were exported from Julia.\"\"\"") push!(lines, " inputs = []") for input_rel in input_rels - push!(lines, " inputs.append(np.load(os.path.join(_script_dir, \"$(input_rel)\")))") + push!( + lines, " inputs.append(np.load(os.path.join(_script_dir, \"$(input_rel)\")))" + ) end push!(lines, " return tuple(inputs)") push!(lines, "") - + # Main function arg_names = ["arg$i" for i in 1:length(input_paths)] arg_list = join(arg_names, ", ") - + push!(lines, "def run_$(function_name)($(arg_list)):") push!(lines, " \"\"\"") push!(lines, " Call the exported Julia function via EnzymeJAX.") push!(lines, " ") push!(lines, " Args:") - + for (i, info) in enumerate(input_info) # Note: shapes are already transposed for Python python_shape = reverse(info.shape) - push!(lines, " $(arg_names[i]): Array of shape $(python_shape) and dtype $(NUMPY_SIMPLE_TYPES[info.dtype])") + push!( + lines, + " $(arg_names[i]): Array of shape $(python_shape) and dtype $(NUMPY_SIMPLE_TYPES[info.dtype])", + ) end - + push!(lines, " ") push!(lines, " Returns:") push!(lines, " The result of calling the exported function.") push!(lines, " ") push!(lines, " Note:") - push!(lines, " All inputs must be in row-major (Python/NumPy) order. If you're passing") + push!( + lines, + " All inputs must be in row-major (Python/NumPy) order. If you're passing", + ) push!(lines, " arrays from Julia, make sure to transpose them first using:") push!(lines, " `permutedims(arr, reverse(1:ndims(arr)))`") push!(lines, " \"\"\"") @@ -332,7 +341,7 @@ function _generate_python_script( push!(lines, " source=_hlo_code,") push!(lines, " )") push!(lines, "") - + # Main block push!(lines, "if __name__ == \"__main__\":") push!(lines, " # Load the example inputs") @@ -342,9 +351,15 @@ function _generate_python_script( push!(lines, " print(\"Running $(function_name) with JIT compilation...\")") push!(lines, " result = jax.jit(run_$(function_name))(*inputs)") push!(lines, " print(\"Result:\", result)") - push!(lines, " print(\"Result shape:\", result.shape if hasattr(result, 'shape') else 'scalar')") - push!(lines, " print(\"Result dtype:\", result.dtype if hasattr(result, 'dtype') else type(result))") - + push!( + lines, + " print(\"Result shape:\", result.shape if hasattr(result, 'shape') else 'scalar')", + ) + push!( + lines, + " print(\"Result dtype:\", result.dtype if hasattr(result, 'dtype') else type(result))", + ) + # Write the script write(python_path, join(lines, "\n") * "\n") return nothing From db32dff17f48e3e434de858fd84821e596d5c322 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 6 Dec 2025 13:43:58 -0500 Subject: [PATCH 07/20] chore: cleanup --- .../ReactantPythonCallExt.jl | 19 +- src/serialization/EnzymeJAX.jl | 183 +++++++----------- src/serialization/Serialization.jl | 18 ++ src/serialization/TFSavedModel.jl | 20 +- 4 files changed, 95 insertions(+), 145 deletions(-) diff --git a/ext/ReactantPythonCallExt/ReactantPythonCallExt.jl b/ext/ReactantPythonCallExt/ReactantPythonCallExt.jl index 1f10630808..60e176e8ac 100644 --- a/ext/ReactantPythonCallExt/ReactantPythonCallExt.jl +++ b/ext/ReactantPythonCallExt/ReactantPythonCallExt.jl @@ -3,6 +3,7 @@ module ReactantPythonCallExt using PythonCall: PythonCall, Py, pyconvert, pydict, pyfunc, pyimport, pylist using Reactant: Reactant, TracedRArray, TracedRNumber, @reactant_overlay using Reactant.Ops: @opcall +using Reactant.Serialization: NUMPY_SIMPLE_TYPES const jaxptr = Ref{Py}() const jnpptr = Ref{Py}() @@ -15,24 +16,6 @@ const npptr = Ref{Py}() const SAVED_MODEL_EXPORT_SUPPORTED = Ref{Bool}(false) -const NUMPY_SIMPLE_TYPES = Dict( - Bool => :bool, - Int8 => :int8, - Int16 => :int16, - Int32 => :int32, - Int64 => :int64, - UInt8 => :uint8, - UInt16 => :uint16, - UInt32 => :uint32, - UInt64 => :uint64, - Float16 => :float16, - Float32 => :float32, - Float64 => :float64, - ComplexF16 => :complex16, - ComplexF32 => :complex32, - ComplexF64 => :complex64, -) - function __init__() try jaxptr[] = pyimport("jax") diff --git a/src/serialization/EnzymeJAX.jl b/src/serialization/EnzymeJAX.jl index c26d9f1e43..2acb1247c4 100644 --- a/src/serialization/EnzymeJAX.jl +++ b/src/serialization/EnzymeJAX.jl @@ -1,23 +1,6 @@ module EnzymeJAX -using ..Reactant: Reactant, Compiler, MLIR - -const NUMPY_SIMPLE_TYPES = Dict( - Bool => "np.bool_", - Int8 => "np.int8", - Int16 => "np.int16", - Int32 => "np.int32", - Int64 => "np.int64", - UInt8 => "np.uint8", - UInt16 => "np.uint16", - UInt32 => "np.uint32", - UInt64 => "np.uint64", - Float16 => "np.float16", - Float32 => "np.float32", - Float64 => "np.float64", - ComplexF32 => "np.complex64", - ComplexF64 => "np.complex128", -) +using ..Reactant: Reactant, Compiler, MLIR, Serialization """ export_to_enzymeax( @@ -264,104 +247,88 @@ function _generate_python_script( mlir_rel = relpath(mlir_path, output_dir) input_rels = [relpath(p, output_dir) for p in input_paths] - # Build the Python script without leading indentation - lines = String[] - - # Header - push!(lines, "\"\"\"") - push!( - lines, - "Auto-generated Python script for calling exported Julia/Reactant function via EnzymeJAX.", - ) - push!(lines, "") - push!( - lines, "This script was generated by Reactant.Serialization.export_to_enzymeax()." + # Generate input loading code + input_loads = join( + [ + " inputs.append(np.load(os.path.join(_script_dir, \"$rel\")))" for + rel in input_rels + ], + "\n", ) - push!(lines, "\"\"\"") - push!(lines, "") - push!(lines, "from enzyme_ad.jax import hlo_call") - push!(lines, "import jax") - push!(lines, "import jax.numpy as jnp") - push!(lines, "import numpy as np") - push!(lines, "import os") - push!(lines, "") - push!(lines, "# Get the directory of this script") - push!(lines, "_script_dir = os.path.dirname(os.path.abspath(__file__))") - push!(lines, "") - push!(lines, "# Load the MLIR/StableHLO code") - push!(lines, "with open(os.path.join(_script_dir, \"$(mlir_rel)\"), \"r\") as f:") - push!(lines, " _hlo_code = f.read()") - push!(lines, "") - - # Function to load inputs - push!(lines, "def load_inputs():") - push!(lines, " \"\"\"Load the example inputs that were exported from Julia.\"\"\"") - push!(lines, " inputs = []") - for input_rel in input_rels - push!( - lines, " inputs.append(np.load(os.path.join(_script_dir, \"$(input_rel)\")))" - ) - end - push!(lines, " return tuple(inputs)") - push!(lines, "") - # Main function + # Generate argument list and documentation arg_names = ["arg$i" for i in 1:length(input_paths)] arg_list = join(arg_names, ", ") - push!(lines, "def run_$(function_name)($(arg_list)):") - push!(lines, " \"\"\"") - push!(lines, " Call the exported Julia function via EnzymeJAX.") - push!(lines, " ") - push!(lines, " Args:") - - for (i, info) in enumerate(input_info) - # Note: shapes are already transposed for Python - python_shape = reverse(info.shape) - push!( - lines, - " $(arg_names[i]): Array of shape $(python_shape) and dtype $(NUMPY_SIMPLE_TYPES[info.dtype])", + # Generate docstring for arguments + arg_docs = join( + [ + " $(arg_names[i]): Array of shape $(reverse(info.shape)) and dtype $(Serialization.NUMPY_SIMPLE_TYPES[info.dtype])" + for (i, info) in enumerate(input_info) + ], + "\n", + ) + + # Build the complete Python script + script = """ + \"\"\" + Auto-generated Python script for calling exported Julia/Reactant function via EnzymeJAX. + + This script was generated by Reactant.Serialization.export_to_enzymeax(). + \"\"\" + + from enzyme_ad.jax import hlo_call + import jax + import jax.numpy as jnp + import numpy as np + import os + + # Get the directory of this script + _script_dir = os.path.dirname(os.path.abspath(__file__)) + + # Load the MLIR/StableHLO code + with open(os.path.join(_script_dir, \"$(mlir_rel)\"), \"r\") as f: + _hlo_code = f.read() + + def load_inputs(): + \"\"\"Load the example inputs that were exported from Julia.\"\"\" + inputs = [] + $input_loads + return tuple(inputs) + + def run_$(function_name)($(arg_list)): + \"\"\" + Call the exported Julia function via EnzymeJAX. + + Args: + $arg_docs + + Returns: + The result of calling the exported function. + + Note: + All inputs must be in row-major (Python/NumPy) order. If you're passing + arrays from Julia, make sure to transpose them first using: + \`permutedims(arr, reverse(1:ndims(arr)))\` + \"\"\" + return hlo_call( + $(arg_list), + source=_hlo_code, ) - end - push!(lines, " ") - push!(lines, " Returns:") - push!(lines, " The result of calling the exported function.") - push!(lines, " ") - push!(lines, " Note:") - push!( - lines, - " All inputs must be in row-major (Python/NumPy) order. If you're passing", - ) - push!(lines, " arrays from Julia, make sure to transpose them first using:") - push!(lines, " `permutedims(arr, reverse(1:ndims(arr)))`") - push!(lines, " \"\"\"") - push!(lines, " return hlo_call(") - push!(lines, " $(arg_list),") - push!(lines, " source=_hlo_code,") - push!(lines, " )") - push!(lines, "") - - # Main block - push!(lines, "if __name__ == \"__main__\":") - push!(lines, " # Load the example inputs") - push!(lines, " inputs = load_inputs()") - push!(lines, " ") - push!(lines, " # Run the function (with JIT compilation)") - push!(lines, " print(\"Running $(function_name) with JIT compilation...\")") - push!(lines, " result = jax.jit(run_$(function_name))(*inputs)") - push!(lines, " print(\"Result:\", result)") - push!( - lines, - " print(\"Result shape:\", result.shape if hasattr(result, 'shape') else 'scalar')", - ) - push!( - lines, - " print(\"Result dtype:\", result.dtype if hasattr(result, 'dtype') else type(result))", - ) + if __name__ == \"__main__\": + # Load the example inputs + inputs = load_inputs() + + # Run the function (with JIT compilation) + print(\"Running $(function_name) with JIT compilation...\") + result = jax.jit(run_$(function_name))(*inputs) + print(\"Result:\", result) + print(\"Result shape:\", result.shape if hasattr(result, 'shape') else 'scalar') + print(\"Result dtype:\", result.dtype if hasattr(result, 'dtype') else type(result)) + """ - # Write the script - write(python_path, join(lines, "\n") * "\n") + write(python_path, strip(script) * "\n") return nothing end diff --git a/src/serialization/Serialization.jl b/src/serialization/Serialization.jl index a55b5632db..c0fb2bde84 100644 --- a/src/serialization/Serialization.jl +++ b/src/serialization/Serialization.jl @@ -10,6 +10,24 @@ using ..Reactant: Reactant, Compiler serialization_supported(::Val) = false +const NUMPY_SIMPLE_TYPES = Dict( + Bool => :bool, + Int8 => :int8, + Int16 => :int16, + Int32 => :int32, + Int64 => :int64, + UInt8 => :uint8, + UInt16 => :uint16, + UInt32 => :uint32, + UInt64 => :uint64, + Float16 => :float16, + Float32 => :float32, + Float64 => :float64, + ComplexF16 => :complex16, + ComplexF32 => :complex32, + ComplexF64 => :complex64, +) + include("TFSavedModel.jl") include("EnzymeJAX.jl") diff --git a/src/serialization/TFSavedModel.jl b/src/serialization/TFSavedModel.jl index 38ad8ee814..0409c15b8e 100644 --- a/src/serialization/TFSavedModel.jl +++ b/src/serialization/TFSavedModel.jl @@ -1,28 +1,10 @@ module TFSavedModel -using ..Serialization: serialization_supported +using ..Serialization: serialization_supported, NUMPY_SIMPLE_TYPES using ..Reactant: Compiler, MLIR # https://github.com/openxla/stablehlo/blob/955fa7e6e3b0a6411edc8ff6fcce1e644440acbd/stablehlo/integrations/python/stablehlo/savedmodel/stablehlo_to_tf_saved_model.py -const NUMPY_SIMPLE_TYPES = Dict( - Bool => :bool, - Int8 => :int8, - Int16 => :int16, - Int32 => :int32, - Int64 => :int64, - UInt8 => :uint8, - UInt16 => :uint16, - UInt32 => :uint32, - UInt64 => :uint64, - Float16 => :float16, - Float32 => :float32, - Float64 => :float64, - ComplexF16 => :complex16, - ComplexF32 => :complex32, - ComplexF64 => :complex64, -) - struct VariableSignature shape::Vector{Int} dtype::Symbol From d57a401147f3df3b630e8666e1eab4cefaa45ebe Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 6 Dec 2025 14:04:47 -0500 Subject: [PATCH 08/20] fix: proper export --- src/serialization/EnzymeJAX.jl | 174 ++++++++++----------------------- 1 file changed, 52 insertions(+), 122 deletions(-) diff --git a/src/serialization/EnzymeJAX.jl b/src/serialization/EnzymeJAX.jl index 2acb1247c4..c0bc81844e 100644 --- a/src/serialization/EnzymeJAX.jl +++ b/src/serialization/EnzymeJAX.jl @@ -41,20 +41,16 @@ A tuple `(mlir_path, python_path, input_paths)` containing paths to: using Reactant # Define a simple function -function my_function(x, y) - return x .+ y +function my_function(x, y::NamedTuple) + return x .+ y.x .- y.y end # Create some example inputs x = Reactant.to_rarray(Float32[1, 2, 3]) -y = Reactant.to_rarray(Float32[4, 5, 6]) +y = (; x=Reactant.to_rarray(Float32[4, 5, 6]), y=Reactant.to_rarray(Float32[7, 8, 9])) # Export to EnzymeJAX -mlir_path, python_path, input_paths = Reactant.Serialization.export_to_enzymeax( - my_function, x, y; - output_dir="/tmp/exported", - function_name="my_function" -) +python_file_path = Reactant.Serialization.export_to_enzymeax(my_function, x, y) ``` Then in Python: @@ -67,83 +63,63 @@ result = jax.jit(run_my_function)(*inputs) ``` """ function export_to_enzymeax( - f, args...; output_dir::String=".", function_name::String="exported_function" + f, args...; output_dir::Union{String,Nothing}=nothing, function_name::String=string(f) ) - # Create output directory if it doesn't exist - mkpath(output_dir) + if output_dir === nothing + output_dir = mktempdir(; cleanup=false) + @info "Output directory is $(output_dir)" + else + mkpath(output_dir) + end - # Generate the StableHLO/MLIR code using compile_mlir directly - mod, mlir_fn_res = Compiler.compile_mlir(f, args; shardy_passes=:none) + # Generate the StableHLO/MLIR code using compile_mlir + # This returns compilation result with traced argument information + mod, mlir_fn_res = Compiler.compile_mlir(f, args) hlo_code = string(mod) # Save MLIR code - mlir_path = joinpath(output_dir, "$(function_name).mlir") + fnid = 0 + while isfile(joinpath(output_dir, "$(function_name)_$(fnid).mlir")) + fnid += 1 + end + mlir_path = joinpath(output_dir, "$(function_name)_$(fnid).mlir") write(mlir_path, hlo_code) - # Process and save inputs + invmap = IdDict() + for (k, v) in mlir_fn_res.seen_args + invmap[v] = k + end + + # Process and save inputs based on the linearized arguments input_paths = String[] input_info = [] - - for (i, arg) in enumerate(args) - # Convert to array if needed - arr = _to_array(arg) - + for (i, linear_arg) in enumerate(mlir_fn_res.linear_args) + carg = invmap[linear_arg] # Save the input (transposed for row-major Python/NumPy) - input_path = joinpath(output_dir, "$(function_name)_input_$(i).npy") - _save_transposed_array(input_path, arr) + input_path = joinpath(output_dir, "$(function_name)_$(fnid)_input_$(i).npy") + _save_transposed_array(input_path, _to_array(carg)) push!(input_paths, input_path) - - # Store shape and dtype info (in Julia's column-major ordering) - push!(input_info, (shape=size(arr), dtype=eltype(arr))) + push!(input_info, (shape=size(carg), dtype=eltype(carg))) end # Generate Python script python_path = joinpath(output_dir, "$(function_name).py") _generate_python_script(python_path, function_name, mlir_path, input_paths, input_info) - - return (mlir_path, python_path, input_paths) -end - -""" -Convert Reactant types to regular Julia arrays for saving. -""" -function _to_array(x::Reactant.ConcreteRArray) - return Array(x) -end - -function _to_array(x::Reactant.ConcreteRNumber) - return [x.data] -end - -function _to_array(x::AbstractArray) - return Array(x) -end - -function _to_array(x::Number) - return [x] -end - -function _to_array(x::Tuple) - return error("Tuple arguments are not yet supported. Please flatten your arguments.") + return python_path end -function _to_array(x::NamedTuple) - return error( - "NamedTuple arguments are not yet supported. Please flatten your arguments." - ) -end +_to_array(x::Reactant.ConcreteRArray) = Array(x) +_to_array(x::Reactant.ConcreteRNumber) = Number(x) -""" -Save an array to a .npy file, transposing to account for row-major vs column-major ordering. -""" +# Save an array to a .npy file, transposing to account for row-major vs +# column-major ordering. function _save_transposed_array(path::String, arr::AbstractArray) # For multi-dimensional arrays, we need to reverse the dimensions for Python/NumPy - # Julia: column-major (fastest changing index first) - # Python: row-major (fastest changing index last) transposed = permutedims(arr, reverse(1:ndims(arr))) # Use a simple .npy writer - # NPY format v1.0: magic (6 bytes) + version (2 bytes) + header_len (2 bytes) + header + data + # NPY format v1.0: magic (6 bytes) + version (2 bytes) + header_len (2 bytes) + + # header + data open(path, "w") do io # Magic number for .npy format write(io, UInt8[0x93, 0x4E, 0x55, 0x4D, 0x50, 0x59]) @@ -173,68 +149,22 @@ function _save_transposed_array(path::String, arr::AbstractArray) return nothing end -""" -Get NumPy dtype string for a Julia type. -""" -function _numpy_dtype_string(::Type{Bool}) - return "|b1" -end - -function _numpy_dtype_string(::Type{Int8}) - return "|i1" -end - -function _numpy_dtype_string(::Type{UInt8}) - return "|u1" -end - -function _numpy_dtype_string(::Type{Int16}) - return " Date: Sat, 6 Dec 2025 19:14:21 +0000 Subject: [PATCH 09/20] Use existing seen_args infrastructure instead of manual inverse map Co-authored-by: wsmoses <1260124+wsmoses@users.noreply.github.com> --- src/serialization/EnzymeJAX.jl | 28 ++++++++++++++++------------ 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/src/serialization/EnzymeJAX.jl b/src/serialization/EnzymeJAX.jl index c0bc81844e..7ac1af3915 100644 --- a/src/serialization/EnzymeJAX.jl +++ b/src/serialization/EnzymeJAX.jl @@ -85,21 +85,25 @@ function export_to_enzymeax( mlir_path = joinpath(output_dir, "$(function_name)_$(fnid).mlir") write(mlir_path, hlo_code) - invmap = IdDict() - for (k, v) in mlir_fn_res.seen_args - invmap[v] = k - end - # Process and save inputs based on the linearized arguments + # seen_args is an OrderedIdDict where keys are concrete args and values are traced args + # linear_args contains only the arguments that need to be passed to the function + # We iterate over seen_args which preserves the order, and only save those in linear_args input_paths = String[] input_info = [] - for (i, linear_arg) in enumerate(mlir_fn_res.linear_args) - carg = invmap[linear_arg] - # Save the input (transposed for row-major Python/NumPy) - input_path = joinpath(output_dir, "$(function_name)_$(fnid)_input_$(i).npy") - _save_transposed_array(input_path, _to_array(carg)) - push!(input_paths, input_path) - push!(input_info, (shape=size(carg), dtype=eltype(carg))) + input_idx = 1 + for (concrete_arg, traced_arg) in mlir_fn_res.seen_args + # Only process arguments that are in linear_args (skip computed values) + if traced_arg in mlir_fn_res.linear_args + # Save the input (transposed for row-major Python/NumPy) + input_path = joinpath( + output_dir, "$(function_name)_$(fnid)_input_$(input_idx).npy" + ) + _save_transposed_array(input_path, _to_array(concrete_arg)) + push!(input_paths, input_path) + push!(input_info, (shape=size(concrete_arg), dtype=eltype(concrete_arg))) + input_idx += 1 + end end # Generate Python script From 1f5f62aacf74415175d85f9260e3b43d9e99495f Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 6 Dec 2025 19:19:03 +0000 Subject: [PATCH 10/20] Rename export_to_enzymeax to export_to_enzymejax Co-authored-by: wsmoses <30564094+wsmoses@users.noreply.github.com> --- docs/src/api/serialization.md | 2 +- src/serialization/EnzymeJAX.jl | 8 ++++---- src/serialization/Serialization.jl | 8 ++++---- test/export_enzymeax.jl | 4 ++-- test/export_enzymeax_comprehensive.jl | 6 +++--- 5 files changed, 14 insertions(+), 14 deletions(-) diff --git a/docs/src/api/serialization.md b/docs/src/api/serialization.md index 6e9affed04..f1cb6c5e3b 100644 --- a/docs/src/api/serialization.md +++ b/docs/src/api/serialization.md @@ -45,5 +45,5 @@ The generated Python script can be immediately used with JAX and EnzymeAD withou additional Julia dependencies. ```@docs -Reactant.Serialization.export_to_enzymeax +Reactant.Serialization.export_to_enzymejax ``` diff --git a/src/serialization/EnzymeJAX.jl b/src/serialization/EnzymeJAX.jl index 7ac1af3915..1972343d94 100644 --- a/src/serialization/EnzymeJAX.jl +++ b/src/serialization/EnzymeJAX.jl @@ -3,7 +3,7 @@ module EnzymeJAX using ..Reactant: Reactant, Compiler, MLIR, Serialization """ - export_to_enzymeax( + export_to_enzymejax( f, args...; output_dir::String=".", @@ -50,7 +50,7 @@ x = Reactant.to_rarray(Float32[1, 2, 3]) y = (; x=Reactant.to_rarray(Float32[4, 5, 6]), y=Reactant.to_rarray(Float32[7, 8, 9])) # Export to EnzymeJAX -python_file_path = Reactant.Serialization.export_to_enzymeax(my_function, x, y) +python_file_path = Reactant.Serialization.export_to_enzymejax(my_function, x, y) ``` Then in Python: @@ -62,7 +62,7 @@ import jax result = jax.jit(run_my_function)(*inputs) ``` """ -function export_to_enzymeax( +function export_to_enzymejax( f, args...; output_dir::Union{String,Nothing}=nothing, function_name::String=string(f) ) if output_dir === nothing @@ -208,7 +208,7 @@ function _generate_python_script( \"\"\" Auto-generated Python script for calling exported Julia/Reactant function via EnzymeJAX. - This script was generated by Reactant.Serialization.export_to_enzymeax(). + This script was generated by Reactant.Serialization.export_to_enzymejax(). \"\"\" from enzyme_ad.jax import hlo_call diff --git a/src/serialization/Serialization.jl b/src/serialization/Serialization.jl index c0fb2bde84..2d2a4bf184 100644 --- a/src/serialization/Serialization.jl +++ b/src/serialization/Serialization.jl @@ -128,7 +128,7 @@ function export_as_tf_saved_model( end """ - export_to_enzymeax( + export_to_enzymejax( f, args...; output_dir::String=".", @@ -137,10 +137,10 @@ end Export a Julia function to EnzymeJAX format for use in Python/JAX. -See [`EnzymeJAX.export_to_enzymeax`](@ref) for details. +See [`EnzymeJAX.export_to_enzymejax`](@ref) for details. """ -function export_to_enzymeax(f, args...; kwargs...) - return EnzymeJAX.export_to_enzymeax(f, args...; kwargs...) +function export_to_enzymejax(f, args...; kwargs...) + return EnzymeJAX.export_to_enzymejax(f, args...; kwargs...) end end diff --git a/test/export_enzymeax.jl b/test/export_enzymeax.jl index 30939fa450..59aeee8f22 100644 --- a/test/export_enzymeax.jl +++ b/test/export_enzymeax.jl @@ -16,7 +16,7 @@ using Test y = Reactant.to_rarray(Float32[4, 5, 6]) # Export to EnzymeJAX - mlir_path, python_path, input_paths = Reactant.Serialization.export_to_enzymeax( + mlir_path, python_path, input_paths = Reactant.Serialization.export_to_enzymejax( simple_add, x, y; output_dir=tmpdir, function_name="simple_add" @@ -46,7 +46,7 @@ using Test @test filesize(input_path) > 0 end - println("✓ All export_to_enzymeax tests passed!") + println("✓ All export_to_enzymejax tests passed!") println(" - MLIR file created: $(mlir_path)") println(" - Python file created: $(python_path)") println(" - Input files created: $(length(input_paths))") diff --git a/test/export_enzymeax_comprehensive.jl b/test/export_enzymeax_comprehensive.jl index 0f1fc9c132..86e189488c 100644 --- a/test/export_enzymeax_comprehensive.jl +++ b/test/export_enzymeax_comprehensive.jl @@ -15,7 +15,7 @@ using Test y = Reactant.to_rarray(Float32[7 8; 9 10; 11 12]) # 3x2 matrix # Export to EnzymeJAX - mlir_path, python_path, input_paths = Reactant.Serialization.export_to_enzymeax( + mlir_path, python_path, input_paths = Reactant.Serialization.export_to_enzymejax( matrix_multiply, x, y; output_dir=tmpdir, function_name="matrix_multiply" @@ -56,7 +56,7 @@ end y = Reactant.to_rarray(rand(Float32, 28, 28, 1, 4)) # Export to EnzymeJAX - mlir_path, python_path, input_paths = Reactant.Serialization.export_to_enzymeax( + mlir_path, python_path, input_paths = Reactant.Serialization.export_to_enzymejax( add_3d, x, y; output_dir=tmpdir, function_name="add_3d" @@ -86,7 +86,7 @@ end x = Reactant.to_rarray(Float32[1.0, 2.0, 3.0, 4.0]) - mlir_path, python_path, input_paths = Reactant.Serialization.export_to_enzymeax( + mlir_path, python_path, input_paths = Reactant.Serialization.export_to_enzymejax( simple_fn, x; output_dir=tmpdir, function_name="test_fn" From 709101d8e7342c38946ebe0abd45b33cf064e9c1 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 6 Dec 2025 14:22:13 -0500 Subject: [PATCH 11/20] feat: add path to docstring --- src/serialization/EnzymeJAX.jl | 30 ++++++++++++++++++------------ 1 file changed, 18 insertions(+), 12 deletions(-) diff --git a/src/serialization/EnzymeJAX.jl b/src/serialization/EnzymeJAX.jl index 1972343d94..6c8a0b022c 100644 --- a/src/serialization/EnzymeJAX.jl +++ b/src/serialization/EnzymeJAX.jl @@ -74,7 +74,8 @@ function export_to_enzymejax( # Generate the StableHLO/MLIR code using compile_mlir # This returns compilation result with traced argument information - mod, mlir_fn_res = Compiler.compile_mlir(f, args) + argprefix = gensym("exportarg") + mod, mlir_fn_res = Compiler.compile_mlir(f, args; argprefix) hlo_code = string(mod) # Save MLIR code @@ -93,17 +94,22 @@ function export_to_enzymejax( input_info = [] input_idx = 1 for (concrete_arg, traced_arg) in mlir_fn_res.seen_args + path = Reactant.TracedUtils.get_idx(traced_arg, argprefix)[2:end] + # Only process arguments that are in linear_args (skip computed values) - if traced_arg in mlir_fn_res.linear_args - # Save the input (transposed for row-major Python/NumPy) - input_path = joinpath( - output_dir, "$(function_name)_$(fnid)_input_$(input_idx).npy" - ) - _save_transposed_array(input_path, _to_array(concrete_arg)) - push!(input_paths, input_path) - push!(input_info, (shape=size(concrete_arg), dtype=eltype(concrete_arg))) - input_idx += 1 - end + # Save the input (transposed for row-major Python/NumPy) + input_path = joinpath(output_dir, "$(function_name)_$(fnid)_input_$(input_idx).npy") + _save_transposed_array(input_path, _to_array(concrete_arg)) + push!(input_paths, input_path) + push!( + input_info, + ( + shape=size(concrete_arg), + dtype=eltype(concrete_arg), + path="arg." * join(string.(path), "."), + ), + ) + input_idx += 1 end # Generate Python script @@ -197,7 +203,7 @@ function _generate_python_script( # Generate docstring for arguments arg_docs = join( [ - " $(arg_names[i]): Array of shape $(reverse(info.shape)) and dtype $(Serialization.NUMPY_SIMPLE_TYPES[info.dtype])" + " $(arg_names[i]): Array of shape $(reverse(info.shape)) and dtype $(Serialization.NUMPY_SIMPLE_TYPES[info.dtype]). Path: $(info.path)" for (i, info) in enumerate(input_info) ], "\n", From e4ee6b9cb661da264ed03d6595f3bfd249a6827c Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 6 Dec 2025 15:02:59 -0500 Subject: [PATCH 12/20] fix: use NPZ for proper export --- Project.toml | 2 + ext/ReactantNPZExt.jl | 24 +++++ src/Compiler.jl | 10 +- src/serialization/EnzymeJAX.jl | 144 ++++++++++---------------- test/export_enzymeax.jl | 20 ++-- test/export_enzymeax_comprehensive.jl | 56 +++++----- 6 files changed, 120 insertions(+), 136 deletions(-) create mode 100644 ext/ReactantNPZExt.jl diff --git a/Project.toml b/Project.toml index af33c5249a..c2e1409a57 100644 --- a/Project.toml +++ b/Project.toml @@ -39,6 +39,7 @@ GPUCompiler = "61eb1bfa-7361-4325-ad38-22787b887f55" KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" +NPZ = "15e1cf62-19b3-5cfa-8e77-841668bca605" OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f" PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d" @@ -62,6 +63,7 @@ ReactantFloat8sExt = "Float8s" ReactantKernelAbstractionsExt = "KernelAbstractions" ReactantMPIExt = "MPI" ReactantNNlibExt = ["NNlib", "Statistics"] +ReactantNPZExt = "NPZ" ReactantOffsetArraysExt = "OffsetArrays" ReactantOneHotArraysExt = "OneHotArrays" ReactantPythonCallExt = "PythonCall" diff --git a/ext/ReactantNPZExt.jl b/ext/ReactantNPZExt.jl new file mode 100644 index 0000000000..1565b2d387 --- /dev/null +++ b/ext/ReactantNPZExt.jl @@ -0,0 +1,24 @@ +module ReactantNPZExt + +using NPZ: npzwrite +using Reactant.Serialization: Serialization, EnzymeJAX + +Serialization.serialization_supported(::Val{:NPZ}) = true + +# Helper function to save all input data to a single NPZ file +function EnzymeJAX.save_inputs_npz_impl( + output_path::String, inputs::Dict{String,<:Union{AbstractArray,Number}} +) + # Transpose arrays for Python/NumPy (row-major vs column-major) + transposed_inputs = Dict{String,Union{AbstractArray,Number}}() + for (name, arr) in inputs + transposed_inputs[name] = + arr isa Number ? arr : permutedims(arr, reverse(1:ndims(arr))) + end + + # Save all inputs to a single NPZ file with compression + npzwrite(output_path, transposed_inputs) + return output_path +end + +end # module diff --git a/src/Compiler.jl b/src/Compiler.jl index 7beda792f4..1959d46e8d 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -1411,7 +1411,7 @@ function __get_compile_options_and_kwargs(; ) end -function compile_mlir(f, args; client=nothing, kwargs...) +function compile_mlir(f, args; client=nothing, drop_unsupported_attributes=false, kwargs...) client = client !== nothing ? client : XLA.default_backend() backend = XLA.platform_name(client) @@ -1441,6 +1441,11 @@ function compile_mlir(f, args; client=nothing, kwargs...) mod, f; mlir_fn_res.num_partitions, mlir_fn_res.num_replicas ) + if drop_unsupported_attributes + # Drop some of our attributes + run_pass_pipeline!(mod, "drop-unsupported-attributes", "drop_enzymexla_attributes") + end + return mod, mlir_fn_res end @@ -3584,9 +3589,6 @@ function compile_xla( module_string = "" end - # Drop some of our attributes - run_pass_pipeline!(mod, "drop-unsupported-attributes", "drop_enzymexla_attributes") - if before_xla_optimizations exec = nothing hlo_modules = XLA.HloModule(mod) diff --git a/src/serialization/EnzymeJAX.jl b/src/serialization/EnzymeJAX.jl index 6c8a0b022c..4cc066d1bc 100644 --- a/src/serialization/EnzymeJAX.jl +++ b/src/serialization/EnzymeJAX.jl @@ -6,8 +6,8 @@ using ..Reactant: Reactant, Compiler, MLIR, Serialization export_to_enzymejax( f, args...; - output_dir::String=".", - function_name::String="exported_function", + output_dir::Union{String,Nothing}=nothing, + function_name::String=string(f) ) Export a Julia function to EnzymeJAX format for use in Python/JAX. @@ -15,9 +15,14 @@ Export a Julia function to EnzymeJAX format for use in Python/JAX. This function: 1. Compiles the function to StableHLO via `Reactant.@code_hlo` 2. Saves the MLIR/StableHLO code to a `.mlir` file -3. Saves input arrays to `.npy` files (transposed to account for row-major vs column-major) +3. Saves all input arrays to a single compressed `.npz` file (transposed to account for + row-major vs column-major) 4. Generates a Python script with the function wrapped for EnzymeJAX's `hlo_call` +## Requirements + +- **NPZ.jl**: Must be loaded with `using NPZ` for compression support + ## Arguments - `f`: The Julia function to export @@ -25,32 +30,37 @@ This function: ## Keyword Arguments - - `output_dir::String="."`: Directory where output files will be saved - - `function_name::String="exported_function"`: Base name for generated files + - `output_dir::Union{String,Nothing}`: Directory where output files will be saved. If + `nothing`, uses a temporary directory and prints the path. + - `function_name::String`: Base name for generated files ## Returns -A tuple `(mlir_path, python_path, input_paths)` containing paths to: - - The generated `.mlir` file - - The generated `.py` file - - A vector of paths to input `.npy` files +The path to the generated Python script as a `String`. + +## Files Generated + + - `{function_name}.mlir`: The StableHLO/MLIR module + - `{function_name}_{id}_inputs.npz`: Compressed NPZ file containing all input arrays + - `{function_name}.py`: Python script with the function wrapped for EnzymeJAX ## Example ```julia -using Reactant +using Reactant, NPZ # Define a simple function -function my_function(x, y::NamedTuple) - return x .+ y.x .- y.y +function my_function(x::AbstractArray, y::NamedTuple, z::Number) + return x .+ y.x .- y.y .+ z end # Create some example inputs x = Reactant.to_rarray(Float32[1, 2, 3]) y = (; x=Reactant.to_rarray(Float32[4, 5, 6]), y=Reactant.to_rarray(Float32[7, 8, 9])) +z = Reactant.to_rarray(10.0f0; track_numbers=true) # Export to EnzymeJAX -python_file_path = Reactant.Serialization.export_to_enzymejax(my_function, x, y) +python_file_path = Reactant.Serialization.export_to_enzymejax(my_function, x, y, z) ``` Then in Python: @@ -75,7 +85,9 @@ function export_to_enzymejax( # Generate the StableHLO/MLIR code using compile_mlir # This returns compilation result with traced argument information argprefix = gensym("exportarg") - mod, mlir_fn_res = Compiler.compile_mlir(f, args; argprefix) + mod, mlir_fn_res = Compiler.compile_mlir( + f, args; argprefix, drop_unsupported_attributes=true + ) hlo_code = string(mod) # Save MLIR code @@ -90,114 +102,66 @@ function export_to_enzymejax( # seen_args is an OrderedIdDict where keys are concrete args and values are traced args # linear_args contains only the arguments that need to be passed to the function # We iterate over seen_args which preserves the order, and only save those in linear_args - input_paths = String[] + input_data = Dict{String,Union{AbstractArray,Number}}() input_info = [] input_idx = 1 for (concrete_arg, traced_arg) in mlir_fn_res.seen_args path = Reactant.TracedUtils.get_idx(traced_arg, argprefix)[2:end] - # Only process arguments that are in linear_args (skip computed values) - # Save the input (transposed for row-major Python/NumPy) - input_path = joinpath(output_dir, "$(function_name)_$(fnid)_input_$(input_idx).npy") - _save_transposed_array(input_path, _to_array(concrete_arg)) - push!(input_paths, input_path) + # Store input data for the single NPZ file + arr_key = "arr_$input_idx" + input_data[arr_key] = _to_array(concrete_arg) push!( input_info, ( shape=size(concrete_arg), - dtype=eltype(concrete_arg), + dtype=Reactant.unwrapped_eltype(concrete_arg), path="arg." * join(string.(path), "."), + key=arr_key, ), ) input_idx += 1 end + # Save all inputs to a single NPZ file + input_path = joinpath(output_dir, "$(function_name)_$(fnid)_inputs.npz") + save_inputs_npz(input_path, input_data) + # Generate Python script python_path = joinpath(output_dir, "$(function_name).py") - _generate_python_script(python_path, function_name, mlir_path, input_paths, input_info) + _generate_python_script(python_path, function_name, mlir_path, input_path, input_info) return python_path end _to_array(x::Reactant.ConcreteRArray) = Array(x) -_to_array(x::Reactant.ConcreteRNumber) = Number(x) - -# Save an array to a .npy file, transposing to account for row-major vs -# column-major ordering. -function _save_transposed_array(path::String, arr::AbstractArray) - # For multi-dimensional arrays, we need to reverse the dimensions for Python/NumPy - transposed = permutedims(arr, reverse(1:ndims(arr))) - - # Use a simple .npy writer - # NPY format v1.0: magic (6 bytes) + version (2 bytes) + header_len (2 bytes) + - # header + data - open(path, "w") do io - # Magic number for .npy format - write(io, UInt8[0x93, 0x4E, 0x55, 0x4D, 0x50, 0x59]) - # Version 1.0 - write(io, UInt8[0x01, 0x00]) - - # Prepare header - dtype_str = _numpy_dtype_string(eltype(arr)) - shape_str = join(size(transposed), ", ") - header = "{'descr': '$(dtype_str)', 'fortran_order': False, 'shape': ($(shape_str),)}" - - # Pad header to be aligned on 64 bytes (16-byte alignment for v1.0) - # Total size needs to be divisible by 16 - header_len = length(header) + 1 # +1 for newline - total_len = 10 + header_len # 10 = magic(6) + version(2) + header_len(2) - padding = (16 - (total_len % 16)) % 16 - header = header * " "^padding * "\n" - header_len = length(header) - - # Write header length (little-endian UInt16) - write(io, UInt16(header_len)) - # Write header - write(io, header) - # Write data - write(io, vec(transposed)) +_to_array(x::Reactant.ConcreteRNumber{T}) where {T} = T(x) + +function save_inputs_npz( + output_path::String, inputs::Dict{String,<:Union{AbstractArray,Number}} +) + if !Serialization.serialization_supported(Val(:NPZ)) + error("`NPZ.jl` is required for saving compressed arrays. Please load it with \ + `using NPZ` and try again.") end - return nothing + return save_inputs_npz_impl(output_path, inputs) end -# TODO: use a proper package for this -_numpy_dtype_string(::Type{Bool}) = "|b1" -_numpy_dtype_string(::Type{Int8}) = "|i1" -_numpy_dtype_string(::Type{UInt8}) = "|u1" -_numpy_dtype_string(::Type{Int16}) = " 0 end - + println("✓ All export_to_enzymejax tests passed!") println(" - MLIR file created: $(mlir_path)") println(" - Python file created: $(python_path)") diff --git a/test/export_enzymeax_comprehensive.jl b/test/export_enzymeax_comprehensive.jl index 86e189488c..aad88147d4 100644 --- a/test/export_enzymeax_comprehensive.jl +++ b/test/export_enzymeax_comprehensive.jl @@ -3,39 +3,37 @@ using Test @testset "Export to EnzymeJAX - Multi-dimensional Arrays" begin tmpdir = mktempdir() - + try # Define a function with 2D arrays function matrix_multiply(x, y) return x * y end - + # Create 2D arrays - Julia uses column-major order x = Reactant.to_rarray(Float32[1 2 3; 4 5 6]) # 2x3 matrix y = Reactant.to_rarray(Float32[7 8; 9 10; 11 12]) # 3x2 matrix - + # Export to EnzymeJAX mlir_path, python_path, input_paths = Reactant.Serialization.export_to_enzymejax( - matrix_multiply, x, y; - output_dir=tmpdir, - function_name="matrix_multiply" + matrix_multiply, x, y; output_dir=tmpdir, function_name="matrix_multiply" ) - + @test isfile(mlir_path) @test isfile(python_path) @test length(input_paths) == 2 - + # Read Python file and check for correct shape information python_content = read(python_path, String) - + # The shapes should be transposed for Python (row-major) # Julia x: (2, 3) -> Python: (3, 2) # Julia y: (3, 2) -> Python: (2, 3) @test occursin("(3, 2)", python_content) # Transposed shape of x @test occursin("(2, 3)", python_content) # Transposed shape of y - + println("✓ Multi-dimensional array export test passed!") - + finally rm(tmpdir; recursive=true, force=true) end @@ -43,34 +41,32 @@ end @testset "Export to EnzymeJAX - 3D Arrays" begin tmpdir = mktempdir() - + try # Define a function with 3D arrays (like image data) function add_3d(x, y) return x .+ y end - + # Create 3D arrays - e.g., (height, width, channels, batch) # Julia: (28, 28, 1, 4) -> Python: (4, 1, 28, 28) x = Reactant.to_rarray(rand(Float32, 28, 28, 1, 4)) y = Reactant.to_rarray(rand(Float32, 28, 28, 1, 4)) - + # Export to EnzymeJAX mlir_path, python_path, input_paths = Reactant.Serialization.export_to_enzymejax( - add_3d, x, y; - output_dir=tmpdir, - function_name="add_3d" + add_3d, x, y; output_dir=tmpdir, function_name="add_3d" ) - + @test isfile(mlir_path) @test isfile(python_path) - + # Check that Python file mentions the transposed shape python_content = read(python_path, String) @test occursin("(4, 1, 28, 28)", python_content) - + println("✓ 3D array export test passed!") - + finally rm(tmpdir; recursive=true, force=true) end @@ -78,24 +74,22 @@ end @testset "Export to EnzymeJAX - File Content Verification" begin tmpdir = mktempdir() - + try function simple_fn(x) return x .* 2.0f0 end - + x = Reactant.to_rarray(Float32[1.0, 2.0, 3.0, 4.0]) - + mlir_path, python_path, input_paths = Reactant.Serialization.export_to_enzymejax( - simple_fn, x; - output_dir=tmpdir, - function_name="test_fn" + simple_fn, x; output_dir=tmpdir, function_name="test_fn" ) - + # Verify MLIR contains necessary elements mlir_content = read(mlir_path, String) @test occursin("module", mlir_content) - + # Verify Python file structure python_content = read(python_path, String) @test occursin("import jax", python_content) @@ -104,9 +98,9 @@ end @test occursin("def run_test_fn(arg1)", python_content) @test occursin("source=_hlo_code", python_content) @test occursin("jax.jit(run_test_fn)", python_content) - + println("✓ File content verification test passed!") - + finally rm(tmpdir; recursive=true, force=true) end From ed16977d50d6998bddc321af5a01f6c2a5a24ccd Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 6 Dec 2025 15:42:15 -0500 Subject: [PATCH 13/20] feat: add size checks --- docs/src/api/serialization.md | 9 ++++----- src/serialization/EnzymeJAX.jl | 26 +++++++++++++++++++++++--- src/serialization/Serialization.jl | 16 +--------------- 3 files changed, 28 insertions(+), 23 deletions(-) diff --git a/docs/src/api/serialization.md b/docs/src/api/serialization.md index f1cb6c5e3b..c61cb08a89 100644 --- a/docs/src/api/serialization.md +++ b/docs/src/api/serialization.md @@ -30,15 +30,14 @@ Reactant.Serialization.export_as_tf_saved_model ## Exporting to JAX via EnzymeAD -!!! note "No Dependencies Required" +!!! note "Load NPZ" - Unlike TensorFlow SavedModel export, exporting to JAX via EnzymeAD does not require any - Python dependencies at build time. It generates standalone files that can be used with - EnzymeAD/JAX in Python. + This export functionality requires the `NPZ` package to be loaded. This export functionality generates: + 1. A `.mlir` file containing the StableHLO representation of your Julia function -2. Example input `.npy` files with properly transposed arrays (column-major → row-major) +2. Input `.npz` files containing the input arrays for the function 3. A Python script that wraps the function for use with `enzyme_ad.jax.hlo_call` The generated Python script can be immediately used with JAX and EnzymeAD without any diff --git a/src/serialization/EnzymeJAX.jl b/src/serialization/EnzymeJAX.jl index 4cc066d1bc..24fa3cd20b 100644 --- a/src/serialization/EnzymeJAX.jl +++ b/src/serialization/EnzymeJAX.jl @@ -33,6 +33,8 @@ This function: - `output_dir::Union{String,Nothing}`: Directory where output files will be saved. If `nothing`, uses a temporary directory and prints the path. - `function_name::String`: Base name for generated files + - `preserve_sharding::Bool`: Whether to preserve sharding information in the exported + function. Defaults to `true`. ## Returns @@ -55,8 +57,11 @@ function my_function(x::AbstractArray, y::NamedTuple, z::Number) end # Create some example inputs -x = Reactant.to_rarray(Float32[1, 2, 3]) -y = (; x=Reactant.to_rarray(Float32[4, 5, 6]), y=Reactant.to_rarray(Float32[7, 8, 9])) +x = Reactant.to_rarray(reshape(collect(Float32, 1:6), 2, 3)) +y = (; + x=Reactant.to_rarray(reshape(collect(Float32, 7:12), 2, 3)), + y=Reactant.to_rarray(reshape(collect(Float32, 13:18), 2, 3)) +) z = Reactant.to_rarray(10.0f0; track_numbers=true) # Export to EnzymeJAX @@ -73,7 +78,11 @@ result = jax.jit(run_my_function)(*inputs) ``` """ function export_to_enzymejax( - f, args...; output_dir::Union{String,Nothing}=nothing, function_name::String=string(f) + f, + args...; + output_dir::Union{String,Nothing}=nothing, + function_name::String=string(f), + preserve_sharding::Bool=true, ) if output_dir === nothing output_dir = mktempdir(; cleanup=false) @@ -173,6 +182,15 @@ function _generate_python_script( "\n", ) + arg_size_checks = [ + "assert $(arg_names[i]).shape == $(reverse(info.shape)), f\"Expected shape of $(arg_names[i]) to be $(reverse(info.shape)). Got {$(arg_names[i]).shape} (path: $(info.path))\"" + for (i, info) in enumerate(input_info) + ] + arg_dtype_checks = [ + "assert $(arg_names[i]).dtype == np.dtype('$(Serialization.NUMPY_SIMPLE_TYPES[info.dtype])'), f\"Expected dtype of $(arg_names[i]) to be $(Serialization.NUMPY_SIMPLE_TYPES[info.dtype]). Got {$(arg_names[i]).dtype} (path: $(info.path))\"" + for (i, info) in enumerate(input_info) + ] + load_inputs = ["npz_data['$(info.key)']" for info in input_info] # Build the complete Python script @@ -217,6 +235,8 @@ function _generate_python_script( arrays from Julia, make sure to transpose them first using: \`permutedims(arr, reverse(1:ndims(arr)))\` \"\"\" + $(join(arg_dtype_checks, "\n ")) + $(join(arg_size_checks, "\n ")) return hlo_call( $(arg_list), source=_hlo_code, diff --git a/src/serialization/Serialization.jl b/src/serialization/Serialization.jl index 2d2a4bf184..5fdc28c297 100644 --- a/src/serialization/Serialization.jl +++ b/src/serialization/Serialization.jl @@ -127,20 +127,6 @@ function export_as_tf_saved_model( ) end -""" - export_to_enzymejax( - f, - args...; - output_dir::String=".", - function_name::String="exported_function", - ) - -Export a Julia function to EnzymeJAX format for use in Python/JAX. - -See [`EnzymeJAX.export_to_enzymejax`](@ref) for details. -""" -function export_to_enzymejax(f, args...; kwargs...) - return EnzymeJAX.export_to_enzymejax(f, args...; kwargs...) -end +const export_to_enzymejax = EnzymeJAX.export_to_enzymejax end From 517eb449caa667d8484310a5ab9542064c99b225 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 6 Dec 2025 15:44:06 -0500 Subject: [PATCH 14/20] feat: automatically run jit --- src/serialization/EnzymeJAX.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/serialization/EnzymeJAX.jl b/src/serialization/EnzymeJAX.jl index 24fa3cd20b..0388d8bfca 100644 --- a/src/serialization/EnzymeJAX.jl +++ b/src/serialization/EnzymeJAX.jl @@ -74,7 +74,7 @@ Then in Python: from exported.my_function import run_my_function import jax -result = jax.jit(run_my_function)(*inputs) +result = run_my_function(*inputs) ``` """ function export_to_enzymejax( @@ -220,6 +220,7 @@ function _generate_python_script( inputs = [$(join(load_inputs, ", "))] return tuple(inputs) + @jax.jit def run_$(function_name)($(arg_list)): \"\"\" Call the exported Julia function via EnzymeJAX. @@ -248,7 +249,7 @@ function _generate_python_script( # Run the function (with JIT compilation) print(\"Running $(function_name) with JIT compilation...\") - result = jax.jit(run_$(function_name))(*inputs) + result = run_$(function_name)(*inputs) print(\"Result:\", result) """ From 2a91d8a4bf98155b6aeb6efe507c72e3ebc1db6e Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 6 Dec 2025 16:37:15 -0500 Subject: [PATCH 15/20] feat: preserve sharding --- docs/src/api/serialization.md | 2 +- src/Sharding.jl | 4 +- src/serialization/EnzymeJAX.jl | 126 ++++++++++++++++++++++++++++++--- 3 files changed, 121 insertions(+), 11 deletions(-) diff --git a/docs/src/api/serialization.md b/docs/src/api/serialization.md index c61cb08a89..bfe77a6c98 100644 --- a/docs/src/api/serialization.md +++ b/docs/src/api/serialization.md @@ -44,5 +44,5 @@ The generated Python script can be immediately used with JAX and EnzymeAD withou additional Julia dependencies. ```@docs -Reactant.Serialization.export_to_enzymejax +Reactant.Serialization.EnzymeJAX.export_to_enzymejax ``` diff --git a/src/Sharding.jl b/src/Sharding.jl index 8b357292e7..851a72fd79 100644 --- a/src/Sharding.jl +++ b/src/Sharding.jl @@ -949,7 +949,7 @@ function HloSharding(sharding::NamedSharding, client::XLA.IFRT.Client, _, x) data = XLA.IFRT.AsyncArray(client, x, ifrt_sharding) # XXX: Can we auto-pad this case too? Will think about it later, for now use - # NamedSharidng + # NamedSharding return data, ShardInfo(hlo_sharding, device_to_array_slices), nothing end @@ -997,7 +997,7 @@ function (sharding::HloSharding)( data = XLA.IFRT.AsyncArray(client, x, ifrt_sharding) # XXX: Can we auto-pad this case too? Will think about it later, for now use - # NamedSharidng + # NamedSharding return data, ShardInfo(sharding, device_to_array_slices), nothing end diff --git a/src/serialization/EnzymeJAX.jl b/src/serialization/EnzymeJAX.jl index 0388d8bfca..b13120b374 100644 --- a/src/serialization/EnzymeJAX.jl +++ b/src/serialization/EnzymeJAX.jl @@ -95,7 +95,12 @@ function export_to_enzymejax( # This returns compilation result with traced argument information argprefix = gensym("exportarg") mod, mlir_fn_res = Compiler.compile_mlir( - f, args; argprefix, drop_unsupported_attributes=true + f, + args; + argprefix, + drop_unsupported_attributes=true, + # to support older jax versions which don't support shardy + shardy_passes=:to_mhlo_shardings, ) hlo_code = string(mod) @@ -120,6 +125,13 @@ function export_to_enzymejax( # Store input data for the single NPZ file arr_key = "arr_$input_idx" input_data[arr_key] = _to_array(concrete_arg) + + # Extract sharding information if available and if preserve_sharding is true + sharding_info = nothing + if preserve_sharding && _has_sharding_info(concrete_arg) + sharding_info = _extract_sharding_info(concrete_arg) + end + push!( input_info, ( @@ -127,6 +139,7 @@ function export_to_enzymejax( dtype=Reactant.unwrapped_eltype(concrete_arg), path="arg." * join(string.(path), "."), key=arr_key, + sharding=sharding_info, ), ) input_idx += 1 @@ -138,13 +151,40 @@ function export_to_enzymejax( # Generate Python script python_path = joinpath(output_dir, "$(function_name).py") - _generate_python_script(python_path, function_name, mlir_path, input_path, input_info) + _generate_python_script( + python_path, function_name, mlir_path, input_path, input_info; preserve_sharding + ) return python_path end _to_array(x::Reactant.ConcreteRArray) = Array(x) _to_array(x::Reactant.ConcreteRNumber{T}) where {T} = T(x) +_has_sharding_info(x::Reactant.ConcreteRArray) = Reactant.Sharding.is_sharded(x.sharding) +_has_sharding_info(x) = false + +function _extract_sharding_info(x::Reactant.ConcreteRArray) + sharding = x.sharding + if sharding isa Reactant.Sharding.ShardInfo + inner_sharding = sharding.sharding + if inner_sharding isa Reactant.Sharding.NamedSharding + # TODO: we need to export is_closed, priority, and subaxes at some point + return (; + type="NamedSharding", + mesh=inner_sharding.mesh, + partition_spec=inner_sharding.partition_spec, + ) + elseif inner_sharding isa Reactant.Sharding.Replicated + return (; type="Replicated", mesh=inner_sharding.mesh) + elseif inner_sharding isa Reactant.Sharding.NoSharding + return (; type="NoSharding") + else + error("Unsupported sharding type: $(typeof(inner_sharding))") + end + end + return (; type="NoSharding") +end + function save_inputs_npz( output_path::String, inputs::Dict{String,<:Union{AbstractArray,Number}} ) @@ -162,7 +202,8 @@ function _generate_python_script( function_name::String, mlir_path::String, input_path::String, - input_info::Vector, + input_info::Vector; + preserve_sharding::Bool=true, ) # Get relative paths for the Python script output_dir = dirname(python_path) @@ -191,6 +232,74 @@ function _generate_python_script( for (i, info) in enumerate(input_info) ] + # Generate sharding annotations if available + has_any_sharding = + preserve_sharding && any(info.sharding !== nothing for info in input_info) + + device_put_calls = String[] + if has_any_sharding + inserted_meshes = IdDict() + counter = 0 + for (i, info) in enumerate(input_info) + if info.sharding !== nothing + if haskey(inserted_meshes, info.sharding.mesh) + pymesh = inserted_meshes[info.sharding.mesh] + else + pymesh = "mesh$counter" + counter += 1 + inserted_meshes[info.sharding.mesh] = pymesh + axis_sizes = join(string.(reverse(info.sharding.mesh.axis_sizes)), ", ") + mesh_axes = join( + reverse(["'$(string(x))'" for x in info.sharding.mesh.axis_names]), + ", ", + ) + + push!( + device_put_calls, + "$(pymesh) = jax.make_mesh(($(axis_sizes)), ($(mesh_axes)))", + ) + end + + push!( + device_put_calls, + "# Set up sharding for $(arg_names[i]): $(info.sharding.type)", + ) + + # Create device_put call with NamedSharding + if info.sharding.type == "NoSharding" + device_put_calls_str = "$(arg_names[i]) = jnp.asarray($(arg_names[i]))" + elseif info.sharding.type == "NamedSharding" + pstrings = [ + if length(p) == 1 + p[1] isa Nothing ? "None" : "'$(string(p[1]))'" + else + join(string.(reverse(p)), ", ") + end for p in info.sharding.partition_spec + ] + partition_spec = join(reverse(pstrings), ", ") + device_put_calls_str = "$(arg_names[i]) = jax.device_put($(arg_names[i]), jax.sharding.NamedSharding($(pymesh), P($(partition_spec))))" + else + error("Unsupported sharding type: $(info.sharding.type)") + end + push!(device_put_calls, device_put_calls_str) + end + end + end + + if has_any_sharding + inputs_to_jax_arrays = """# Apply sharding to inputs using device_put and NamedSharding + $(join(device_put_calls, "\n ")) + """ + else + convert_str_list = join( + [" $(argname) = jnp.asarray($(argname))" for argname in arg_names], "\n" + ) + inputs_to_jax_arrays = """ + # Convert inputs to jax arrays + $(convert_str_list) + """ + end + load_inputs = ["npz_data['$(info.key)']" for info in input_info] # Build the complete Python script @@ -203,6 +312,7 @@ function _generate_python_script( from enzyme_ad.jax import hlo_call import jax + from jax.sharding import PartitionSpec as P import jax.numpy as jnp import numpy as np import os @@ -245,11 +355,11 @@ function _generate_python_script( if __name__ == \"__main__\": # Load the example inputs - inputs = load_inputs() - - # Run the function (with JIT compilation) - print(\"Running $(function_name) with JIT compilation...\") - result = run_$(function_name)(*inputs) + ($(arg_list),) = load_inputs() + $(inputs_to_jax_arrays) + # Run the function + print(\"Running $(function_name)...\") + result = run_$(function_name)($(arg_list)) print(\"Result:\", result) """ From 8e459e2e77b3f28468a97231de7ce87d5a090bd3 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 6 Dec 2025 18:11:05 -0500 Subject: [PATCH 16/20] test: exported functions --- CondaPkg.toml | 6 +- Project.toml | 1 + src/Compiler.jl | 3 + src/serialization/EnzymeJAX.jl | 2 +- test/Project.toml | 5 +- test/export_enzymeax.jl | 55 ------- test/export_enzymeax_comprehensive.jl | 109 ------------- test/integration/enzymejax.jl | 217 ++++++++++++++++++++++++++ test/runtests.jl | 17 +- 9 files changed, 245 insertions(+), 170 deletions(-) delete mode 100644 test/export_enzymeax.jl delete mode 100644 test/export_enzymeax_comprehensive.jl create mode 100644 test/integration/enzymejax.jl diff --git a/CondaPkg.toml b/CondaPkg.toml index b1db4f8e75..52a193ead9 100644 --- a/CondaPkg.toml +++ b/CondaPkg.toml @@ -1,7 +1,7 @@ [deps] -python = "<=3.13,>=3.9,<4" +python = "<=3.12,>=3.9,<4" [pip.deps] -jax = ">= 0.6" +jax = ">= 0.5" tensorflow = ">= 2.17" -numpy = ">= 2" +numpy = ">= 1, >= 2" diff --git a/Project.toml b/Project.toml index c2e1409a57..1d23deea97 100644 --- a/Project.toml +++ b/Project.toml @@ -98,6 +98,7 @@ Libdl = "1.10" LinearAlgebra = "1.10" MPI = "0.20" NNlib = "0.9.26" +NPZ = "0.4" OffsetArrays = "1" OneHotArrays = "0.2.10" OrderedCollections = "1" diff --git a/src/Compiler.jl b/src/Compiler.jl index 1959d46e8d..6b08d0b5af 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -3576,6 +3576,9 @@ function compile_xla( mod, f; mlir_fn_res.num_partitions, mlir_fn_res.num_replicas ) + # Drop some of our attributes + run_pass_pipeline!(mod, "drop-unsupported-attributes", "drop_enzymexla_attributes") + # compile MLIR module to XLA executable global_device_ids = collect(Int64, mlir_fn_res.global_device_ids) mlir_fn_res.is_sharded && (device = nothing) diff --git a/src/serialization/EnzymeJAX.jl b/src/serialization/EnzymeJAX.jl index b13120b374..7c103be8a7 100644 --- a/src/serialization/EnzymeJAX.jl +++ b/src/serialization/EnzymeJAX.jl @@ -1,6 +1,6 @@ module EnzymeJAX -using ..Reactant: Reactant, Compiler, MLIR, Serialization +using ..Reactant: Reactant, Compiler, Serialization """ export_to_enzymejax( diff --git a/test/Project.toml b/test/Project.toml index ebb2029081..29584d8c7c 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -3,6 +3,7 @@ Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" +CondaPkg = "992eb4ea-22a4-4c89-a5bb-47a3300528ab" DLFP8Types = "f4c16678-4a16-415b-82ef-ed337c5d6c7c" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" @@ -18,9 +19,10 @@ KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" LuxLib = "82251201-b29d-42c6-8e01-566dec8acb11" -MethodAnalysis = "85b6ec6f-f7df-4429-9514-a64bcd9ee824" MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195" +MethodAnalysis = "85b6ec6f-f7df-4429-9514-a64bcd9ee824" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" +NPZ = "15e1cf62-19b3-5cfa-8e77-841668bca605" OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" @@ -56,6 +58,7 @@ Lux = "1.21" LuxLib = "1.11" MPI = "0.20" NNlib = "0.9.26" +NPZ = "0.4" OffsetArrays = "1" OneHotArrays = "0.2.6" Optimisers = "0.4" diff --git a/test/export_enzymeax.jl b/test/export_enzymeax.jl deleted file mode 100644 index ffb73a8a78..0000000000 --- a/test/export_enzymeax.jl +++ /dev/null @@ -1,55 +0,0 @@ -using Reactant -using Test - -@testset "Export to EnzymeJAX" begin - # Create a temporary directory for the export - tmpdir = mktempdir() - - try - # Define a simple function - function simple_add(x, y) - return x .+ y - end - - # Create some example inputs - x = Reactant.to_rarray(Float32[1, 2, 3]) - y = Reactant.to_rarray(Float32[4, 5, 6]) - - # Export to EnzymeJAX - mlir_path, python_path, input_paths = Reactant.Serialization.export_to_enzymejax( - simple_add, x, y; output_dir=tmpdir, function_name="simple_add" - ) - - # Verify that all files were created - @test isfile(mlir_path) - @test isfile(python_path) - @test length(input_paths) == 2 - @test all(isfile, input_paths) - - # Verify MLIR file has content - mlir_content = read(mlir_path, String) - @test !isempty(mlir_content) - @test occursin("module", mlir_content) - - # Verify Python file has content and correct structure - python_content = read(python_path, String) - @test !isempty(python_content) - @test occursin("from enzyme_ad.jax import hlo_call", python_content) - @test occursin("def run_simple_add", python_content) - @test occursin("def load_inputs", python_content) - @test occursin("if __name__ == \"__main__\":", python_content) - - # Verify input files exist and have reasonable sizes - for input_path in input_paths - @test filesize(input_path) > 0 - end - - println("✓ All export_to_enzymejax tests passed!") - println(" - MLIR file created: $(mlir_path)") - println(" - Python file created: $(python_path)") - println(" - Input files created: $(length(input_paths))") - finally - # Clean up - rm(tmpdir; recursive=true, force=true) - end -end diff --git a/test/export_enzymeax_comprehensive.jl b/test/export_enzymeax_comprehensive.jl deleted file mode 100644 index aad88147d4..0000000000 --- a/test/export_enzymeax_comprehensive.jl +++ /dev/null @@ -1,109 +0,0 @@ -using Reactant -using Test - -@testset "Export to EnzymeJAX - Multi-dimensional Arrays" begin - tmpdir = mktempdir() - - try - # Define a function with 2D arrays - function matrix_multiply(x, y) - return x * y - end - - # Create 2D arrays - Julia uses column-major order - x = Reactant.to_rarray(Float32[1 2 3; 4 5 6]) # 2x3 matrix - y = Reactant.to_rarray(Float32[7 8; 9 10; 11 12]) # 3x2 matrix - - # Export to EnzymeJAX - mlir_path, python_path, input_paths = Reactant.Serialization.export_to_enzymejax( - matrix_multiply, x, y; output_dir=tmpdir, function_name="matrix_multiply" - ) - - @test isfile(mlir_path) - @test isfile(python_path) - @test length(input_paths) == 2 - - # Read Python file and check for correct shape information - python_content = read(python_path, String) - - # The shapes should be transposed for Python (row-major) - # Julia x: (2, 3) -> Python: (3, 2) - # Julia y: (3, 2) -> Python: (2, 3) - @test occursin("(3, 2)", python_content) # Transposed shape of x - @test occursin("(2, 3)", python_content) # Transposed shape of y - - println("✓ Multi-dimensional array export test passed!") - - finally - rm(tmpdir; recursive=true, force=true) - end -end - -@testset "Export to EnzymeJAX - 3D Arrays" begin - tmpdir = mktempdir() - - try - # Define a function with 3D arrays (like image data) - function add_3d(x, y) - return x .+ y - end - - # Create 3D arrays - e.g., (height, width, channels, batch) - # Julia: (28, 28, 1, 4) -> Python: (4, 1, 28, 28) - x = Reactant.to_rarray(rand(Float32, 28, 28, 1, 4)) - y = Reactant.to_rarray(rand(Float32, 28, 28, 1, 4)) - - # Export to EnzymeJAX - mlir_path, python_path, input_paths = Reactant.Serialization.export_to_enzymejax( - add_3d, x, y; output_dir=tmpdir, function_name="add_3d" - ) - - @test isfile(mlir_path) - @test isfile(python_path) - - # Check that Python file mentions the transposed shape - python_content = read(python_path, String) - @test occursin("(4, 1, 28, 28)", python_content) - - println("✓ 3D array export test passed!") - - finally - rm(tmpdir; recursive=true, force=true) - end -end - -@testset "Export to EnzymeJAX - File Content Verification" begin - tmpdir = mktempdir() - - try - function simple_fn(x) - return x .* 2.0f0 - end - - x = Reactant.to_rarray(Float32[1.0, 2.0, 3.0, 4.0]) - - mlir_path, python_path, input_paths = Reactant.Serialization.export_to_enzymejax( - simple_fn, x; output_dir=tmpdir, function_name="test_fn" - ) - - # Verify MLIR contains necessary elements - mlir_content = read(mlir_path, String) - @test occursin("module", mlir_content) - - # Verify Python file structure - python_content = read(python_path, String) - @test occursin("import jax", python_content) - @test occursin("import numpy as np", python_content) - @test occursin("from enzyme_ad.jax import hlo_call", python_content) - @test occursin("def run_test_fn(arg1)", python_content) - @test occursin("source=_hlo_code", python_content) - @test occursin("jax.jit(run_test_fn)", python_content) - - println("✓ File content verification test passed!") - - finally - rm(tmpdir; recursive=true, force=true) - end -end - -println("\n✅ All comprehensive tests passed!") diff --git a/test/integration/enzymejax.jl b/test/integration/enzymejax.jl new file mode 100644 index 0000000000..701f08866e --- /dev/null +++ b/test/integration/enzymejax.jl @@ -0,0 +1,217 @@ +using Reactant, Test, NPZ, PythonCall + +function run_exported_enzymejax_function(python_file_path::String, function_name::String) + output_dir = dirname(python_file_path) + + sys = pyimport("sys") + importlib = pyimport("importlib.util") + sys.path.insert(0, "$(output_dir)") + + spec = importlib.spec_from_file_location("test_module", "$(python_file_path)") + mod = importlib.module_from_spec(spec) + spec.loader.exec_module(mod) + + loaded_inputs = pygetattr(mod, "load_inputs")() + res = pyconvert(Array, pygetattr(mod, function_name)(loaded_inputs...)[0]) + return permutedims(res, ndims(res):-1:1) +end + +@testset "EnzymeJAX Export - Without Sharding" begin + @testset "Simple function" begin + f_simple(x) = sin.(x) .+ cos.(x) + + x_data = Reactant.TestUtils.construct_test_array(Float32, 4, 5) + x = Reactant.to_rarray(x_data) + + # Compute expected result + expected_result = f_simple(x_data) + + # Export the function + python_file_path = Reactant.Serialization.export_to_enzymejax( + f_simple, x; output_dir=mktempdir(; cleanup=false) + ) + + @test isfile(python_file_path) + @test endswith(python_file_path, ".py") + + # Check that generated files exist + output_dir = dirname(python_file_path) + mlir_files = filter(f -> endswith(f, ".mlir"), readdir(output_dir)) + npz_files = filter(f -> endswith(f, ".npz"), readdir(output_dir)) + + @test length(mlir_files) > 0 + @test length(npz_files) > 0 + + # Verify Python script contains key components + python_content = read(python_file_path, String) + @test contains(python_content, "hlo_call") + @test contains(python_content, "f_simple") + + # Run the exported script and verify results + result = run_exported_enzymejax_function(python_file_path, "run_f_simple") + @test isapprox(Array(result), expected_result; atol=1e-5, rtol=1e-5) + end + + @testset "Matrix multiplication" begin + f_matmul(x, y) = x * y + + x_data = Reactant.TestUtils.construct_test_array(Float32, 3, 4) + y_data = Reactant.TestUtils.construct_test_array(Float32, 4, 5) + x = Reactant.to_rarray(x_data) + y = Reactant.to_rarray(y_data) + + # Compute expected result + expected_result = f_matmul(x_data, y_data) + + # Export the function + python_file_path = Reactant.Serialization.export_to_enzymejax( + f_matmul, x, y; output_dir=mktempdir(; cleanup=false), function_name="matmul" + ) + + @test isfile(python_file_path) + + output_dir = dirname(python_file_path) + npz_files = filter(f -> endswith(f, ".npz"), readdir(output_dir)) + @test length(npz_files) > 0 + + # Verify the NPZ file contains both inputs + npz_data = npzread( + first(filter(f -> endswith(f, ".npz"), readdir(output_dir; join=true))) + ) + @test haskey(npz_data, "arr_1") || haskey(npz_data, "arr_2") + + # Run the exported script and verify results + result = run_exported_enzymejax_function(python_file_path, "run_matmul") + @test isapprox(Array(result), expected_result; atol=1e-5, rtol=1e-5) + end + + @testset "Complex function with multiple arguments" begin + f_complex(x, y, z) = sum(x .* y .+ sin.(z); dims=2) + + x_data = Reactant.TestUtils.construct_test_array(Float32, 5, 4) + y_data = Reactant.TestUtils.construct_test_array(Float32, 5, 4) + z_data = Reactant.TestUtils.construct_test_array(Float32, 5, 4) + x = Reactant.to_rarray(x_data) + y = Reactant.to_rarray(y_data) + z = Reactant.to_rarray(z_data) + + # Compute expected result + expected_result = f_complex(x_data, y_data, z_data) + + # Export the function + python_file_path = Reactant.Serialization.export_to_enzymejax( + f_complex, + x, + y, + z; + output_dir=mktempdir(; cleanup=false), + function_name="complex_fn", + ) + + @test isfile(python_file_path) + + output_dir = dirname(python_file_path) + mlir_files = filter(f -> endswith(f, ".mlir"), readdir(output_dir)) + npz_files = filter(f -> endswith(f, ".npz"), readdir(output_dir)) + + @test length(mlir_files) > 0 + @test length(npz_files) > 0 + + python_content = read(python_file_path, String) + @test contains(python_content, "complex_fn") + + # Run the exported script and verify results + result = run_exported_enzymejax_function(python_file_path, "run_complex_fn") + @test isapprox(Array(result), expected_result; atol=1e-5, rtol=1e-5) + end +end + +@testset "EnzymeJAX Export - With Sharding" begin + # Only run sharding tests if we have multiple devices + addressable_devices = Reactant.addressable_devices() + + if length(addressable_devices) ≥ 8 + mesh = Reactant.Sharding.Mesh( + reshape(addressable_devices[1:8], 2, 4), ("batch", "feature") + ) + + @testset "Export with sharding and preserve_sharding=true" begin + f_sharded(x, y) = x .+ y + + x_data = Reactant.TestUtils.construct_test_array(Float32, 2, 4) + y_data = Reactant.TestUtils.construct_test_array(Float32, 2, 4) + x = Reactant.to_rarray( + x_data; sharding=Reactant.Sharding.NamedSharding(mesh, ("batch", "feature")) + ) + y = Reactant.to_rarray(y_data; sharding=Reactant.Sharding.Replicated(mesh)) + + # Compute expected result + expected_result = f_sharded(x_data, y_data) + + # Export with sharding preservation enabled + python_file_path = Reactant.Serialization.export_to_enzymejax( + f_sharded, + x, + y; + output_dir=mktempdir(; cleanup=false), + function_name="f_sharded_with_preserve", + preserve_sharding=true, + ) + + @test isfile(python_file_path) + + # Check that Python script includes sharding information + python_content = read(python_file_path, String) + @test ( + contains(python_content, "NamedSharding") || + contains(python_content, "pmap") || + contains(python_content, "mesh") + ) + + # Run the exported script and verify results + result = run_exported_enzymejax_function( + python_file_path, "run_f_sharded_with_preserve" + ) + @test isapprox(Array(result), expected_result; atol=1e-5, rtol=1e-5) + end + + @testset "Export with sharding but preserve_sharding=false" begin + f_sharded_no_preserve(x, y) = x .- y + + x_data = Reactant.TestUtils.construct_test_array(Float32, 2, 4) + y_data = Reactant.TestUtils.construct_test_array(Float32, 2, 4) + x = Reactant.to_rarray( + x_data; sharding=Reactant.Sharding.NamedSharding(mesh, ("batch", "feature")) + ) + y = Reactant.to_rarray(y_data; sharding=Reactant.Sharding.Replicated(mesh)) + + # Compute expected result + expected_result = f_sharded_no_preserve(x_data, y_data) + + # Export without sharding preservation + python_file_path = Reactant.Serialization.export_to_enzymejax( + f_sharded_no_preserve, + x, + y; + output_dir=mktempdir(; cleanup=false), + function_name="f_sharded_no_preserve", + preserve_sharding=false, + ) + + @test isfile(python_file_path) + + # Check that Python script does NOT include explicit sharding directives + python_content = read(python_file_path, String) + # Should have hlo_call but without the advanced sharding setup + @test contains(python_content, "hlo_call") + + # Run the exported script and verify results + result = run_exported_enzymejax_function( + python_file_path, "run_f_sharded_no_preserve" + ) + @test isapprox(Array(result), expected_result; atol=1e-5, rtol=1e-5) + end + else + @warn "Skipping sharding tests: insufficient devices (need ≥8, have $(length(addressable_devices)))" + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 97a258d29c..a2585e2c18 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -6,6 +6,19 @@ end const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all")) +using CondaPkg + +const ENZYMEAD_INSTALLED = Ref(false) +# Install specific packages. Pkg.test doesn't pick up CondaPkg.toml in test folder +if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "integration" + CondaPkg.add_pip("jax"; version="==0.5") + try + CondaPkg.add_pip("enzyme_ad"; version=">=0.0.9") + ENZYMEAD_INSTALLED[] = true + catch + end +end + @testset "Reactant.jl Tests" begin if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "core" if Sys.isapple() && haskey(Reactant.XLA.global_backend_state.clients, "metal") @@ -38,7 +51,6 @@ const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all")) @safetestset "Cluster Detection" include("cluster_detector.jl") @safetestset "Config" include("config.jl") @safetestset "Batching" include("batching.jl") - @safetestset "Export to EnzymeJAX" include("export_enzymeax.jl") @safetestset "QA" include("qa.jl") end @@ -54,6 +66,9 @@ const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all")) @safetestset "Python" include("integration/python.jl") @safetestset "Optimisers" include("integration/optimisers.jl") @safetestset "FillArrays" include("integration/fillarrays.jl") + if ENZYMEAD_INSTALLED[] + @safetestset "EnzymeJAX Export" include("integration/enzymejax.jl") + end @safetestset "MPI" begin using MPI nranks = 2 From da189071381d7436c7ecdcfa5cbe66434915de48 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 6 Dec 2025 23:36:23 -0500 Subject: [PATCH 17/20] test: cleanup --- test/integration/enzymejax.jl | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/test/integration/enzymejax.jl b/test/integration/enzymejax.jl index 701f08866e..cd532801c2 100644 --- a/test/integration/enzymejax.jl +++ b/test/integration/enzymejax.jl @@ -28,7 +28,7 @@ end # Export the function python_file_path = Reactant.Serialization.export_to_enzymejax( - f_simple, x; output_dir=mktempdir(; cleanup=false) + f_simple, x; output_dir=mktempdir(; cleanup=true) ) @test isfile(python_file_path) @@ -65,7 +65,7 @@ end # Export the function python_file_path = Reactant.Serialization.export_to_enzymejax( - f_matmul, x, y; output_dir=mktempdir(; cleanup=false), function_name="matmul" + f_matmul, x, y; output_dir=mktempdir(; cleanup=true), function_name="matmul" ) @test isfile(python_file_path) @@ -104,7 +104,7 @@ end x, y, z; - output_dir=mktempdir(; cleanup=false), + output_dir=mktempdir(; cleanup=true), function_name="complex_fn", ) @@ -153,7 +153,7 @@ end f_sharded, x, y; - output_dir=mktempdir(; cleanup=false), + output_dir=mktempdir(; cleanup=true), function_name="f_sharded_with_preserve", preserve_sharding=true, ) @@ -193,7 +193,7 @@ end f_sharded_no_preserve, x, y; - output_dir=mktempdir(; cleanup=false), + output_dir=mktempdir(; cleanup=true), function_name="f_sharded_no_preserve", preserve_sharding=false, ) From e63fa4c2dfa9dc9e2418786e16ea12cc9a8ba148 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 7 Dec 2025 01:36:10 -0500 Subject: [PATCH 18/20] Apply suggestions from code review --- test/runtests.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index a2585e2c18..b6cf869ae1 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -8,13 +8,13 @@ const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all")) using CondaPkg -const ENZYMEAD_INSTALLED = Ref(false) +const ENZYMEJAX_INSTALLED = Ref(false) # Install specific packages. Pkg.test doesn't pick up CondaPkg.toml in test folder if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "integration" CondaPkg.add_pip("jax"; version="==0.5") try CondaPkg.add_pip("enzyme_ad"; version=">=0.0.9") - ENZYMEAD_INSTALLED[] = true + ENZYMEJAX_INSTALLED[] = true catch end end From fa5615a9bdd7b71a45b56e58caa680d11c91a8b2 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 7 Dec 2025 09:20:24 -0500 Subject: [PATCH 19/20] Apply suggestion from @avik-pal --- test/runtests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index b6cf869ae1..3aa763a03e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -66,7 +66,7 @@ end @safetestset "Python" include("integration/python.jl") @safetestset "Optimisers" include("integration/optimisers.jl") @safetestset "FillArrays" include("integration/fillarrays.jl") - if ENZYMEAD_INSTALLED[] + if ENZYMEJAX_INSTALLED[] @safetestset "EnzymeJAX Export" include("integration/enzymejax.jl") end @safetestset "MPI" begin From 406465b79be98ebfeff982e87c16f3ddc80eacea Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 7 Dec 2025 11:02:15 -0500 Subject: [PATCH 20/20] Apply suggestion from @avik-pal --- test/runtests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index 3aa763a03e..bbd5e0855f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -66,7 +66,7 @@ end @safetestset "Python" include("integration/python.jl") @safetestset "Optimisers" include("integration/optimisers.jl") @safetestset "FillArrays" include("integration/fillarrays.jl") - if ENZYMEJAX_INSTALLED[] + if ENZYMEJAX_INSTALLED[] && !Sys.isapple() @safetestset "EnzymeJAX Export" include("integration/enzymejax.jl") end @safetestset "MPI" begin