@@ -315,7 +315,7 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true)
315315
316316 toolkit = " "
317317 if isdefined (Reactant_jll, :ptxas_path )
318- toolkit = Reactant_jll. ptxas_path[1 : end - length (" /bin/ptxas" )]
318+ toolkit = Reactant_jll. ptxas_path[1 : ( end - length (" /bin/ptxas" ) )]
319319 end
320320 kern = " lower-kernel{toolkitPath=$toolkit cuLaunchKernelPtr=$(cuLaunch[]) cuModuleLoadDataPtr=$(cuModule[]) cuModuleGetFunctionPtr=$(cuFunc[]) }"
321321 if optimize === :all
@@ -329,7 +329,7 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true)
329329 " remove-unnecessary-enzyme-ops" ,
330330 " enzyme-simplify-math" ,
331331 opt_passes,
332- kern
332+ kern,
333333 ],
334334 ' ,' ,
335335 ),
@@ -385,7 +385,7 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true)
385385 " remove-unnecessary-enzyme-ops" ,
386386 " enzyme-simplify-math" ,
387387 opt_passes,
388- kern
388+ kern,
389389 ],
390390 ' ,' ,
391391 ),
@@ -394,7 +394,7 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true)
394394 run_pass_pipeline! (mod, join ([opt_passes, " enzyme-batch" , opt_passes], " ," ))
395395 run_pass_pipeline! (mod, " enzyme,arith-raise{stablehlo=true}" ; enable_verifier= false )
396396 run_pass_pipeline! (
397- mod, " canonicalize,remove-unnecessary-enzyme-ops,enzyme-simplify-math," * kern
397+ mod, " canonicalize,remove-unnecessary-enzyme-ops,enzyme-simplify-math," * kern
398398 )
399399 elseif optimize != = :none
400400 error (" Invalid optimize option: $(Meta. quot (optimize)) " )
@@ -638,10 +638,30 @@ function codegen_unflatten!(
638638 end
639639
640640 # unroll path tree
641- for p in path
641+ for p in path[ 1 : ( end - 1 )]
642642 unflatcode = :(traced_getfield ($ unflatcode, $ (Meta. quot (p))))
643643 end
644- unflatcode = :($ unflatcode. data = $ concrete_res_name)
644+ final_val = gensym (" final_val" )
645+ unflatcode = quote
646+ $ final_val = traced_getfield ($ unflatcode, $ (Meta. quot (path[end ])))
647+ if $ final_val isa TracedRArray
648+ setfield! (
649+ $ unflatcode,
650+ $ (Meta. quot (path[end ])),
651+ ConcreteRArray {eltype($final_val),ndims($final_val)} (
652+ $ concrete_res_name, size ($ final_val)
653+ ),
654+ )
655+ elseif $ final_val isa TracedRNumber
656+ setfield! (
657+ $ unflatcode,
658+ $ (Meta. quot (path[end ])),
659+ ConcreteRNumber {eltype($final_val)} ($ concrete_res_name),
660+ )
661+ else
662+ setfield! ($ final_val, :data , $ concrete_res_name)
663+ end
664+ end
645665
646666 push! (unflatten_code, unflatcode)
647667 end
0 commit comments