From 7375cb59d20e4bac0865021be7e0d88b69ed6960 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jofre=20Vall=C3=A8s?= Date: Wed, 19 Jun 2024 12:58:35 +0200 Subject: [PATCH] Remove unnecessary helper function --- src/Ansatz/Chain.jl | 68 +++++++++++++++++---------------------------- 1 file changed, 26 insertions(+), 42 deletions(-) diff --git a/src/Ansatz/Chain.jl b/src/Ansatz/Chain.jl index 052c961..5d10dd6 100644 --- a/src/Ansatz/Chain.jl +++ b/src/Ansatz/Chain.jl @@ -41,8 +41,8 @@ function Chain(::State, boundary::Periodic, arrays::Vector{<:AbstractArray}; ord n = length(arrays) symbols = [nextindex() for _ in 1:2n] - function get_index(directions, i) - map(directions) do dir + _tensors = map(enumerate(arrays)) do (i, array) + inds = map(order) do dir if dir == :o symbols[i] elseif dir == :r @@ -53,10 +53,6 @@ function Chain(::State, boundary::Periodic, arrays::Vector{<:AbstractArray}; ord throw(ArgumentError("Invalid direction: $dir")) end end - end - - _tensors = map(enumerate(arrays)) do (i, array) - inds = get_index(order, i) Tensor(array, inds) end @@ -69,20 +65,22 @@ function Chain(::State, boundary::Open, arrays::Vector{<:AbstractArray}; order = @assert ndims(arrays[1]) == 2 "First array must have 2 dimensions" @assert all(==(3) ∘ ndims, arrays[2:end-1]) "All arrays must have 3 dimensions" @assert ndims(arrays[end]) == 2 "Last array must have 2 dimensions" - issetequal(order, defaultorder(State())) || - throw(ArgumentError("order must be a permutation of $(String.(defaultorder(State())))")) + issetequal(order, defaultorder(Chain, State())) || + throw(ArgumentError("order must be a permutation of $(String.(defaultorder(Chain, State())))")) n = length(arrays) symbols = [nextindex() for _ in 1:2n] - function get_index(directions, i, is_first, is_last) - if is_first - directions = filter(x -> x != :l, directions) - elseif is_last - directions = filter(x -> x != :r, directions) + _tensors = map(enumerate(arrays)) do (i, array) + if i == 1 + _order = filter(x -> x != :l, order) + elseif i == n + _order = filter(x -> x != :r, order) + else + _order = order end - map(directions) do dir + inds = map(_order) do dir if dir == :o symbols[i] elseif dir == :r @@ -93,12 +91,6 @@ function Chain(::State, boundary::Open, arrays::Vector{<:AbstractArray}; order = throw(ArgumentError("Invalid direction: $dir")) end end - end - - _tensors = map(enumerate(arrays)) do (i, array) - is_first = (i == 1) - is_last = (i == n) - inds = get_index(order, i, is_first, is_last) Tensor(array, inds) end @@ -109,14 +101,14 @@ end function Chain(::Operator, boundary::Periodic, arrays::Vector{<:AbstractArray}; order = defaultorder(Chain, Operator())) @assert all(==(4) ∘ ndims, arrays) "All arrays must have 4 dimensions" - issetequal(order, defaultorder(Operator())) || - throw(ArgumentError("order must be a permutation of $(String.(defaultorder(Operator())))")) + issetequal(order, defaultorder(Chain, Operator())) || + throw(ArgumentError("order must be a permutation of $(String.(defaultorder(Chain, Operator())))")) n = length(arrays) symbols = [nextindex() for _ in 1:3n] - function get_index(directions, i) - map(directions) do dir + _tensors = map(enumerate(arrays)) do (i, array) + inds = map(order) do dir if dir == :o symbols[i] elseif dir == :i @@ -129,10 +121,6 @@ function Chain(::Operator, boundary::Periodic, arrays::Vector{<:AbstractArray}; throw(ArgumentError("Invalid direction: $dir")) end end - end - - _tensors = map(enumerate(arrays)) do (i, array) - inds = get_index(order, i) Tensor(array, inds) end @@ -146,20 +134,22 @@ function Chain(::Operator, boundary::Open, arrays::Vector{<:AbstractArray}; orde @assert ndims(arrays[1]) == 3 "First array must have 3 dimensions" @assert all(==(4) ∘ ndims, arrays[2:end-1]) "All arrays must have 4 dimensions" @assert ndims(arrays[end]) == 3 "Last array must have 3 dimensions" - issetequal(order, defaultorder(Operator())) || - throw(ArgumentError("order must be a permutation of $(String.(defaultorder(Operator())))")) + issetequal(order, defaultorder(Chain, Operator())) || + throw(ArgumentError("order must be a permutation of $(String.(defaultorder(Chain, Operator())))")) n = length(arrays) symbols = [nextindex() for _ in 1:3n-1] - function get_index(directions, i, is_first, is_last) - if is_first - directions = filter(x -> x != :l, directions) - elseif is_last - directions = filter(x -> x != :r, directions) + _tensors = map(enumerate(arrays)) do (i, array) + if i == 1 + _order = filter(x -> x != :l, order) + elseif i == n + _order = filter(x -> x != :r, order) + else + _order = order end - map(directions) do dir + inds = map(_order) do dir if dir == :o symbols[i] elseif dir == :i @@ -172,12 +162,6 @@ function Chain(::Operator, boundary::Open, arrays::Vector{<:AbstractArray}; orde throw(ArgumentError("Invalid direction: $dir")) end end - end - - _tensors = map(enumerate(arrays)) do (i, array) - is_first = (i == 1) - is_last = (i == n) - inds = get_index(order, i, is_first, is_last) Tensor(array, inds) end