
unsupported keyword argument "time" when taking a gradient with Zygote #140

Open
lazarusA opened this issue Apr 29, 2023 · 1 comment

@lazarusA

Running the following throws the error below. Is this expected, and are there any hints? A similar function without AxisKeys works, but then I lose the advantage of indexing by the time dimension, which I would like to keep.

I suppose I need a new rule, but I'm not sure how to start writing one.

```julia
using AxisKeys
using AxisKeys: KeyedArray as KA
using Zygote

ab = (;
    a = KA([5.0f0, 10.0f0]; time=1:2),
    b = KA([-2.0f0, 0.1f0]; time=1:2),
)

function getVals(ab::NamedTuple, ts::Int)
    map(ab) do v
        in(:time, AxisKeys.dimnames(v)) ? v[time=ts][1] : v
    end
end

gradient(x -> x^2 + x*sum(getVals(ab,2)), 5)
```

```
ERROR: MethodError: no method matching adjoint(::Zygote.Context{false}, ::typeof(getindex), ::KeyedArray{Float32, 1, NamedDimsArray{(:time,), Float32, 1, Vector{Float32}}, Base.RefValue{UnitRange{Int64}}}; time=2)
Closest candidates are:
  adjoint(::ZygoteRules.AContext, ::typeof(getindex), ::AbstractArray, ::Any...) at none:0 got unsupported keyword argument "time"
  adjoint(::ZygoteRules.AContext, ::Base.Fix2, ::Any) at none:0 got unsupported keyword argument "time"
  adjoint(::ZygoteRules.AContext, ::Base.Fix1, ::Any) at none:0 got unsupported keyword argument "time"
  ...
Stacktrace:
  [1] _pullback
    @ ~/.julia/packages/ZygoteRules/OgCVT/src/adjoint.jl:75 [inlined]
  [2] _pullback
...
```
@mcabbott
Owner

MWE is this:

```julia
julia> using Zygote, NamedDims

julia> gradient(x -> x[1], NamedDimsArray(rand(3), :a))  # easy case
(NamedDimsArray([1.0, 0.0, 0.0], :a),)

julia> gradient(x -> x[a=1], NamedDimsArray(rand(3), :a))
ERROR: MethodError: no method matching adjoint(::Zygote.Context{false}, ::typeof(getindex), ::NamedDimsArray{(:a,), Float64, 1, Vector{Float64}}; a::Int64)
Closest candidates are:
  adjoint(::ZygoteRules.AContext, ::typeof(getindex), ::AbstractArray, ::Any...) got unsupported keyword argument "a"
```
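The same kind of MethodError can be reproduced in plain Julia, without Zygote or NamedDims. The sketch below uses a hypothetical toy function, `toy_adjoint` (not Zygote's real API), shaped like the `adjoint(::AContext, ::typeof(getindex), ::AbstractArray, ::Any...)` candidate above minus the context argument. Its only method matches the positional types but declares no keywords, so a keyword call has nowhere to go and throws instead of falling through to a later, keyword-free `getindex`.

```julia
# `toy_adjoint` is a hypothetical stand-in, not Zygote's real API. Like the
# candidate `adjoint(::AContext, ::typeof(getindex), ::AbstractArray, ::Any...)`
# (minus the context argument), it has one positional method and no keywords.
toy_adjoint(::typeof(getindex), x::AbstractArray, inds...) = (x[inds...], :pullback)

x = [0.5, 0.2, 0.6]

# Positional indexing matches the method.
y, _ = toy_adjoint(getindex, x, 1)

# Keyword indexing: the positional types would match, but the method accepts
# no keywords, so Julia throws a MethodError rather than trying anything else.
err = try
    toy_adjoint(getindex, x; a = 1)
catch e
    e
end

println(y)             # 0.5
println(typeof(err))   # MethodError
```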

Ideally Zygote would treat this call as having no rule and keep going, so that it sees the later calls to getindex without keywords. On some level keywords don't participate in dispatch, but something like the following, where rrule returns nothing, is the desired outcome:

```julia
julia> using ChainRulesCore

julia> rrule(getindex, NamedDimsArray(rand(3), :a), 1)  # easy case
(0.48728253527621856, ChainRules.var"#getindex_pullback#1601"{NamedDimsArray{(:a,), Float64, 1, Vector{Float64}}, Tuple{Int64}, Tuple{NoTangent}}(NamedDimsArray([0.48728253527621856, 0.20976131397698006, 0.6193363603857295], :a), (1,), (NoTangent(),)))

julia> rrule(getindex, NamedDimsArray(rand(3), :a); a=1)  # nothing, as if no rule

julia> rrule(getindex, NamedDimsArray(rand(3), :a))  # same positional arguments, without keyword
ERROR: BoundsError: attempt to access 3-element Vector{Float64} at index []
```

That's not how Zygote works, because (1) it uses its own @adjoint rule for getindex, and (2) even when it uses an rrule, it doesn't call it and check for nothing; instead it asks the compiler which method would hypothetically be used.
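The nothing in the rrule transcript follows from how Julia dispatches keyword calls: only methods that declare keywords are candidates. The sketch below mimics that pattern with a hypothetical toy, `toy_rrule` (not ChainRulesCore's actual `rrule`), assuming a generic fallback that accepts any keywords and returns `nothing`, similar in shape to the fallback ChainRulesCore defines. The three calls mirror the three REPL lines above.

```julia
# Hypothetical toy, not ChainRulesCore's real `rrule`.

# Generic fallback: accepts any arguments, including keywords, and returns
# `nothing` to mean "no rule for this call".
toy_rrule(f, args...; kwargs...) = nothing

# Specific rule for positional getindex on arrays; it declares no keywords.
toy_rrule(::typeof(getindex), x::AbstractArray, inds...) = (x[inds...], :positional_rule)

x = [0.5, 0.2, 0.6]

# Positional call: the specific rule is more specific, so it wins.
r1 = toy_rrule(getindex, x, 1)       # (0.5, :positional_rule)

# Keyword call: only the fallback declares keywords, so it is chosen.
r2 = toy_rrule(getindex, x; a = 1)   # nothing

# Same positional arguments without the keyword: the specific rule runs
# with no indices, and x[] on a 3-element vector throws a BoundsError.
r3 = try
    toy_rrule(getindex, x)
catch e
    e
end
```

Zygote does not take this route for the reasons given above, so the toy only illustrates the dispatch behaviour, not a fix.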
