diff --git a/Project.toml b/Project.toml index 5df2370..3259e66 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "TensorAlgebra" uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a" authors = ["ITensor developers and contributors"] -version = "0.1.0" +version = "0.1.1" [deps] ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a" @@ -9,6 +9,7 @@ BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e" EllipsisNotation = "da5c29d0-fa7d-589e-88eb-ea29b0a81949" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" TupleTools = "9d95972d-f1c8-5527-a6e0-b4b365fa01f6" +TypeParameterAccessors = "7e5a90cf-f82e-492e-a09b-e3e26432c138" [weakdeps] GradedUnitRanges = "e2de450a-8a67-46c7-b59c-01d5a3d041c5" @@ -23,4 +24,5 @@ EllipsisNotation = "1.8.0" GradedUnitRanges = "0.1.0" LinearAlgebra = "1.10" TupleTools = "1.6.0" +TypeParameterAccessors = "0.2.1" julia = "1.10" diff --git a/src/TensorAlgebra.jl b/src/TensorAlgebra.jl index 05358a4..9035f0b 100644 --- a/src/TensorAlgebra.jl +++ b/src/TensorAlgebra.jl @@ -1,6 +1,7 @@ module TensorAlgebra include("blockedpermutation.jl") include("BaseExtensions/BaseExtensions.jl") +include("blockedtuple.jl") include("fusedims.jl") include("splitdims.jl") include("contract/contract.jl") diff --git a/src/blockedtuple.jl b/src/blockedtuple.jl new file mode 100644 index 0000000..966c76c --- /dev/null +++ b/src/blockedtuple.jl @@ -0,0 +1,107 @@ +# This file defines BlockedTuple, a Tuple of heterogeneous Tuple with a BlockArrays.jl +# like interface + +using BlockArrays: Block, BlockArrays, BlockIndexRange, BlockRange, blockedrange + +using TypeParameterAccessors: unspecify_type_parameters + +# +# ================================== AbstractBlockTuple ================================== +# +abstract type AbstractBlockTuple end + +# Base interface +Base.axes(bt::AbstractBlockTuple) = (blockedrange([blocklengths(bt)...]),) + +Base.deepcopy(bt::AbstractBlockTuple) = deepcopy.(bt) + +Base.firstindex(::AbstractBlockTuple) = 1 + +Base.getindex(bt::AbstractBlockTuple, i::Integer) = Tuple(bt)[i] +Base.getindex(bt::AbstractBlockTuple, r::AbstractUnitRange) = Tuple(bt)[r] +Base.getindex(bt::AbstractBlockTuple, b::Block{1}) = blocks(bt)[Int(b)] +function Base.getindex(bt::AbstractBlockTuple, br::BlockRange{1}) + r = Int.(br) + T = unspecify_type_parameters(typeof(bt)) + flat = Tuple(bt)[blockfirsts(bt)[first(r)]:blocklasts(bt)[last(r)]] + return T{blocklengths(bt)[r]}(flat) +end +function Base.getindex(bt::AbstractBlockTuple, bi::BlockIndexRange{1}) + return bt[Block(bi)][only(bi.indices)] +end + +Base.iterate(bt::AbstractBlockTuple) = iterate(Tuple(bt)) +Base.iterate(bt::AbstractBlockTuple, i::Int) = iterate(Tuple(bt), i) + +Base.length(bt::AbstractBlockTuple) = length(Tuple(bt)) + +Base.lastindex(bt::AbstractBlockTuple) = length(bt) + +function Base.map(f, bt::AbstractBlockTuple) + return unspecify_type_parameters(typeof(bt)){blocklengths(bt)}(map(f, Tuple(bt))) +end + +# Broadcast interface +Base.broadcastable(bt::AbstractBlockTuple) = bt +struct AbstractBlockTupleBroadcastStyle{BlockLengths,BT} <: Broadcast.BroadcastStyle end +function Base.BroadcastStyle(T::Type{<:AbstractBlockTuple}) + return AbstractBlockTupleBroadcastStyle{blocklengths(T),unspecify_type_parameters(T)}() +end + +# BroadcastStyle is not called for two identical styles +function Base.BroadcastStyle( + ::AbstractBlockTupleBroadcastStyle, ::AbstractBlockTupleBroadcastStyle +) + throw(DimensionMismatch("Incompatible blocks")) +end +function Base.copy( + bc::Broadcast.Broadcasted{AbstractBlockTupleBroadcastStyle{BlockLengths,BT}} +) where {BlockLengths,BT} + return BT{BlockLengths}(bc.f.((Tuple.(bc.args))...)) +end + +# BlockArrays interface +function BlockArrays.blockfirsts(bt::AbstractBlockTuple) + return (0, cumsum(Base.front(blocklengths(bt)))...) .+ 1 +end + +function BlockArrays.blocklasts(bt::AbstractBlockTuple) + return cumsum(blocklengths(bt)[begin:end]) +end + +BlockArrays.blocklength(bt::AbstractBlockTuple) = length(blocklengths(bt)) + +BlockArrays.blocklengths(bt::AbstractBlockTuple) = blocklengths(typeof(bt)) + +function BlockArrays.blocks(bt::AbstractBlockTuple) + bf = blockfirsts(bt) + bl = blocklasts(bt) + return ntuple(i -> Tuple(bt)[bf[i]:bl[i]], blocklength(bt)) +end + +# +# ===================================== BlockedTuple ===================================== +# +struct BlockedTuple{BlockLengths,Flat} <: AbstractBlockTuple + flat::Flat + + function BlockedTuple{BlockLengths}(flat::Tuple) where {BlockLengths} + length(flat) != sum(BlockLengths) && throw(DimensionMismatch("Invalid total length")) + return new{BlockLengths,typeof(flat)}(flat) + end +end + +# TensorAlgebra Interface +tuplemortar(tt::Tuple{Vararg{Tuple}}) = BlockedTuple{length.(tt)}(flatten_tuples(tt)) +function BlockedTuple(flat::Tuple, BlockLengths::Tuple{Vararg{Int}}) + return BlockedTuple{BlockLengths}(flat) +end +BlockedTuple(bt::AbstractBlockTuple) = BlockedTuple{blocklengths(bt)}(Tuple(bt)) + +# Base interface +Base.Tuple(bt::BlockedTuple) = bt.flat + +# BlockArrays interface +function BlockArrays.blocklengths(::Type{<:BlockedTuple{BlockLengths}}) where {BlockLengths} + return BlockLengths +end diff --git a/test/Project.toml b/test/Project.toml index 487e9ee..fe35b0f 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,4 +1,5 @@ [deps] +Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e" Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" EllipsisNotation = "da5c29d0-fa7d-589e-88eb-ea29b0a81949" @@ -8,16 +9,16 @@ LabelledNumbers = "f856a3a6-4152-4ec4-b2a7-02c1a55d7993" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" NDTensors = "23ae76d9-e61a-49c4-8f12-3f1a16adf9cf" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" -TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a" -TensorOperations = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2" -Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb" +TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a" +TensorOperations = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a" [compat] -TensorOperations = "4.1.1" Aqua = "0.8.9" SafeTestsets = "0.1" Suppressor = "0.2" +TensorOperations = "5.1.3" Test = "1.10" diff --git a/test/test_blockedtuple.jl b/test/test_blockedtuple.jl new file mode 100644 index 0000000..fd599ce --- /dev/null +++ b/test/test_blockedtuple.jl @@ -0,0 +1,55 @@ +using Test: @test, @test_throws + +using BlockArrays: Block, blocklength, blocklengths, blockedrange, blockisequal, blocks +using TestExtras: @constinferred + +using TensorAlgebra: BlockedTuple, tuplemortar + +@testset "BlockedTuple" begin + flat = (true, 'a', 2, "b", 3.0) + divs = (1, 2, 2) + + bt = BlockedTuple{divs}(flat) + + @test (@constinferred Tuple(bt)) == flat + @test bt == tuplemortar(((true,), ('a', 2), ("b", 3.0))) + @test bt == BlockedTuple(flat, divs) + @test BlockedTuple(bt) == bt + @test blocklength(bt) == 3 + @test blocklengths(bt) == (1, 2, 2) + @test (@constinferred blocks(bt)) == ((true,), ('a', 2), ("b", 3.0)) + + @test (@constinferred bt[1]) == true + @test (@constinferred bt[2]) == 'a' + + # it is hard to make bt[Block(1)] type stable as compile-time knowledge of 1 is lost in Block + @test bt[Block(1)] == blocks(bt)[1] + @test bt[Block(2)] == blocks(bt)[2] + @test bt[Block(1):Block(2)] == tuplemortar(((true,), ('a', 2))) + @test bt[Block(2)[1:2]] == ('a', 2) + @test bt[2:4] == ('a', 2, "b") + + @test firstindex(bt) == 1 + @test lastindex(bt) == 5 + @test length(bt) == 5 + + @test iterate(bt) == (1, 2) + @test iterate(bt, 2) == ('a', 3) + @test blockisequal(only(axes(bt)), blockedrange([1, 2, 2])) + + @test_throws DimensionMismatch BlockedTuple{(1, 2, 3)}(flat) + + bt = tuplemortar(((1,), (4, 2), (5, 3))) + @test Tuple(bt) == (1, 4, 2, 5, 3) + @test blocklengths(bt) == (1, 2, 2) + @test deepcopy(bt) == bt + + @test (@constinferred map(n -> n + 1, bt)) == + BlockedTuple{blocklengths(bt)}(Tuple(bt) .+ 1) + @test bt .+ tuplemortar(((1,), (1, 1), (1, 1))) == + BlockedTuple{blocklengths(bt)}(Tuple(bt) .+ 1) + @test_throws DimensionMismatch bt .+ tuplemortar(((1, 1), (1, 1), (1,))) + + bt = tuplemortar(((1:2, 1:2), (1:3,))) + @test length.(bt) == tuplemortar(((2, 2), (3,))) +end