@@ -235,7 +235,7 @@ function overload_autodiff(
235235 primf = f. val
236236 primargs = ((v. val for v in args). .. ,)
237237
238- fnwrap, func2, traced_result, result, seen_args, ret, linear_args, in_tys, linear_results = Reactant . TracedUtils. make_mlir_fn (
238+ fnwrap, func2, traced_result, result, seen_args, ret, linear_args, in_tys, linear_results = TracedUtils. make_mlir_fn (
239239 primf, primargs, (), string (f) * " _autodiff" , false
240240 )
241241
@@ -302,7 +302,7 @@ function overload_autodiff(
302302 cst = MLIR. IR. result (MLIR. Dialects. stablehlo. constant (; value= attr), 1 )
303303 push! (ad_inputs, cst)
304304 end
305- else
305+ elseif TracedUtils . has_argidx (a)
306306 idx, path = TracedUtils. get_argidx (a)
307307 if idx == 1 && fnwrap
308308 act = act_from_type (f, reverse, true )
@@ -322,6 +322,12 @@ function overload_autodiff(
322322 end
323323 TracedUtils. push_val! (ad_inputs, args[idx]. dval, path[3 : end ])
324324 end
325+ else
326+ act = act_from_type (Enzyme. Const, reverse, true )
327+ push! (ret_activity, act)
328+ if act != enzyme_out && act != enzyme_outnoneed
329+ continue
330+ end
325331 end
326332 end
327333
@@ -385,7 +391,7 @@ function overload_autodiff(
385391 end
386392 residx += 1
387393 end
388- else
394+ elseif TracedUtils . has_argidx (a)
389395 idx, path = TracedUtils. get_argidx (a)
390396 if idx == 1 && fnwrap
391397 TracedUtils. set! (
@@ -405,6 +411,9 @@ function overload_autodiff(
405411 )
406412 residx += 1
407413 end
414+ else
415+ TracedUtils. set! (a, (), TracedUtils. transpose_val (MLIR. IR. result (res, residx)))
416+ residx += 1
408417 end
409418 end
410419
0 commit comments