Make getindex rule work for AxisArrays #779

Closed
wants to merge 3 commits into from

simsurace

Due to the distinction between Base.axes and AxisArrays.axes, the existing rules did not work for AxisArrays.


[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
AxisArrays = "39de3d68-74b9-583c-8d2d-e117c070f3a9"
Member

IMO ChainRules should not depend on AxisArrays.

Author

I don't have any insight on the merits for or against this. But what is your suggestion?

Member

AFAICT it has been a general policy to not accept such dependencies, see e.g. JuliaArrays/FillArrays.jl#153 (comment)

Author

An extension to FillArrays is also out of the question?

Member

The PR was before extensions existed, so I have been thinking for a while one should try again with an extension. I managed to get in an extension on PDMats recently, so I think it seems likely that it would be approved.

Member

Making that FillArrays PR into an extension there would be great.

Comment on lines 132 to 133
_setindex_zero(x::AbstractArray{<:Number}, dy, inds::Integer...) = fill!(similar(x, typeof(dy), axes(x)), false)
_setindex_zero(x::AbstractArray{<:Number}, dy, inds...) = fill!(similar(x, eltype(dy), axes(x)), false)
Member

Is there a reason why these are not just

Suggested change
- _setindex_zero(x::AbstractArray{<:Number}, dy, inds::Integer...) = fill!(similar(x, typeof(dy), axes(x)), false)
- _setindex_zero(x::AbstractArray{<:Number}, dy, inds...) = fill!(similar(x, eltype(dy), axes(x)), false)
+ _setindex_zero(x::AbstractArray{<:Number}, dy, inds::Integer...) = fill!(similar(x, typeof(dy)), false)
+ _setindex_zero(x::AbstractArray{<:Number}, dy, inds...) = fill!(similar(x, eltype(dy)), false)

AFAICT this would also fix the AxisArrays problem.
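
For readers following the thread, here is a minimal sketch of how a helper like `_setindex_zero` can feed a `getindex` pullback: allocate a zero array shaped like `x`, then add the incoming cotangent `dy` at the indexed positions. The names `zero_like` and `getindex_pullback` are hypothetical and this is not the ChainRules implementation; in particular, repeated indices (which need accumulation) are deliberately not handled.

```julia
# Hypothetical stand-ins for the helpers under review (3-arg `similar` variant).
zero_like(x::AbstractArray{<:Number}, dy, inds::Integer...) =
    fill!(similar(x, typeof(dy), axes(x)), false)
zero_like(x::AbstractArray{<:Number}, dy, inds...) =
    fill!(similar(x, eltype(dy), axes(x)), false)

# Simplified pullback for y = x[inds...]; assumes no repeated indices.
function getindex_pullback(x, dy, inds...)
    dx = zero_like(x, dy, inds...)   # zero gradient with the shape of x
    view(dx, inds...) .+= dy         # write dy into the indexed slots
    return dx
end

x = [1.0, 2.0, 3.0]
g = getindex_pullback(x, 1.0, 1)     # gradient of x -> x[1]

m = collect(reshape(1.0:6.0, 3, 2))
gm = getindex_pullback(m, ones(3), :, 2)   # gradient of m -> m[:, 2]
```

With a plain `Vector` or `Matrix` the result is dense either way; the disagreement in this thread is about what `similar` returns for wrapper types.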

Author

This breaks existing tests. The problem is that if you don't pass the axes, you don't get a dense array.

Member

Which tests are broken? The two-arg method is even advised in the Julia docs: https://docs.julialang.org/en/v1/manual/methods/#Building-a-similar-type-with-a-different-type-parameter

Member

The 3-arg one removes structured matrices like Symmetric, iirc
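
A small sketch of the difference being discussed, using only `LinearAlgebra` from the standard library: the two-argument `similar` keeps the `Symmetric` wrapper, while the three-argument form with `axes` falls back to the parent's array type and returns a dense `Matrix`. The variable names are illustrative only.

```julia
using LinearAlgebra

S = Symmetric([1.0 2.0; 2.0 3.0])

kept  = similar(S, Float64)            # two-arg: wrapper is preserved
dense = similar(S, Float64, axes(S))   # three-arg: plain dense array

println(typeof(kept))
println(typeof(dense))
```

A dense result allows writing to any single entry, which a `Symmetric` wrapper does not permit off the diagonal; this is presumably why the existing rule passes `axes(x)`.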

Member

Maybe we should just add special cases for these? At first glance, it doesn't seem very desirable to remove structure (as the AxisArrays case shows).

Member

IMO the ideal situation is for axes to return more information in AxisArrays, à la mcabbott/AxisKeys.jl#6, since the relevant properties do belong to individual axes, not to the whole (like Symmetric). But we ran out of energy to fix things.

@mcabbott
Member

mcabbott commented Feb 9, 2024

What I thought should work is plain_inds = Base.to_indices(x, inds), which should turn non-indices like Colon and also Symbols etc. into actual indices. Then the gradient can remain a plain array. This doesn't work at present, but perhaps something nearby does?
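
As a hedged illustration of what `Base.to_indices` does, the sketch below applies it to a plain `Matrix` only. It does not show the AxisArrays case, which the comment above reports as not working at present.

```julia
x = collect(reshape(1.0:6.0, 3, 2))

# Colon is not itself an index; to_indices replaces it with a Base.Slice
# over the corresponding axis, and leaves the integer untouched.
plain_inds = Base.to_indices(x, (:, 2))

col = x[plain_inds...]   # same result as x[:, 2]
```

After this normalisation, downstream code only has to handle integers and arrays of integers, which is why it is attractive as a preprocessing step for an indexing rule.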

If the gradient is another AxisArray (as proposed here), then one question is what meaning its axis vectors have. E.g. if I differentiate x = AxisArray(rand(3), ax = [1.1, 2.2, 3.3]) with respect to the 2.2, you might argue the gradient representation should be AxisArray(zeros(3), ax = [0, 1.0, 0]), which is in conflict with the meaning here.

@simsurace
Author

Is this something that people do (label axes with floats)?

@mcabbott
Member

mcabbott commented Feb 9, 2024

Certainly some people want AxisArrays to be a sort of DataFrame, with labels just identifying columns.

I think the structure is more generally useful for storing anything which varies along an axis, and passing this along, never using it to replace indexing. For instance keeping a probability vector associated with eachcol(matrix).

No idea really which is more common.

The current behaviour with Zygote is this, which I think means you won't silently get wrong answers. But you may get errors if you wish to combine the two functions:

julia> gradient(x -> x[1], AxisArray([1,2,3.0], aux=[4,5,6.0]))  # natural 
([1.0, 0.0, 0.0],)

julia> gradient(x -> AxisArrays.axes(x)[1][1], AxisArray([1,2,3.0], aux=[4,5,6.0]))  # structural
((data = nothing, axes = ((val = [1.0, 0.0, 0.0],),)),)

AxisKeys acquired some projection rules which seem to be designed only for the first case:

julia> gradient(x -> x[1], KeyedArray([1,2,3.0], aux=[4,5,6.0]))[1]
1-dimensional KeyedArray(...) with keys:
↓   3-element Vector{Float64}
And data, 3-element ChainRules.OneElement{Float64, 1, Tuple{Int64}, Tuple{Base.OneTo{Int64}}}:
 (4.0)  1.0
 (5.0)  0.0
 (6.0)  0.0

julia> gradient(x -> axiskeys(x)[1][1], KeyedArray([1,2,3.0], aux=[4,5,6.0]))
ERROR: MethodError: (::ChainRulesCore.ProjectTo{KeyedArray, @NamedTuple{data::ChainRulesCore.ProjectTo{AbstractArray, @NamedTuple{element::ChainRulesCore.ProjectTo{Float64, @NamedTuple{}}, axes::Tuple{Base.OneTo{Int64}}}}, keys::@NamedTuple{aux::Vector{Float64}}}})(::ChainRulesCore.Tangent{KeyedArray{Float64, 1, NamedDimsArray{(:aux,), Float64, 1, Vector{Float64}}, Base.RefValue{Vector{Float64}}}, @NamedTuple{data::ChainRulesCore.NoTangent, keys::ChainRulesCore.Tangent{Base.RefValue{Vector{Float64}}, @NamedTuple{x::ChainRules.OneElement{Float64, 1, Tuple{Int64}, Tuple{Base.OneTo{Int64}}}}}}}) is ambiguous.

@simsurace
Author

It would seem to me that the first use case, where the index is not differentiable (strings, symbols), is the more pressing one to tackle, as it is clear what the right behavior should be (i.e. identical to indexing with integers).

Could one of the maintainers please make a final decision as to whether this PR is going to be rejected based on the additional dependency? In that case, and if we can't make #780 work, I will open a PR for an extension over at AxisArrays.

@mcabbott
Member

CR doesn't define rules for any packages, only the standard library. This isn't a whole rrule, but it seems that depending on a package to define a method is pretty much the same thing. So I don't think it is an acceptable solution.

If there's an easy way to tweak a rule to play nicely that's OK, and there are some tests involving packages. But most things should be handled by packages depending on ChainRulesCore, or pkg extensions.

Making AxisArrays opt out of the rule might work, as it must ultimately index the parent array, and probably that's the right time to call this rule? I'm still a little surprised that Base.to_indices doesn't produce valid indices; altering that logic might be another way.

@simsurace simsurace closed this Feb 10, 2024
@simsurace simsurace deleted the axisarrays-getindex branch February 10, 2024 15:24