
accum_param_gradients! does not support scale_factor for static functions #387

Open
ztangent opened this issue Feb 25, 2021 · 1 comment

Comments

@ztangent
Member

As stated in the title, `accumulate_param_gradients!` does not support `scale_factor` for static functions. Calling `accumulate_param_gradients!` with a third argument throws `ERROR: Not implemented`, because the call falls through to the abstract GFI definition.

This is due to (1) the lack of a generated method definition with the appropriate signature. The only generated method takes `trace` and `retval_grad`, with no `scale_factor` argument:

```julia
push!(generated_functions, quote
    @generated function $(GlobalRef(Gen, :accumulate_param_gradients!))(trace::T, retval_grad) where {T<:$(QuoteNode(StaticIRTrace))}
        $(QuoteNode(codegen_accumulate_param_gradients!))(trace, retval_grad)
    end
end)
```
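To see why the three-argument call errors, here is a minimal, self-contained Julia model of the dispatch involved. It is a sketch under assumptions, not Gen's code: `AbstractTrace`, `accum_grads!`, `BrokenTrace`, and `FixedTrace` are hypothetical stand-ins, plain methods replace the `@generated` one, and a generic three-argument fallback plus a two-argument method forwarding a default scale factor of `1.0` is an assumed shape for the abstract GFI definitions. Each toy trace accumulates the gradient of a single parameter, whose local gradient is taken to be `1.0`.

```julia
# Hypothetical toy model of the dispatch problem; these are NOT Gen's definitions.
abstract type AbstractTrace end

# Generic fallback, standing in for the abstract GFI definition (assumed shape).
accum_grads!(trace::AbstractTrace, retval_grad, scale_factor) = error("Not implemented")

# Two-argument form forwards an assumed default scale factor of 1.0.
accum_grads!(trace::AbstractTrace, retval_grad) = accum_grads!(trace, retval_grad, 1.0)

# Before the fix: only the two-argument signature is specialized.
mutable struct BrokenTrace <: AbstractTrace
    grad::Float64  # accumulator for a single parameter
end
# 1.0 stands in for the parameter's local gradient.
accum_grads!(trace::BrokenTrace, retval_grad) = (trace.grad += 1.0; trace)

# After the fix: the three-argument signature is specialized and applies the scale.
mutable struct FixedTrace <: AbstractTrace
    grad::Float64
end
accum_grads!(trace::FixedTrace, retval_grad, scale_factor) =
    (trace.grad += scale_factor * 1.0; trace)

b = BrokenTrace(0.0)
accum_grads!(b, nothing)          # dispatches to the BrokenTrace method
# accum_grads!(b, nothing, 0.5)   # would hit the generic fallback and throw "Not implemented"

f = FixedTrace(0.0)
accum_grads!(f, nothing, 0.5)     # f.grad == 0.5
accum_grads!(f, nothing)          # forwards scale_factor = 1.0, so f.grad == 1.5
```

Because `BrokenTrace` only specializes the two-argument signature, the three-argument call has nowhere to go but the generic method. Specializing the three-argument signature, as `FixedTrace` does, lets the two-argument form keep working through the forwarding method.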

And (2) the lack of logic to handle a scale factor in the backward pass for trainable parameter nodes, which adds the node's gradient to the current parameter gradient unscaled:

```julia
function back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked, node::TrainableParameterNode, mode)
    # handle case when it is the return node
    if node === ir.return_node && node in fwd_marked
        @assert node in back_marked
        push!(stmts, :(isnothing(retval_grad) && error("Required return value gradient but got nothing")))
        push!(stmts, :($(gradient_var(node)) += retval_grad))
    end
    if node in fwd_marked && node in back_marked
        cur_param_grad = :($(QuoteNode(get_param_grad))(trace.$static_ir_gen_fn_ref,
                                                      $(QuoteNode(node.name))))
        push!(stmts, :($(QuoteNode(set_param_grad!))(trace.$static_ir_gen_fn_ref,
                                                    $(QuoteNode(node.name)),
                                                    $cur_param_grad + $(gradient_var(node)))))
    end
end
```
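For (2), the intended arithmetic is presumably that each parameter's gradient accumulator is incremented by `scale_factor` times the gradient computed in the backward pass, rather than by the raw gradient. The following is a minimal sketch of emitting such statements, assuming the generated method has a `scale_factor` variable in scope. `emit_param_update!`, `param_grads`, and `apply_update` are hypothetical names, and a plain `Dict` stands in for the storage that `get_param_grad`/`set_param_grad!` access.

```julia
# Hypothetical sketch, NOT Gen's back_codegen!: it only mirrors the idea of pushing
# statements into `stmts`, and assumes `scale_factor` is in scope where they run.
function emit_param_update!(stmts::Vector{Any}, name::Symbol, grad_var::Symbol)
    key = QuoteNode(name)  # embeds the parameter name as a literal Symbol
    # Increment the accumulator by the *scaled* gradient instead of the raw gradient.
    push!(stmts, :(param_grads[$key] = param_grads[$key] + scale_factor * $grad_var))
    return stmts
end

stmts = Any[]
emit_param_update!(stmts, :theta, :theta_grad)
emit_param_update!(stmts, :phi, :phi_grad)

# Splice the emitted statements into an anonymous function and compile it with eval.
apply_update = eval(:((param_grads, theta_grad, phi_grad, scale_factor) -> begin
    $(stmts...)
    param_grads
end))

param_grads = Dict(:theta => 0.0, :phi => 1.0)
apply_update(param_grads, 2.0, -4.0, 0.25)
# param_grads[:theta] == 0.5, param_grads[:phi] == 0.0
```

With `scale_factor = 1.0` the emitted statements reduce to the unscaled update shown in the excerpt, so the two-argument behavior is unchanged.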

@marcoct
Collaborator

marcoct commented Jul 13, 2021

This was addressed by #417. However, I don't see a test for it here: https://github.com/probcomp/Gen.jl/blob/20210512-marcoct-gradopts/test/static_ir/gradients.jl.
