-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
5 changed files
with
171 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
# This file defines BlockedTuple, a Tuple of heterogeneous Tuple with a BlockArrays.jl | ||
# like interface | ||
|
||
using BlockArrays: Block, BlockArrays, BlockIndexRange, BlockRange, blockedrange | ||
|
||
using TypeParameterAccessors: unspecify_type_parameters | ||
|
||
# | ||
# ================================== AbstractBlockTuple ================================== | ||
# | ||
abstract type AbstractBlockTuple end | ||
|
||
# Base interface | ||
Base.axes(bt::AbstractBlockTuple) = (blockedrange([blocklengths(bt)...]),) | ||
|
||
Base.deepcopy(bt::AbstractBlockTuple) = deepcopy.(bt) | ||
|
||
Base.firstindex(::AbstractBlockTuple) = 1 | ||
|
||
Base.getindex(bt::AbstractBlockTuple, i::Integer) = Tuple(bt)[i] | ||
Base.getindex(bt::AbstractBlockTuple, r::AbstractUnitRange) = Tuple(bt)[r] | ||
Base.getindex(bt::AbstractBlockTuple, b::Block{1}) = blocks(bt)[Int(b)] | ||
function Base.getindex(bt::AbstractBlockTuple, br::BlockRange{1}) | ||
r = Int.(br) | ||
T = unspecify_type_parameters(typeof(bt)) | ||
flat = Tuple(bt)[blockfirsts(bt)[first(r)]:blocklasts(bt)[last(r)]] | ||
return T{blocklengths(bt)[r]}(flat) | ||
end | ||
function Base.getindex(bt::AbstractBlockTuple, bi::BlockIndexRange{1}) | ||
return bt[Block(bi)][only(bi.indices)] | ||
end | ||
|
||
Base.iterate(bt::AbstractBlockTuple) = iterate(Tuple(bt)) | ||
Base.iterate(bt::AbstractBlockTuple, i::Int) = iterate(Tuple(bt), i) | ||
|
||
Base.length(bt::AbstractBlockTuple) = length(Tuple(bt)) | ||
|
||
Base.lastindex(bt::AbstractBlockTuple) = length(bt) | ||
|
||
function Base.map(f, bt::AbstractBlockTuple) | ||
return unspecify_type_parameters(typeof(bt)){blocklengths(bt)}(map(f, Tuple(bt))) | ||
end | ||
|
||
# Broadcast interface | ||
Base.broadcastable(bt::AbstractBlockTuple) = bt | ||
struct AbstractBlockTupleBroadcastStyle{BlockLengths,BT} <: Broadcast.BroadcastStyle end | ||
function Base.BroadcastStyle(T::Type{<:AbstractBlockTuple}) | ||
return AbstractBlockTupleBroadcastStyle{blocklengths(T),unspecify_type_parameters(T)}() | ||
end | ||
|
||
# BroadcastStyle is not called for two identical styles | ||
function Base.BroadcastStyle( | ||
::AbstractBlockTupleBroadcastStyle, ::AbstractBlockTupleBroadcastStyle | ||
) | ||
throw(DimensionMismatch("Incompatible blocks")) | ||
end | ||
function Base.copy( | ||
bc::Broadcast.Broadcasted{AbstractBlockTupleBroadcastStyle{BlockLengths,BT}} | ||
) where {BlockLengths,BT} | ||
return BT{BlockLengths}(bc.f.((Tuple.(bc.args))...)) | ||
end | ||
|
||
# BlockArrays interface | ||
function BlockArrays.blockfirsts(bt::AbstractBlockTuple) | ||
return (0, cumsum(Base.front(blocklengths(bt)))...) .+ 1 | ||
end | ||
|
||
function BlockArrays.blocklasts(bt::AbstractBlockTuple) | ||
return cumsum(blocklengths(bt)[begin:end]) | ||
end | ||
|
||
BlockArrays.blocklength(bt::AbstractBlockTuple) = length(blocklengths(bt)) | ||
|
||
BlockArrays.blocklengths(bt::AbstractBlockTuple) = blocklengths(typeof(bt)) | ||
|
||
function BlockArrays.blocks(bt::AbstractBlockTuple) | ||
bf = blockfirsts(bt) | ||
bl = blocklasts(bt) | ||
return ntuple(i -> Tuple(bt)[bf[i]:bl[i]], blocklength(bt)) | ||
end | ||
|
||
# | ||
# ===================================== BlockedTuple ===================================== | ||
# | ||
struct BlockedTuple{BlockLengths,Flat} <: AbstractBlockTuple | ||
flat::Flat | ||
|
||
function BlockedTuple{BlockLengths}(flat::Tuple) where {BlockLengths} | ||
length(flat) != sum(BlockLengths) && throw(DimensionMismatch("Invalid total length")) | ||
return new{BlockLengths,typeof(flat)}(flat) | ||
end | ||
end | ||
|
||
# TensorAlgebra Interface | ||
tuplemortar(tt::Tuple{Vararg{Tuple}}) = BlockedTuple{length.(tt)}(flatten_tuples(tt)) | ||
function BlockedTuple(flat::Tuple, BlockLengths::Tuple{Vararg{Int}}) | ||
return BlockedTuple{BlockLengths}(flat) | ||
end | ||
BlockedTuple(bt::AbstractBlockTuple) = BlockedTuple{blocklengths(bt)}(Tuple(bt)) | ||
|
||
# Base interface | ||
Base.Tuple(bt::BlockedTuple) = bt.flat | ||
|
||
# BlockArrays interface | ||
function BlockArrays.blocklengths(::Type{<:BlockedTuple{BlockLengths}}) where {BlockLengths} | ||
return BlockLengths | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
using Test: @test, @test_throws | ||
|
||
using BlockArrays: Block, blocklength, blocklengths, blockedrange, blockisequal, blocks | ||
using TestExtras: @constinferred | ||
|
||
using TensorAlgebra: BlockedTuple, tuplemortar | ||
|
||
@testset "BlockedTuple" begin | ||
flat = (true, 'a', 2, "b", 3.0) | ||
divs = (1, 2, 2) | ||
|
||
bt = BlockedTuple{divs}(flat) | ||
|
||
@test (@constinferred Tuple(bt)) == flat | ||
@test bt == tuplemortar(((true,), ('a', 2), ("b", 3.0))) | ||
@test bt == BlockedTuple(flat, divs) | ||
@test BlockedTuple(bt) == bt | ||
@test blocklength(bt) == 3 | ||
@test blocklengths(bt) == (1, 2, 2) | ||
@test (@constinferred blocks(bt)) == ((true,), ('a', 2), ("b", 3.0)) | ||
|
||
@test (@constinferred bt[1]) == true | ||
@test (@constinferred bt[2]) == 'a' | ||
|
||
# it is hard to make bt[Block(1)] type stable as compile-time knowledge of 1 is lost in Block | ||
@test bt[Block(1)] == blocks(bt)[1] | ||
@test bt[Block(2)] == blocks(bt)[2] | ||
@test bt[Block(1):Block(2)] == tuplemortar(((true,), ('a', 2))) | ||
@test bt[Block(2)[1:2]] == ('a', 2) | ||
@test bt[2:4] == ('a', 2, "b") | ||
|
||
@test firstindex(bt) == 1 | ||
@test lastindex(bt) == 5 | ||
@test length(bt) == 5 | ||
|
||
@test iterate(bt) == (1, 2) | ||
@test iterate(bt, 2) == ('a', 3) | ||
@test blockisequal(only(axes(bt)), blockedrange([1, 2, 2])) | ||
|
||
@test_throws DimensionMismatch BlockedTuple{(1, 2, 3)}(flat) | ||
|
||
bt = tuplemortar(((1,), (4, 2), (5, 3))) | ||
@test Tuple(bt) == (1, 4, 2, 5, 3) | ||
@test blocklengths(bt) == (1, 2, 2) | ||
@test deepcopy(bt) == bt | ||
|
||
@test (@constinferred map(n -> n + 1, bt)) == | ||
BlockedTuple{blocklengths(bt)}(Tuple(bt) .+ 1) | ||
@test bt .+ tuplemortar(((1,), (1, 1), (1, 1))) == | ||
BlockedTuple{blocklengths(bt)}(Tuple(bt) .+ 1) | ||
@test_throws DimensionMismatch bt .+ tuplemortar(((1, 1), (1, 1), (1,))) | ||
|
||
bt = tuplemortar(((1:2, 1:2), (1:3,))) | ||
@test length.(bt) == tuplemortar(((2, 2), (3,))) | ||
end |