Skip to content

Commit

Permalink
[NDTensors] Introduce LabelledNumbers and GradedAxesNext (#1351)
Browse files Browse the repository at this point in the history
  • Loading branch information
mtfishman committed Mar 15, 2024
1 parent 638624f commit 12fbcc2
Show file tree
Hide file tree
Showing 16 changed files with 592 additions and 0 deletions.
2 changes: 2 additions & 0 deletions NDTensors/src/imports.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ for lib in [
:BroadcastMapConversion,
:RankFactorization,
:Sectors,
:LabelledNumbers,
:GradedAxesNext,
:GradedAxes,
:TensorAlgebra,
:SparseArrayInterface,
Expand Down
2 changes: 2 additions & 0 deletions NDTensors/src/lib/GradedAxesNext/.JuliaFormatter.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
style = "blue"
indent = 2
3 changes: 3 additions & 0 deletions NDTensors/src/lib/GradedAxesNext/src/GradedAxesNext.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
module GradedAxesNext
include("gradedunitrange.jl")
end
245 changes: 245 additions & 0 deletions NDTensors/src/lib/GradedAxesNext/src/gradedunitrange.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,245 @@
using BlockArrays:
BlockArrays,
Block,
BlockedUnitRange,
BlockRange,
BlockVector,
blockedrange,
BlockIndexRange,
blockfirsts,
blocklasts,
blocklengths,
findblock,
findblockindex,
mortar
using ..LabelledNumbers: LabelledNumbers, LabelledInteger, label, labelled, unlabel

# Custom `BlockedUnitRange` constructor that takes a unit range
# and a set of block lengths, similar to `BlockArray(::AbstractArray, blocklengths...)`.
function blockedunitrange(a::AbstractUnitRange, blocklengths)
blocklengths_shifted = copy(blocklengths)
blocklengths_shifted[1] += (first(a) - 1)
blocklasts = cumsum(blocklengths_shifted)
return BlockArrays._BlockedUnitRange(first(a), blocklasts)
end

# Circumvents issue in `findblock` that assumes the `BlockedUnitRange`
# starts at 1.
# TODO: Raise an issue with `BlockArrays`.
function blockedunitrange_findblock(a::BlockedUnitRange, index::Integer)
@boundscheck index in 1:length(a) || throw(BoundsError(a, index))
return @inbounds findblock(a, index + first(a) - 1)
end

# Circumvents issue in `findblockindex` that assumes the `BlockedUnitRange`
# starts at 1.
# TODO: Raise an issue with `BlockArrays`.
function blockedunitrange_findblockindex(a::BlockedUnitRange, index::Integer)
@boundscheck index in 1:length(a) || throw(BoundsError())
return @inbounds findblockindex(a, index + first(a) - 1)
end

const GradedUnitRange{BlockLasts<:Vector{<:LabelledInteger}} = BlockedUnitRange{BlockLasts}

function gradedrange(lblocklengths::AbstractVector{<:LabelledInteger})
brange = blockedrange(unlabel.(lblocklengths))
lblocklasts = labelled.(blocklasts(brange), label.(lblocklengths))
# TODO: `first` is forced to be `Int` in `BlockArrays.BlockedUnitRange`,
# so this doesn't do anything right now. Make a PR to generalize it.
firstlength = first(lblocklengths)
lfirst = oneunit(firstlength)
return BlockArrays._BlockedUnitRange(lfirst, lblocklasts)
end

Base.last(a::GradedUnitRange) = isempty(a.lasts) ? first(a) - 1 : last(a.lasts)

function gradedrange(lblocklengths::AbstractVector{<:Pair{<:Any,<:Integer}})
return gradedrange(labelled.(last.(lblocklengths), first.(lblocklengths)))
end

function labelled_blocks(a::BlockedUnitRange, labels)
return BlockArrays._BlockedUnitRange(a.first, labelled.(a.lasts, labels))
end

function BlockArrays.findblock(a::GradedUnitRange, index::Integer)
return blockedunitrange_findblock(unlabel_blocks(a), index)
end

function blockedunitrange_findblock(a::GradedUnitRange, index::Integer)
return blockedunitrange_findblock(unlabel_blocks(a), index)
end

function blockedunitrange_findblockindex(a::GradedUnitRange, index::Integer)
return blockedunitrange_findblockindex(unlabel_blocks(a), index)
end

function BlockArrays.findblockindex(a::GradedUnitRange, index::Integer)
return blockedunitrange_findblockindex(unlabel_blocks(a), index)
end

## Block label interface

# Internal function
function get_label(a::BlockedUnitRange, index::Block{1})
return label(blocklasts(a)[Int(index)])
end

# Internal function
function get_label(a::BlockedUnitRange, index::Integer)
return get_label(a, blockedunitrange_findblock(a, index))
end

function blocklabels(a::BlockVector)
return map(BlockRange(a)) do block
return label(@view(a[block]))
end
end

function blocklabels(a::BlockedUnitRange)
# Using `a.lasts` here since that is what is stored
# inside of `BlockedUnitRange`, maybe change that.
# For example, it could be something like:
#
# map(BlockRange(a)) do block
# return label(@view(a[block]))
# end
#
return label.(a.lasts)
end

# TODO: This relies on internals of `BlockArrays`, maybe redesign
# to try to avoid that.
# TODO: Define `set_grades`, `set_sector_labels`, `set_labels`.
function unlabel_blocks(a::BlockedUnitRange)
return BlockArrays._BlockedUnitRange(a.first, unlabel.(a.lasts))
end

## BlockedUnitRage interface

function Base.axes(ga::GradedUnitRange)
return map(axes(unlabel_blocks(ga))) do a
return labelled_blocks(a, blocklabels(ga))
end
end

function BlockArrays.blockfirsts(a::GradedUnitRange)
return labelled.(blockfirsts(unlabel_blocks(a)), blocklabels(a))
end

function BlockArrays.blocklasts(a::GradedUnitRange)
return labelled.(blocklasts(unlabel_blocks(a)), blocklabels(a))
end

function BlockArrays.blocklengths(a::GradedUnitRange)
return labelled.(blocklengths(unlabel_blocks(a)), blocklabels(a))
end

function Base.first(a::GradedUnitRange)
return labelled(first(unlabel_blocks(a)), label(a[Block(1)]))
end

function firstblockindices(a::GradedUnitRange)
return labelled.(firstblockindices(unlabel_blocks(a)), blocklabels(a))
end

function blockedunitrange_getindex(a::GradedUnitRange, index)
# This uses `blocklasts` since that is what is stored
# in `BlockedUnitRange`, maybe abstract that away.
return labelled(unlabel_blocks(a)[index], get_label(a, index))
end

# Like `a[indices]` but preserves block structure.
using BlockArrays: block, blockindex
function blockedunitrange_getindices(
a::BlockedUnitRange, indices::AbstractUnitRange{<:Integer}
)
first_blockindex = blockedunitrange_findblockindex(a, first(indices))
last_blockindex = blockedunitrange_findblockindex(a, last(indices))
first_block = block(first_blockindex)
last_block = block(last_blockindex)
blocklengths = if first_block == last_block
[length(indices)]
else
map(first_block:last_block) do block
if block == first_block
return length(a[first_block]) - blockindex(first_blockindex) + 1
end
if block == last_block
return blockindex(last_blockindex)
end
return length(a[block])
end
end
return blockedunitrange(indices .+ (first(a) - 1), blocklengths)
end

function blockedunitrange_getindices(a::BlockedUnitRange, indices::BlockIndexRange)
return a[block(indices)][only(indices.indices)]
end

function blockedunitrange_getindices(a::BlockedUnitRange, indices::Vector{<:Integer})
return map(index -> a[index], indices)
end

function blockedunitrange_getindices(
a::BlockedUnitRange, indices::Vector{<:Union{Block{1},BlockIndexRange{1}}}
)
return mortar(map(index -> a[index], indices))
end

function blockedunitrange_getindices(a::BlockedUnitRange, indices)
return error("Not implemented.")
end

# The blocks of the corresponding slice.
_blocks(a::AbstractUnitRange, indices) = error("Not implemented")
function _blocks(a::AbstractUnitRange, indices::AbstractUnitRange)
return findblock(a, first(indices)):findblock(a, last(indices))
end
function _blocks(a::AbstractUnitRange, indices::BlockRange)
return indices
end

# The block labels of the corresponding slice.
function blocklabels(a::AbstractUnitRange, indices)
return map(_blocks(a, indices)) do block
return label(a[block])
end
end

function blockedunitrange_getindices(
ga::GradedUnitRange, indices::AbstractUnitRange{<:Integer}
)
a_indices = blockedunitrange_getindices(unlabel_blocks(ga), indices)
return labelled_blocks(a_indices, blocklabels(ga, indices))
end

function blockedunitrange_getindices(ga::GradedUnitRange, indices::BlockRange)
return labelled_blocks(unlabel_blocks(ga)[indices], blocklabels(ga, indices))
end

function Base.getindex(a::GradedUnitRange, index::Integer)
return blockedunitrange_getindex(a, index)
end

function Base.getindex(a::GradedUnitRange, index::Block{1})
return blockedunitrange_getindex(a, index)
end

function Base.getindex(a::GradedUnitRange, indices::BlockIndexRange)
return blockedunitrange_getindices(a, indices)
end

function Base.getindex(
a::GradedUnitRange, indices::BlockRange{1,<:Tuple{AbstractUnitRange{Int}}}
)
return blockedunitrange_getindices(a, indices)
end

function Base.getindex(a::GradedUnitRange, indices)
return blockedunitrange_getindices(a, indices)
end

function Base.getindex(a::GradedUnitRange, indices::AbstractUnitRange{<:Integer})
return blockedunitrange_getindices(a, indices)
end
4 changes: 4 additions & 0 deletions NDTensors/src/lib/GradedAxesNext/test/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
[deps]
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
NDTensors = "23ae76d9-e61a-49c4-8f12-3f1a16adf9cf"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
107 changes: 107 additions & 0 deletions NDTensors/src/lib/GradedAxesNext/test/runtests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
@eval module $(gensym())
using BlockArrays:
Block, BlockVector, blockedrange, blockfirsts, blocklasts, blocklength, blocklengths
using NDTensors.GradedAxesNext: GradedUnitRange, blocklabels, gradedrange
using NDTensors.LabelledNumbers: LabelledUnitRange, label, unlabel
using Test: @test, @test_broken, @testset
@testset "GradedAxes" begin
a = gradedrange(["x" => 2, "y" => 3])
@test a isa GradedUnitRange
@test length(a) == 5
@test a[Block(2)] == 3:5
@test label(a[Block(2)]) == "y"
@test a[Block(2)] isa LabelledUnitRange
@test a[4] == 4
@test label(a[4]) == "y"
@test unlabel(a[4]) == 4
@test blocklengths(a) == [2, 3]
@test blocklabels(a) == ["x", "y"]
@test label.(blocklengths(a)) == ["x", "y"]
@test blockfirsts(a) == [1, 3]
@test label.(blockfirsts(a)) == ["x", "y"]
@test first(a) == 1
@test label(first(a)) == "x"
@test blocklasts(a) == [2, 5]
@test label.(blocklasts(a)) == ["x", "y"]
@test last(a) == 5
@test label(last(a)) == "y"
@test a[Block(2)] == 3:5
@test label(a[Block(2)]) == "y"
@test length(a[Block(2)]) == 3
@test blocklengths(only(axes(a))) == blocklengths(a)
@test blocklabels(only(axes(a))) == blocklabels(a)

# Slicing operations
x = gradedrange(["x" => 2, "y" => 3])
a = x[2:4]
@test a isa GradedUnitRange
@test length(a) == 3
@test blocklength(a) == 2
@test a[Block(1)] == 2:2
@test label(a[Block(1)]) == "x"
@test a[Block(2)] == 3:4
@test label(a[Block(2)]) == "y"
@test isone(first(only(axes(a))))
@test length(only(axes(a))) == length(a)
@test blocklengths(only(axes(a))) == blocklengths(a)

x = gradedrange(["x" => 2, "y" => 3])
a = x[3:4]
@test a isa GradedUnitRange
@test length(a) == 2
@test blocklength(a) == 1
@test a[Block(1)] == 3:4
@test label(a[Block(1)]) == "y"

x = gradedrange(["x" => 2, "y" => 3])
a = x[2:4][1:2]
@test a isa GradedUnitRange
@test length(a) == 2
@test blocklength(a) == 2
@test a[Block(1)] == 2:2
@test label(a[Block(1)]) == "x"
@test a[Block(2)] == 3:3
@test label(a[Block(2)]) == "y"

x = gradedrange(["x" => 2, "y" => 3])
a = x[Block(2)[2:3]]
@test a isa LabelledUnitRange
@test length(a) == 2
@test a == 4:5
@test label(a) == "y"

x = gradedrange(["x" => 2, "y" => 3, "z" => 4])
a = x[Block(2):Block(3)]
@test a isa GradedUnitRange
@test length(a) == 7
@test blocklength(a) == 2
@test blocklengths(a) == [3, 4]
@test blocklabels(a) == ["y", "z"]
@test a[Block(1)] == 3:5
@test a[Block(2)] == 6:9

x = gradedrange(["x" => 2, "y" => 3, "z" => 4])
a = x[[Block(3), Block(2)]]
@test a isa BlockVector
@test length(a) == 7
@test blocklength(a) == 2
# TODO: `BlockArrays` doesn't define `blocklengths`
# for `BlockVector`, should it?
@test_broken blocklengths(a) == [4, 3]
@test blocklabels(a) == ["z", "y"]
@test a[Block(1)] == 6:9
@test a[Block(2)] == 3:5

x = gradedrange(["x" => 2, "y" => 3, "z" => 4])
a = x[[Block(3)[2:3], Block(2)[2:3]]]
@test a isa BlockVector
@test length(a) == 4
@test blocklength(a) == 2
# TODO: `BlockArrays` doesn't define `blocklengths`
# for `BlockVector`, should it?
@test_broken blocklengths(a) == [2, 2]
@test blocklabels(a) == ["z", "y"]
@test a[Block(1)] == 7:8
@test a[Block(2)] == 4:5
end
end
2 changes: 2 additions & 0 deletions NDTensors/src/lib/LabelledNumbers/.JuliaFormatter.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
style = "blue"
indent = 2
7 changes: 7 additions & 0 deletions NDTensors/src/lib/LabelledNumbers/src/LabelledNumbers.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
module LabelledNumbers
include("labelled_interface.jl")
include("labellednumber.jl")
include("labelledinteger.jl")
include("labelledarray.jl")
include("labelledunitrange.jl")
end
Loading

0 comments on commit 12fbcc2

Please sign in to comment.