Test with Yota, too #105
Paging @dfdx about these errors.
docs/src/index.md
Unfortunately this example doesn't actually run right now. This is the error:
```
julia> loss, (∇function, ∇model, ∇image) = Yota.grad(model, image) do m, x
           sum(m(x))
       end;
┌ Error: Failed to compile rrule for #233(Chain(Conv((3, 3), 64 => 64, pad=1, bias=false), BatchNorm(64, relu), Conv((3, 3), 64 => 64, pad=1, bias=false), BatchNorm(64)),), extract details via:
│  (f, args) = Yota.RRULE_VIA_AD_STATE[]
└ @ Yota ~/.julia/packages/Yota/GIFMf/src/cr_api.jl:160
ERROR: No deriative rule found for op %3 = getfield(%1, :x)::Array{Float32, 4} , try defining it using
...
```
Maybe this should stay WIP for a bit.
Thanks for pinging me! I'll be able to check out these errors during the weekend.

Some of the broken tests are already fixed on … One thing that I seem to be missing is why …
```julia
model = MyModel()
state = Optimisers.setup(Optimisers.Adam(), model)
input = ...
loss = ...
for i = 1:N
    grad = gradient(loss, model, input)                   # differentiable part
    state, model = Optimisers.update(state, model, grad)  # at every step; not necessarily differentiable
end
```
Sounds good. I have no idea whether the tests have ZeroTangent() and NoTangent() the wrong way around; it's fine to adjust the tests to whatever is produced. I broke Flux at some point because it turned out half the SciML universe rested on the gradient of …
I have a question regarding tests like this:
```julia
@test Yota_gradient(m -> destructure(m)[1][2], m2)[1] == ([0,1,0], [0,0,0])
```
Currently, Yota returns …
```julia
g = Yota_gradient(m -> destructure(m)[1][2], m2)[1]
g + [0, 0, 0]  # => [0.0, 1.0, 0.0]
```
Is this what Optimisers.jl expects, or should I return a plain tuple as in the test?
I think a Tangent is fine; this term is the gradient with respect to a Tuple. The test should be changed to allow for this, or perhaps the …
I fixed the most hardcore issues in the tests, but after several days of investigation I can't solve 2 remaining problems: …
```julia
# Unwrap Tangents so results look like Zygote's plain tuples / named tuples.
# `ChainRulesCore.backing` is the accessor; `x.backing` may be intercepted by Tangent's own `getproperty`.
unpack(x::Tangent) = ChainRulesCore.backing(x)
unpack(x) = x

function Yota_gradient(f, xs...)
    g = Base.tail(Yota.grad(f, xs...)[2])  # drop the gradient w.r.t. `f` itself
    return map(unpack, g)
end
```
It helped with some tests, but broke others. Structurally, the results seem to be correct, but I don't quite understand what needs to be adjusted: Yota, … I'm going to proceed with testing Yota on Flux models + Optimisers, which should uncover more inconsistencies, but if you want to make another pass on these tests, please try Yota@0.8.0 and share your thoughts!
Had a quick go with 0.8, and still see many errors? But will update a few things so long.
test/runtests.jl
```
@@ -13,6 +13,8 @@ struct TwoThirds a; b; c; end
Functors.@functor TwoThirds (a, c)
Optimisers.trainable(x::TwoThirds) = (a = x.a,)

Yota_gradient(f, xs...) = Base.tail(Yota.grad(f, xs...)[2])
```
Maybe this is a better rough translation function, much like the suggestion above:
```diff
-Yota_gradient(f, xs...) = Base.tail(Yota.grad(f, xs...)[2])
+Yota_gradient(f, xs...) = map(y2z, Base.tail(Yota.grad(f, xs...)[2]))
+y2z(::AbstractZero) = nothing  # we don't care about different flavours
+y2z(t::Tangent) = map(y2z, ChainRulesCore.backing(canonicalize(t)))
+y2z(x) = x
```
The only goal is to have as few changes as possible between the tests using Zygote and the same tests using Yota. I don't think we care at all about the different kinds of special Zero.
Well, we care internally that all should be accepted. But when testing what's returned, we are happy if we get any one of them.
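A package-free sketch of the "any zero is fine" comparison, assuming that collapsing every zero flavour to `nothing` is all a test needs. `FakeZeroTangent`, `FakeNoTangent` and `normalize_zeros` are hypothetical stand-ins, not ChainRulesCore's types; the `y2z` suggestion above also canonicalizes `Tangent`s, which this sketch skips.

```julia
# Hypothetical stand-ins for the different "zero" flavours an AD may return.
abstract type FakeZero end
struct FakeZeroTangent <: FakeZero end
struct FakeNoTangent <: FakeZero end

# Collapse every zero flavour to `nothing`, recursing through tuples and named tuples,
# so a test can compare results without caring which flavour was produced.
normalize_zeros(::FakeZero) = nothing
normalize_zeros(t::Union{Tuple,NamedTuple}) = map(normalize_zeros, t)
normalize_zeros(x) = x

grad_a = (FakeZeroTangent(), (a = [1.0, 2.0], b = FakeNoTangent()))
grad_b = (FakeNoTangent(), (a = [1.0, 2.0], b = FakeZeroTangent()))
normalize_zeros(grad_a) == normalize_zeros(grad_b)  # true
```

Two gradients that differ only in which zero flavour they carry compare equal after normalization, which is the property the tests rely on.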
I can successfully run the tests in this PR on Julia nightly with this rule added:
```julia
using ChainRulesCore
import ChainRulesCore: rrule

function rrule(::typeof(getfield), s, f::Symbol)
    y = getfield(s, f)  # primal should match the op being differentiated (getproperty may be overloaded)
    function getfield_pullback(dy)
        dy = unthunk(dy)
        T = typeof(s)
        nt = NamedTuple{(f,)}((dy,))
        return NoTangent(), Tangent{T}(; nt...), ZeroTangent()
    end
    return y, getfield_pullback
end
```
Yota contains the same rule for …
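Roughly, the pullback above places `dy` in the slot named `f` and leaves the other fields of the structural tangent as implicit zeros. A package-free sketch of that shape, using a `NamedTuple` with `nothing` for the zero fields; the helper `getfield_cotangent` and the `nothing` convention are assumptions for illustration, not how ChainRulesCore stores a `Tangent`:

```julia
# Hypothetical helper: the cotangent of `getfield(s, f)` with respect to `s`,
# written as a NamedTuple over all of `s`'s fields. The accessed field gets `dy`,
# every other field gets `nothing` (standing in for a structural zero).
function getfield_cotangent(s, f::Symbol, dy)
    names = fieldnames(typeof(s))
    f in names || throw(ArgumentError("$(typeof(s)) has no field $f"))
    vals = map(n -> n === f ? dy : nothing, names)
    return NamedTuple{names}(vals)
end

struct Point
    x::Float64
    y::Float64
end

getfield_cotangent(Point(1.0, 2.0), :x, 3.0)  # (x = 3.0, y = nothing)
```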
It's possible that this package and Functors.jl should think more about whether to call … But looking at the errors on CI, maybe it's from somewhere deeper inside, involving …

That said, having a rule for …
```julia
using ChainRulesCore
import ChainRulesCore: rrule
using InteractiveUtils  # for @code_warntype outside the REPL

function rrule(::typeof(getfield), x::T, f::Symbol) where T
    y = getfield(x, f)  # primal should match the op being differentiated
    proj = ProjectTo(x)
    # valT = Val(T)  # perhaps more stable inside closure?
    function getfield_pullback(dy)
        nt = NamedTuple{(f,)}((unthunk(dy),))
        # Not really sure whether this ought to unthunk or not; maybe ProjectTo will anyway, in which case best to be explicit?
        return NoTangent(), proj(Tangent{T}(; nt...)), ZeroTangent()
    end
    return y, getfield_pullback
end

# These print lots in red:
@code_warntype rrule(getfield, (x=1, y=2.0), :x)
@code_warntype rrule(getfield, (x=1, y=2.0), :x)[2](3)
# But these are OK:
@code_warntype (nt -> rrule(getfield, nt, :x))((x=1, y=2.0))
@code_warntype (nt -> rrule(getfield, nt, :x)[2](3.0))((x=1, y=2.0))
```
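The red `@code_warntype` output above is consistent with the field name reaching `NamedTuple{(f,)}` as a runtime `Symbol`: inference cannot know the resulting type, whereas inside a closure with a literal `:x`, constant propagation can. A stdlib-only check of that effect, using a hypothetical `field_nt` helper and comparing a `Symbol` argument against a `Val`, in the spirit of the commented-out `Val` line:

```julia
# Building a one-field NamedTuple from a runtime Symbol: the field name is only
# known at run time, so inference cannot produce a concrete return type.
field_nt(f::Symbol, dy) = NamedTuple{(f,)}((dy,))

# Lifting the name into the type domain with `Val` makes it a compile-time constant.
field_nt(::Val{f}, dy) where {f} = NamedTuple{(f,)}((dy,))

rt_symbol = only(Base.return_types(field_nt, (Symbol, Float64)))
rt_val = only(Base.return_types(field_nt, (Val{:x}, Float64)))

isconcretetype(rt_symbol)  # false
isconcretetype(rt_val)     # true
```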
It's not in the tests here, but running the Metalhead example in the docs I still get this error (with or without the getfield rule, on Julia 1.8 and 1.9): …
It's even curiouser! Running a random test in the REPL works fine:
```julia
julia> re1 = destructure(m1)[2]
Restructure(Array, ..., 3)

julia> @test Yota_gradient(x -> re1(x)[1], rand(3))[1] == [1,0,0]
Test Passed
```
But wrap it in a testset:
```julia
julia> @testset "using Yota" begin
           re1 = destructure(m1)[2]
           @test Yota_gradient(x -> re1(x)[1], rand(3))[1] == [1,0,0]
       end
using Yota: Error During Test at REPL[69]:3
  Test threw exception
  Expression: (Yota_gradient((x->((re1(x))[1];)), rand(3)))[1] == [1, 0, 0]
  No deriative rule found for op %3 = getfield(%1, :re1)::Optimisers.Restructure{Vector{Float64}, Int64} , try defining it using
  ChainRulesCore.rrule(::typeof(getfield), ::var"#95#96"{Optimisers.Restructure{Vector{Float64}, Int64}}, ::Symbol) = ...
...
```
Perhaps, …
Yes, it makes sense. Regarding type stability, I'm going to include your definition in Yota as is for now to keep the focus on correctness, and come back to performance later.

I'm looking at it.
I may have spotted one of the bugs related to the failures in the Metalhead example, but I must make sure first. In this piece of code in generic broadcasting:
```julia
ys3, backs = unzip_broadcast(args...) do a...
    rrule_via_ad(cfg, f, a...)
end
```
when calling
```julia
f = x -> identity(x)
args = (rand(3),)
rrule(cfg, broadcasted, f, args...)
```
which of the following is invoked: `rrule_via_ad(cfg, broadcasted, f, args...)` or `rrule_via_ad(cfg, f, args...)`?
I don't see an obvious mistake. The intention is for … This …
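For context on the question above, a package-free sketch of the per-element pattern a generic broadcast rule follows: call a per-element rule on every element, split the resulting `(y, pullback)` pairs, and apply each stored pullback to the matching element of the incoming cotangent. `square_rrule` is a hypothetical stand-in for `rrule_via_ad(cfg, f, x)` and `broadcast_rrule` for the broadcast rule; neither is ChainRules code.

```julia
# Hypothetical per-element rule for f(x) = x^2: returns the primal and a pullback.
square_rrule(x) = (x^2, dy -> 2x * dy)

# Sketch of a broadcast rule built from a per-element rule `rr`.
function broadcast_rrule(rr, xs::AbstractArray)
    pairs = map(rr, xs)            # array of (y, pullback) tuples, same shape as xs
    ys = map(first, pairs)         # primal outputs
    pullbacks = map(last, pairs)   # one pullback per element
    bc_pullback(dys) = map((pb, dy) -> pb(dy), pullbacks, dys)
    return ys, bc_pullback
end

ys, back = broadcast_rrule(square_rrule, [1.0 2.0; 3.0 4.0])
ys                # [1.0 4.0; 9.0 16.0]
back(ones(2, 2))  # [2.0 4.0; 6.0 8.0]
```

This version allocates the intermediate array of tuples; `unzip_broadcast` relies on StructArrays precisely to write the two outputs directly instead.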
Oh, I don't think it's a mistake in the generic broadcasting, but rather in …
```julia
using Flux, Yota
model = Dense(28*28, 1024, x -> identity(x))
x = rand(Float32, 28*28, 4)
grad((model, x) -> sum(model(x)), model, x)
```
which produces this nice stacktrace:

ERROR: all field arrays must have same shape
Stacktrace:
[1] error(s::String)
@ Base ./error.jl:35
[2] (::StructArrays.var"#6#7"{Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})(ci::Vector{Function})
@ StructArrays ~/.julia/packages/StructArrays/w2GaP/src/structarray.jl:21
[3] map
@ ./tuple.jl:221 [inlined]
[4] (StructArrays.StructArray{Tuple{Float32, Function}, 2, Tuple{Matrix{Float32}, Vector{Function}}})(c::Tuple{Matrix{Float32}, Vector{Function}})
@ StructArrays ~/.julia/packages/StructArrays/w2GaP/src/structarray.jl:20
[5] (StructArrays.StructArray{Tuple{Float32, Function}})(c::Tuple{Matrix{Float32}, Vector{Function}})
@ StructArrays ~/.julia/packages/StructArrays/w2GaP/src/structarray.jl:97
[6] _widenstructarray(dest::StructArrays.StructArray{Tuple{Float32, var"#25#27"}, 2, Tuple{Matrix{Float32}, Matrix{var"#25#27"}}, Int64}, i::Int64, #unused#::Type{Tuple{Float32, Function}})
@ StructArrays ~/.julia/packages/StructArrays/w2GaP/src/collect.jl:118
[7] widen_from_type(dest::StructArrays.StructArray{Tuple{Float32, var"#25#27"}, 2, Tuple{Matrix{Float32}, Matrix{var"#25#27"}}, Int64}, i::Int64, #unused#::Type{Tuple{Float32, var"#24#26"}})
@ StructArrays ~/.julia/packages/StructArrays/w2GaP/src/collect.jl:109
[8] widen_from_instance(dest::StructArrays.StructArray{Tuple{Float32, var"#25#27"}, 2, Tuple{Matrix{Float32}, Matrix{var"#25#27"}}, Int64}, i::Int64, el::Tuple{Float32, var"#24#26"})
@ StructArrays ~/.julia/packages/StructArrays/w2GaP/src/collect.jl:105
[9] collect_to_structarray!(dest::StructArrays.StructArray{Tuple{Float32, var"#25#27"}, 2, Tuple{Matrix{Float32}, Matrix{var"#25#27"}}, Int64}, itr::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{2}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, ChainRules.var"#1705#1707"{YotaRuleConfig, var"#141#142"}, Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{2}, Nothing, typeof(+), Tuple{Matrix{Float32}, Vector{Float32}}}}}, offs::Int64, st::Tuple{CartesianIndices{2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}, CartesianIndex{2}})
@ StructArrays ~/.julia/packages/StructArrays/w2GaP/src/collect.jl:77
[10] _collect_structarray!
@ ~/.julia/packages/StructArrays/w2GaP/src/collect.jl:59 [inlined]
[11] _collect_structarray(itr::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{2}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, ChainRules.var"#1705#1707"{YotaRuleConfig, var"#141#142"}, Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{2}, Nothing, typeof(+), Tuple{Matrix{Float32}, Vector{Float32}}}}}, elem::Tuple{Tuple{Float32, var"#25#27"}, Tuple{CartesianIndices{2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}, CartesianIndex{2}}}, ax::Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}; initializer::StructArrays.StructArrayInitializer{typeof(StructArrays.alwaysfalse), typeof(StructArrays.arrayof)})
@ StructArrays ~/.julia/packages/StructArrays/w2GaP/src/collect.jl:54
[12] collect_structarray(itr::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{2}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, ChainRules.var"#1705#1707"{YotaRuleConfig, var"#141#142"}, Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{2}, Nothing, typeof(+), Tuple{Matrix{Float32}, Vector{Float32}}}}}; initializer::StructArrays.StructArrayInitializer{typeof(StructArrays.alwaysfalse), typeof(StructArrays.arrayof)})
@ StructArrays ~/.julia/packages/StructArrays/w2GaP/src/collect.jl:40
[13] StructArrays.StructArray(v::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{2}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, ChainRules.var"#1705#1707"{YotaRuleConfig, var"#141#142"}, Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{2}, Nothing, typeof(+), Tuple{Matrix{Float32}, Vector{Float32}}}}}; unwrap::typeof(StructArrays.alwaysfalse))
@ StructArrays ~/.julia/packages/StructArrays/w2GaP/src/structarray.jl:261
[14] StructArray
@ ~/.julia/packages/StructArrays/w2GaP/src/structarray.jl:260 [inlined]
[15] unzip_broadcast
@ ~/.julia/packages/ChainRules/DUopG/src/unzipped.jl:39 [inlined]
[16] split_bc_pullbacks(cfg::YotaRuleConfig, f::var"#141#142", args::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{2}, Nothing, typeof(+), Tuple{Matrix{Float32}, Vector{Float32}}})
@ ChainRules ~/.julia/packages/ChainRules/DUopG/src/rulesets/Base/broadcast.jl:127
[17] rrule(cfg::YotaRuleConfig, #unused#::typeof(Base.Broadcast.broadcasted), #unused#::Base.Broadcast.DefaultArrayStyle{2}, f::var"#141#142", args::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{2}, Nothing, typeof(+), Tuple{Matrix{Float32}, Vector{Float32}}})
@ ChainRules ~/.julia/packages/ChainRules/DUopG/src/rulesets/Base/broadcast.jl:44
[18] mkcall(::Function, ::YotaRuleConfig, ::Vararg{Any}; val::Missing, line::Core.LineInfoNode, kwargs::NamedTuple{(), Tuple{}}, free_kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
@ Umlaut ~/.julia/packages/Umlaut/vGy3v/src/tape.jl:192
[19] chainrules_transform!(tape::Tape{GradCtx})
@ Main ~/work/Yota/src/grad.jl:184
[20] gradtape!(tape::Tape{GradCtx}; seed::Int64)
@ Main ~/work/Yota/src/grad.jl:271
[21] gradtape(::Function, ::Dense{var"#141#142", Matrix{Float32}, Vector{Float32}}, ::Vararg{Any}; ctx::GradCtx, seed::Int64)
@ Main ~/work/Yota/src/grad.jl:300
[22] grad(::Function, ::Dense{var"#141#142", Matrix{Float32}, Vector{Float32}}, ::Vararg{Any}; seed::Int64)
@ Main ~/work/Yota/src/grad.jl:370
[23] grad(::Function, ::Dense{var"#141#142", Matrix{Float32}, Vector{Float32}}, ::Vararg{Any})
@ Main ~/work/Yota/src/grad.jl:362
[24] top-level scope
@ REPL[26]:1

From the stacktrace I infer that …
```julia
julia> y, bk = rrule_via_ad(YotaRuleConfig(), broadcasted, sqrt, [1.0, 2, 3])
...

julia> y
3-element Vector{Float64}:
 1.0
 1.4142135623730951
 1.7320508075688772

julia> bk
#24 (generic function with 1 method)
```
and that …
Quite the stacktrace! These lines look correct to me: The same function …
```
[16] split_bc_pullbacks(cfg::YotaRuleConfig, f::var"#141#142", args::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{2}, Nothing, typeof(+), Tuple{Matrix{Float32}, Vector{Float32}}})
   @ ChainRules ~/.julia/packages/ChainRules/DUopG/src/rulesets/Base/broadcast.jl:127
[17] rrule(cfg::YotaRuleConfig, #unused#::typeof(Base.Broadcast.broadcasted), #unused#::Base.Broadcast.DefaultArrayStyle{2}, f::var"#141#142", args::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{2}, Nothing, typeof(+), Tuple{Matrix{Float32}, Vector{Float32}}})
   @ ChainRules ~/.julia/packages/ChainRules/DUopG/src/rulesets/Base/broadcast.jl:44
```
To get to line 39 (https://github.com/JuliaDiff/ChainRules.jl/blob/main/src/unzipped.jl#L39), the function …
Here's an interesting observation. If I run the same example as is:
```julia
using Flux, Yota, ChainRules
myid = x -> identity(x)
model = Dense(5, 3, myid)
x = rand(Float32, 5, 1);
val, g = grad((model, x) -> sum(model(x)), model, x)
@show val
@show g
```
I get the same stacktrace as posted above, complaining about "ERROR: all field arrays must have same shape". However, if I slightly modify `unzip_broadcast`:
```julia
function unzip_broadcast(f::F, args...) where {F}
    T = Broadcast.combine_eltypes(f, args)
    if isconcretetype(T)
        T <: Tuple || throw(ArgumentError("""unzip_broadcast(f, args) only works on functions returning a tuple,
            but f = $(sprint(show, f)) returns type T = $T"""))
    end
    bc = Broadcast.instantiate(Broadcast.broadcasted(f, args...))
    bcs = Broadcast.BroadcastStyle(typeof(bc))
    if bcs isa AbstractGPUArrayStyle
        # This is a crude way to allow GPU arrays, not currently tested, TODO.
        # See also https://github.com/JuliaArrays/StructArrays.jl/issues/150
        return unzip(broadcast(f, args...))
    elseif bcs isa Broadcast.AbstractArrayStyle
        Broadcast.materialize(bc)  # <-- this line added
        return StructArrays.components(StructArray(bc))
    else
        return unzip(broadcast(f, args...))  # e.g. tuples
    end
    # TODO maybe this if-else can be replaced by methods of `unzip(:::Broadcast.Broadcasted)`?
end
```
the error disappears! The only hypothesis I have is that materializing a broadcasted variable changes something in the global Julia state that makes it more friendly to …
That is pretty odd. I can reproduce this, by … Edit: I've pasted in a complete session below. This …
```julia
julia> using Flux, Yota, ChainRules

julia> ENV["JULIA_DEBUG"] = ChainRules;

julia> begin
           myid = x -> identity(x)
           model = Dense(5, 3, myid)
           x = rand(Float32, 5, 1)
       end;

julia> val, g = grad((model, x) -> sum(model(x)), model, x)
┌ Debug: broadcasting: plus
│   length(xs) = 2
└ @ ChainRules ~/.julia/packages/ChainRules/fgVxV/src/rulesets/Base/broadcast.jl:161
┌ Debug: split broadcasting generic
│   f = #7 (generic function with 1 method)
│   N = 1
└ @ ChainRules ~/.julia/packages/ChainRules/fgVxV/src/rulesets/Base/broadcast.jl:126
```
ERROR: all field arrays must have same shape
Stacktrace:
[1] error(s::String)
@ Base ./error.jl:35
[2] (::StructArrays.var"#6#7"{Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})(ci::Vector{Function})
@ StructArrays ~/.julia/packages/StructArrays/w2GaP/src/structarray.jl:21
[3] map
@ ./tuple.jl:273 [inlined]
[4] (StructArrays.StructArray{Tuple{Float32, Function}, 2, Tuple{Matrix{Float32}, Vector{Function}}})(c::Tuple{Matrix{Float32}, Vector{Function}})
@ StructArrays ~/.julia/packages/StructArrays/w2GaP/src/structarray.jl:20
[5] (StructArrays.StructArray{Tuple{Float32, Function}})(c::Tuple{Matrix{Float32}, Vector{Function}})
@ StructArrays ~/.julia/packages/StructArrays/w2GaP/src/structarray.jl:97
[6] _widenstructarray(dest::StructArrays.StructArray{Tuple{Float32, Yota.var"#21#23"}, 2, Tuple{Matrix{Float32}, Matrix{Yota.var"#21#23"}}, Int64}, i::Int64, #unused#::Type{Tuple{Float32, Function}})
@ StructArrays ~/.julia/packages/StructArrays/w2GaP/src/collect.jl:118
[7] widen_from_type(dest::StructArrays.StructArray{Tuple{Float32, Yota.var"#21#23"}, 2, Tuple{Matrix{Float32}, Matrix{Yota.var"#21#23"}}, Int64}, i::Int64, #unused#::Type{Tuple{Float32, Yota.var"#20#22"}})
@ StructArrays ~/.julia/packages/StructArrays/w2GaP/src/collect.jl:109
[8] widen_from_instance(dest::StructArrays.StructArray{Tuple{Float32, Yota.var"#21#23"}, 2, Tuple{Matrix{Float32}, Matrix{Yota.var"#21#23"}}, Int64}, i::Int64, el::Tuple{Float32, Yota.var"#20#22"})
@ StructArrays ~/.julia/packages/StructArrays/w2GaP/src/collect.jl:105
[9] collect_to_structarray!(dest::StructArrays.StructArray{Tuple{Float32, Yota.var"#21#23"}, 2, Tuple{Matrix{Float32}, Matrix{Yota.var"#21#23"}}, Int64}, itr::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{2}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, ChainRules.var"#1707#1709"{Yota.YotaRuleConfig, var"#7#8"}, Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{2}, Nothing, typeof(+), Tuple{Matrix{Float32}, Vector{Float32}}}}}, offs::Int64, st::Tuple{CartesianIndices{2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}, CartesianIndex{2}})
@ StructArrays ~/.julia/packages/StructArrays/w2GaP/src/collect.jl:77
[10] _collect_structarray!
@ ~/.julia/packages/StructArrays/w2GaP/src/collect.jl:59 [inlined]
[11] _collect_structarray(itr::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{2}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, ChainRules.var"#1707#1709"{Yota.YotaRuleConfig, var"#7#8"}, Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{2}, Nothing, typeof(+), Tuple{Matrix{Float32}, Vector{Float32}}}}}, elem::Tuple{Tuple{Float32, Yota.var"#21#23"}, Tuple{CartesianIndices{2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}, CartesianIndex{2}}}, ax::Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}; initializer::StructArrays.StructArrayInitializer{typeof(StructArrays.alwaysfalse), typeof(StructArrays.arrayof)})
@ StructArrays ~/.julia/packages/StructArrays/w2GaP/src/collect.jl:54
[12] collect_structarray(itr::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{2}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, ChainRules.var"#1707#1709"{Yota.YotaRuleConfig, var"#7#8"}, Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{2}, Nothing, typeof(+), Tuple{Matrix{Float32}, Vector{Float32}}}}}; initializer::StructArrays.StructArrayInitializer{typeof(StructArrays.alwaysfalse), typeof(StructArrays.arrayof)})
@ StructArrays ~/.julia/packages/StructArrays/w2GaP/src/collect.jl:40
[13] StructArrays.StructArray(v::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{2}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, ChainRules.var"#1707#1709"{Yota.YotaRuleConfig, var"#7#8"}, Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{2}, Nothing, typeof(+), Tuple{Matrix{Float32}, Vector{Float32}}}}}; unwrap::typeof(StructArrays.alwaysfalse))
@ StructArrays ~/.julia/packages/StructArrays/w2GaP/src/structarray.jl:261
[14] StructArray
@ ~/.julia/packages/StructArrays/w2GaP/src/structarray.jl:260 [inlined]
[15] unzip_broadcast
@ ~/.julia/packages/ChainRules/fgVxV/src/unzipped.jl:39 [inlined]
[16] split_bc_pullbacks(cfg::Yota.YotaRuleConfig, f::var"#7#8", args::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{2}, Nothing, typeof(+), Tuple{Matrix{Float32}, Vector{Float32}}})
@ ChainRules ~/.julia/packages/ChainRules/fgVxV/src/rulesets/Base/broadcast.jl:127
[17] rrule(cfg::Yota.YotaRuleConfig, #unused#::typeof(Base.Broadcast.broadcasted), #unused#::Base.Broadcast.DefaultArrayStyle{2}, f::var"#7#8", args::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{2}, Nothing, typeof(+), Tuple{Matrix{Float32}, Vector{Float32}}})
@ ChainRules ~/.julia/packages/ChainRules/fgVxV/src/rulesets/Base/broadcast.jl:44
[18] mkcall(::Function, ::Yota.YotaRuleConfig, ::Vararg{Any}; val::Missing, line::Core.LineInfoNode, kwargs::NamedTuple{(), Tuple{}}, free_kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
@ Umlaut ~/.julia/packages/Umlaut/vGy3v/src/tape.jl:192
[19] chainrules_transform!(tape::Umlaut.Tape{Yota.GradCtx})
@ Yota ~/.julia/packages/Yota/uu3H0/src/grad.jl:181
[20] gradtape!(tape::Umlaut.Tape{Yota.GradCtx}; seed::Int64)
@ Yota ~/.julia/packages/Yota/uu3H0/src/grad.jl:268
[21] gradtape(::Function, ::Dense{var"#7#8", Matrix{Float32}, Vector{Float32}}, ::Vararg{Any}; ctx::Yota.GradCtx, seed::Int64)
@ Yota ~/.julia/packages/Yota/uu3H0/src/grad.jl:297
[22] grad(::Function, ::Dense{var"#7#8", Matrix{Float32}, Vector{Float32}}, ::Vararg{Any}; seed::Int64)
@ Yota ~/.julia/packages/Yota/uu3H0/src/grad.jl:367
[23] grad(::Function, ::Dense{var"#7#8", Matrix{Float32}, Vector{Float32}}, ::Vararg{Any})
@ Yota ~/.julia/packages/Yota/uu3H0/src/grad.jl:359
[24] top-level scope
@ REPL[4]:1
[25] top-level scope
@ ~/.julia/packages/CUDA/DfvRa/src/initialization.jl:52
```julia
julia> @eval ChainRules function unzip_broadcast(f::F, args...) where {F}
           T = Broadcast.combine_eltypes(f, args)
           if isconcretetype(T)
               T <: Tuple || throw(ArgumentError("""unzip_broadcast(f, args) only works on functions returning a tuple,
                   but f = $(sprint(show, f)) returns type T = $T"""))
           end
           bc = Broadcast.instantiate(Broadcast.broadcasted(f, args...))
           bcs = Broadcast.BroadcastStyle(typeof(bc))
           if bcs isa AbstractGPUArrayStyle
               # This is a crude way to allow GPU arrays, not currently tested, TODO.
               # See also https://github.com/JuliaArrays/StructArrays.jl/issues/150
               return unzip(broadcast(f, args...))
           elseif bcs isa Broadcast.AbstractArrayStyle
               # Broadcast.materialize(bc) # <-- this line added # <-- now removed, identical to original
               return StructArrays.components(StructArray(bc))
           else
               return unzip(broadcast(f, args...)) # e.g. tuples
           end
           # TODO maybe this if-else can be replaced by methods of `unzip(:::Broadcast.Broadcasted)`?
       end
unzip_broadcast (generic function with 1 method)

julia> val, g = grad((model, x) -> sum(model(x)), model, x)
┌ Debug: broadcasting: plus
│   length(xs) = 2
└ @ ChainRules ~/.julia/packages/ChainRules/fgVxV/src/rulesets/Base/broadcast.jl:161
┌ Debug: split broadcasting generic
│   f = #7 (generic function with 1 method)
│   N = 1
└ @ ChainRules ~/.julia/packages/ChainRules/fgVxV/src/rulesets/Base/broadcast.jl:126
(2.6951299f0, (ChainRulesCore.ZeroTangent(), Tangent{Dense{var"#7#8", Matrix{Float32}, Vector{Float32}}}(σ = ChainRulesCore.ZeroTangent(), weight = Float32[0.88211715 0.71158904 … 0.74754727 0.49648; 0.88211715 0.71158904 … 0.74754727 0.49648; 0.88211715 0.71158904 … 0.74754727 0.49648], bias = Float32[1.0, 1.0, 1.0]), Float32[1.0559639; 1.8083295; … ; 0.78016365; -0.7226729;;]))
```
Here's a hypothesis for the world age problem: …

But: …
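As background for the world-age hypothesis: a method defined via `eval` while a function is running is not visible to that running call unless the call goes through `Base.invokelatest`. The package-free sketch below (function names are hypothetical) shows both behaviours; it only illustrates the mechanism and is not a diagnosis of the bug in this thread.

```julia
# Define a brand-new function at run time, then call it from the same call frame.
function call_new_method_directly(x)
    g = Base.eval(@__MODULE__, :($(gensym(:fresh))(y) = 2y))  # eval returns the new function
    return g(x)  # dispatches in the caller's (older) world: the new method is not visible
end

function call_new_method_latest(x)
    g = Base.eval(@__MODULE__, :($(gensym(:fresh))(y) = 2y))
    return Base.invokelatest(g, x)  # dispatches in the latest world, so the method is found
end

call_new_method_latest(3)  # 6

direct_ok = try
    call_new_method_directly(3)
    true
catch err
    err isa MethodError || rethrow()
    false  # MethodError: the method is too new for the running world
end
```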
How are you adding this …?
I have a file called …:
```julia
include("core.jl")  # in turn, core.jl includes all the files from Yota, so now Main ~ Yota
using Flux
# I think these imports are not needed anymore, but just copy-pasting them
import ChainRules: unzip_broadcast, RCR, TRI_NO, AbstractGPUArrayStyle, StructArrays
import ChainRules.StructArrays: StructArray

@eval ChainRules function unzip_broadcast(f::F, args...) where {F}
    global BC_STATE = (f, args)
    T = Broadcast.combine_eltypes(f, args)
    if isconcretetype(T)
        T <: Tuple || throw(ArgumentError("""unzip_broadcast(f, args) only works on functions returning a tuple,
            but f = $(sprint(show, f)) returns type T = $T"""))
    end
    # bc - rrule_via_ad's wrapper broadcasted to all arguments
    bc = Broadcast.instantiate(Broadcast.broadcasted(f, args...))
    bcs = Broadcast.BroadcastStyle(typeof(bc))
    if bcs isa AbstractGPUArrayStyle
        # This is a crude way to allow GPU arrays, not currently tested, TODO.
        # See also https://github.com/JuliaArrays/StructArrays.jl/issues/150
        return unzip(broadcast(f, args...))
    elseif bcs isa Broadcast.AbstractArrayStyle
        println("World age before materialize(bc): $(Base.get_world_counter())")
        # Broadcast.materialize(bc)
        println("World age after materialize(bc): $(Base.get_world_counter())")
        # global BC = bc
        return StructArrays.components(StructArray(bc))
    else
        return unzip(broadcast(f, args...))  # e.g. tuples
    end
    # TODO maybe this if-else can be replaced by methods of `unzip(:::Broadcast.Broadcasted)`?
end

function bc_test()
    myid = x -> identity(x)
    model = Dense(5, 3, myid)
    x = rand(Float32, 5, 1)
    grad((model, x) -> sum(model(x)), model, x)
end
```
Whenever I make a change, I include the whole file, thus updating all definitions from Yota + … I also noticed that the problem is fixed if I replace `rrule_via_ad` with this dummy version:
```julia
function ChainRulesCore.rrule_via_ad(cfg::YotaRuleConfig, f, args...)
    return 1.0, dy -> (ZeroTangent(), [ZeroTangent() for _ in args]...)
end
```
In theory, I can make …
Focusing on these lines … here's a smaller reproducer:
```julia
julia> using ChainRules, Yota

julia> y, bk = ChainRules.split_bc_pullbacks(Yota.YotaRuleConfig(), identity, Broadcast.broadcasted(+, [1 2; 3 4], [5, 6]));

julia> bk([7 8; 9 0])  # with identity it works fine, also sqrt
(ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent(), [7 8; 9 0])

julia> y, bk = ChainRules.split_bc_pullbacks(Yota.YotaRuleConfig(), x -> identity(x), Broadcast.broadcasted(+, [1 2; 3 4], [5, 6]));
ERROR: all field arrays must have same shape

(@v1.9) pkg> st Yota ChainRules
Status `~/.julia/environments/v1.9/Project.toml`
  [082447d4] ChainRules v1.44.5
  [cd998857] Yota v0.8.0
```
This seems to avoid my order-of-loading weirdness above. If I …
Some half-way steps:
```julia
julia> using ChainRules, Yota

# Easy case
julia> broadcast(Broadcast.broadcasted(+, [1 2; 3 4], [5, 6])) do x
           ChainRules.rrule_via_ad(Yota.YotaRuleConfig(), identity, x)
       end
2×2 Matrix{Tuple{Int64, ChainRules.var"#identity_pullback#1201"}}:
 (6, identity_pullback)  (7, identity_pullback)
 (9, identity_pullback)  (10, identity_pullback)

julia> ChainRules.unzip(ans)
([6 7; 9 10], [ChainRules.var"#identity_pullback#1201"() ChainRules.var"#identity_pullback#1201"(); ChainRules.var"#identity_pullback#1201"() ChainRules.var"#identity_pullback#1201"()])

julia> broadcast(|>, [7 8; 9 0], ans[2])
2×2 Matrix{Tuple{ChainRulesCore.NoTangent, Int64}}:
 (NoTangent(), 7)  (NoTangent(), 8)
 (NoTangent(), 9)  (NoTangent(), 0)

julia> ChainRules.unzip_broadcast(Broadcast.broadcasted(+, [1 2; 3 4], [5, 6])) do x
           ChainRules.rrule_via_ad(Yota.YotaRuleConfig(), identity, x)
       end
([6 7; 9 10], [ChainRules.var"#identity_pullback#1201"() ChainRules.var"#identity_pullback#1201"(); ChainRules.var"#identity_pullback#1201"() ChainRules.var"#identity_pullback#1201"()])

julia> broadcast(|>, [7 8; 9 0], ans[2])
2×2 Matrix{Tuple{ChainRulesCore.NoTangent, Int64}}:
 (NoTangent(), 7)  (NoTangent(), 8)
 (NoTangent(), 9)  (NoTangent(), 0)

# Now try with y -> identity(y)
julia> broadcast(Broadcast.broadcasted(+, [1 2; 3 4], [5, 6])) do x
           ChainRules.rrule_via_ad(Yota.YotaRuleConfig(), y -> identity(y), x)
       end
2×2 Matrix{Tuple{Int64, Function}}:  ## <-- notice Function, abstract type
 (6, #21)  (7, #20)
 (9, #20)  (10, #20)

julia> ChainRules.unzip(ans)  ## notice Core.Box
([6 7; 9 10], Function[Yota.var"#21#23"(Core.Box(Yota.var"##pullback_#72#328#86"{ChainRules.var"#identity_pullback#1201"}(ChainRules.var"#identity_pullback#1201"()))) Yota.var"#20#22"(Core.Box(Yota.var"##pullback_#72#328#86"{ChainRules.var"#identity_pullback#1201"}(ChainRules.var"#identity_pullback#1201"()))); Yota.var"#20#22"(Core.Box(Yota.var"##pullback_#72#328#86"{ChainRules.var"#identity_pullback#1201"}(ChainRules.var"#identity_pullback#1201"()))) Yota.var"#20#22"(Core.Box(Yota.var"##pullback_#72#328#86"{ChainRules.var"#identity_pullback#1201"}(ChainRules.var"#identity_pullback#1201"())))])

julia> broadcast(|>, [7 8; 9 0], ans[2])
2×2 Matrix{Tuple{ChainRulesCore.ZeroTangent, Int64}}:
 (ZeroTangent(), 7)  (ZeroTangent(), 8)
 (ZeroTangent(), 9)  (ZeroTangent(), 0)

julia> ChainRules.unzip_broadcast(Broadcast.broadcasted(+, [1 2; 3 4], [5, 6])) do x
           ChainRules.rrule_via_ad(Yota.YotaRuleConfig(), y -> identity(y), x)
       end
ERROR: all field arrays must have same shape

# Name the function:
julia> myid(x) = x;

julia> broadcast(Broadcast.broadcasted(+, [1 2; 3 4], [5, 6])) do x
           ChainRules.rrule_via_ad(Yota.YotaRuleConfig(), myid, x)
       end
2×2 Matrix{Tuple{Int64, Function}}:  ## <-- looks as bad
 (6, #21)  (7, #20)
 (9, #20)  (10, #20)

julia> ChainRules.unzip_broadcast(Broadcast.broadcasted(+, [1 2; 3 4], [5, 6])) do x
           ChainRules.rrule_via_ad(Yota.YotaRuleConfig(), myid, x)  ## now this works, with Core.Box
       end
([6 7; 9 10], Yota.var"#20#22"[Yota.var"#20#22"(Core.Box(Yota.var"##pullback_myid#334#89"{ChainRules.var"#identity_pullback#1201"}(ChainRules.var"#identity_pullback#1201"()))) Yota.var"#20#22"(Core.Box(Yota.var"##pullback_myid#334#89"{ChainRules.var"#identity_pullback#1201"}(ChainRules.var"#identity_pullback#1201"()))); Yota.var"#20#22"(Core.Box(Yota.var"##pullback_myid#334#89"{ChainRules.var"#identity_pullback#1201"}(ChainRules.var"#identity_pullback#1201"()))) Yota.var"#20#22"(Core.Box(Yota.var"##pullback_myid#334#89"{ChainRules.var"#identity_pullback#1201"}(ChainRules.var"#identity_pullback#1201"())))])
```
Apparently, in the last example there's no error because …
```julia
julia> ChainRules.unzip_broadcast(Broadcast.broadcasted(+, [1 2; 3 4], [5, 6])) do x
           ChainRules.rrule_via_ad(Yota.YotaRuleConfig(), myid, x)
       end
...
ERROR: all field arrays must have same shape
...
```
My current understanding is as follows: …
```julia
unzip_broadcast(args...) do a...
    rrule_via_ad(cfg, f, a...)
end
```
Removing any of these factors solves the problem. Also, if in `unzip_broadcast` I replace
```julia
bc = Broadcast.instantiate(Broadcast.broadcasted(f, args...))
```
with
```julia
bc = broadcast(f, args...)
```
it leads to earlier evaluation and fixes the issue too. Given that a few lines later we materialize in all 3 branches,
```julia
if bcs isa AbstractGPUArrayStyle
    # This is a crude way to allow GPU arrays, not currently tested, TODO.
    # See also https://github.com/JuliaArrays/StructArrays.jl/issues/150
    return unzip(broadcast(f, args...))
elseif bcs isa Broadcast.AbstractArrayStyle
    return StructArrays.components(StructArray(bc))
else
    return unzip(broadcast(f, args...))  # e.g. tuples
end
```
I wonder why we need …
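The two spellings discussed above differ in when `f` runs: `Broadcast.instantiate(Broadcast.broadcasted(...))` only builds a lazy `Broadcasted` object with checked axes, while `broadcast(...)` (or `Broadcast.materialize`) calls `f` for every element right away. A stdlib-only sketch with the same inputs as the reproducers here; the counter and `counted_plus` are illustrative assumptions:

```julia
# Count how often the broadcast function runs, to see when evaluation happens.
const CALLS = Ref(0)
counted_plus(a, b) = (CALLS[] += 1; a + b)

# Lazy: builds a Broadcasted object (axes are checked) but calls nothing yet.
bc = Broadcast.instantiate(Broadcast.broadcasted(counted_plus, [1 2; 3 4], [5, 6]))
calls_after_instantiate = CALLS[]  # 0
axes(bc)                           # (Base.OneTo(2), Base.OneTo(2))

# Eager: materializing runs the function once per output element.
out = Broadcast.materialize(bc)    # [6 7; 9 10]
calls_after_materialize = CALLS[]  # 4
```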
But one of them materialises two arrays directly, instead of allocating an array of tuples first. This path is the entire reason for this function, and for depending on StructArrays. Cc @piever in case this weird error rings any bells. (I wonder if it's possible to hit it without AD being involved?)
Not sure if this is helpful, but here are some thoughts that could be useful: …

Though I am definitely puzzled as to why this is happening. It looks like the collection mechanism …
Here's a reproducible example without Yota and ChainRules:
```julia
import StructArrays
import StructArrays: StructArray

# eval a new function similar to rrule()
function make_rrule(f, args...)
    name = gensym()
    ex = :(function $name(f, args...)
        y = sum(args)
        pullback(dy) = dy + y
        return y, pullback
    end)
    return Base.eval(@__MODULE__, ex)
end

# wrap the rrule-like function with the required number of invokelatest() calls
function rrule_via_ad(f, args...)
    rr = make_rrule(f, args...)
    res = Base.invokelatest(rr, f, args...)
    y, pb_ = res
    pb = dy -> Base.invokelatest(pb_, dy)
    return y, pb
end

# original split_bc_pullbacks stripped to the bones
function split_bc_pullbacks(f::F, args::Vararg{Any,N}) where {F,N}
    wf = (a...) -> rrule_via_ad(f, a...)
    # comment/uncomment the next 2 lines to make the example fail/work
    bc = Broadcast.instantiate(Broadcast.broadcasted(wf, args...))
    # bc = broadcast(wf, args...)
    return StructArrays.components(StructArray(bc))
end

bce() = Broadcast.broadcasted(+, [1 2; 3 4], [5, 6])
split_bc_pullbacks(x -> identity(x), bce())
```
Since …
Still happens on the PR's branch. You can simplify a bit further, and note that acting on a vector is OK, but higher ndims fails:
Trying to pick bits out of the stack trace, is this correct?
Agh, no it isn't, well spotted! Somehow the widening mechanism was not updated to support arrays of arbitrary shape and only worked for 2D things... JuliaArrays/StructArrays.jl#246 should hopefully fix it!
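For context on "widening": when the result type is not known up front, a collector can start from the type of the first result and, on a mismatch, reallocate with a wider element type and copy over what it has stored so far. The following Base-only sketch shows that idea; it is not StructArrays' actual code, and `map_widen` is a hypothetical name. It is shape-agnostic because it indexes with `CartesianIndices` and allocates with `similar`, so the same loop serves vectors, matrices and higher-dimensional arrays:

```julia
# Map `f` over `A` (any ndims), widening the output element type on demand.
function map_widen(f, A::AbstractArray)
    inds = CartesianIndices(A)
    isempty(inds) && throw(ArgumentError("empty input not handled in this sketch"))
    # Start from the type of the first result (f is called on it again in the loop).
    dest = similar(A, typeof(f(A[first(inds)])))   # same shape and axes as A
    for (k, I) in enumerate(inds)
        v = f(A[I])
        if !(v isa eltype(dest))
            # Widen: allocate with a joined element type, copy the k - 1 stored values.
            wider = similar(A, typejoin(eltype(dest), typeof(v)))
            for J in Iterators.take(inds, k - 1)
                wider[J] = dest[J]
            end
            dest = wider
        end
        dest[I] = v
    end
    return dest
end
```

For example, `map_widen(x -> x > 2 ? 1.5 : 1, [1 2; 3 4])` starts with an `Int` matrix and widens it to `Matrix{Real}` when the first `1.5` appears.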
I can confirm JuliaArrays/StructArrays.jl#246 fixes all the issues up to my first reproducer using Flux and Yota. Thanks for the quick fix! The Metalhead example still fails though, but that's another story, which I'm looking at now.
```julia
loss, (∇function, ∇model, ∇image) = Yota.grad(model, image) do m, x
    sum(m(x))
end;
```
I rebased this and tests pass! This example does not; it fails with a seemingly simple error:

```
julia> loss, (∇function, ∇model, ∇image) = Yota.grad(model, image) do m, x
           sum(m(x))
       end;
loss, (_, ∇model) = Yota.grad(m -> sum(m(image)), model)
ERROR: No derivative rule found for op %454 = ntuple(%452, 4)::NTuple{4, Int64} , try defining it using
  ChainRulesCore.rrule(::typeof(ntuple), ::Flux.var"#336#337"{4, Array{Float32, 4}}, ::Int64) = ...
Stacktrace:
 [1] error(s::String)
   @ Base ./error.jl:35
 [2] step_back!(tape::Umlaut.Tape{Yota.GradCtx}, y::Umlaut.Variable)
   @ Yota ~/.julia/packages/Yota/KJQ6n/src/grad.jl:219
```
That was on tagged Yota; on the latest versions of everything, it instead seems to take forever, and I interrupted it here:
```
julia> loss, (∇function, ∇model, ∇image) = Yota.grad(model, image) do m, x
           sum(m(x))
       end;
^CERROR: InterruptException:
Stacktrace:
  [1] collect(itr::Base.Generator{Vector{Umlaut.Variable}, Yota.var"#68#72"{Umlaut.Tape{Yota.GradCtx}}})
    @ Base ./array.jl:792
  [2] todo_list(tape::Umlaut.Tape{Yota.GradCtx}, y::Umlaut.Variable)
    @ Yota ~/.julia/packages/Yota/5CVY7/src/grad.jl:113
  [3] #68
    @ ./none:0 [inlined]
  [4] iterate
    @ ./generator.jl:47 [inlined]
  [5] collect(itr::Base.Generator{Vector{Umlaut.Variable}, Yota.var"#68#72"{Umlaut.Tape{Yota.GradCtx}}})
    @ Base ./array.jl:787
  [6] todo_list(tape::Umlaut.Tape{Yota.GradCtx}, y::Umlaut.Variable)
    @ Yota ~/.julia/packages/Yota/5CVY7/src/grad.jl:113
  [7] #68
    @ ./array.jl:0 [inlined]
  [8] iterate
    @ ./generator.jl:47 [inlined]
  [9] collect_to!(dest::Vector{Vector{Umlaut.Variable}}, itr::Base.Generator{Vector{Umlaut.Variable}, Yota.var"#68#72"{Umlaut.Tape{Yota.GradCtx}}}, offs::Int64, st::Int64)
    @ Base ./array.jl:845
 [10] collect_to_with_first!(dest::Vector{Vector{Umlaut.Variable}}, v1::Vector{Umlaut.Variable}, itr::Base.Generator{Vector{Umlaut.Variable}, Yota.var"#68#72"{Umlaut.Tape{Yota.GradCtx}}}, st::Int64)
    @ Base ./array.jl:823
 [11] collect(itr::Base.Generator{Vector{Umlaut.Variable}, Yota.var"#68#72"{Umlaut.Tape{Yota.GradCtx}}})
    @ Base ./array.jl:797
--- the last 10 lines are repeated 2 more times ---
```

```
(jl_aZPcXz) pkg> st
Status `/private/var/folders/yq/4p2zwd614y59gszh7y9ypyhh0000gn/T/jl_aZPcXz/Project.toml`
  [dbeba491] Metalhead v0.8.0-DEV `https://github.com/FluxML/Metalhead.jl.git#master`
  [3bd65402] Optimisers v0.2.10 `~/.julia/dev/Optimisers`
  [09ab397b] StructArrays v0.6.13 `https://github.com/JuliaArrays/StructArrays.jl.git#master`
  [cd998857] Yota v0.8.1 `https://github.com/dfdx/Yota.jl.git#main`
```
Hmm, I was indeed investigating the incredibly long processing time, but the profiler blamed type inference / the abstract interpreter, so I started a long search for a better way to trace functions (e.g. see my recent post on Discourse). However, your stack trace implies the problem may actually appear after tracing. I will try to investigate this option too, closer to the end of the week.
FYI: I opened an issue to track this.
Fixed. The ResNet(18) example now compiles and runs in 61 seconds (compared to 47 seconds with Zygote). Subsequent calls take ~0.4 seconds on my CPU.
Great, I see something similar locally, on 0.8.2.
Are the failures on nightly easy to resolve?
It's a failure in CompilerPluginTools.jl, which apparently has not been adapted for Julia 1.9 yet. I opened JuliaCompilerPlugins/CompilerPluginTools.jl#8 to track it.
Should we just skip tests on nightly, so that this can go in? @dfdx do you know whether 1.9 works?
It looks like there's more work to do in CompilerPluginTools.jl to make it work on Julia 1.9, so I don't think it will happen in the near future. If we can skip the Yota tests for Julia 1.9, that should be the most efficient solution for now. Note that Julia nightly now points to Julia 1.10, so perhaps we need a separate entry for 1.9.
…thout checking first because I forgot about this for ages
Tests with Yota are now skipped for 1.9 & up. Should be ready to go. Can someone approve?
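One way to express "skip the Yota tests on 1.9 and up" so that 1.9 and 1.10 nightlies are covered too is a version comparison against `v"1.9.0-"`: the trailing `-` makes that literal sort below every 1.9 prerelease. This is a sketch only, not necessarily what the PR does; `yota_tests_enabled` and the file name `yota.jl` are hypothetical:

```julia
# Hypothetical predicate for gating the Yota tests by Julia version.
# v"1.9.0-" sorts below every 1.9 prerelease (1.9.0-DEV.x, 1.9.0-rc1, ...),
# so 1.9 nightlies and anything newer are excluded as well.
yota_tests_enabled(v::VersionNumber = VERSION) = v < v"1.9.0-"

if yota_tests_enabled()
    @info "Running Yota tests on Julia $VERSION"
    # include("yota.jl")   # hypothetical file holding the Yota-specific tests
else
    @info "Skipping Yota tests on Julia $VERSION"
end
```

Comparing against a prerelease-aware bound like this avoids having to list each nightly version separately in CI.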
Does not close #96, in fact this surely makes tests slower. But perhaps it's good to get something besides Zygote running?