Skip to content

Commit

Permalink
[BlockSparseArrays] Permute and merge blocks (#1514)
Browse files Browse the repository at this point in the history
* [BlockSparseArrays] Permute and merge blocks

* [NDTensors] Bump to v0.3.39
  • Loading branch information
mtfishman committed Jul 1, 2024
1 parent 2985e9b commit d734e64
Show file tree
Hide file tree
Showing 14 changed files with 673 additions and 96 deletions.
2 changes: 1 addition & 1 deletion NDTensors/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "NDTensors"
uuid = "23ae76d9-e61a-49c4-8f12-3f1a16adf9cf"
authors = ["Matthew Fishman <mfishman@flatironinstitute.org>"]
version = "0.3.38"
version = "0.3.39"

[deps]
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,12 @@ function TensorAlgebra.splitdims(
return length(axis) length(axes(a, i))
end
blockperms = invblockperm.(blocksortperm.(axes_prod))
a_blockpermed = a[blockperms...]
# TODO: This is doing extra copies of the blocks,
# use `@view a[axes_prod...]` instead.
# That will require implementing some reindexing logic
# for this combination of slicing.
a_unblocked = a[axes_prod...]
a_blockpermed = a_unblocked[blockperms...]
return splitdims(BlockReshapeFusion(), a_blockpermed, split_axes...)
end

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,14 +87,9 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
a = BlockSparseArray{elt}(d1, d2, d1, d2)
blockdiagonal!(randn!, a)
m = fusedims(a, (1, 2), (3, 4))
# TODO: Once block merging is implemented, this should
# be the real test.
for ax in axes(m)
@test ax isa GradedOneTo
# TODO: Current `fusedims` doesn't merge
# common sectors, need to fix.
@test_broken blocklabels(ax) == [U1(0), U1(1), U1(2)]
@test blocklabels(ax) == [U1(0), U1(1), U1(1), U1(2)]
@test blocklabels(ax) == [U1(0), U1(1), U1(2)]
end
for I in CartesianIndices(m)
if I CartesianIndex.([(1, 1), (4, 4)])
Expand All @@ -105,10 +100,7 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
end
@test a[1, 1, 1, 1] == m[1, 1]
@test a[2, 2, 2, 2] == m[4, 4]
# TODO: Current `fusedims` doesn't merge
# common sectors, need to fix.
@test_broken blocksize(m) == (3, 3)
@test blocksize(m) == (4, 4)
@test blocksize(m) == (3, 3)
@test a == splitdims(m, (d1, d2), (d1, d2))
end
@testset "dual axes" begin
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,14 @@ using BlockArrays:
AbstractBlockArray,
AbstractBlockVector,
Block,
BlockIndex,
BlockIndexRange,
BlockRange,
BlockSlice,
BlockVector,
BlockedOneTo,
BlockedUnitRange,
BlockVector,
BlockSlice,
BlockedVector,
block,
blockaxes,
blockedrange,
Expand All @@ -17,8 +20,30 @@ using BlockArrays:
findblockindex
using Compat: allequal
using Dictionaries: Dictionary, Indices
using ..GradedAxes: blockedunitrange_getindices
using ..SparseArrayInterface: stored_indices
using ..GradedAxes: blockedunitrange_getindices, to_blockindices
using ..SparseArrayInterface: SparseArrayInterface, nstored, stored_indices

# A return type for `blocks(array)` when `array` isn't blocked.
# Represents a vector with just that single block.
struct SingleBlockView{T,N,Array<:AbstractArray{T,N}} <: AbstractArray{T,N}
array::Array
end
blocks_maybe_single(a) = blocks(a)
blocks_maybe_single(a::Array) = SingleBlockView(a)
function Base.getindex(a::SingleBlockView{<:Any,N}, index::Vararg{Int,N}) where {N}
@assert all(isone, index)
return a.array
end

# A wrapper around a potentially blocked array that is not blocked.
struct NonBlockedArray{T,N,Array<:AbstractArray{T,N}} <: AbstractArray{T,N}
array::Array
end
Base.size(a::NonBlockedArray) = size(a.array)
Base.getindex(a::NonBlockedArray{<:Any,N}, I::Vararg{Integer,N}) where {N} = a.array[I...]
BlockArrays.blocks(a::NonBlockedArray) = SingleBlockView(a.array)
const NonBlockedVector{T,Array} = NonBlockedArray{T,1,Array}
NonBlockedVector(array::AbstractVector) = NonBlockedArray(array)

# BlockIndices works around an issue that the indices of BlockSlice
# are restricted to AbstractUnitRange{Int}.
Expand All @@ -37,6 +62,43 @@ function Base.getindex(S::BlockIndices, i::BlockSlice{<:Block{1}})
@assert length(S.indices[Block(i)]) == length(i.indices)
return BlockSlice(S.blocks[Int(Block(i))], S.indices[Block(i)])
end

# This is used in slicing like:
# a = BlockSparseArray{Float64}([2, 2, 2, 2], [2, 2, 2, 2])
# I = BlockedVector([Block(4), Block(3), Block(2), Block(1)], [2, 2])
# a[I, I]
function Base.getindex(
S::BlockIndices{<:AbstractBlockVector{<:Block{1}}}, i::BlockSlice{<:Block{1}}
)
# TODO: Check for conistency of indices.
# Wrapping the indices in `NonBlockedVector` reinterprets the blocked indices
# as a single block, since the result shouldn't be blocked.
return NonBlockedVector(BlockIndices(S.blocks[Block(i)], S.indices[Block(i)]))
end
function Base.getindex(
S::BlockIndices{<:BlockedVector{<:Block{1},<:BlockRange{1}}}, i::BlockSlice{<:Block{1}}
)
return i
end

# Used in indexing such as:
# ```julia
# a = BlockSparseArray{Float64}([2, 2, 2, 2], [2, 2, 2, 2])
# I = BlockedVector([Block(4), Block(3), Block(2), Block(1)], [2, 2])
# b = @view a[I, I]
# @view b[Block(1, 1)[1:2, 2:2]]
# ```
# This is similar to the definition:
# blocksparse_to_indices(a, inds, I::Tuple{UnitRange{<:Integer},Vararg{Any}})
function Base.getindex(
a::NonBlockedVector{<:Integer,<:BlockIndices}, I::UnitRange{<:Integer}
)
ax = only(axes(a.array.indices))
brs = to_blockindices(ax, I)
inds = blockedunitrange_getindices(ax, I)
return NonBlockedVector(a.array[BlockSlice(brs, inds)])
end

function Base.getindex(S::BlockIndices, i::BlockSlice{<:BlockRange{1}})
# TODO: Check that `i.indices` is consistent with `S.indices`.
# TODO: Turn this into a `blockedunitrange_getindices` definition.
Expand All @@ -50,6 +112,34 @@ function Base.getindex(S::BlockIndices, i::BlockSlice{<:BlockRange{1}})
return BlockIndices(subblocks, subindices)
end

# Used when performing slices like:
# @views a[[Block(2), Block(1)]][2:4, 2:4]
function Base.getindex(S::BlockIndices, i::BlockSlice{<:BlockVector{<:BlockIndex{1}}})
subblocks = mortar(
map(blocks(i.block)) do br
return S.blocks[Int(Block(br))][only(br.indices)]
end,
)
subindices = mortar(
map(blocks(i.block)) do br
S.indices[br]
end,
)
return BlockIndices(subblocks, subindices)
end

# Similar to the definition of `BlockArrays.BlockSlices`:
# ```julia
# const BlockSlices = Union{Base.Slice,BlockSlice{<:BlockRange{1}}}
# ```
# but includes `BlockIndices`, where the blocks aren't contiguous.
const BlockSliceCollection = Union{
Base.Slice,BlockSlice{<:BlockRange{1}},BlockIndices{<:Vector{<:Block{1}}}
}
const SubBlockSliceCollection = BlockIndices{
<:BlockVector{<:BlockIndex{1},<:Vector{<:BlockIndexRange{1}}}
}

# TODO: This is type piracy. This is used in `reindex` when making
# views of blocks of sliced block arrays, for example:
# ```julia
Expand Down Expand Up @@ -218,6 +308,12 @@ function blockrange(axis::AbstractUnitRange, r::UnitRange)
return findblock(axis, first(r)):findblock(axis, last(r))
end

# Occurs when slicing with `a[2:4, 2:4]`.
function blockrange(axis::BlockedOneTo{<:Integer}, r::BlockedUnitRange{<:Integer})
# TODO: Check the blocks are commensurate.
return findblock(axis, first(r)):findblock(axis, last(r))
end

function blockrange(axis::AbstractUnitRange, r::Int)
## return findblock(axis, r)
return error("Slicing with integer values isn't supported.")
Expand All @@ -241,14 +337,17 @@ function blockrange(axis::BlockedOneTo{<:Integer}, r::BlockedOneTo{<:Integer})
return only(blockaxes(r))
end

# This handles changing the blocking, for example:
# This handles block merging:
# a = BlockSparseArray{Float64}([2, 2, 2, 2], [2, 2, 2, 2])
# I = BlockedVector(Block.(1:4), [2, 2])
# I = BlockVector(Block.(1:4), [2, 2])
# I = BlockedVector([Block(4), Block(3), Block(2), Block(1)], [2, 2])
# I = BlockVector([Block(4), Block(3), Block(2), Block(1)], [2, 2])
# a[I, I]
# TODO: Generalize to `AbstractBlockedUnitRange` and `AbstractBlockVector`.
function blockrange(axis::BlockedOneTo{<:Integer}, r::BlockVector{<:Integer})
# TODO: Probably this is incorrect and should be something like:
# return findblock(axis, first(r)):findblock(axis, last(r))
function blockrange(axis::BlockedOneTo{<:Integer}, r::AbstractBlockVector{<:Block{1}})
for b in r
@assert b blockaxes(axis, 1)
end
return only(blockaxes(r))
end

Expand Down Expand Up @@ -287,6 +386,10 @@ function blockrange(axis::AbstractUnitRange, r::Base.Slice)
return only(blockaxes(axis))
end

function blockrange(axis::AbstractUnitRange, r::NonBlockedVector)
return Block(1):Block(1)
end

function blockrange(axis::AbstractUnitRange, r)
return error("Slicing not implemented for range of type `$(typeof(r))`.")
end
Expand Down Expand Up @@ -423,7 +526,18 @@ function Base.setindex!(a::BlockView{<:Any,N}, value, index::Vararg{Int,N}) wher
return a
end

function view!(a::BlockSparseArray{<:Any,N}, index::Block{N}) where {N}
function SparseArrayInterface.nstored(a::BlockView)
# TODO: Store whether or not the block is stored already as
# a Bool in `BlockView`.
I = CartesianIndex(Int.(a.block))
# TODO: Use `block_stored_indices`.
if I stored_indices(blocks(a.array))
return nstored(blocks(a.array)[I])
end
return 0
end

function view!(a::AbstractArray{<:Any,N}, index::Block{N}) where {N}
return view!(a, Tuple(index)...)
end
function view!(a::AbstractArray{<:Any,N}, index::Vararg{Block{1},N}) where {N}
Expand Down
3 changes: 2 additions & 1 deletion NDTensors/src/lib/BlockSparseArrays/src/BlockSparseArrays.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
module BlockSparseArrays
include("BlockArraysExtensions/BlockArraysExtensions.jl")
include("blocksparsearrayinterface/blocksparsearrayinterface.jl")
include("blocksparsearrayinterface/linearalgebra.jl")
include("blocksparsearrayinterface/blockzero.jl")
include("blocksparsearrayinterface/broadcast.jl")
include("blocksparsearrayinterface/arraylayouts.jl")
include("blocksparsearrayinterface/views.jl")
include("abstractblocksparsearray/abstractblocksparsearray.jl")
include("abstractblocksparsearray/wrappedabstractblocksparsearray.jl")
include("abstractblocksparsearray/abstractblocksparsematrix.jl")
Expand All @@ -15,7 +17,6 @@ include("abstractblocksparsearray/broadcast.jl")
include("abstractblocksparsearray/map.jl")
include("blocksparsearray/defaults.jl")
include("blocksparsearray/blocksparsearray.jl")
include("BlockArraysExtensions/BlockArraysExtensions.jl")
include("BlockArraysSparseArrayInterfaceExt/BlockArraysSparseArrayInterfaceExt.jl")
include("../ext/BlockSparseArraysTensorAlgebraExt/src/BlockSparseArraysTensorAlgebraExt.jl")
include("../ext/BlockSparseArraysGradedAxesExt/src/BlockSparseArraysGradedAxesExt.jl")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,19 +25,57 @@ end
# This is type piracy, try to avoid this, maybe requires defining `map`.
## Base.promote_shape(a1::Tuple{Vararg{BlockedUnitRange}}, a2::Tuple{Vararg{BlockedUnitRange}}) = combine_axes(a1, a2)

reblock(a) = a

# If the blocking of the slice doesn't match the blocking of the
# parent array, reblock according to the blocking of the parent array.
function reblock(
a::SubArray{<:Any,<:Any,<:AbstractBlockSparseArray,<:Tuple{Vararg{AbstractUnitRange}}}
)
# TODO: This relies on the behavior that slicing a block sparse
# array with a UnitRange inherits the blocking of the underlying
# block sparse array, we might change that default behavior
# so this might become something like `@blocked parent(a)[...]`.
return @view parent(a)[UnitRange{Int}.(parentindices(a))...]
end

function reblock(
a::SubArray{<:Any,<:Any,<:AbstractBlockSparseArray,<:Tuple{Vararg{NonBlockedArray}}}
)
return @view parent(a)[map(I -> I.array, parentindices(a))...]
end

function reblock(
a::SubArray{
<:Any,
<:Any,
<:AbstractBlockSparseArray,
<:Tuple{Vararg{BlockIndices{<:AbstractBlockVector{<:Block{1}}}}},
},
)
# Remove the blocking.
return @view parent(a)[map(I -> Vector(I.blocks), parentindices(a))...]
end

# TODO: Rewrite this so that it takes the blocking structure
# made by combining the blocking of the axes (i.e. the blocking that
# is used to determine `union_stored_blocked_cartesianindices(...)`).
# `reblock` is a partial solution to that, but a bit ad-hoc.
# TODO: Move to `blocksparsearrayinterface/map.jl`.
function SparseArrayInterface.sparse_map!(
::BlockSparseArrayStyle, f, a_dest::AbstractArray, a_srcs::Vararg{AbstractArray}
)
a_dest, a_srcs = reblock(a_dest), reblock.(a_srcs)
for I in union_stored_blocked_cartesianindices(a_dest, a_srcs...)
BI_dest = blockindexrange(a_dest, I)
BI_srcs = map(a_src -> blockindexrange(a_src, I), a_srcs)
# TODO: Investigate why this doesn't work:
# block_dest = @view a_dest[_block(BI_dest)]
block_dest = blocks(a_dest)[Int.(Tuple(_block(BI_dest)))...]
block_dest = blocks_maybe_single(a_dest)[Int.(Tuple(_block(BI_dest)))...]
# TODO: Investigate why this doesn't work:
# block_srcs = ntuple(i -> @view(a_srcs[i][_block(BI_srcs[i])]), length(a_srcs))
block_srcs = ntuple(length(a_srcs)) do i
return blocks(a_srcs[i])[Int.(Tuple(_block(BI_srcs[i])))...]
return blocks_maybe_single(a_srcs[i])[Int.(Tuple(_block(BI_srcs[i])))...]
end
subblock_dest = @view block_dest[BI_dest.indices...]
subblock_srcs = ntuple(i -> @view(block_srcs[i][BI_srcs[i].indices...]), length(a_srcs))
Expand Down
Loading

2 comments on commit d734e64

@mtfishman
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register subdir=NDTensors

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/110179

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a NDTensors-v0.3.39 -m "<description of version>" d734e640a385ffa9157d5edd4786aea98033fc0b
git push origin NDTensors-v0.3.39

Please sign in to comment.