From c931458357220f84cf5b38a43e3c2156f6ea1801 Mon Sep 17 00:00:00 2001 From: Jutho Haegeman Date: Fri, 1 Dec 2023 15:52:57 +0100 Subject: [PATCH 1/5] rewrite constructbraidingtensor preprocessor --- src/planar/preprocessors.jl | 234 +++++++++++++++++++--------------- src/tensors/braidingtensor.jl | 4 + test/planar.jl | 35 +++++ 3 files changed, 171 insertions(+), 102 deletions(-) diff --git a/src/planar/preprocessors.jl b/src/planar/preprocessors.jl index 6c284ad9..3d6abecc 100644 --- a/src/planar/preprocessors.jl +++ b/src/planar/preprocessors.jl @@ -68,129 +68,159 @@ _is_adjoint(ex) = isexpr(ex, TO.prime) _remove_adjoint(ex) = _is_adjoint(ex) ? ex.args[1] : ex _add_adjoint(ex) = Expr(TO.prime, ex) -# used by `@planar`: realize explicit braiding tensors +# used by `@planar`: identify braiding tensors (corresponding to name τ) and discover their +# spaces from the rest of the expression. Construct the explicit BraidingTensor objects and +# insert them in the expression. function _construct_braidingtensors(ex::Expr) if TO.isdefinition(ex) || TO.isassignment(ex) lhs, rhs = TO.getlhs(ex), TO.getrhs(ex) - if TO.istensorexpr(rhs) - list = TO.gettensors(_conj_to_adjoint(rhs)) - if TO.isassignment(ex) && TO.istensor(lhs) - obj, l, r = TO.decomposetensor(lhs) - lhs_as_rhs = Expr(:typed_vcat, _add_adjoint(obj), - Expr(:tuple, r...), Expr(:tuple, l...)) - push!(list, lhs_as_rhs) - end - else + if !TO.istensorexpr(rhs) return ex end + preargs = Vector{Any}() + indexmap = Dict{Any,Any}() + if TO.isassignment(ex) && TO.istensor(lhs) + obj, leftind, rightind = TO.decomposetensor(lhs) + for (i, l) in enumerate(leftind) + indexmap[l] = Expr(:call, :space, _add_adjoint(obj), i) + end + for (i, l) in enumerate(rightind) + indexmap[l] = Expr(:call, :space, _add_adjoint(obj), length(leftind) + i) + end + end + newrhs, success = _construct_braidingtensors!(rhs, preargs, indexmap) + success || + throw(ArgumentError("cannot determine the spaces of all braiding tensors in $ex")) + pre = Expr(:macrocall, Symbol("@notensor"), + LineNumberNode(@__LINE__, Symbol(@__FILE__)), Expr(:block, preargs...)) + return Expr(:block, pre, Expr(ex.head, lhs, newrhs)) elseif TO.istensorexpr(ex) - list = TO.gettensors(_conj_to_adjoint(ex)) + preargs = Vector{Any}() + indexmap = Dict{Any,Any}() + newex, success = _construct_braidingtensors!(ex, preargs, indexmap) + success || + throw(ArgumentError("cannot determine the spaces of all braiding tensors in $ex")) + pre = Expr(:macrocall, Symbol("@notensor"), + LineNumberNode(@__LINE__, Symbol(@__FILE__)), Expr(:block, preargs...)) + return Expr(:block, pre, newex) else return Expr(ex.head, map(_construct_braidingtensors, ex.args)...) end +end +_construct_braidingtensors(x) = x - i = 1 - translatebraidings = Dict{Any,Any}() - while i <= length(list) - t = list[i] - if _remove_adjoint(TO.gettensorobject(t)) == :τ - translatebraidings[t] = Expr(:call, GlobalRef(TensorKit, :BraidingTensor)) - deleteat!(list, i) - else - i += 1 - end - end - - unresolved = Any[] # list of indices that we couldn't yet figure out - indexmap = Dict{Any,Any}() - # indexmap[i] contains the expression to resolve the space for index i - for (t, construct_expr) in translatebraidings - obj, leftind, rightind = TO.decomposetensor(t) - length(leftind) == length(rightind) == 2 || - throw(ArgumentError("The name τ is reserved for the braiding, and should have two input and two output indices.")) - if _is_adjoint(obj) - i1b, i2b, = leftind - i2a, i1a, = rightind - else - i2b, i1b, = leftind - i1a, i2a, = rightind - end - - obj_and_pos1a = _findindex(i1a, list) - obj_and_pos2a = _findindex(i2a, list) - obj_and_pos1b = _findindex(i1b, list) - obj_and_pos2b = _findindex(i2b, list) +function _construct_braidingtensors!(ex, preargs, indexmap) # ex is guaranteed to be a single tensor expression + if TO.istensor(ex) + obj, leftind, rightind = TO.decomposetensor(ex) + if _remove_adjoint(obj) == :τ + # try to construct a braiding tensor + length(leftind) == length(rightind) == 2 || + throw(ArgumentError("The name τ is reserved for the braiding, and should have two input and two output indices.")) + if _is_adjoint(obj) + i1b, i2b, = leftind + i2a, i1a, = rightind + else + i2b, i1b, = leftind + i1a, i2a, = rightind + end - if !isnothing(obj_and_pos1a) - indexmap[i1b] = Expr(:call, :space, obj_and_pos1a...) - indexmap[i1a] = Expr(:call, :space, obj_and_pos1a...) - elseif !isnothing(obj_and_pos1b) - indexmap[i1b] = Expr(TO.prime, Expr(:call, :space, obj_and_pos1b...)) - indexmap[i1a] = Expr(TO.prime, Expr(:call, :space, obj_and_pos1b...)) + foundV1, foundV2 = false, false + if haskey(indexmap, i1a) + V1 = indexmap[i1a] + foundV1 = true + elseif haskey(indexmap, i1b) + V1 = Expr(:call, :dual, indexmap[i1b]) + foundV1 = true + end + if haskey(indexmap, i2a) + V2 = indexmap[i2a] + foundV2 = true + elseif haskey(indexmap, i2b) + V2 = Expr(:call, :dual, indexmap[i2b]) + foundV2 = true + end + if foundV1 && foundV2 + s = gensym(:τ) + constructex = Expr(:call, GlobalRef(TensorKit, :BraidingTensor), V1, V2) + push!(preargs, Expr(:(=), s, constructex)) + obj = _is_adjoint(obj) ? _add_adjoint(s) : s + success = true + else + success = false + end + newex = Expr(:typed_vcat, obj, Expr(:tuple, leftind...), + Expr(:tuple, rightind...)) else - push!(unresolved, (i1a, i1b)) + newex = ex + success = true end - - if !isnothing(obj_and_pos2a) - indexmap[i2b] = Expr(:call, :space, obj_and_pos2a...) - indexmap[i2a] = Expr(:call, :space, obj_and_pos2a...) - elseif !isnothing(obj_and_pos2b) - indexmap[i2b] = Expr(TO.prime, Expr(:call, :space, obj_and_pos2b...)) - indexmap[i2a] = Expr(TO.prime, Expr(:call, :space, obj_and_pos2b...)) - else - push!(unresolved, (i2a, i2b)) + # add spaces of the tensor to the indexmap + for (i, l) in enumerate(leftind) + if !haskey(indexmap, l) + indexmap[l] = Expr(:call, :space, obj, i) + end end - end - # simple loop that tries to resolve as many indices as possible - changed = true - while changed == true - changed = false - i = 1 - while i <= length(unresolved) - (i1, i2) = unresolved[i] - if i1 in keys(indexmap) - changed = true - indexmap[i2] = indexmap[i1] - deleteat!(unresolved, i) - elseif i2 in keys(indexmap) - changed = true - indexmap[i1] = indexmap[i2] - deleteat!(unresolved, i) - else - i += 1 + for (i, l) in enumerate(rightind) + if !haskey(indexmap, l) + indexmap[l] = Expr(:call, :space, obj, length(leftind) + i) end end - end - !isempty(unresolved) && - throw(ArgumentError("cannot determine the spaces of indices " * - string(tuple(unresolved...)) * - "for the braiding tensors in $(ex)")) - - pre = Expr(:block) - for (t, construct_expr) in translatebraidings - obj, leftind, rightind = TO.decomposetensor(t) - if _is_adjoint(obj) - i1b, i2b, = leftind - i2a, i1a, = rightind - else - i2b, i1b, = leftind - i1a, i2a, = rightind + return newex, success + elseif TO.isgeneraltensor(ex) + args = ex.args + newargs = Vector{Any}(undef, length(args)) + success = true + for i in 1:length(ex.args) + newargs[i], successa = _construct_braidingtensors!(args[i], preargs, indexmap) + success = success && successa end - push!(construct_expr.args, indexmap[i1b]) - push!(construct_expr.args, indexmap[i2b]) - s = gensym(:τ) - push!(pre.args, :(@notensor $s = $construct_expr)) - ex = TO.replacetensorobjects(ex) do o, l, r - if o == obj && l == leftind && r == rightind - return obj == :τ ? s : Expr(TO.prime, s) - else - return o + newex = Expr(ex.head, newargs...) + return newex, success + elseif isexpr(ex, :call) && ex.args[1] == :* + args = ex.args + newargs = Vector{Any}(undef, length(args)) + newargs[1] = args[1] + successes = map(i -> false, args) + successes[1] = true + numsuccess = 1 + while !all(successes) + for i in 2:length(ex.args) + successes[i] && continue + newargs[i], successa = _construct_braidingtensors!(args[i], preargs, + indexmap) + successes[i] = successa + end + if numsuccess == count(successes) + break + end + numsuccess = count(successes) + end + success = numsuccess == length(successes) + newex = Expr(ex.head, newargs...) + return newex, success + elseif isexpr(ex, :call) && ex.args[1] ∈ (:+, :-) + args = ex.args + newargs = Vector{Any}(undef, length(args)) + newargs[1] = args[1] + success = true + indices = TO.getindices(ex) + for i in 2:length(ex.args) + indexmapa = copy(indexmap) + newargs[i], successa = _construct_braidingtensors!(args[i], preargs, indexmapa) + for l in indices[i] + if !haskey(indexmap, l) && haskey(indexmapa, l) + indexmap[l] = indexmapa[l] + end end + success = success && successa end + newex = Expr(ex.head, newargs...) + return newex, success + else + @show("huh?") + return ex, true end - return Expr(:block, pre, ex) end -_construct_braidingtensors(x) = x # used by non-planar parser of `@plansor`: remove explicit braiding tensors function _remove_braidingtensors(ex) diff --git a/src/tensors/braidingtensor.jl b/src/tensors/braidingtensor.jl index 785ebf49..2ee46f66 100644 --- a/src/tensors/braidingtensor.jl +++ b/src/tensors/braidingtensor.jl @@ -3,9 +3,13 @@ #====================================================================# """ struct BraidingTensor{S<:IndexSpace} <: AbstractTensorMap{S, 2, 2} + BraidingTensor(V1::S, V2::S, adjoint::Bool=false) where {S<:IndexSpace} Specific subtype of [`AbstractTensorMap`](@ref) for representing the braiding tensor that braids the first input over the second input; its inverse can be obtained as the adjoint. + +It holds that `domain(BraidingTensor(V1, V2)) == V1 ⊗ V2` and +`codomain(BraidingTensor(V1, V2)) == V2 ⊗ V1`. """ struct BraidingTensor{S<:IndexSpace,A} <: AbstractTensorMap{S,2,2} V1::S diff --git a/test/planar.jl b/test/planar.jl index bf3efc19..755f6a46 100644 --- a/test/planar.jl +++ b/test/planar.jl @@ -140,6 +140,23 @@ end conj(x′[1 3; -1]) @test force_planar(ρ2) ≈ ρ2′ @test ρ2 ≈ ρ3 + + # Periodic boundary conditions + # ---------------------------- + f1 = isomorphism(storagetype(O), fuse(Vmpo^3), Vmpo ⊗ Vmpo' ⊗ Vmpo) + f2 = isomorphism(storagetype(O), fuse(Vmpo^3), Vmpo ⊗ Vmpo' ⊗ Vmpo) + f1′ = force_planar(f1) + f2′ = force_planar(f2) + @tensor O_periodic1[-1 -2; -3 -4] := O[1 -2; -3 2] * f1[-1; 1 3 4] * + conj(f2[-4; 2 3 4]) + @plansor O_periodic2[-1 -2; -3 -4] := O[1 2; -3 6] * f1[-1; 1 3 5] * + conj(f2[-4; 6 7 8]) * τ[2 3; 7 4] * + τ[4 5; 8 -2] + @planar O_periodic′[-1 -2; -3 -4] := O′[1 2; -3 6] * f1′[-1; 1 3 5] * + conj(f2′[-4; 6 7 8]) * τ[2 3; 7 4] * + τ[4 5; 8 -2] + @test O_periodic1 ≈ O_periodic2 + @test force_planar(O_periodic1) ≈ O_periodic′ end @testset "MERA networks" begin @@ -171,4 +188,22 @@ end end @test C ≈ C′ end + + @testset "Issue 93" begin + T = Float64 + V1 = ℂ^2 + V2 = ℂ^3 + t1 = TensorMap(rand, T, V1 ← V2) + t2 = TensorMap(rand, T, V2 ← V1) + + tr1 = @planar opt = true t1[a; b] * t2[b; a] + tr2 = @planar opt = true t1[d; a] * t2[b; c] * τ[c b; a d] + tr3 = @planar opt = true t1[d; a] * t2[b; c] * τ[a c; d b] + tr4 = @planar opt = true t1[f; a] * t2[c; d] * τ[d b; c e] * τ[e b; a f] + tr5 = @planar opt = true t1[f; a] * t2[c; d] * τ[d b; c e] * τ[a e; f b] + tr6 = @planar opt = true t1[f; a] * t2[c; d] * τ[c d; e b] * τ[e b; a f] + tr7 = @planar opt = true t1[f; a] * t2[c; d] * τ[c d; e b] * τ[a e; f b] + + @test tr1 ≈ tr2 ≈ tr3 ≈ tr4 ≈ tr5 ≈ tr6 ≈ tr7 + end end From f067ed48f2e67e14e0a22f85eb39a866530675ff Mon Sep 17 00:00:00 2001 From: Jutho Haegeman Date: Fri, 1 Dec 2023 17:04:14 +0100 Subject: [PATCH 2/5] some more planar fixes --- src/planar/analyzers.jl | 4 ++++ src/planar/preprocessors.jl | 21 +++++++++++++++++---- src/tensors/braidingtensor.jl | 23 +++++++++++++++++++++++ test/planar.jl | 14 +++++++------- 4 files changed, 51 insertions(+), 11 deletions(-) diff --git a/src/planar/analyzers.jl b/src/planar/analyzers.jl index b91ff2ba..4d22cf0a 100644 --- a/src/planar/analyzers.jl +++ b/src/planar/analyzers.jl @@ -36,6 +36,10 @@ function get_possible_planar_indices(ex) end end return inds + elseif isexpr(ex, :call) && ex.args[1] == :/ + return get_possible_planar_indices(ex.args[2]) + elseif isexpr(ex, :call) && ex.args[1] == :\ + return get_possible_planar_indices(ex.args[3]) else return Any[] end diff --git a/src/planar/preprocessors.jl b/src/planar/preprocessors.jl index 3d6abecc..9fbd51a5 100644 --- a/src/planar/preprocessors.jl +++ b/src/planar/preprocessors.jl @@ -110,7 +110,9 @@ end _construct_braidingtensors(x) = x function _construct_braidingtensors!(ex, preargs, indexmap) # ex is guaranteed to be a single tensor expression - if TO.istensor(ex) + if TO.isscalarexpr(ex) + return ex, true + elseif TO.istensor(ex) obj, leftind, rightind = TO.decomposetensor(ex) if _remove_adjoint(obj) == :τ # try to construct a braiding tensor @@ -216,9 +218,14 @@ function _construct_braidingtensors!(ex, preargs, indexmap) # ex is guaranteed t end newex = Expr(ex.head, newargs...) return newex, success + elseif isexpr(ex, :call) && ex.args[1] == :/ && length(ex.args) == 3 + newarg, success = _construct_braidingtensors!(ex.args[2], preargs, indexmap) + return Expr(:call, :/, newarg, ex.args[3]), success + elseif isexpr(ex, :call) && ex.args[1] == :\ && length(ex.args) == 3 + newarg, success = _construct_braidingtensors!(ex.args[3], preargs, indexmap) + return Expr(:call, :\, ex.args[2], newarg), success else - @show("huh?") - return ex, true + error("unexpected expression $ex") end end @@ -520,8 +527,14 @@ function _extract_contraction_pairs(rhs, lhs, pre, temporaries) for a in rhs.args[2:end]] return Expr(rhs.head, rhs.args[1], args...) + elseif isexpr(rhs, :call) && rhs.args[1] == :/ + newarg = _extract_contraction_pairs(rhs.args[2], lhs, pre, temporaries) + return Expr(:call, :/, newarg, rhs.args[3]) + elseif isexpr(rhs, :call) && rhs.args[1] == :\ + newarg = _extract_contraction_pairs(rhs.args[3], lhs, pre, temporaries) + return Expr(:call, :\, rhs.args[2], newarg) else - throw(ArgumentError("unknown tensor expression")) + throw(ArgumentError("unknown tensor expression $ex")) end end diff --git a/src/tensors/braidingtensor.jl b/src/tensors/braidingtensor.jl index 2ee46f66..ab945e81 100644 --- a/src/tensors/braidingtensor.jl +++ b/src/tensors/braidingtensor.jl @@ -271,6 +271,18 @@ function planarcontract!(C::AbstractTensorMap{S,N₁,N₂}, end return C end +function planarcontract!(C::AbstractTensorMap{S,N₁,N₂}, + A::BraidingTensor{S}, + (oindA, cindA)::Index2Tuple{2,2}, + B::BraidingTensor{S}, + (cindB, oindB)::Index2Tuple{2,2}, + (p1, p2)::Index2Tuple{N₁,N₂}, + α::Number, β::Number, + backend::Backend...) where {S,N₁,N₂} + return planarcontract!(C, copy(A), (oindA, cindA), B, (cindB, oindB), (p1, p2), α, β, + backend...) +end + function planarcontract!(C::AbstractTensorMap{S,N₁,N₂}, A::AbstractTensorMap{S}, (oindA, cindA)::Index2Tuple{N₃,2}, @@ -317,6 +329,17 @@ function planarcontract!(C::AbstractTensorMap{S,N₁,N₂}, end return C end +function planarcontract!(C::AbstractTensorMap{S,N₁,N₂}, + A::BraidingTensor{S}, + (oindA, cindA)::Index2Tuple{2,2}, + B::BraidingTensor{S}, + (cindB, oindB)::Index2Tuple{2,2}, + (p1, p2)::Index2Tuple{N₁,N₂}, + α::Number, β::Number, + backend::Backend...) where {S,N₁,N₂} + return planarcontract!(C, copy(A), (oindA, cindA), B, (cindB, oindB), (p1, p2), α, β, + backend...) +end # Fallback cases for planarcontract! # TODO: implement specialised cases for contracting 0, 1, 3 and 4 indices diff --git a/test/planar.jl b/test/planar.jl index 755f6a46..f8940aa4 100644 --- a/test/planar.jl +++ b/test/planar.jl @@ -196,13 +196,13 @@ end t1 = TensorMap(rand, T, V1 ← V2) t2 = TensorMap(rand, T, V2 ← V1) - tr1 = @planar opt = true t1[a; b] * t2[b; a] - tr2 = @planar opt = true t1[d; a] * t2[b; c] * τ[c b; a d] - tr3 = @planar opt = true t1[d; a] * t2[b; c] * τ[a c; d b] - tr4 = @planar opt = true t1[f; a] * t2[c; d] * τ[d b; c e] * τ[e b; a f] - tr5 = @planar opt = true t1[f; a] * t2[c; d] * τ[d b; c e] * τ[a e; f b] - tr6 = @planar opt = true t1[f; a] * t2[c; d] * τ[c d; e b] * τ[e b; a f] - tr7 = @planar opt = true t1[f; a] * t2[c; d] * τ[c d; e b] * τ[a e; f b] + tr1 = @planar opt = true t1[a; b] * t2[b; a]/2 + tr2 = @planar opt = true t1[d; a] * t2[b; c] * 1/2 * τ[c b; a d] + tr3 = @planar opt = true t1[d; a] * t2[b; c] * τ[a c; d b] /2 + tr4 = @planar opt = true t1[f; a] * 1/2 * t2[c; d] * τ[d b; c e] * τ[e b; a f] + tr5 = @planar opt = true t1[f; a] * t2[c; d]/2 * τ[d b; c e] * τ[a e; f b] + tr6 = @planar opt = true t1[f; a] * t2[c; d] * τ[c d; e b]/2 * τ[e b; a f] + tr7 = @planar opt = true t1[f; a] * t2[c; d] * (τ[c d; e b] * τ[a e; f b] /2) @test tr1 ≈ tr2 ≈ tr3 ≈ tr4 ≈ tr5 ≈ tr6 ≈ tr7 end From 1a63e383d005ccc6f432e7bbe9005409a72bf1e5 Mon Sep 17 00:00:00 2001 From: Jutho Date: Thu, 4 Jan 2024 00:15:59 +0100 Subject: [PATCH 3/5] fix construct and remove braidingtensor preprocessors --- src/planar/macros.jl | 2 +- src/planar/planaroperations.jl | 2 + src/planar/preprocessors.jl | 284 +++++++++++++++++---------------- src/tensors/braidingtensor.jl | 12 -- test/planar.jl | 24 ++- 5 files changed, 167 insertions(+), 157 deletions(-) diff --git a/src/planar/macros.jl b/src/planar/macros.jl index e10308d2..2682d864 100644 --- a/src/planar/macros.jl +++ b/src/planar/macros.jl @@ -102,7 +102,7 @@ function _plansor(expr, kwargs...) tparser = TO.tensorparser(expr, kwargs...) pparser = planarparser(expr, kwargs...) - insert!(tparser.preprocessors, 5, _remove_braidingtensors) + insert!(tparser.preprocessors, 4, _remove_braidingtensors) tensorex = tparser(expr) planarex = pparser(expr) diff --git a/src/planar/planaroperations.jl b/src/planar/planaroperations.jl index ee23bcd7..1d04b764 100644 --- a/src/planar/planaroperations.jl +++ b/src/planar/planaroperations.jl @@ -61,6 +61,7 @@ function planarcontract!(C::AbstractTensorMap{S,N₁,N₂}, codB, domB = codomainind(B), domainind(B) oindA, cindA = pA cindB, oindB = pB + # @show codA, domA, codB, domB, oindA, cindA, oindB, cindB, pAB oindA, cindA, oindB, cindB = reorder_indices(codA, domA, codB, domB, oindA, cindA, oindB, cindB, pAB...) @@ -91,6 +92,7 @@ _cyclicpermute(t::Tuple{}) = () function reorder_indices(codA, domA, codB, domB, oindA, oindB, p1, p2) N₁ = length(oindA) N₂ = length(oindB) + # @show codA, domA, codB, domB, oindA, oindB, p1, p2 @assert length(p1) == N₁ && all(in(p1), 1:N₁) @assert length(p2) == N₂ && all(in(p2), N₁ .+ (1:N₂)) oindA2 = TupleTools.getindices(oindA, p1) diff --git a/src/planar/preprocessors.jl b/src/planar/preprocessors.jl index 9fbd51a5..00b8ec5b 100644 --- a/src/planar/preprocessors.jl +++ b/src/planar/preprocessors.jl @@ -71,8 +71,11 @@ _add_adjoint(ex) = Expr(TO.prime, ex) # used by `@planar`: identify braiding tensors (corresponding to name τ) and discover their # spaces from the rest of the expression. Construct the explicit BraidingTensor objects and # insert them in the expression. -function _construct_braidingtensors(ex::Expr) - if TO.isdefinition(ex) || TO.isassignment(ex) +function _construct_braidingtensors(ex) + ex isa Expr || return ex + if ex.head == :macrocall && ex.args[1] == Symbol("@notensor") + return ex + elseif TO.isdefinition(ex) || TO.isassignment(ex) lhs, rhs = TO.getlhs(ex), TO.getrhs(ex) if !TO.istensorexpr(rhs) return ex @@ -82,10 +85,11 @@ function _construct_braidingtensors(ex::Expr) if TO.isassignment(ex) && TO.istensor(lhs) obj, leftind, rightind = TO.decomposetensor(lhs) for (i, l) in enumerate(leftind) - indexmap[l] = Expr(:call, :space, _add_adjoint(obj), i) + indexmap[l] = Expr(:call, :dual, Expr(:call, :space, obj, i)) end for (i, l) in enumerate(rightind) - indexmap[l] = Expr(:call, :space, _add_adjoint(obj), length(leftind) + i) + indexmap[l] = Expr(:call, :dual, + Expr(:call, :space, obj, length(leftind) + i)) end end newrhs, success = _construct_braidingtensors!(rhs, preargs, indexmap) @@ -107,11 +111,11 @@ function _construct_braidingtensors(ex::Expr) return Expr(ex.head, map(_construct_braidingtensors, ex.args)...) end end -_construct_braidingtensors(x) = x function _construct_braidingtensors!(ex, preargs, indexmap) # ex is guaranteed to be a single tensor expression if TO.isscalarexpr(ex) - return ex, true + # ex could be tensorscalar call with more braiding tensors + return _construct_braidingtensors(ex), true elseif TO.istensor(ex) obj, leftind, rightind = TO.decomposetensor(ex) if _remove_adjoint(obj) == :τ @@ -156,15 +160,17 @@ function _construct_braidingtensors!(ex, preargs, indexmap) # ex is guaranteed t newex = ex success = true end - # add spaces of the tensor to the indexmap - for (i, l) in enumerate(leftind) - if !haskey(indexmap, l) - indexmap[l] = Expr(:call, :space, obj, i) + if success == true + # add spaces of the tensor to the indexmap + for (i, l) in enumerate(leftind) + if !haskey(indexmap, l) + indexmap[l] = Expr(:call, :space, obj, i) + end end - end - for (i, l) in enumerate(rightind) - if !haskey(indexmap, l) - indexmap[l] = Expr(:call, :space, obj, length(leftind) + i) + for (i, l) in enumerate(rightind) + if !haskey(indexmap, l) + indexmap[l] = Expr(:call, :space, obj, length(leftind) + i) + end end end return newex, success @@ -232,132 +238,144 @@ end # used by non-planar parser of `@plansor`: remove explicit braiding tensors function _remove_braidingtensors(ex) ex isa Expr || return ex - outgoing = [] - - if TO.isdefinition(ex) || TO.isassignment(ex) + if ex.head == :macrocall && ex.args[1] == Symbol("@notensor") + return ex + elseif TO.isdefinition(ex) || TO.isassignment(ex) lhs, rhs = TO.getlhs(ex), TO.getrhs(ex) - if TO.istensorexpr(rhs) - list = TO.gettensors(_conj_to_adjoint(rhs)) - if TO.istensor(lhs) - obj, l, r = TO.decomposetensor(lhs) - outgoing = [l; r] - end - else + if !TO.istensorexpr(rhs) return ex end + indexmap = Dict{Any,Any}() + if TO.istensor(lhs) + obj, leftind, rightind = TO.decomposetensor(lhs) + end + newrhs, unchanged = _remove_braidingtensors!(rhs, indexmap) + isempty(indexmap) || + throw(ArgumentError("cannot determine the spaces of all braiding tensors in $ex")) + return Expr(ex.head, lhs, newrhs) elseif TO.istensorexpr(ex) - list = TO.gettensors(_conj_to_adjoint(ex)) + indexmap = Dict{Any,Any}() + + newex, unchanged = _remove_braidingtensors!(ex, indexmap) + isempty(indexmap) || + throw(ArgumentError("cannot determine the spaces of all braiding tensors in $ex")) + return newex else return Expr(ex.head, map(_remove_braidingtensors, ex.args)...) end +end - τs = Any[] - i = 1 - while i <= length(list) - t = list[i] - if _remove_adjoint(TO.gettensorobject(t)) == :τ - push!(τs, t) - deleteat!(list, i) - else - i += 1 - end - end +function _remove_braidingtensors!(ex, indexmap) # ex is guaranteed to be a single tensor expression + if TO.isscalarexpr(ex) + return _remove_braidingtensors(ex), true + elseif TO.istensor(ex) + obj, leftind, rightind = TO.decomposetensor(ex) + if _remove_adjoint(obj) == :τ + # remove braiding tensor and add labels to indexmap + length(leftind) == length(rightind) == 2 || + throw(ArgumentError("The name τ is reserved for the braiding, and should have two input and two output indices.")) - indexmap = Dict{Any,Any}() - # to remove the braidingtensors, we need to map certain indices to other indices - for t in τs - obj, leftind, rightind = TO.decomposetensor(t) - length(leftind) == length(rightind) == 2 || - throw(ArgumentError("The name τ is reserved for the braiding, and should have two input and two output indices.")) - if _is_adjoint(obj) i1b, i2b, = leftind i2a, i1a, = rightind - else - i2b, i1b, = leftind - i1a, i2a, = rightind - end - - i1a = get(indexmap, i1a, i1a) - i1b = get(indexmap, i1b, i1b) - i2a = get(indexmap, i2a, i2a) - i2b = get(indexmap, i2b, i2b) - - obj_and_pos1a = _findindex(i1a, list) - obj_and_pos2a = _findindex(i2a, list) - obj_and_pos1b = _findindex(i1b, list) - obj_and_pos2b = _findindex(i2b, list) - - if i1a in outgoing - indexmap[i1a] = i1a - indexmap[i1b] = i1a - elseif i1b in outgoing - indexmap[i1a] = i1b - indexmap[i1b] = i1b - else - if i1a isa Int && i1b isa Int - indexmap[i1a] = max(i1a, i1b) - indexmap[i1b] = max(i1a, i1b) + if i1a == i1b || (haskey(indexmap, i1a) && haskey(indexmap, i1b)) + throw(IndexError("Cannot resolve indices $i1a and $i1b that occur only on braidings.")) + elseif haskey(indexmap, i1a) + i1c = indexmap[i1a] + indexmap[i1c] = i1b + indexmap[i1b] = i1c + delete!(indexmap, i1a) + elseif haskey(indexmap, i1b) + i1c = indexmap[i1b] + indexmap[i1c] = i1a + indexmap[i1a] = i1c + delete!(indexmap, i1b) else - indexmap[i1a] = i1a + indexmap[i1a] = i1b indexmap[i1b] = i1a end - end - - if i2a in outgoing - indexmap[i2a] = i2a - indexmap[i2b] = i2a - elseif i2b in outgoing - indexmap[i2a] = i2b - indexmap[i2b] = i2b - else - if i2a isa Int && i2b isa Int - indexmap[i2a] = max(i2a, i2b) - indexmap[i2b] = max(i2a, i2b) + if i2a == i2b || (haskey(indexmap, i2a) && haskey(indexmap, i2b)) + throw(IndexError("Cannot resolve indices $i2a and $i2b that occur only on braidings.")) + elseif haskey(indexmap, i2a) + i2c = indexmap[i2a] + indexmap[i2c] = i2b + indexmap[i2b] = i2c + delete!(indexmap, i2a) + elseif haskey(indexmap, i2b) + i2c = indexmap[i2b] + indexmap[i2c] = i2a + indexmap[i2a] = i2c + delete!(indexmap, i2b) else - indexmap[i2a] = i2a + indexmap[i2a] = i2b indexmap[i2b] = i2a end + return One(), false # when there are still braiding tensors, we haven't finished + else + unchanged = true + for (i, l) in enumerate(leftind) + if haskey(indexmap, l) + unchanged = false + l′ = indexmap[l] + leftind[i] = l′ + delete!(indexmap, l) + delete!(indexmap, l′) + end + end + for (i, l) in enumerate(rightind) + if haskey(indexmap, l) + unchanged = false + l′ = indexmap[l] + rightind[i] = l′ + delete!(indexmap, l) + delete!(indexmap, l′) + end + end + return Expr(:typed_vcat, obj, Expr(:tuple, leftind...), + Expr(:tuple, rightind...)), unchanged end - end - - # simple loop that tries to simplify the indicemaps (a=>b,b=>c -> a=>c,b=>c) - changed = true - while changed == true - changed = false - i = 1 - for (k, v) in indexmap - if v in keys(indexmap) && indexmap[v] != v - changed = true - indexmap[k] = indexmap[v] + elseif TO.isgeneraltensor(ex) + args = ex.args + newargs = Vector{Any}(undef, length(args)) + unchanged = true + for i in 1:length(ex.args) + newargs[i], unchangeda = _remove_braidingtensors!(args[i], indexmap) + unchanged = unchanged && unchangeda + end + newex = Expr(ex.head, newargs...) + return newex, unchanged + elseif isexpr(ex, :call) && ex.args[1] == :* + args = ex.args + newargs = copy(args) + unchanged = map(i -> false, args) + unchanged[1] = true + for i in 2:length(ex.args) + newargs[i], unchanged[i] = _remove_braidingtensors!(newargs[i], indexmap) + end + all(unchanged) && return ex, true + while !all(unchanged) + for i in 2:length(ex.args) + newargs[i], unchanged[i] = _remove_braidingtensors!(newargs[i], indexmap) end end - end - - ex = TO.replaceindices(i -> get(indexmap, i, i), ex) - return _purge_braidingtensors(ex) -end - -function _purge_braidingtensors(ex) # actually remove the braidingtensors - ex isa Expr || return ex - args = collect(filter(ex.args) do a - if isexpr(a, :call) && a.args[1] == :conj - a = a.args[2] - end - if a isa Expr && TO.istensor(a) && - _remove_adjoint(TO.gettensorobject(a)) == :τ - _, leftind, rightind = TO.decomposetensor(a) - (leftind[1] == rightind[2] && leftind[2] == rightind[1]) || - throw(ArgumentError("unable to remove braiding tensor $a")) - return false - end - return true - end) - - # multiplication with only a single argument is (rightfully) seen as invalid syntax - if isexpr(ex, :call) && args[1] == :* && length(args) == 2 - return _purge_braidingtensors(args[2]) + return Expr(ex.head, newargs...), false + elseif isexpr(ex, :call) && ex.args[1] ∈ (:+, :-) + newargs = copy(ex.args) + indexmaps = [copy(indexmap) for _ in 1:(length(newargs) - 1)] + unchanged = true + for i in 2:length(ex.args) + newargs[i], unchangeda = _remove_braidingtensors!(ex.args[i], indexmaps[i - 1]) + unchanged = unchanged && unchangeda + end + newex = Expr(ex.head, newargs...) + return newex, unchanged + elseif isexpr(ex, :call) && ex.args[1] == :/ && length(ex.args) == 3 + newarg, unchanged = _remove_braidingtensors!(ex.args[2], indexmap) + return Expr(:call, :/, newarg, ex.args[3]), unchanged + elseif isexpr(ex, :call) && ex.args[1] == :\ && length(ex.args) == 3 + newarg, unchanged = _remove_braidingtensors!(ex.args[3], indexmap) + return Expr(:call, :\, ex.args[2], newarg), unchanged else - return Expr(ex.head, map(_purge_braidingtensors, args)...) + error("unexpected expression $ex") end end @@ -396,24 +414,15 @@ function _decompose_planar_contractions(ex::Expr, temporaries) end if TO.isassignment(ex) || TO.isdefinition(ex) lhs, rhs = TO.getlhs(ex), TO.getrhs(ex) - if TO.istensorexpr(rhs) - pre = Vector{Any}() - if TO.istensor(lhs) - rhs = _extract_contraction_pairs(rhs, lhs, pre, temporaries) - return Expr(:block, pre..., Expr(ex.head, lhs, rhs)) - else - lhssym = gensym(string(lhs)) - lhstensor = Expr(:typed_vcat, lhssym, Expr(:tuple), Expr(:tuple)) - rhs = _extract_contraction_pairs(rhs, lhstensor, pre, temporaries) - push!(temporaries, lhssym) - return Expr(:block, pre..., Expr(:(:=), lhstensor, rhs), - Expr(:(=), lhs, lhstensor)) - end + pre = Vector{Any}() + if TO.istensor(lhs) + rhs = _extract_contraction_pairs(rhs, lhs, pre, temporaries) else - return ex + rhs = _extract_contraction_pairs(rhs, (Any[], Any[]), pre, temporaries) end + return Expr(:block, pre..., Expr(ex.head, lhs, rhs)) end - if TO.istensorexpr(ex) + if TO.istensorexpr(ex) || (isexpr(ex, :call) && ex.args[1] == :tensorscalar) pre = Vector{Any}() rhs = _extract_contraction_pairs(ex, (Any[], Any[]), pre, temporaries) return Expr(:block, pre..., rhs) @@ -429,7 +438,10 @@ end # if lhs is an expression, it contains the existing lhs and thus the index order # if lhs is a tuple, the result is a temporary object and the tuple (lind, rind) gives a suggestion for the preferred index order function _extract_contraction_pairs(rhs, lhs, pre, temporaries) - if TO.isscalarexpr(rhs) + if isexpr(rhs, :call) && rhs.args[1] == :tensorscalar + newarg = _extract_contraction_pairs(rhs.args[2], lhs, pre, temporaries) + return Expr(:call, :tensorscalar, newarg) + elseif TO.isscalarexpr(rhs) return rhs elseif TO.isgeneraltensor(rhs) if TO.hastraceindices(rhs) && lhs isa Tuple @@ -479,7 +491,6 @@ function _extract_contraction_pairs(rhs, lhs, pre, temporaries) a2 = _extract_contraction_pairs(rhs.args[3], (cind2, reverse(oind2)), pre, temporaries) end - # @show a1, a2, oind1, oind2 if TO.isscalarexpr(a1) || TO.isscalarexpr(a2) rhs = Expr(:call, :*, a1, a2) @@ -499,7 +510,6 @@ function _extract_contraction_pairs(rhs, lhs, pre, temporaries) ind1, ind2 = ind2, ind1 oind1, oind2 = oind2, oind1 end - # @show a1, a2, oind1, oind2 if lhs isa Tuple rhs = Expr(:call, :*, a1, a2) s = gensym() diff --git a/src/tensors/braidingtensor.jl b/src/tensors/braidingtensor.jl index ab945e81..95028587 100644 --- a/src/tensors/braidingtensor.jl +++ b/src/tensors/braidingtensor.jl @@ -271,18 +271,6 @@ function planarcontract!(C::AbstractTensorMap{S,N₁,N₂}, end return C end -function planarcontract!(C::AbstractTensorMap{S,N₁,N₂}, - A::BraidingTensor{S}, - (oindA, cindA)::Index2Tuple{2,2}, - B::BraidingTensor{S}, - (cindB, oindB)::Index2Tuple{2,2}, - (p1, p2)::Index2Tuple{N₁,N₂}, - α::Number, β::Number, - backend::Backend...) where {S,N₁,N₂} - return planarcontract!(C, copy(A), (oindA, cindA), B, (cindB, oindB), (p1, p2), α, β, - backend...) -end - function planarcontract!(C::AbstractTensorMap{S,N₁,N₂}, A::AbstractTensorMap{S}, (oindA, cindA)::Index2Tuple{N₃,2}, diff --git a/test/planar.jl b/test/planar.jl index f8940aa4..0ba74faa 100644 --- a/test/planar.jl +++ b/test/planar.jl @@ -196,13 +196,23 @@ end t1 = TensorMap(rand, T, V1 ← V2) t2 = TensorMap(rand, T, V2 ← V1) - tr1 = @planar opt = true t1[a; b] * t2[b; a]/2 - tr2 = @planar opt = true t1[d; a] * t2[b; c] * 1/2 * τ[c b; a d] - tr3 = @planar opt = true t1[d; a] * t2[b; c] * τ[a c; d b] /2 - tr4 = @planar opt = true t1[f; a] * 1/2 * t2[c; d] * τ[d b; c e] * τ[e b; a f] - tr5 = @planar opt = true t1[f; a] * t2[c; d]/2 * τ[d b; c e] * τ[a e; f b] - tr6 = @planar opt = true t1[f; a] * t2[c; d] * τ[c d; e b]/2 * τ[e b; a f] - tr7 = @planar opt = true t1[f; a] * t2[c; d] * (τ[c d; e b] * τ[a e; f b] /2) + tr1 = @planar opt = true t1[a; b] * t2[b; a] / 2 + tr2 = @planar opt = true t1[d; a] * t2[b; c] * 1 / 2 * τ[c b; a d] + tr3 = @planar opt = true t1[d; a] * t2[b; c] * τ[a c; d b] / 2 + tr4 = @planar opt = true t1[f; a] * 1 / 2 * t2[c; d] * τ[d b; c e] * τ[e b; a f] + tr5 = @planar opt = true t1[f; a] * t2[c; d] / 2 * τ[d b; c e] * τ[a e; f b] + tr6 = @planar opt = true t1[f; a] * t2[c; d] * τ[c d; e b] / 2 * τ[e b; a f] + tr7 = @planar opt = true t1[f; a] * t2[c; d] * (τ[c d; e b] * τ[a e; f b] / 2) + + @test tr1 ≈ tr2 ≈ tr3 ≈ tr4 ≈ tr5 ≈ tr6 ≈ tr7 + + tr1 = @plansor opt = true t1[a; b] * t2[b; a] / 2 + tr2 = @plansor opt = true t1[d; a] * t2[b; c] * 1 / 2 * τ[c b; a d] + tr3 = @plansor opt = true t1[d; a] * t2[b; c] * τ[a c; d b] / 2 + tr4 = @plansor opt = true t1[f; a] * 1 / 2 * t2[c; d] * τ[d b; c e] * τ[e b; a f] + tr5 = @plansor opt = true t1[f; a] * t2[c; d] / 2 * τ[d b; c e] * τ[a e; f b] + tr6 = @plansor opt = true t1[f; a] * t2[c; d] * τ[c d; e b] / 2 * τ[e b; a f] + tr7 = @plansor opt = true t1[f; a] * t2[c; d] * (τ[c d; e b] * τ[a e; f b] / 2) @test tr1 ≈ tr2 ≈ tr3 ≈ tr4 ≈ tr5 ≈ tr6 ≈ tr7 end From a2c90c9e122471792f3921f182fd99a7cd61ff51 Mon Sep 17 00:00:00 2001 From: Jutho Date: Thu, 4 Jan 2024 15:00:23 +0100 Subject: [PATCH 4/5] Apply suggestions from code review Co-authored-by: Lukas <37111893+lkdvos@users.noreply.github.com> --- src/planar/planaroperations.jl | 2 -- src/planar/preprocessors.jl | 2 +- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/src/planar/planaroperations.jl b/src/planar/planaroperations.jl index 1d04b764..ee23bcd7 100644 --- a/src/planar/planaroperations.jl +++ b/src/planar/planaroperations.jl @@ -61,7 +61,6 @@ function planarcontract!(C::AbstractTensorMap{S,N₁,N₂}, codB, domB = codomainind(B), domainind(B) oindA, cindA = pA cindB, oindB = pB - # @show codA, domA, codB, domB, oindA, cindA, oindB, cindB, pAB oindA, cindA, oindB, cindB = reorder_indices(codA, domA, codB, domB, oindA, cindA, oindB, cindB, pAB...) @@ -92,7 +91,6 @@ _cyclicpermute(t::Tuple{}) = () function reorder_indices(codA, domA, codB, domB, oindA, oindB, p1, p2) N₁ = length(oindA) N₂ = length(oindB) - # @show codA, domA, codB, domB, oindA, oindB, p1, p2 @assert length(p1) == N₁ && all(in(p1), 1:N₁) @assert length(p2) == N₂ && all(in(p2), N₁ .+ (1:N₂)) oindA2 = TupleTools.getindices(oindA, p1) diff --git a/src/planar/preprocessors.jl b/src/planar/preprocessors.jl index 00b8ec5b..7f93560f 100644 --- a/src/planar/preprocessors.jl +++ b/src/planar/preprocessors.jl @@ -375,7 +375,7 @@ function _remove_braidingtensors!(ex, indexmap) # ex is guaranteed to be a singl newarg, unchanged = _remove_braidingtensors!(ex.args[3], indexmap) return Expr(:call, :\, ex.args[2], newarg), unchanged else - error("unexpected expression $ex") + throw(ArgumentError("unexpected expression $ex")) end end From dc552e937648b52cfdc5c605912202f6479d2a0a Mon Sep 17 00:00:00 2001 From: Jutho Date: Fri, 5 Jan 2024 08:56:36 +0100 Subject: [PATCH 5/5] update Project.toml and fix rebase result --- Project.toml | 2 +- test/planar.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index fc0c157c..6af92733 100644 --- a/Project.toml +++ b/Project.toml @@ -19,7 +19,7 @@ HalfIntegers = "1" LRUCache = "1.0.2" PackageExtensionCompat = "1" Strided = "2" -TensorOperations = "4.0.6 - 4.0.7" +TensorOperations = "4.1" TupleTools = "1.1" VectorInterface = "0.4" WignerSymbols = "1,2" diff --git a/test/planar.jl b/test/planar.jl index 0ba74faa..d7edf146 100644 --- a/test/planar.jl +++ b/test/planar.jl @@ -89,7 +89,7 @@ end @planar C1[i; j] := A[i; k l] * τ[k l; m n] * B[m n; j] @planar contractcheck = true C2[i; j] := A[i; k l] * τ[k l; m n] * B[m n; j] @test C1 ≈ C2 - @test_throws SpaceMismatch("incompatible spaces for l: $V ≠ $(V')") begin + @test_throws SpaceMismatch("incompatible spaces for m: $V ≠ $(V')") begin @planar contractcheck = true C3[i; j] := A[i; k l] * τ[k l; m n] * B[n j; m] end end