Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable broadcasted assignment with trailing singleton dimensions #141

Merged
merged 2 commits into from
Jan 31, 2024

Conversation

wkearn
Copy link
Contributor

@wkearn wkearn commented Jan 30, 2024

With regular Arrays, it is possible to do a broadcasted assignment to a destination with fewer dimensions when the trailing dimensions all have size 1:

dest = zeros(10,9)
src = rand(10,9,1,1)
dest .= src

@assert dest == src[:,:,1,1]

But if src is a DiskArray, this fails:

using DiskArrays

struct _DiskArray{T,N,A<:AbstractArray{T,N}} <: AbstractArray{T,N}
    parent::A
    chunksize::NTuple{N,Int}
end
_DiskArray(a; chunksize=size(a)) = _DiskArray(a, chunksize)
DiskArrays.@implement_diskarray _DiskArray
Base.size(a::_DiskArray) = size(a.parent)
DiskArrays.haschunks(::_DiskArray) = DiskArrays.Chunked()
DiskArrays.eachchunk(a::_DiskArray) = DiskArrays.GridChunks(a, a.chunksize)
DiskArrays.readblock!(a::_DiskArray, aout, i::AbstractUnitRange...) = aout .= a.parent[i...]
DiskArrays.writeblock!(a::_DiskArray, v, i::AbstractUnitRange...) = view(a.parent, i...) .= v

data = rand(10,9,1,1)
src = _DiskArray(data)
dest = zeros(10,9)
dest .= src

@assert dest == data[:,:,1,1]

with the error:

ERROR: MethodError: no method matching splittuple()

Closest candidates are:
  splittuple(::Any, ::Any...)
   @ DiskArrays ~/.julia/packages/DiskArrays/1rcQi/src/broadcast.jl:127

Stacktrace:
  [1] maybeonerange(out::Tuple{UnitRange{Int64}, UnitRange{Int64}}, sizes::Tuple{Int64, Int64}, ranges::Tuple{})
    @ DiskArrays ~/.julia/packages/DiskArrays/1rcQi/src/broadcast.jl:122
  [2] maybeonerange(out::Tuple{UnitRange{Int64}}, sizes::Tuple{Int64, Int64, Int64}, ranges::Tuple{UnitRange{Int64}})
    @ DiskArrays ~/.julia/packages/DiskArrays/1rcQi/src/broadcast.jl:123
  [3] maybeonerange(out::Tuple{}, sizes::NTuple{4, Int64}, ranges::Tuple{UnitRange{Int64}, UnitRange{Int64}})
    @ DiskArrays ~/.julia/packages/DiskArrays/1rcQi/src/broadcast.jl:123
  [4] maybeonerange(sizes::NTuple{4, Int64}, ranges::Tuple{UnitRange{Int64}, UnitRange{Int64}})
    @ DiskArrays ~/.julia/packages/DiskArrays/1rcQi/src/broadcast.jl:126
  [5] subsetarg(x::_DiskArray{Float64, 4, Array{Float64, 4}}, a::Tuple{UnitRange{Int64}, UnitRange{Int64}})
    @ Main ~/.julia/packages/DiskArrays/1rcQi/src/broadcast.jl:144
  [6] (::DiskArrays.var"#62#64"{Tuple{UnitRange{Int64}, UnitRange{Int64}}})(i::_DiskArray{Float64, 4, Array{Float64, 4}})
    @ DiskArrays ~/.julia/packages/DiskArrays/1rcQi/src/broadcast.jl:41
  [7] map
    @ ./tuple.jl:291 [inlined]
  [8] (::DiskArrays.var"#61#63"{Matrix{}, Base.Broadcast.Broadcasted{}})(cnow::Tuple{UnitRange{…}, UnitRange{…}})
    @ DiskArrays ~/.julia/packages/DiskArrays/1rcQi/src/broadcast.jl:41
  [9] foreach(f::DiskArrays.var"#61#63"{Matrix{}, Base.Broadcast.Broadcasted{}}, itr::DiskArrays.GridChunks{2, Tuple{…}})
    @ Base ./abstractarray.jl:3094
 [10] copyto!(dest::Matrix{…}, bc::Base.Broadcast.Broadcasted{…})
    @ DiskArrays ~/.julia/packages/DiskArrays/1rcQi/src/broadcast.jl:38
 [11] materialize!
    @ Base.Broadcast ./broadcast.jl:914 [inlined]
 [12] materialize!(dest::Matrix{…}, bc::Base.Broadcast.Broadcasted{…})
    @ Base.Broadcast ./broadcast.jl:911
 [13] top-level scope
    @ REPL[16]:1

This seems to occur in the case where there are trailing singleton dimensions because the recursive maybeonerange does not reach the correct base case. This can be fixed by adding the following cases

maybeonerange(out, sizes, ::Tuple{}) = out
maybeonerange(out, ::Tuple{}, ::Tuple{}) = out

where the first one is the case that we reach with the trailing singleton dimensions. The second one is needed to resolve method ambiguities. I think by the time the call gets to this point, the array shapes have all been checked, so it should be safe to do this, but I could be wrong.

This pull request makes that change and adds a test.

@rafaqz
Copy link
Collaborator

rafaqz commented Jan 30, 2024

From the error it looks like this didnt work previously on 1.6? (julia throws an error before we hit the DiskArrays.jl error)

We may need to only test this after 1.9/1.10 or whenever it was implemented.

@wkearn
Copy link
Contributor Author

wkearn commented Jan 31, 2024

Looks like this is the commit that added this broadcasting behavior, so the test should work for versions 1.7 and higher. I've wrapped the testset in a check to run it only on these versions. but let me know if you would rather do it a different way.

CI is failing on nightly, but I think that is unrelated to this change. See #142

@rafaqz
Copy link
Collaborator

rafaqz commented Jan 31, 2024

Thanks

@rafaqz rafaqz merged commit 38bae53 into JuliaIO:main Jan 31, 2024
9 of 12 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants