
Error in accumulate when I have one argument as a tuple #664

Open
pevnak opened this issue Feb 14, 2024 · 2 comments

@pevnak

pevnak commented Feb 14, 2024

Hello,

For educational purposes I have been implementing an RNN by hand, and I wanted to be fancy and use accumulate instead of recursion or a for loop. But I ran into an error when one of the operands in accumulate is a tuple.
I have carved out an MWE, which looks like this:

using Zygote

x = [randn(Float32, 2) for i in 1:3]
h = randn(Float32, 2)


function f(α, h, x)
	o = accumulate(x, init = h) do h, x
		α * h + x
	end
end

function g(α, h, x)
	o = accumulate(x, init = (h, x[1])) do (h,_),x
		(α * h + x, x)
	end
	first.(o)
end

gradient(α -> sum(sum(g(α, h, x))), 1f0)[1]
gradient(α -> sum(sum(f(α, h, x))), 1f0)[1]

While computing the gradient of f succeeds, computing the gradient of g crashes with:

julia> gradient(α -> sum(sum(g(α, h, x))), 1f0)[1]
ERROR: MethodError: no method matching construct(::Type{Any}, ::Tuple{FillArrays.Fill{Float32, 1, Tuple{Base.OneTo{Int64}}}, ChainRulesCore.NoTangent})

Closest candidates are:
  construct(::Type{T}, ::T) where T<:Tuple
   @ ChainRulesCore ~/.julia/packages/ChainRulesCore/6DiyF/src/tangent_types/structural_tangent.jl:251
  construct(::Type{T}, ::NamedTuple{L}) where {T, L}
   @ ChainRulesCore ~/.julia/packages/ChainRulesCore/6DiyF/src/tangent_types/structural_tangent.jl:235

Stacktrace:
  [1] +(a::ChainRulesCore.Tangent{Tuple{…}, Tuple{…}}, d::ChainRulesCore.Tangent{Any, Tuple{…}})
    @ ChainRulesCore ~/.julia/packages/ChainRulesCore/6DiyF/src/tangent_arithmetic.jl:142
  [2] (::ChainRules.var"#1699#1702")(::Tuple{…}, ::Tuple{…})
    @ ChainRules ~/.julia/packages/ChainRules/FLsQJ/src/rulesets/Base/mapreduce.jl:541
  [3] iterate(itr::Base.Iterators.Accumulate)
    @ Base.Iterators ./iterators.jl:589 [inlined]
  [4] collect_to!
    @ ./array.jl:892 [inlined]
  [5] collect_to_with_first!
    @ ./array.jl:870 [inlined]
  [6] _collect(c::Any, itr::Any, ::Base.EltypeUnknown, isz::Union{Base.HasLength, Base.HasShape})
    @ Base ./array.jl:864 [inlined]
  [7] collect(itr::Base.Generator)
    @ Base ./array.jl:759 [inlined]
  [8] #accumulate#893
    @ ./accumulate.jl:281 [inlined]
  [9] accumulate
    @ ./accumulate.jl:278 [inlined]
 [10] (::ChainRules.var"#decumulate#1701"{})(dy::Vector{…})
    @ ChainRules ~/.julia/packages/ChainRules/FLsQJ/src/rulesets/Base/mapreduce.jl:540
 [11] ZBack
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/chainrules.jl:211 [inlined]
 [12] (::Zygote.var"#kw_zpullback#53"{ChainRules.var"#decumulate#1701"{}})(dy::Vector{Tuple{…}})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/chainrules.jl:237
 [13] g
    @ ./REPL[43]:2 [inlined]
 [14] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::FillArrays.Fill{FillArrays.Fill{…}, 1, Tuple{…}})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [15] #53
    @ ./REPL[44]:1 [inlined]
 [16] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [17] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{}, Tuple{}}})(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface.jl:91
 [18] gradient(f::Function, args::Float32)
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface.jl:148
 [19] top-level scope
    @ REPL[44]:1
Some type information was truncated. Use `show(err)` to see complete types.
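For checking any workaround, it may help to know the gradient that both f and g should return. The following is a minimal sketch, assuming the same linear recurrence as f (h_t = α * h_{t-1} + x_t) and the MWE's loss sum(sum(...)). It uses Float64 instead of Float32 so that a finite-difference check is accurate, and the helper names forward, loss and grad_alpha are hypothetical (not Zygote or ChainRules API).

```julia
# Hypothetical helpers: hand-written backpropagation through time for the
# recurrence h_t = α * h_{t-1} + x_t used in f, with loss = sum of all
# entries of all hidden states (as in sum(sum(f(α, h, x)))).

# Forward pass: returns the hidden states h_1, ..., h_T (same recurrence as f).
forward(α, h0, xs) = accumulate((h, x) -> α * h + x, xs; init = h0)

# Scalar loss matching sum(sum(f(α, h, x))).
loss(α, h0, xs) = sum(sum, forward(α, h0, xs))

# Reverse pass: δ_t = ∂loss/∂h_t satisfies δ_t = 1 + α * δ_{t+1}, with δ_{T+1} = 0,
# and ∂loss/∂α = Σ_t dot(δ_t, h_{t-1}).
function grad_alpha(α, h0, xs)
    hs = forward(α, h0, xs)
    prev = vcat([h0], hs[1:end-1])   # h_0, ..., h_{T-1}
    δ = zero(h0)
    dα = zero(α)
    for t in length(xs):-1:1
        δ = α .* δ .+ 1              # gradient with respect to h_t
        dα += sum(δ .* prev[t])      # h_t depends on α through α * h_{t-1}
    end
    return dα
end

# Usage: compare against a central finite difference.
xs = [[0.5, -1.0], [2.0, 0.25], [-0.75, 1.5]]
h0 = [1.0, -2.0]
α = 0.9
ε = 1e-6
fd = (loss(α + ε, h0, xs) - loss(α - ε, h0, xs)) / (2ε)
println(grad_alpha(α, h0, xs), " ≈ ", fd)
```

Because the reverse loop works directly on the hidden states, it never builds a tangent for a tuple-valued carry, which is the part that fails in the trace above.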

Julia and environment

julia> versioninfo()
Julia Version 1.10.0-rc2
Commit dbb9c46795b (2023-12-03 15:25 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: macOS (x86_64-apple-darwin22.4.0)
  CPU: 8 × Intel(R) Core(TM) i5-8279U CPU @ 2.40GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-15.0.7 (ORCJIT, skylake)
  Threads: 1 on 8 virtual cores

(tmp) pkg> st
Status `/private/tmp/Project.toml`
  [082447d4] ChainRules v1.63.0
  [d360d2e6] ChainRulesCore v1.21.1
  [26cc04aa] FiniteDifferences v0.12.31
  [587475ba] Flux v0.14.11
  [3bd65402] Optimisers v0.3.2
  [eeda0dda] SafeTensors v1.0.0
  [2913bbd2] StatsBase v0.34.2
  [e88e6eb3] Zygote v0.6.69

Thanks for the help.

@nmheim

nmheim commented Feb 15, 2024

Zygote is constructing the tangents that enter the decumulate pullback via wrap_chainrules_input. In this case it's hitting the method for Union{Tuple,NamedTuple}, which is interesting, because I think it should be using the method for Tuple.

I think this could be fixed by making sure wrap_chainrules_input returns a StructuralTangent; at least, if I make this change in Zygote:

@inline function wrap_chainrules_input(dxs::Union{Tuple, NamedTuple})
  xp = map(wrap_chainrules_input, dxs)
  # This produces Tangent{Any} since it does not get to see the primal, `x`.
  # ChainRulesCore.Tangent{Any, typeof(xp)}(xp) -- comment this out and replace by line below
  ChainRulesCore.StructuralTangent{typeof(xp)}(xp)
end

then things seem to work.

@mcabbott
Member

Same error with JuliaDiff/ChainRules.jl#569, FWIW.

Not certain this is relevant, but notice the similarity to this:

julia> accumulate(=>, (1,2,3))
(1, 1 => 2, (1 => 2) => 3)

julia> accumulate(=>, [1,2,3])
ERROR: MethodError: Cannot `convert` an object of type Int64 to an object of type Pair{Int64, Int64}
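The error message says the output vector was allocated with element type Pair{Int64, Int64}, so the first result, the plain Int 1, cannot be stored in it. A Base-only sketch of the contrast (an interpretation, not confirmed in this thread): the tuple method and the collected lazy iterator both let each step have its own type.

```julia
xs = [1, 2, 3]

# Tuple input: each accumulated value may have a different type,
# and the result is itself a tuple.
t = accumulate(=>, (1, 2, 3))

# Vector input: the output vector is allocated with a single element type
# (Pair{Int64, Int64} according to the error above), so storing the
# first result, the plain Int 1, fails.
vec_failed = try
    accumulate(=>, xs)
    false
catch err
    err isa MethodError
end

# The lazy iterator, when collected, widens the element type as new
# types appear, so it yields the same values as the tuple version.
v = collect(Iterators.accumulate(=>, xs))
```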

and that this gradient works with x::Tuple:

julia> gradient(α -> sum(sum(g(α, h, Tuple(x)))), 1f0)[1]
15.059713f0

julia> gradient(α -> sum(sum(g(α, h, x))), 1f0)[1]  # with x::Vector as above
ERROR: MethodError: no method matching construct(::Type{Any}, ::Tuple{FillArrays.Fill{Float32, 1, Tuple{Base.OneTo{Int64}}}, ChainRulesCore.NoTangent})
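The observation that the gradient works with x::Tuple suggests converting the input as a stopgap. Before relying on that, a Base-only check may be useful: the tuple-carrying recurrence from g gives the same hidden states for a Vector and for Tuple(xs); only the container type of the result differs. The helper g_states and the data are hypothetical, and Zygote is not loaded, so this checks the forward pass only.

```julia
# Hypothetical stand-in for the MWE's g: carry (h, x) as a tuple and
# keep only the hidden state of each step.
g_states(α, h0, xs) =
    first.(accumulate(((h, _), x) -> (α * h + x, x), xs; init = (h0, xs[1])))

xs = [[0.5, -1.0], [2.0, 0.25], [-0.75, 1.5]]
h0 = [1.0, -2.0]
α = 0.9

from_vector = g_states(α, h0, xs)          # Vector of hidden states
from_tuple  = g_states(α, h0, Tuple(xs))   # Tuple of hidden states
```

Converting a long sequence to a Tuple may increase compile time, so this is more convenient for short sequences than as a general fix.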
