Skip to content

Commit

Permalink
Fix some broken tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mtfishman committed Jun 30, 2024
1 parent 17e274a commit 29c5fbe
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,22 @@ function Base.getindex(S::BlockIndices, i::BlockSlice{<:BlockRange{1}})
return BlockIndices(subblocks, subindices)
end

# Used when performing slices like:
# @views a[[Block(2), Block(1)]][2:4, 2:4]
function Base.getindex(S::BlockIndices, i::BlockSlice{<:BlockVector{<:BlockIndex{1}}})
subblocks = mortar(
map(blocks(i.block)) do br
return S.blocks[Int(Block(br))][only(br.indices)]
end,
)
subindices = mortar(
map(blocks(i.block)) do br
S.indices[br]
end,
)
return BlockIndices(subblocks, subindices)
end

# Similar to the definition of `BlockArrays.BlockSlices`:
# ```julia
# const BlockSlices = Union{Base.Slice,BlockSlice{<:BlockRange{1}}}
Expand Down
18 changes: 9 additions & 9 deletions NDTensors/src/lib/BlockSparseArrays/test/test_basics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -650,18 +650,18 @@ include("TestBlockSparseArraysUtils.jl")
@test block_nstored(c) == 2
@test blocksize(c) == (2, 2)
@test blocklengths.(axes(c)) == ([2, 3], [2, 3])
@test_broken size(c[Block(1, 1)]) == (2, 2)
@test_broken c[Block(1, 1)] == a[Block(2, 2)[2:3, 2:3]]
@test_broken size(c[Block(2, 2)]) == (3, 3)
@test_broken c[Block(2, 2)] == a[Block(1, 1)[1:3, 1:3]]
@test_broken size(c[Block(2, 1)]) == (3, 2)
@test_broken iszero(c[Block(2, 1)])
@test_broken size(c[Block(1, 2)]) == (2, 3)
@test_broken iszero(c[Block(1, 2)])
@test size(c[Block(1, 1)]) == (2, 2)
@test c[Block(1, 1)] == a[Block(2, 2)[2:3, 2:3]]
@test size(c[Block(2, 2)]) == (3, 3)
@test c[Block(2, 2)] == a[Block(1, 1)[1:3, 1:3]]
@test size(c[Block(2, 1)]) == (3, 2)
@test iszero(c[Block(2, 1)])
@test size(c[Block(1, 2)]) == (2, 3)
@test iszero(c[Block(1, 2)])

x = randn(elt, 3, 3)
c[Block(2, 2)] = x
@test_broken c[Block(2, 2)] == x
@test c[Block(2, 2)] == x
@test a[Block(1, 1)[1:3, 1:3]] == x

a = BlockSparseArray{elt}([2, 3], [3, 4])
Expand Down

0 comments on commit 29c5fbe

Please sign in to comment.