Skip to content

Commit 89b7690

Browse files
Copilotwsmoses
andcommitted
Use existing seen_args infrastructure instead of manual inverse map
Co-authored-by: wsmoses <1260124+wsmoses@users.noreply.github.com>
1 parent 43aaf0e commit 89b7690

File tree

1 file changed

+16
-12
lines changed

1 file changed

+16
-12
lines changed

src/serialization/EnzymeJAX.jl

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -85,21 +85,25 @@ function export_to_enzymeax(
8585
mlir_path = joinpath(output_dir, "$(function_name)_$(fnid).mlir")
8686
write(mlir_path, hlo_code)
8787

88-
invmap = IdDict()
89-
for (k, v) in mlir_fn_res.seen_args
90-
invmap[v] = k
91-
end
92-
9388
# Process and save inputs based on the linearized arguments
89+
# seen_args is an OrderedIdDict where keys are concrete args and values are traced args
90+
# linear_args contains only the arguments that need to be passed to the function
91+
# We iterate over seen_args which preserves the order, and only save those in linear_args
9492
input_paths = String[]
9593
input_info = []
96-
for (i, linear_arg) in enumerate(mlir_fn_res.linear_args)
97-
carg = invmap[linear_arg]
98-
# Save the input (transposed for row-major Python/NumPy)
99-
input_path = joinpath(output_dir, "$(function_name)_$(fnid)_input_$(i).npy")
100-
_save_transposed_array(input_path, _to_array(carg))
101-
push!(input_paths, input_path)
102-
push!(input_info, (shape=size(carg), dtype=eltype(carg)))
94+
input_idx = 1
95+
for (concrete_arg, traced_arg) in mlir_fn_res.seen_args
96+
# Only process arguments that are in linear_args (skip computed values)
97+
if traced_arg in mlir_fn_res.linear_args
98+
# Save the input (transposed for row-major Python/NumPy)
99+
input_path = joinpath(
100+
output_dir, "$(function_name)_$(fnid)_input_$(input_idx).npy"
101+
)
102+
_save_transposed_array(input_path, _to_array(concrete_arg))
103+
push!(input_paths, input_path)
104+
push!(input_info, (shape=size(concrete_arg), dtype=eltype(concrete_arg)))
105+
input_idx += 1
106+
end
103107
end
104108

105109
# Generate Python script

0 commit comments

Comments
 (0)