Skip to content

Commit b8b5e44

Browse files
committed
fix: handle traced array returns inside objects
1 parent 0bfc722 commit b8b5e44

File tree

3 files changed

+56
-7
lines changed

3 files changed

+56
-7
lines changed

src/Compiler.jl

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -631,17 +631,45 @@ function codegen_unflatten!(
631631
path = path[2:end]
632632
result_stores[path] = concrete_res_name
633633
continue
634-
else
635-
@assert path[1] == :resargs
634+
elseif path[1] == :resargs
636635
unflatcode = :(args[$(path[2])])
637636
path = path[3:end]
637+
else
638+
@show "path[1] == $(path[1]) has been ignored..." # XXX: Validate this is correct
639+
continue
638640
end
639641

640642
# unroll path tree
641-
for p in path
643+
for p in path[1:(end - 1)]
642644
unflatcode = :(traced_getfield($unflatcode, $(Meta.quot(p))))
643645
end
644-
unflatcode = :($unflatcode.data = $concrete_res_name)
646+
if length(path) > 0
647+
final_val = gensym("final_val")
648+
unflatcode = quote
649+
$final_val = traced_getfield($unflatcode, $(Meta.quot(path[end])))
650+
if $final_val isa TracedRArray
651+
setfield!(
652+
$unflatcode,
653+
$(Meta.quot(path[end])),
654+
ConcreteRArray{eltype($final_val),ndims($final_val)}(
655+
$concrete_res_name, size($final_val)
656+
),
657+
)
658+
elseif $final_val isa TracedRNumber
659+
setfield!(
660+
$unflatcode,
661+
$(Meta.quot(path[end])),
662+
ConcreteRNumber{eltype($final_val)}($concrete_res_name),
663+
)
664+
else
665+
setfield!($final_val, :data, $concrete_res_name)
666+
end
667+
end
668+
else
669+
unflatcode = quote
670+
$unflatcode.data = $concrete_res_name
671+
end
672+
end
645673

646674
push!(unflatten_code, unflatcode)
647675
end

src/Interpreter.jl

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,7 @@ function overload_autodiff(
235235
primf = f.val
236236
primargs = ((v.val for v in args)...,)
237237

238-
fnwrap, func2, traced_result, result, seen_args, ret, linear_args, in_tys, linear_results = Reactant.TracedUtils.make_mlir_fn(
238+
fnwrap, func2, traced_result, result, seen_args, ret, linear_args, in_tys, linear_results = TracedUtils.make_mlir_fn(
239239
primf, primargs, (), string(f) * "_autodiff", false
240240
)
241241

@@ -302,7 +302,7 @@ function overload_autodiff(
302302
cst = MLIR.IR.result(MLIR.Dialects.stablehlo.constant(; value=attr), 1)
303303
push!(ad_inputs, cst)
304304
end
305-
else
305+
elseif TracedUtils.has_argidx(a)
306306
idx, path = TracedUtils.get_argidx(a)
307307
if idx == 1 && fnwrap
308308
act = act_from_type(f, reverse, true)
@@ -322,6 +322,12 @@ function overload_autodiff(
322322
end
323323
TracedUtils.push_val!(ad_inputs, args[idx].dval, path[3:end])
324324
end
325+
else
326+
act = act_from_type(Enzyme.Const, reverse, true)
327+
push!(ret_activity, act)
328+
if act != enzyme_out && act != enzyme_outnoneed
329+
continue
330+
end
325331
end
326332
end
327333

@@ -385,7 +391,7 @@ function overload_autodiff(
385391
end
386392
residx += 1
387393
end
388-
else
394+
elseif TracedUtils.has_argidx(a)
389395
idx, path = TracedUtils.get_argidx(a)
390396
if idx == 1 && fnwrap
391397
TracedUtils.set!(
@@ -405,6 +411,9 @@ function overload_autodiff(
405411
)
406412
residx += 1
407413
end
414+
else
415+
TracedUtils.set!(a, (), TracedUtils.transpose_val(MLIR.IR.result(res, residx)))
416+
residx += 1
408417
end
409418
end
410419

src/TracedUtils.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -341,6 +341,18 @@ function get_argidx(x)
341341
throw(AssertionError("No path found for $x"))
342342
end
343343

344+
function has_argidx(x)
345+
for path in x.paths
346+
if length(path) == 0
347+
continue
348+
end
349+
if path[1] == :args
350+
return true
351+
end
352+
end
353+
return false
354+
end
355+
344356
function set!(x, path, tostore; emptypath=false)
345357
for p in path
346358
x = Reactant.Compiler.traced_getfield(x, p)

0 commit comments

Comments
 (0)