@@ -85,21 +85,25 @@ function export_to_enzymeax(
8585 mlir_path = joinpath (output_dir, " $(function_name) _$(fnid) .mlir" )
8686 write (mlir_path, hlo_code)
8787
88- invmap = IdDict ()
89- for (k, v) in mlir_fn_res. seen_args
90- invmap[v] = k
91- end
92-
9388 # Process and save inputs based on the linearized arguments
89+ # seen_args is an OrderedIdDict where keys are concrete args and values are traced args
90+ # linear_args contains only the arguments that need to be passed to the function
91+ # We iterate over seen_args which preserves the order, and only save those in linear_args
9492 input_paths = String[]
9593 input_info = []
96- for (i, linear_arg) in enumerate (mlir_fn_res. linear_args)
97- carg = invmap[linear_arg]
98- # Save the input (transposed for row-major Python/NumPy)
99- input_path = joinpath (output_dir, " $(function_name) _$(fnid) _input_$(i) .npy" )
100- _save_transposed_array (input_path, _to_array (carg))
101- push! (input_paths, input_path)
102- push! (input_info, (shape= size (carg), dtype= eltype (carg)))
94+ input_idx = 1
95+ for (concrete_arg, traced_arg) in mlir_fn_res. seen_args
96+ # Only process arguments that are in linear_args (skip computed values)
97+ if traced_arg in mlir_fn_res. linear_args
98+ # Save the input (transposed for row-major Python/NumPy)
99+ input_path = joinpath (
100+ output_dir, " $(function_name) _$(fnid) _input_$(input_idx) .npy"
101+ )
102+ _save_transposed_array (input_path, _to_array (concrete_arg))
103+ push! (input_paths, input_path)
104+ push! (input_info, (shape= size (concrete_arg), dtype= eltype (concrete_arg)))
105+ input_idx += 1
106+ end
103107 end
104108
105109 # Generate Python script
0 commit comments