Error in accumulate when one of the arguments is a tuple #664
Zygote is constructing tangents that enter the … I think this could be fixed by making sure …

```julia
@inline function wrap_chainrules_input(dxs::Union{Tuple, NamedTuple})
    xp = map(wrap_chainrules_input, dxs)
    # This produces Tangent{Any} since it does not get to see the primal, `x`.
    # ChainRulesCore.Tangent{Any, typeof(xp)}(xp) -- comment this out and replace by line below
    ChainRulesCore.StructuralTangent{typeof(xp)}(xp)
end
```

With this change, things seem to work out.
Same error with JuliaDiff/ChainRules.jl#569, FWIW. Not certain this is relevant, but notice the similarity to this:

```julia
julia> accumulate(=>, (1,2,3))
(1, 1 => 2, (1 => 2) => 3)

julia> accumulate(=>, [1,2,3])
ERROR: MethodError: Cannot `convert` an object of type Int64 to an object of type Pair{Int64, Int64}
```

and that this gradient works with `x::Tuple`:

```julia
julia> gradient(α -> sum(sum(g(α, h, Tuple(x)))), 1f0)[1]
15.059713f0

julia> gradient(α -> sum(sum(g(α, h, x))), 1f0)[1]  # with x::Vector as above
ERROR: MethodError: no method matching construct(::Type{Any}, ::Tuple{FillArrays.Fill{Float32, 1, Tuple{Base.OneTo{Int64}}}, ChainRulesCore.NoTangent})
```
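A rough Base-only illustration of the tuple/vector difference above (a sketch, not Base's actual implementation). As far as I can tell, `accumulate` on an array allocates its result with a single element type inferred from `op` applied to the input element type (here `Pair{Int64, Int64}`), so storing the untouched first element `1` needs a `convert` that fails. The tuple method, by contrast, can return a tuple whose elements all have different types. The hypothetical helpers `scan_tuple` and `scan_any` mimic both behaviours: a recursive scan over a tuple, and a vector scan that collects into `Vector{Any}` so the element type is free to change at every step.

```julia
# Hypothetical helper `scan_tuple` (not Base's implementation): a left scan
# over a tuple, built by recursion so each output element keeps its own type.
scan_tuple(op, acc, ::Tuple{}) = ()
function scan_tuple(op, acc, t::Tuple)
    next = op(acc, first(t))
    return (next, scan_tuple(op, next, Base.tail(t))...)
end
scan_tuple(op, t::Tuple) =
    isempty(t) ? () : (first(t), scan_tuple(op, first(t), Base.tail(t))...)

# Hypothetical helper `scan_any`: the same scan over a vector, but collecting
# into a Vector{Any} so the element type may differ from step to step.
function scan_any(op, v::AbstractVector)
    out = Any[]
    isempty(v) && return out
    acc = first(v)
    push!(out, acc)
    for x in Iterators.drop(v, 1)
        acc = op(acc, x)
        push!(out, acc)
    end
    return out
end

scan_tuple(=>, (1, 2, 3))  # (1, 1 => 2, (1 => 2) => 3), same as accumulate(=>, (1, 2, 3))
scan_any(=>, [1, 2, 3])    # Any[1, 1 => 2, (1 => 2) => 3]
```

The `Vector{Any}` route trades type stability for flexibility; the tuple route keeps concrete types but recurses once per element, so it is only sensible for short, fixed-length sequences.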
Hello,
I have been, for educational purposes, implementing an RNN by hand and wanted to be fancy and use `accumulate` instead of recursion or a for loop. But I ran into an error when one of the operands in `accumulate` is a tuple. I have carved out an MWE, which would look like this:
Computing the gradient of `f` succeeds, but computing the gradient of `g` crashes with:

Julia and environment

Thanks for your help.
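For readers unfamiliar with the pattern, here is a minimal sketch of an `accumulate`-based RNN forward pass, assuming a toy cell h_t = tanh(W h_{t-1} + U x_t). It is not the author's MWE, and `rnn_step`, `rnn_loop`, `W`, `U`, `xs` and `h0` are all hypothetical names. It shows `accumulate(...; init = h0)` threading the hidden state through a sequence, the same scan over a `Tuple` of inputs, and the explicit for loop that `accumulate` replaces. It covers only the forward pass; no gradients are taken here.

```julia
using LinearAlgebra  # matrix-vector products

# Hypothetical toy RNN cell (not the issue author's MWE):
#   h_t = tanh.(W * h_{t-1} .+ U * x_t)
rnn_step(W, U) = (h, x) -> tanh.(W * h .+ U * x)

W  = [0.5 0.0; 0.0 0.5]                    # hypothetical recurrent weights
U  = [1.0 0.0; 0.0 1.0]                    # hypothetical input weights
xs = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # a length-3 input sequence
h0 = zeros(2)                              # initial hidden state

# accumulate threads the hidden state through the sequence and keeps every h_t.
hs = accumulate(rnn_step(W, U), xs; init = h0)

# The same scan with the inputs passed as a Tuple instead of a Vector.
hs_tuple = accumulate(rnn_step(W, U), Tuple(xs); init = h0)

# Reference: the explicit for loop that accumulate replaces.
function rnn_loop(W, U, xs, h0)
    h = h0
    out = Vector{Vector{Float64}}()
    for x in xs
        h = rnn_step(W, U)(h, x)
        push!(out, h)
    end
    return out
end
```

Both `accumulate` calls and `rnn_loop` perform the same floating-point operations in the same order, so their hidden states should agree exactly; the tuple version simply returns a `Tuple` of vectors rather than a `Vector`.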