Skip to content

Commit 3633105

Browse files
committed
fix: handle traced array returns inside objects
1 parent 556d014 commit 3633105

File tree

4 files changed

+56
-13
lines changed

4 files changed

+56
-13
lines changed

src/Compiler.jl

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -315,7 +315,7 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true)
315315

316316
toolkit = ""
317317
if isdefined(Reactant_jll, :ptxas_path)
318-
toolkit = Reactant_jll.ptxas_path[1:end-length("/bin/ptxas")]
318+
toolkit = Reactant_jll.ptxas_path[1:(end - length("/bin/ptxas"))]
319319
end
320320
kern = "lower-kernel{toolkitPath=$toolkit cuLaunchKernelPtr=$(cuLaunch[]) cuModuleLoadDataPtr=$(cuModule[]) cuModuleGetFunctionPtr=$(cuFunc[])}"
321321
if optimize === :all
@@ -329,7 +329,7 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true)
329329
"remove-unnecessary-enzyme-ops",
330330
"enzyme-simplify-math",
331331
opt_passes,
332-
kern
332+
kern,
333333
],
334334
',',
335335
),
@@ -385,7 +385,7 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true)
385385
"remove-unnecessary-enzyme-ops",
386386
"enzyme-simplify-math",
387387
opt_passes,
388-
kern
388+
kern,
389389
],
390390
',',
391391
),
@@ -394,7 +394,7 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true)
394394
run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes], ","))
395395
run_pass_pipeline!(mod, "enzyme,arith-raise{stablehlo=true}"; enable_verifier=false)
396396
run_pass_pipeline!(
397-
mod, "canonicalize,remove-unnecessary-enzyme-ops,enzyme-simplify-math,"*kern
397+
mod, "canonicalize,remove-unnecessary-enzyme-ops,enzyme-simplify-math," * kern
398398
)
399399
elseif optimize !== :none
400400
error("Invalid optimize option: $(Meta.quot(optimize))")
@@ -638,10 +638,30 @@ function codegen_unflatten!(
638638
end
639639

640640
# unroll path tree
641-
for p in path
641+
for p in path[1:(end - 1)]
642642
unflatcode = :(traced_getfield($unflatcode, $(Meta.quot(p))))
643643
end
644-
unflatcode = :($unflatcode.data = $concrete_res_name)
644+
final_val = gensym("final_val")
645+
unflatcode = quote
646+
$final_val = traced_getfield($unflatcode, $(Meta.quot(path[end])))
647+
if $final_val isa TracedRArray
648+
setfield!(
649+
$unflatcode,
650+
$(Meta.quot(path[end])),
651+
ConcreteRArray{eltype($final_val),ndims($final_val)}(
652+
$concrete_res_name, size($final_val)
653+
),
654+
)
655+
elseif $final_val isa TracedRNumber
656+
setfield!(
657+
$unflatcode,
658+
$(Meta.quot(path[end])),
659+
ConcreteRNumber{eltype($final_val)}($concrete_res_name),
660+
)
661+
else
662+
setfield!($final_val, :data, $concrete_res_name)
663+
end
664+
end
645665

646666
push!(unflatten_code, unflatcode)
647667
end

src/Interpreter.jl

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,7 @@ function overload_autodiff(
235235
primf = f.val
236236
primargs = ((v.val for v in args)...,)
237237

238-
fnwrap, func2, traced_result, result, seen_args, ret, linear_args, in_tys, linear_results = Reactant.TracedUtils.make_mlir_fn(
238+
fnwrap, func2, traced_result, result, seen_args, ret, linear_args, in_tys, linear_results = TracedUtils.make_mlir_fn(
239239
primf, primargs, (), string(f) * "_autodiff", false
240240
)
241241

@@ -302,7 +302,7 @@ function overload_autodiff(
302302
cst = MLIR.IR.result(MLIR.Dialects.stablehlo.constant(; value=attr), 1)
303303
push!(ad_inputs, cst)
304304
end
305-
else
305+
elseif TracedUtils.has_argidx(a)
306306
idx, path = TracedUtils.get_argidx(a)
307307
if idx == 1 && fnwrap
308308
act = act_from_type(f, reverse, true)
@@ -322,6 +322,12 @@ function overload_autodiff(
322322
end
323323
TracedUtils.push_val!(ad_inputs, args[idx].dval, path[3:end])
324324
end
325+
else
326+
act = act_from_type(Enzyme.Const, reverse, true)
327+
push!(ret_activity, act)
328+
if act != enzyme_out && act != enzyme_outnoneed
329+
continue
330+
end
325331
end
326332
end
327333

@@ -385,7 +391,7 @@ function overload_autodiff(
385391
end
386392
residx += 1
387393
end
388-
else
394+
elseif TracedUtils.has_argidx(a)
389395
idx, path = TracedUtils.get_argidx(a)
390396
if idx == 1 && fnwrap
391397
TracedUtils.set!(
@@ -405,6 +411,11 @@ function overload_autodiff(
405411
)
406412
residx += 1
407413
end
414+
else
415+
TracedUtils.set!(
416+
a, (), TracedUtils.transpose_val(MLIR.IR.result(res, residx))
417+
)
418+
residx += 1
408419
end
409420
end
410421

src/TracedUtils.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -341,6 +341,18 @@ function get_argidx(x)
341341
throw(AssertionError("No path found for $x"))
342342
end
343343

344+
function has_argidx(x)
345+
for path in x.paths
346+
if length(path) == 0
347+
continue
348+
end
349+
if path[1] == :args
350+
return true
351+
end
352+
end
353+
return false
354+
end
355+
344356
function set!(x, path, tostore; emptypath=false)
345357
for p in path
346358
x = Reactant.Compiler.traced_getfield(x, p)

src/utils.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -229,9 +229,9 @@ function make_oc(
229229
sig::Type, rt::Type, src::Core.CodeInfo, nargs::Int, isva::Bool, f::Any
230230
)::Core.OpaqueClosure
231231
key = (sig, rt, src, nargs, isva, f)
232-
if haskey(oc_captures, key)
233-
return oc_captures[key]
234-
else
232+
# if haskey(oc_captures, key)
233+
# return oc_captures[key]
234+
# else
235235
ores = ccall(
236236
:jl_new_opaque_closure_from_code_info,
237237
Any,
@@ -250,7 +250,7 @@ function make_oc(
250250
)::Core.OpaqueClosure
251251
oc_captures[key] = ores
252252
return ores
253-
end
253+
# end
254254
end
255255

256256
function safe_print(name, x)

0 commit comments

Comments
 (0)