From a34381c1e2c0c20037a3653ae92337f51c871616 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jofre=20Vall=C3=A8s?= Date: Tue, 18 Jun 2024 13:26:09 +0200 Subject: [PATCH 01/10] Add order kwarg in Chain constructor --- src/Ansatz/Chain.jl | 121 +++++++++++++++++++++++++++++++++++--------- 1 file changed, 98 insertions(+), 23 deletions(-) diff --git a/src/Ansatz/Chain.jl b/src/Ansatz/Chain.jl index f33e2ce..e631531 100644 --- a/src/Ansatz/Chain.jl +++ b/src/Ansatz/Chain.jl @@ -30,14 +30,33 @@ function Chain(tn::TensorNetwork, sites, args...; kwargs...) Chain(Quantum(tn, sites), args...; kwargs...) end -function Chain(::State, boundary::Periodic, arrays::Vector{<:AbstractArray}) +defaultorder(::State) = (:o, :l, :r) +defaultorder(::Operator) = (:o, :i, :l, :r) + +function Chain(::State, boundary::Periodic, arrays::Vector{<:AbstractArray}; order = defaultorder(State())) @assert all(==(3) ∘ ndims, arrays) "All arrays must have 3 dimensions" + issetequal(order, defaultorder(State())) || throw(ArgumentError("order must be a permutation of $(String.(defaultorder(State())))")) n = length(arrays) symbols = [nextindex() for _ in 1:2n] + function get_index(directions, i) + map(directions) do dir + if dir == :o + symbols[i] + elseif dir == :l + symbols[n + mod1(i, n)] + elseif dir == :r + symbols[n + mod1(i - 1, n)] + else + throw(ArgumentError("Invalid direction: $dir")) + end + end + end + _tensors = map(enumerate(arrays)) do (i, array) - Tensor(array, [symbols[i], symbols[n+mod1(i - 1, n)], symbols[n+mod1(i, n)]]) + inds = get_index(order, i) + Tensor(array, inds) end sitemap = Dict(Site(i) => symbols[i] for i in 1:n) @@ -45,65 +64,121 @@ function Chain(::State, boundary::Periodic, arrays::Vector{<:AbstractArray}) Chain(Quantum(TensorNetwork(_tensors), sitemap), boundary) end -function Chain(::State, boundary::Open, arrays::Vector{<:AbstractArray}) +function Chain(::State, boundary::Open, arrays::Vector{<:AbstractArray}; order = defaultorder(State())) @assert ndims(arrays[1]) == 2 "First array must have 2 dimensions" @assert all(==(3) ∘ ndims, arrays[2:end-1]) "All arrays must have 3 dimensions" @assert ndims(arrays[end]) == 2 "Last array must have 2 dimensions" + issetequal(order, defaultorder(State())) || throw(ArgumentError("order must be a permutation of $(String.(defaultorder(State())))")) n = length(arrays) symbols = [nextindex() for _ in 1:2n-1] - _tensors = map(enumerate(arrays)) do (i, array) - if i == 1 - Tensor(array, [symbols[1], symbols[1+n]]) - elseif i == n - Tensor(array, [symbols[n], symbols[n+mod1(n - 1, n)]]) - else - Tensor(array, [symbols[i], symbols[n+mod1(i - 1, n)], symbols[n+mod1(i, n)]]) + function get_index(directions, i, is_first, is_last) + if is_first + directions = filter(x -> x != :l, directions) + elseif is_last + directions = filter(x -> x != :r, directions) + end + + map(directions) do dir + if dir == :o + symbols[i] + elseif dir == :l + symbols[n + mod1(i, n)] + elseif dir == :r + symbols[n + mod1(i - 1, n)] + else + throw(ArgumentError("Invalid direction: $dir")) + end end end + _tensors = map(enumerate(arrays)) do (i, array) + is_first = (i == 1) + is_last = (i == n) + inds = get_index(order, i, is_first, is_last) + Tensor(array, inds) + end + sitemap = Dict(Site(i) => symbols[i] for i in 1:n) Chain(Quantum(TensorNetwork(_tensors), sitemap), boundary) end -function Chain(::Operator, boundary::Periodic, arrays::Vector{<:AbstractArray}) - @assert all(==(4) ∘ ndims, arrays) "All arrays must have 3 dimensions" +function Chain(::Operator, boundary::Periodic, arrays::Vector{<:AbstractArray}; order = defaultorder(Operator())) + @assert all(==(4) ∘ ndims, arrays) "All arrays must have 4 dimensions" + issetequal(order, defaultorder(Operator())) || throw(ArgumentError("order must be a permutation of $(String.(defaultorder(Operator())))")) n = length(arrays) symbols = [nextindex() for _ in 1:3n] + function get_index(directions, i) + map(directions) do dir + if dir == :o + symbols[i] + elseif dir == :i + symbols[i + n] + elseif dir == :l + symbols[2n + mod1(i - 1, n)] + elseif dir == :r + symbols[2n + mod1(i, n)] + else + throw(ArgumentError("Invalid direction: $dir")) + end + end + end + _tensors = map(enumerate(arrays)) do (i, array) - Tensor(array, [symbols[i], symbols[i+n], symbols[2n+mod1(i - 1, n)], symbols[2n+mod1(i, n)]]) + inds = get_index(order, i) + Tensor(array, inds) end sitemap = Dict(Site(i) => symbols[i] for i in 1:n) - merge!(sitemap, Dict(Site(i; dual = true) => symbols[i+n] for i in 1:n)) + merge!(sitemap, Dict(Site(i; dual = true) => symbols[i + n] for i in 1:n)) Chain(Quantum(TensorNetwork(_tensors), sitemap), boundary) end -function Chain(::Operator, boundary::Open, arrays::Vector{<:AbstractArray}) +function Chain(::Operator, boundary::Open, arrays::Vector{<:AbstractArray}; order = defaultorder(Operator())) @assert ndims(arrays[1]) == 3 "First array must have 3 dimensions" @assert all(==(4) ∘ ndims, arrays[2:end-1]) "All arrays must have 4 dimensions" @assert ndims(arrays[end]) == 3 "Last array must have 3 dimensions" + issetequal(order, defaultorder(Operator())) || throw(ArgumentError("order must be a permutation of $(String.(defaultorder(Operator())))")) n = length(arrays) symbols = [nextindex() for _ in 1:3n-1] - _tensors = map(enumerate(arrays)) do (i, array) - if i == 1 - Tensor(array, [symbols[1], symbols[n+1], symbols[1+2n]]) - elseif i == n - Tensor(array, [symbols[n], symbols[2n], symbols[2n+mod1(n - 1, n)]]) - else - Tensor(array, [symbols[i], symbols[i+n], symbols[2n+mod1(i - 1, n)], symbols[2n+mod1(i, n)]]) + function get_index(directions, i, is_first, is_last) + if is_first + directions = filter(x -> x != :l, directions) + elseif is_last + directions = filter(x -> x != :r, directions) end + + map(directions) do dir + if dir == :o + symbols[i] + elseif dir == :i + symbols[i + n] + elseif dir == :l + symbols[2n + mod1(i - 1, n)] + elseif dir == :r + symbols[2n + mod1(i, n)] + else + throw(ArgumentError("Invalid direction: $dir")) + end + end + end + + _tensors = map(enumerate(arrays)) do (i, array) + is_first = (i == 1) + is_last = (i == n) + inds = get_index(order, i, is_first, is_last) + Tensor(array, inds) end sitemap = Dict(Site(i) => symbols[i] for i in 1:n) - merge!(sitemap, Dict(Site(i; dual = true) => symbols[i+n] for i in 1:n)) + merge!(sitemap, Dict(Site(i; dual = true) => symbols[i + n] for i in 1:n)) Chain(Quantum(TensorNetwork(_tensors), sitemap), boundary) end From fc426682820216cf43042bfd9a3f4d1f2a7ed25a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jofre=20Vall=C3=A8s?= Date: Tue, 18 Jun 2024 13:26:23 +0200 Subject: [PATCH 02/10] Add tests --- test/Ansatz/Chain_test.jl | 188 +++++++++++++++++++++++++++++++------- 1 file changed, 157 insertions(+), 31 deletions(-) diff --git a/test/Ansatz/Chain_test.jl b/test/Ansatz/Chain_test.jl index 0eb7bcb..4183086 100644 --- a/test/Ansatz/Chain_test.jl +++ b/test/Ansatz/Chain_test.jl @@ -1,35 +1,161 @@ @testset "Chain ansatz" begin - qtn = Chain(State(), Periodic(), [rand(2, 4, 4) for _ in 1:3]) - @test socket(qtn) == State() - @test ninputs(qtn) == 0 - @test noutputs(qtn) == 3 - @test issetequal(sites(qtn), [site"1", site"2", site"3"]) - @test boundary(qtn) == Periodic() - @test leftindex(qtn, site"1") == rightindex(qtn, site"3") != nothing - - qtn = Chain(State(), Open(), [rand(2, 2), rand(2, 2, 2), rand(2, 2)]) - @test socket(qtn) == State() - @test ninputs(qtn) == 0 - @test noutputs(qtn) == 3 - @test issetequal(sites(qtn), [site"1", site"2", site"3"]) - @test boundary(qtn) == Open() - @test leftindex(qtn, site"1") == rightindex(qtn, site"3") == nothing - - qtn = Chain(Operator(), Periodic(), [rand(2, 2, 4, 4) for _ in 1:3]) - @test socket(qtn) == Operator() - @test ninputs(qtn) == 3 - @test noutputs(qtn) == 3 - @test issetequal(sites(qtn), [site"1", site"2", site"3", site"1'", site"2'", site"3'"]) - @test boundary(qtn) == Periodic() - @test leftindex(qtn, site"1") == rightindex(qtn, site"3") != nothing - - qtn = Chain(Operator(), Open(), [rand(2, 2, 4), rand(2, 2, 4, 4), rand(2, 2, 4)]) - @test socket(qtn) == Operator() - @test ninputs(qtn) == 3 - @test noutputs(qtn) == 3 - @test issetequal(sites(qtn), [site"1", site"2", site"3", site"1'", site"2'", site"3'"]) - @test boundary(qtn) == Open() - @test leftindex(qtn, site"1") == rightindex(qtn, site"3") == nothing + @testset "Periodic boundary" begin + @testset "State" begin + qtn = Chain(State(), Periodic(), [rand(2, 4, 4) for _ in 1:3]) + @test socket(qtn) == State() + @test ninputs(qtn) == 0 + @test noutputs(qtn) == 3 + @test issetequal(sites(qtn), [site"1", site"2", site"3"]) + @test boundary(qtn) == Periodic() + @test leftindex(qtn, site"1") == rightindex(qtn, site"3") != nothing + + arrays = [rand(2, 1, 4), rand(2, 4, 3), rand(2, 3, 1)] + qtn = Chain(State(), Periodic(), arrays) # Default order (:o, :l, :r) + + @test size(tensors(qtn; at = Site(1))) == (2, 1, 4) + @test size(tensors(qtn; at = Site(2))) == (2, 4, 3) + @test size(tensors(qtn; at = Site(3))) == (2, 3, 1) + + @test leftindex(qtn, Site(1)) == rightindex(qtn, Site(3)) + @test leftindex(qtn, Site(2)) == rightindex(qtn, Site(1)) + @test leftindex(qtn, Site(3)) == rightindex(qtn, Site(2)) + + arrays = [permutedims(array, (3, 1, 2)) for array in arrays] # now we have (:r, :o, :l) + qtn = Chain(State(), Periodic(), arrays, order=[:r, :o, :l]) + + @test size(tensors(qtn; at = Site(1))) == (4, 2, 1) + @test size(tensors(qtn; at = Site(2))) == (3, 2, 4) + @test size(tensors(qtn; at = Site(3))) == (1, 2, 3) + + @test leftindex(qtn, Site(1)) == rightindex(qtn, Site(3)) + @test leftindex(qtn, Site(2)) == rightindex(qtn, Site(1)) + @test leftindex(qtn, Site(3)) == rightindex(qtn, Site(2)) + + for i in 1:nsites(qtn) + @test size(TensorNetwork(qtn), inds(qtn; at = Site(i))) == 2 + end + end + @testset "Operator" begin + qtn = Chain(Operator(), Periodic(), [rand(2, 2, 4, 4) for _ in 1:3]) + @test socket(qtn) == Operator() + @test ninputs(qtn) == 3 + @test noutputs(qtn) == 3 + @test issetequal(sites(qtn), [site"1", site"2", site"3", site"1'", site"2'", site"3'"]) + @test boundary(qtn) == Periodic() + @test leftindex(qtn, site"1") == rightindex(qtn, site"3") != nothing + + arrays = [rand(2, 4, 1, 3), rand(2, 4, 3, 6), rand(2, 4, 6, 1)] # Default order (:o, :i, :l, :r) + qtn = Chain(Operator(), Periodic(), arrays) + + @test size(tensors(qtn; at = Site(1))) == (2, 4, 1, 3) + @test size(tensors(qtn; at = Site(2))) == (2, 4, 3, 6) + @test size(tensors(qtn; at = Site(3))) == (2, 4, 6, 1) + + @test leftindex(qtn, Site(1)) == rightindex(qtn, Site(3)) + @test leftindex(qtn, Site(2)) == rightindex(qtn, Site(1)) + @test leftindex(qtn, Site(3)) == rightindex(qtn, Site(2)) + + for i in 1:length(arrays) + @test size(TensorNetwork(qtn), inds(qtn; at = Site(i))) == 2 + @test size(TensorNetwork(qtn), inds(qtn; at = Site(i; dual=true))) == 4 + end + + arrays = [permutedims(array, (4, 1, 3, 2)) for array in arrays] # now we have (:r, :o, :l, :i) + qtn = Chain(Operator(), Periodic(), arrays, order=[:r, :o, :l, :i]) + + @test size(tensors(qtn; at = Site(1))) == (3, 2, 1, 4) + @test size(tensors(qtn; at = Site(2))) == (6, 2, 3, 4) + @test size(tensors(qtn; at = Site(3))) == (1, 2, 6, 4) + + @test leftindex(qtn, Site(1)) == rightindex(qtn, Site(3)) + @test leftindex(qtn, Site(2)) == rightindex(qtn, Site(1)) + @test leftindex(qtn, Site(3)) == rightindex(qtn, Site(2)) + + for i in 1:length(arrays) + @test size(TensorNetwork(qtn), inds(qtn; at = Site(i))) == 2 + @test size(TensorNetwork(qtn), inds(qtn; at = Site(i; dual=true))) == 4 + end + end + end + + @testset "Open boundary" begin + @testset "State" begin + qtn = Chain(State(), Open(), [rand(2, 2), rand(2, 2, 2), rand(2, 2)]) + @test socket(qtn) == State() + @test ninputs(qtn) == 0 + @test noutputs(qtn) == 3 + @test issetequal(sites(qtn), [site"1", site"2", site"3"]) + @test boundary(qtn) == Open() + @test leftindex(qtn, site"1") == rightindex(qtn, site"3") == nothing + + arrays = [rand(2, 1), rand(2, 1, 3), rand(2, 3)] + qtn = Chain(State(), Open(), arrays) # Default order (:o, :l, :r) + + @test size(tensors(qtn; at = Site(1))) == (2, 1) + @test size(tensors(qtn; at = Site(2))) == (2, 1, 3) + @test size(tensors(qtn; at = Site(3))) == (2, 3) + + @test leftindex(qtn, Site(1)) == rightindex(qtn, Site(3)) === nothing + @test leftindex(qtn, Site(2)) == rightindex(qtn, Site(1)) + @test leftindex(qtn, Site(3)) == rightindex(qtn, Site(2)) + + arrays = [permutedims(arrays[1], (2, 1)), permutedims(arrays[2], (3, 1, 2)), permutedims(arrays[3], (2, 1))] # now we have (:l, :o, :r) + qtn = Chain(State(), Open(), arrays, order=[:r, :o, :l]) + + @test size(tensors(qtn; at = Site(1))) == (1, 2) + @test size(tensors(qtn; at = Site(2))) == (3, 2, 1) + @test size(tensors(qtn; at = Site(3))) == (3, 2) + + @test leftindex(qtn, Site(1)) == rightindex(qtn, Site(3)) === nothing + @test leftindex(qtn, Site(2)) == rightindex(qtn, Site(1)) + @test leftindex(qtn, Site(3)) == rightindex(qtn, Site(2)) + + for i in 1:nsites(qtn) + @test size(TensorNetwork(qtn), inds(qtn; at = Site(i))) == 2 + end + end + @testset "Operator" begin + qtn = Chain(Operator(), Open(), [rand(2, 2, 4), rand(2, 2, 4, 4), rand(2, 2, 4)]) + @test socket(qtn) == Operator() + @test ninputs(qtn) == 3 + @test noutputs(qtn) == 3 + @test issetequal(sites(qtn), [site"1", site"2", site"3", site"1'", site"2'", site"3'"]) + @test boundary(qtn) == Open() + @test leftindex(qtn, site"1") == rightindex(qtn, site"3") == nothing + + arrays = [rand(2, 4, 1), rand(2, 4, 1, 3), rand(2, 4, 3)] # Default order (:o :i, :l, :r) + qtn = Chain(Operator(), Open(), arrays) + + @test size(tensors(qtn; at = Site(1))) == (2, 4, 1) + @test size(tensors(qtn; at = Site(2))) == (2, 4, 1, 3) + @test size(tensors(qtn; at = Site(3))) == (2, 4, 3) + + @test leftindex(qtn, Site(1)) == rightindex(qtn, Site(3)) === nothing + @test leftindex(qtn, Site(2)) == rightindex(qtn, Site(1)) + @test leftindex(qtn, Site(3)) == rightindex(qtn, Site(2)) + + for i in 1:length(arrays) + @test size(TensorNetwork(qtn), inds(qtn; at = Site(i))) == 2 + @test size(TensorNetwork(qtn), inds(qtn; at = Site(i; dual=true))) == 4 + end + + arrays = [permutedims(array, (3, 1, 2)), permutedims(array, (4, 1, 3, 2)), permutedims(array, (3, 1, 2))] # now we have (:r, :o, :l, :i) + qtn = Chain(Operator(), Open(), arrays, order=[:r, :o, :l, :i]) + + @test size(tensors(qtn; at = Site(1))) == (1, 2, 4) + @test size(tensors(qtn; at = Site(2))) == (3, 2, 1, 4) + @test size(tensors(qtn; at = Site(3))) == (3, 2, 4) + + @test leftindex(qtn, Site(1)) == rightindex(qtn, Site(3)) === nothing + @test leftindex(qtn, Site(2)) == rightindex(qtn, Site(1)) + @test leftindex(qtn, Site(3)) == rightindex(qtn, Site(2)) + + for i in 1:length(arrays) + @test size(TensorNetwork(qtn), inds(qtn; at = Site(i))) == 2 + @test size(TensorNetwork(qtn), inds(qtn; at = Site(i; dual=true))) == 4 + end + end + end @testset "Site" begin using Qrochet: leftsite, rightsite From 73449c37a072a16215a496a785a1f5b862e61024 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jofre=20Vall=C3=A8s?= Date: Tue, 18 Jun 2024 15:13:46 +0200 Subject: [PATCH 03/10] Fix minor typos --- src/Ansatz/Chain.jl | 2 +- test/Ansatz/Chain_test.jl | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/Ansatz/Chain.jl b/src/Ansatz/Chain.jl index e631531..46e4f54 100644 --- a/src/Ansatz/Chain.jl +++ b/src/Ansatz/Chain.jl @@ -71,7 +71,7 @@ function Chain(::State, boundary::Open, arrays::Vector{<:AbstractArray}; order = issetequal(order, defaultorder(State())) || throw(ArgumentError("order must be a permutation of $(String.(defaultorder(State())))")) n = length(arrays) - symbols = [nextindex() for _ in 1:2n-1] + symbols = [nextindex() for _ in 1:2n] function get_index(directions, i, is_first, is_last) if is_first diff --git a/test/Ansatz/Chain_test.jl b/test/Ansatz/Chain_test.jl index 4183086..3fbf664 100644 --- a/test/Ansatz/Chain_test.jl +++ b/test/Ansatz/Chain_test.jl @@ -99,12 +99,12 @@ @test leftindex(qtn, Site(2)) == rightindex(qtn, Site(1)) @test leftindex(qtn, Site(3)) == rightindex(qtn, Site(2)) - arrays = [permutedims(arrays[1], (2, 1)), permutedims(arrays[2], (3, 1, 2)), permutedims(arrays[3], (2, 1))] # now we have (:l, :o, :r) + arrays = [permutedims(arrays[1], (2, 1)), permutedims(arrays[2], (3, 1, 2)), permutedims(arrays[3], (1, 2))] # now we have (:r, :o, :l) qtn = Chain(State(), Open(), arrays, order=[:r, :o, :l]) @test size(tensors(qtn; at = Site(1))) == (1, 2) @test size(tensors(qtn; at = Site(2))) == (3, 2, 1) - @test size(tensors(qtn; at = Site(3))) == (3, 2) + @test size(tensors(qtn; at = Site(3))) == (2, 3) @test leftindex(qtn, Site(1)) == rightindex(qtn, Site(3)) === nothing @test leftindex(qtn, Site(2)) == rightindex(qtn, Site(1)) @@ -139,12 +139,12 @@ @test size(TensorNetwork(qtn), inds(qtn; at = Site(i; dual=true))) == 4 end - arrays = [permutedims(array, (3, 1, 2)), permutedims(array, (4, 1, 3, 2)), permutedims(array, (3, 1, 2))] # now we have (:r, :o, :l, :i) + arrays = [permutedims(arrays[1], (3, 1, 2)), permutedims(arrays[2], (4, 1, 3, 2)), permutedims(arrays[3], (1, 3, 2))] # now we have (:r, :o, :l, :i) qtn = Chain(Operator(), Open(), arrays, order=[:r, :o, :l, :i]) @test size(tensors(qtn; at = Site(1))) == (1, 2, 4) @test size(tensors(qtn; at = Site(2))) == (3, 2, 1, 4) - @test size(tensors(qtn; at = Site(3))) == (3, 2, 4) + @test size(tensors(qtn; at = Site(3))) == (2, 3, 4) @test leftindex(qtn, Site(1)) == rightindex(qtn, Site(3)) === nothing @test leftindex(qtn, Site(2)) == rightindex(qtn, Site(1)) From f305b002620da57c8be7ea07bd5f48cd78f9cdd0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jofre=20Vall=C3=A8s?= Date: Wed, 19 Jun 2024 09:00:10 +0200 Subject: [PATCH 04/10] Add minor fixes in Chain constructor --- src/Ansatz/Chain.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/Ansatz/Chain.jl b/src/Ansatz/Chain.jl index 46e4f54..ea2b224 100644 --- a/src/Ansatz/Chain.jl +++ b/src/Ansatz/Chain.jl @@ -44,9 +44,9 @@ function Chain(::State, boundary::Periodic, arrays::Vector{<:AbstractArray}; ord map(directions) do dir if dir == :o symbols[i] - elseif dir == :l - symbols[n + mod1(i, n)] elseif dir == :r + symbols[n + mod1(i, n)] + elseif dir == :l symbols[n + mod1(i - 1, n)] else throw(ArgumentError("Invalid direction: $dir")) @@ -83,9 +83,9 @@ function Chain(::State, boundary::Open, arrays::Vector{<:AbstractArray}; order = map(directions) do dir if dir == :o symbols[i] - elseif dir == :l - symbols[n + mod1(i, n)] elseif dir == :r + symbols[n + mod1(i, n)] + elseif dir == :l symbols[n + mod1(i - 1, n)] else throw(ArgumentError("Invalid direction: $dir")) From 1467f809ed42a35dee2a53d7d92a324cf6808a3a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jofre=20Vall=C3=A8s?= Date: Wed, 19 Jun 2024 09:00:27 +0200 Subject: [PATCH 05/10] Enhance tests --- test/Ansatz/Chain_test.jl | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/test/Ansatz/Chain_test.jl b/test/Ansatz/Chain_test.jl index 3fbf664..fe9fa3f 100644 --- a/test/Ansatz/Chain_test.jl +++ b/test/Ansatz/Chain_test.jl @@ -67,9 +67,9 @@ @test size(tensors(qtn; at = Site(2))) == (6, 2, 3, 4) @test size(tensors(qtn; at = Site(3))) == (1, 2, 6, 4) - @test leftindex(qtn, Site(1)) == rightindex(qtn, Site(3)) - @test leftindex(qtn, Site(2)) == rightindex(qtn, Site(1)) - @test leftindex(qtn, Site(3)) == rightindex(qtn, Site(2)) + @test leftindex(qtn, Site(1)) == rightindex(qtn, Site(3)) !== nothing + @test leftindex(qtn, Site(2)) == rightindex(qtn, Site(1)) !== nothing + @test leftindex(qtn, Site(3)) == rightindex(qtn, Site(2)) !== nothing for i in 1:length(arrays) @test size(TensorNetwork(qtn), inds(qtn; at = Site(i))) == 2 @@ -107,8 +107,8 @@ @test size(tensors(qtn; at = Site(3))) == (2, 3) @test leftindex(qtn, Site(1)) == rightindex(qtn, Site(3)) === nothing - @test leftindex(qtn, Site(2)) == rightindex(qtn, Site(1)) - @test leftindex(qtn, Site(3)) == rightindex(qtn, Site(2)) + @test leftindex(qtn, Site(2)) == rightindex(qtn, Site(1)) !== nothing + @test leftindex(qtn, Site(3)) == rightindex(qtn, Site(2)) !== nothing for i in 1:nsites(qtn) @test size(TensorNetwork(qtn), inds(qtn; at = Site(i))) == 2 @@ -131,8 +131,8 @@ @test size(tensors(qtn; at = Site(3))) == (2, 4, 3) @test leftindex(qtn, Site(1)) == rightindex(qtn, Site(3)) === nothing - @test leftindex(qtn, Site(2)) == rightindex(qtn, Site(1)) - @test leftindex(qtn, Site(3)) == rightindex(qtn, Site(2)) + @test leftindex(qtn, Site(2)) == rightindex(qtn, Site(1)) !== nothing + @test leftindex(qtn, Site(3)) == rightindex(qtn, Site(2)) !== nothing for i in 1:length(arrays) @test size(TensorNetwork(qtn), inds(qtn; at = Site(i))) == 2 @@ -147,8 +147,8 @@ @test size(tensors(qtn; at = Site(3))) == (2, 3, 4) @test leftindex(qtn, Site(1)) == rightindex(qtn, Site(3)) === nothing - @test leftindex(qtn, Site(2)) == rightindex(qtn, Site(1)) - @test leftindex(qtn, Site(3)) == rightindex(qtn, Site(2)) + @test leftindex(qtn, Site(2)) == rightindex(qtn, Site(1)) !== nothing + @test leftindex(qtn, Site(3)) == rightindex(qtn, Site(2)) !== nothing for i in 1:length(arrays) @test size(TensorNetwork(qtn), inds(qtn; at = Site(i))) == 2 From 5f5070707ee087233eb8f718be83d55b5a991e57 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jofre=20Vall=C3=A8s?= Date: Wed, 19 Jun 2024 09:14:02 +0200 Subject: [PATCH 06/10] Format code --- src/Ansatz/Chain.jl | 36 ++++++++++++++++++++---------------- test/Ansatz/Chain_test.jl | 22 +++++++++++++--------- 2 files changed, 33 insertions(+), 25 deletions(-) diff --git a/src/Ansatz/Chain.jl b/src/Ansatz/Chain.jl index ea2b224..56d37f3 100644 --- a/src/Ansatz/Chain.jl +++ b/src/Ansatz/Chain.jl @@ -35,7 +35,8 @@ defaultorder(::Operator) = (:o, :i, :l, :r) function Chain(::State, boundary::Periodic, arrays::Vector{<:AbstractArray}; order = defaultorder(State())) @assert all(==(3) ∘ ndims, arrays) "All arrays must have 3 dimensions" - issetequal(order, defaultorder(State())) || throw(ArgumentError("order must be a permutation of $(String.(defaultorder(State())))")) + issetequal(order, defaultorder(State())) || + throw(ArgumentError("order must be a permutation of $(String.(defaultorder(State())))")) n = length(arrays) symbols = [nextindex() for _ in 1:2n] @@ -45,9 +46,9 @@ function Chain(::State, boundary::Periodic, arrays::Vector{<:AbstractArray}; ord if dir == :o symbols[i] elseif dir == :r - symbols[n + mod1(i, n)] + symbols[n+mod1(i, n)] elseif dir == :l - symbols[n + mod1(i - 1, n)] + symbols[n+mod1(i - 1, n)] else throw(ArgumentError("Invalid direction: $dir")) end @@ -68,7 +69,8 @@ function Chain(::State, boundary::Open, arrays::Vector{<:AbstractArray}; order = @assert ndims(arrays[1]) == 2 "First array must have 2 dimensions" @assert all(==(3) ∘ ndims, arrays[2:end-1]) "All arrays must have 3 dimensions" @assert ndims(arrays[end]) == 2 "Last array must have 2 dimensions" - issetequal(order, defaultorder(State())) || throw(ArgumentError("order must be a permutation of $(String.(defaultorder(State())))")) + issetequal(order, defaultorder(State())) || + throw(ArgumentError("order must be a permutation of $(String.(defaultorder(State())))")) n = length(arrays) symbols = [nextindex() for _ in 1:2n] @@ -84,9 +86,9 @@ function Chain(::State, boundary::Open, arrays::Vector{<:AbstractArray}; order = if dir == :o symbols[i] elseif dir == :r - symbols[n + mod1(i, n)] + symbols[n+mod1(i, n)] elseif dir == :l - symbols[n + mod1(i - 1, n)] + symbols[n+mod1(i - 1, n)] else throw(ArgumentError("Invalid direction: $dir")) end @@ -107,7 +109,8 @@ end function Chain(::Operator, boundary::Periodic, arrays::Vector{<:AbstractArray}; order = defaultorder(Operator())) @assert all(==(4) ∘ ndims, arrays) "All arrays must have 4 dimensions" - issetequal(order, defaultorder(Operator())) || throw(ArgumentError("order must be a permutation of $(String.(defaultorder(Operator())))")) + issetequal(order, defaultorder(Operator())) || + throw(ArgumentError("order must be a permutation of $(String.(defaultorder(Operator())))")) n = length(arrays) symbols = [nextindex() for _ in 1:3n] @@ -117,11 +120,11 @@ function Chain(::Operator, boundary::Periodic, arrays::Vector{<:AbstractArray}; if dir == :o symbols[i] elseif dir == :i - symbols[i + n] + symbols[i+n] elseif dir == :l - symbols[2n + mod1(i - 1, n)] + symbols[2n+mod1(i - 1, n)] elseif dir == :r - symbols[2n + mod1(i, n)] + symbols[2n+mod1(i, n)] else throw(ArgumentError("Invalid direction: $dir")) end @@ -134,7 +137,7 @@ function Chain(::Operator, boundary::Periodic, arrays::Vector{<:AbstractArray}; end sitemap = Dict(Site(i) => symbols[i] for i in 1:n) - merge!(sitemap, Dict(Site(i; dual = true) => symbols[i + n] for i in 1:n)) + merge!(sitemap, Dict(Site(i; dual = true) => symbols[i+n] for i in 1:n)) Chain(Quantum(TensorNetwork(_tensors), sitemap), boundary) end @@ -143,7 +146,8 @@ function Chain(::Operator, boundary::Open, arrays::Vector{<:AbstractArray}; orde @assert ndims(arrays[1]) == 3 "First array must have 3 dimensions" @assert all(==(4) ∘ ndims, arrays[2:end-1]) "All arrays must have 4 dimensions" @assert ndims(arrays[end]) == 3 "Last array must have 3 dimensions" - issetequal(order, defaultorder(Operator())) || throw(ArgumentError("order must be a permutation of $(String.(defaultorder(Operator())))")) + issetequal(order, defaultorder(Operator())) || + throw(ArgumentError("order must be a permutation of $(String.(defaultorder(Operator())))")) n = length(arrays) symbols = [nextindex() for _ in 1:3n-1] @@ -159,11 +163,11 @@ function Chain(::Operator, boundary::Open, arrays::Vector{<:AbstractArray}; orde if dir == :o symbols[i] elseif dir == :i - symbols[i + n] + symbols[i+n] elseif dir == :l - symbols[2n + mod1(i - 1, n)] + symbols[2n+mod1(i - 1, n)] elseif dir == :r - symbols[2n + mod1(i, n)] + symbols[2n+mod1(i, n)] else throw(ArgumentError("Invalid direction: $dir")) end @@ -178,7 +182,7 @@ function Chain(::Operator, boundary::Open, arrays::Vector{<:AbstractArray}; orde end sitemap = Dict(Site(i) => symbols[i] for i in 1:n) - merge!(sitemap, Dict(Site(i; dual = true) => symbols[i + n] for i in 1:n)) + merge!(sitemap, Dict(Site(i; dual = true) => symbols[i+n] for i in 1:n)) Chain(Quantum(TensorNetwork(_tensors), sitemap), boundary) end diff --git a/test/Ansatz/Chain_test.jl b/test/Ansatz/Chain_test.jl index fe9fa3f..36d34ee 100644 --- a/test/Ansatz/Chain_test.jl +++ b/test/Ansatz/Chain_test.jl @@ -21,7 +21,7 @@ @test leftindex(qtn, Site(3)) == rightindex(qtn, Site(2)) arrays = [permutedims(array, (3, 1, 2)) for array in arrays] # now we have (:r, :o, :l) - qtn = Chain(State(), Periodic(), arrays, order=[:r, :o, :l]) + qtn = Chain(State(), Periodic(), arrays, order = [:r, :o, :l]) @test size(tensors(qtn; at = Site(1))) == (4, 2, 1) @test size(tensors(qtn; at = Site(2))) == (3, 2, 4) @@ -57,11 +57,11 @@ for i in 1:length(arrays) @test size(TensorNetwork(qtn), inds(qtn; at = Site(i))) == 2 - @test size(TensorNetwork(qtn), inds(qtn; at = Site(i; dual=true))) == 4 + @test size(TensorNetwork(qtn), inds(qtn; at = Site(i; dual = true))) == 4 end arrays = [permutedims(array, (4, 1, 3, 2)) for array in arrays] # now we have (:r, :o, :l, :i) - qtn = Chain(Operator(), Periodic(), arrays, order=[:r, :o, :l, :i]) + qtn = Chain(Operator(), Periodic(), arrays, order = [:r, :o, :l, :i]) @test size(tensors(qtn; at = Site(1))) == (3, 2, 1, 4) @test size(tensors(qtn; at = Site(2))) == (6, 2, 3, 4) @@ -73,7 +73,7 @@ for i in 1:length(arrays) @test size(TensorNetwork(qtn), inds(qtn; at = Site(i))) == 2 - @test size(TensorNetwork(qtn), inds(qtn; at = Site(i; dual=true))) == 4 + @test size(TensorNetwork(qtn), inds(qtn; at = Site(i; dual = true))) == 4 end end end @@ -100,7 +100,7 @@ @test leftindex(qtn, Site(3)) == rightindex(qtn, Site(2)) arrays = [permutedims(arrays[1], (2, 1)), permutedims(arrays[2], (3, 1, 2)), permutedims(arrays[3], (1, 2))] # now we have (:r, :o, :l) - qtn = Chain(State(), Open(), arrays, order=[:r, :o, :l]) + qtn = Chain(State(), Open(), arrays, order = [:r, :o, :l]) @test size(tensors(qtn; at = Site(1))) == (1, 2) @test size(tensors(qtn; at = Site(2))) == (3, 2, 1) @@ -136,11 +136,15 @@ for i in 1:length(arrays) @test size(TensorNetwork(qtn), inds(qtn; at = Site(i))) == 2 - @test size(TensorNetwork(qtn), inds(qtn; at = Site(i; dual=true))) == 4 + @test size(TensorNetwork(qtn), inds(qtn; at = Site(i; dual = true))) == 4 end - arrays = [permutedims(arrays[1], (3, 1, 2)), permutedims(arrays[2], (4, 1, 3, 2)), permutedims(arrays[3], (1, 3, 2))] # now we have (:r, :o, :l, :i) - qtn = Chain(Operator(), Open(), arrays, order=[:r, :o, :l, :i]) + arrays = [ + permutedims(arrays[1], (3, 1, 2)), + permutedims(arrays[2], (4, 1, 3, 2)), + permutedims(arrays[3], (1, 3, 2)), + ] # now we have (:r, :o, :l, :i) + qtn = Chain(Operator(), Open(), arrays, order = [:r, :o, :l, :i]) @test size(tensors(qtn; at = Site(1))) == (1, 2, 4) @test size(tensors(qtn; at = Site(2))) == (3, 2, 1, 4) @@ -152,7 +156,7 @@ for i in 1:length(arrays) @test size(TensorNetwork(qtn), inds(qtn; at = Site(i))) == 2 - @test size(TensorNetwork(qtn), inds(qtn; at = Site(i; dual=true))) == 4 + @test size(TensorNetwork(qtn), inds(qtn; at = Site(i; dual = true))) == 4 end end end From ff981fa09d122355091dcf421178ba85262c3a25 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jofre=20Vall=C3=A8s=20Muns?= <61060572+jofrevalles@users.noreply.github.com> Date: Wed, 19 Jun 2024 12:49:06 +0200 Subject: [PATCH 07/10] Apply @mofeing suggestions from code review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Sergio Sánchez Ramírez <15837247+mofeing@users.noreply.github.com> --- src/Ansatz/Chain.jl | 14 +++++++------- test/Ansatz/Chain_test.jl | 1 + 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/src/Ansatz/Chain.jl b/src/Ansatz/Chain.jl index 56d37f3..052c961 100644 --- a/src/Ansatz/Chain.jl +++ b/src/Ansatz/Chain.jl @@ -30,12 +30,12 @@ function Chain(tn::TensorNetwork, sites, args...; kwargs...) Chain(Quantum(tn, sites), args...; kwargs...) end -defaultorder(::State) = (:o, :l, :r) -defaultorder(::Operator) = (:o, :i, :l, :r) +defaultorder(::Type{Chain}, ::State) = (:o, :l, :r) +defaultorder(::Type{Chain}, ::Operator) = (:o, :i, :l, :r) -function Chain(::State, boundary::Periodic, arrays::Vector{<:AbstractArray}; order = defaultorder(State())) +function Chain(::State, boundary::Periodic, arrays::Vector{<:AbstractArray}; order = defaultorder(Chain, State())) @assert all(==(3) ∘ ndims, arrays) "All arrays must have 3 dimensions" - issetequal(order, defaultorder(State())) || + issetequal(order, defaultorder(Chain, State())) || throw(ArgumentError("order must be a permutation of $(String.(defaultorder(State())))")) n = length(arrays) @@ -65,7 +65,7 @@ function Chain(::State, boundary::Periodic, arrays::Vector{<:AbstractArray}; ord Chain(Quantum(TensorNetwork(_tensors), sitemap), boundary) end -function Chain(::State, boundary::Open, arrays::Vector{<:AbstractArray}; order = defaultorder(State())) +function Chain(::State, boundary::Open, arrays::Vector{<:AbstractArray}; order = defaultorder(Chain, State())) @assert ndims(arrays[1]) == 2 "First array must have 2 dimensions" @assert all(==(3) ∘ ndims, arrays[2:end-1]) "All arrays must have 3 dimensions" @assert ndims(arrays[end]) == 2 "Last array must have 2 dimensions" @@ -107,7 +107,7 @@ function Chain(::State, boundary::Open, arrays::Vector{<:AbstractArray}; order = Chain(Quantum(TensorNetwork(_tensors), sitemap), boundary) end -function Chain(::Operator, boundary::Periodic, arrays::Vector{<:AbstractArray}; order = defaultorder(Operator())) +function Chain(::Operator, boundary::Periodic, arrays::Vector{<:AbstractArray}; order = defaultorder(Chain, Operator())) @assert all(==(4) ∘ ndims, arrays) "All arrays must have 4 dimensions" issetequal(order, defaultorder(Operator())) || throw(ArgumentError("order must be a permutation of $(String.(defaultorder(Operator())))")) @@ -142,7 +142,7 @@ function Chain(::Operator, boundary::Periodic, arrays::Vector{<:AbstractArray}; Chain(Quantum(TensorNetwork(_tensors), sitemap), boundary) end -function Chain(::Operator, boundary::Open, arrays::Vector{<:AbstractArray}; order = defaultorder(Operator())) +function Chain(::Operator, boundary::Open, arrays::Vector{<:AbstractArray}; order = defaultorder(Chain, Operator())) @assert ndims(arrays[1]) == 3 "First array must have 3 dimensions" @assert all(==(4) ∘ ndims, arrays[2:end-1]) "All arrays must have 4 dimensions" @assert ndims(arrays[end]) == 3 "Last array must have 3 dimensions" diff --git a/test/Ansatz/Chain_test.jl b/test/Ansatz/Chain_test.jl index 36d34ee..7496d85 100644 --- a/test/Ansatz/Chain_test.jl +++ b/test/Ansatz/Chain_test.jl @@ -35,6 +35,7 @@ @test size(TensorNetwork(qtn), inds(qtn; at = Site(i))) == 2 end end + @testset "Operator" begin qtn = Chain(Operator(), Periodic(), [rand(2, 2, 4, 4) for _ in 1:3]) @test socket(qtn) == Operator() From 7375cb59d20e4bac0865021be7e0d88b69ed6960 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jofre=20Vall=C3=A8s?= Date: Wed, 19 Jun 2024 12:58:35 +0200 Subject: [PATCH 08/10] Remove unnecessary helper function --- src/Ansatz/Chain.jl | 68 +++++++++++++++++---------------------------- 1 file changed, 26 insertions(+), 42 deletions(-) diff --git a/src/Ansatz/Chain.jl b/src/Ansatz/Chain.jl index 052c961..5d10dd6 100644 --- a/src/Ansatz/Chain.jl +++ b/src/Ansatz/Chain.jl @@ -41,8 +41,8 @@ function Chain(::State, boundary::Periodic, arrays::Vector{<:AbstractArray}; ord n = length(arrays) symbols = [nextindex() for _ in 1:2n] - function get_index(directions, i) - map(directions) do dir + _tensors = map(enumerate(arrays)) do (i, array) + inds = map(order) do dir if dir == :o symbols[i] elseif dir == :r @@ -53,10 +53,6 @@ function Chain(::State, boundary::Periodic, arrays::Vector{<:AbstractArray}; ord throw(ArgumentError("Invalid direction: $dir")) end end - end - - _tensors = map(enumerate(arrays)) do (i, array) - inds = get_index(order, i) Tensor(array, inds) end @@ -69,20 +65,22 @@ function Chain(::State, boundary::Open, arrays::Vector{<:AbstractArray}; order = @assert ndims(arrays[1]) == 2 "First array must have 2 dimensions" @assert all(==(3) ∘ ndims, arrays[2:end-1]) "All arrays must have 3 dimensions" @assert ndims(arrays[end]) == 2 "Last array must have 2 dimensions" - issetequal(order, defaultorder(State())) || - throw(ArgumentError("order must be a permutation of $(String.(defaultorder(State())))")) + issetequal(order, defaultorder(Chain, State())) || + throw(ArgumentError("order must be a permutation of $(String.(defaultorder(Chain, State())))")) n = length(arrays) symbols = [nextindex() for _ in 1:2n] - function get_index(directions, i, is_first, is_last) - if is_first - directions = filter(x -> x != :l, directions) - elseif is_last - directions = filter(x -> x != :r, directions) + _tensors = map(enumerate(arrays)) do (i, array) + if i == 1 + _order = filter(x -> x != :l, order) + elseif i == n + _order = filter(x -> x != :r, order) + else + _order = order end - map(directions) do dir + inds = map(_order) do dir if dir == :o symbols[i] elseif dir == :r @@ -93,12 +91,6 @@ function Chain(::State, boundary::Open, arrays::Vector{<:AbstractArray}; order = throw(ArgumentError("Invalid direction: $dir")) end end - end - - _tensors = map(enumerate(arrays)) do (i, array) - is_first = (i == 1) - is_last = (i == n) - inds = get_index(order, i, is_first, is_last) Tensor(array, inds) end @@ -109,14 +101,14 @@ end function Chain(::Operator, boundary::Periodic, arrays::Vector{<:AbstractArray}; order = defaultorder(Chain, Operator())) @assert all(==(4) ∘ ndims, arrays) "All arrays must have 4 dimensions" - issetequal(order, defaultorder(Operator())) || - throw(ArgumentError("order must be a permutation of $(String.(defaultorder(Operator())))")) + issetequal(order, defaultorder(Chain, Operator())) || + throw(ArgumentError("order must be a permutation of $(String.(defaultorder(Chain, Operator())))")) n = length(arrays) symbols = [nextindex() for _ in 1:3n] - function get_index(directions, i) - map(directions) do dir + _tensors = map(enumerate(arrays)) do (i, array) + inds = map(order) do dir if dir == :o symbols[i] elseif dir == :i @@ -129,10 +121,6 @@ function Chain(::Operator, boundary::Periodic, arrays::Vector{<:AbstractArray}; throw(ArgumentError("Invalid direction: $dir")) end end - end - - _tensors = map(enumerate(arrays)) do (i, array) - inds = get_index(order, i) Tensor(array, inds) end @@ -146,20 +134,22 @@ function Chain(::Operator, boundary::Open, arrays::Vector{<:AbstractArray}; orde @assert ndims(arrays[1]) == 3 "First array must have 3 dimensions" @assert all(==(4) ∘ ndims, arrays[2:end-1]) "All arrays must have 4 dimensions" @assert ndims(arrays[end]) == 3 "Last array must have 3 dimensions" - issetequal(order, defaultorder(Operator())) || - throw(ArgumentError("order must be a permutation of $(String.(defaultorder(Operator())))")) + issetequal(order, defaultorder(Chain, Operator())) || + throw(ArgumentError("order must be a permutation of $(String.(defaultorder(Chain, Operator())))")) n = length(arrays) symbols = [nextindex() for _ in 1:3n-1] - function get_index(directions, i, is_first, is_last) - if is_first - directions = filter(x -> x != :l, directions) - elseif is_last - directions = filter(x -> x != :r, directions) + _tensors = map(enumerate(arrays)) do (i, array) + if i == 1 + _order = filter(x -> x != :l, order) + elseif i == n + _order = filter(x -> x != :r, order) + else + _order = order end - map(directions) do dir + inds = map(_order) do dir if dir == :o symbols[i] elseif dir == :i @@ -172,12 +162,6 @@ function Chain(::Operator, boundary::Open, arrays::Vector{<:AbstractArray}; orde throw(ArgumentError("Invalid direction: $dir")) end end - end - - _tensors = map(enumerate(arrays)) do (i, array) - is_first = (i == 1) - is_last = (i == n) - inds = get_index(order, i, is_first, is_last) Tensor(array, inds) end From e52302baaab72bbe40757d0f199859d8d1464049 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jofre=20Vall=C3=A8s=20Muns?= <61060572+jofrevalles@users.noreply.github.com> Date: Wed, 19 Jun 2024 13:04:08 +0200 Subject: [PATCH 09/10] Apply suggestions from code review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Sergio Sánchez Ramírez <15837247+mofeing@users.noreply.github.com> --- src/Ansatz/Chain.jl | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/Ansatz/Chain.jl b/src/Ansatz/Chain.jl index 5d10dd6..95ffb4e 100644 --- a/src/Ansatz/Chain.jl +++ b/src/Ansatz/Chain.jl @@ -77,7 +77,13 @@ function Chain(::State, boundary::Open, arrays::Vector{<:AbstractArray}; order = elseif i == n _order = filter(x -> x != :r, order) else - _order = order + local order = if i == 1 + filter(x -> x != :l, order) + elseif i == n + filter(x -> x != :r, order) + else + order + end end inds = map(_order) do dir From 025a693a05b19a449357c96413b57e2a1708c4f2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jofre=20Vall=C3=A8s?= Date: Wed, 19 Jun 2024 13:05:52 +0200 Subject: [PATCH 10/10] Minor aesthetic updates in code --- src/Ansatz/Chain.jl | 22 ++++++++-------------- 1 file changed, 8 insertions(+), 14 deletions(-) diff --git a/src/Ansatz/Chain.jl b/src/Ansatz/Chain.jl index 95ffb4e..4bc699b 100644 --- a/src/Ansatz/Chain.jl +++ b/src/Ansatz/Chain.jl @@ -72,18 +72,12 @@ function Chain(::State, boundary::Open, arrays::Vector{<:AbstractArray}; order = symbols = [nextindex() for _ in 1:2n] _tensors = map(enumerate(arrays)) do (i, array) - if i == 1 - _order = filter(x -> x != :l, order) + _order = if i == 1 + filter(x -> x != :l, order) elseif i == n - _order = filter(x -> x != :r, order) + filter(x -> x != :r, order) else - local order = if i == 1 - filter(x -> x != :l, order) - elseif i == n - filter(x -> x != :r, order) - else - order - end + order end inds = map(_order) do dir @@ -147,12 +141,12 @@ function Chain(::Operator, boundary::Open, arrays::Vector{<:AbstractArray}; orde symbols = [nextindex() for _ in 1:3n-1] _tensors = map(enumerate(arrays)) do (i, array) - if i == 1 - _order = filter(x -> x != :l, order) + _order = if i == 1 + filter(x -> x != :l, order) elseif i == n - _order = filter(x -> x != :r, order) + filter(x -> x != :r, order) else - _order = order + order end inds = map(_order) do dir