diff --git a/src/Ansatz.jl b/src/Ansatz.jl index 3378fea..de3b911 100644 --- a/src/Ansatz.jl +++ b/src/Ansatz.jl @@ -74,7 +74,7 @@ struct MissingSchmidtCoefficientsException <: Base.Exception bond::NTuple{2,Site} end -MissingSchmidtCoefficientsException(bond::Vector{Site}) = MissingSchmidtCoefficientsException(tuple(bond...)) +MissingSchmidtCoefficientsException(bond::Vector{<:Site}) = MissingSchmidtCoefficientsException(tuple(bond...)) function Base.showerror(io::IO, e::MissingSchmidtCoefficientsException) print(io, "Can't access the spectrum on bond $(e.bond)") diff --git a/src/Ansatz/Chain.jl b/src/Ansatz/Chain.jl index d567b09..ec62f94 100644 --- a/src/Ansatz/Chain.jl +++ b/src/Ansatz/Chain.jl @@ -117,12 +117,12 @@ function Base.convert(::Type{Chain}, qtn::Product) end leftsite(tn::Chain, site::Site) = leftsite(boundary(tn), tn, site) -leftsite(::Open, tn::Chain, site::Site) = site.id ∈ range(2, nlanes(tn)) ? Site(site.id - 1) : nothing -leftsite(::Periodic, tn::Chain, site::Site) = Site(mod1(site.id - 1, nlanes(tn))) +leftsite(::Open, tn::Chain, site::Site) = id(site) ∈ range(2, nlanes(tn)) ? Site(id(site) - 1) : nothing +leftsite(::Periodic, tn::Chain, site::Site) = Site(mod1(id(site) - 1, nlanes(tn))) rightsite(tn::Chain, site::Site) = rightsite(boundary(tn), tn, site) -rightsite(::Open, tn::Chain, site::Site) = site.id ∈ range(1, nlanes(tn) - 1) ? Site(site.id + 1) : nothing -rightsite(::Periodic, tn::Chain, site::Site) = Site(mod1(site.id + 1, nlanes(tn))) +rightsite(::Open, tn::Chain, site::Site) = id(site) ∈ range(1, nlanes(tn) - 1) ? Site(id(site) + 1) : nothing +rightsite(::Periodic, tn::Chain, site::Site) = Site(mod1(id(site) + 1, nlanes(tn))) leftindex(tn::Chain, site::Site) = leftindex(boundary(tn), tn, site) leftindex(::Open, tn::Chain, site::Site) = site == site"1" ? nothing : leftindex(Periodic(), tn, site) @@ -230,7 +230,21 @@ end Tenet.contract(tn::Chain, query::Symbol, args...; kwargs...) = contract!(copy(tn), Val(query), args...; kwargs...) Tenet.contract!(tn::Chain, query::Symbol, args...; kwargs...) = contract!(tn, Val(query), args...; kwargs...) -function Tenet.contract!(tn::Chain, ::Val{:between}, site1::Site, site2::Site; direction::Symbol = :left) +""" + Tenet.contract!(tn::Chain, ::Val{:between}, site1::Site, site2::Site; direction::Symbol = :left, delete_Λ = true) + +For a given [`Chain`](@ref) tensor network, contracts the singular values Λ between two sites `site1` and `site2`. +The `direction` keyword argument specifies the direction of the contraction, and the `delete_Λ` keyword argument +specifies whether to delete the singular values tensor after the contraction. +""" +function Tenet.contract!( + tn::Chain, + ::Val{:between}, + site1::Site, + site2::Site; + direction::Symbol = :left, + delete_Λ = true, +) Λᵢ = select(tn, :between, site1, site2) Λᵢ === nothing && return tn @@ -244,7 +258,7 @@ function Tenet.contract!(tn::Chain, ::Val{:between}, site1::Site, site2::Site; d throw(ArgumentError("Unknown direction=:$direction")) end - delete!(TensorNetwork(tn), Λᵢ) + delete_Λ && delete!(TensorNetwork(tn), Λᵢ) return tn end @@ -418,12 +432,12 @@ and in the center there is a matrix with singular values. """ function mixed_canonize!(::Open, tn::Chain, center::Site) # TODO: center could be a range of sites # left-to-right QR sweep (left-canonical tensors) - for i in 1:center.id-1 + for i in 1:id(center)-1 canonize_site!(tn, Site(i); direction = :right, method = :qr) end # right-to-left QR sweep (right-canonical tensors) - for i in nsites(tn):-1:center.id+1 + for i in nsites(tn):-1:id(center)+1 canonize_site!(tn, Site(i); direction = :left, method = :qr) end @@ -447,7 +461,7 @@ end Applies a local operator `gate` to the [`Chain`](@ref) tensor network. """ -function evolve!(qtn::Chain, gate::Dense; threshold = nothing, maxdim = nothing) +function evolve!(qtn::Chain, gate::Dense; threshold = nothing, maxdim = nothing, iscanonical = false) # check gate is a valid operator if !(socket(gate) isa Operator) throw(ArgumentError("Gate must be an operator, but got $(socket(gate))")) @@ -468,13 +482,13 @@ function evolve!(qtn::Chain, gate::Dense; threshold = nothing, maxdim = nothing) elseif nlanes(gate) == 2 # check gate sites are contiguous # TODO refactor this out? - gate_inputs = sort!(map(x -> x.id, inputs(gate))) + gate_inputs = sort!(map(id, inputs(gate))) range = UnitRange(extrema(gate_inputs)...) range != gate_inputs && throw(ArgumentError("Gate lanes must be contiguous")) # TODO check correctly for periodic boundary conditions - evolve_2site!(qtn, gate; threshold, maxdim) + evolve_2site!(qtn, gate; threshold, maxdim, iscanonical = iscanonical) else # TODO generalize for more than 2 lanes throw(ArgumentError("Invalid number of lanes $(nlanes(gate)), maximum is 2")) @@ -502,7 +516,8 @@ function evolve_1site!(qtn::Chain, gate::Dense) contract!(TensorNetwork(qtn), contracting_index) end -function evolve_2site!(qtn::Chain, gate::Dense; threshold, maxdim) +# TODO: Maybe rename iscanonical kwarg ? +function evolve_2site!(qtn::Chain, gate::Dense; threshold, maxdim, iscanonical = false) # shallow copy to avoid problems if errors in mid execution gate = copy(gate) @@ -510,9 +525,9 @@ function evolve_2site!(qtn::Chain, gate::Dense; threshold, maxdim) left_inds::Vector{Symbol} = !isnothing(leftindex(qtn, sitel)) ? [leftindex(qtn, sitel)] : Symbol[] right_inds::Vector{Symbol} = !isnothing(rightindex(qtn, siter)) ? [rightindex(qtn, siter)] : Symbol[] - # contract virtual index virtualind::Symbol = select(qtn, :bond, bond...) - contract!(TensorNetwork(qtn), virtualind) + + iscanonical ? contract_2sitewf!(qtn, bond) : contract!(TensorNetwork(qtn), virtualind) # reindex contracting index contracting_inds = [gensym(:tmp) for _ in inputs(gate)] @@ -537,8 +552,12 @@ function evolve_2site!(qtn::Chain, gate::Dense; threshold, maxdim) # decompose using SVD push!(left_inds, select(qtn, :index, sitel)) push!(right_inds, select(qtn, :index, siter)) - svd!(TensorNetwork(qtn); left_inds, right_inds, virtualind) + if iscanonical + unpack_2sitewf!(qtn, bond, left_inds, right_inds, virtualind) + else + svd!(TensorNetwork(qtn); left_inds, right_inds, virtualind) + end # truncate virtual index if any(!isnothing, [threshold, maxdim]) truncate!(qtn, bond; threshold, maxdim) @@ -547,6 +566,67 @@ function evolve_2site!(qtn::Chain, gate::Dense; threshold, maxdim) return qtn end +""" + contract_2sitewf!(ψ::Chain, bond) + +For a given [`Chain`](@ref) in the canonical form, creates the two-site wave function θ with Λᵢ₋₁Γᵢ₋₁ΛᵢΓᵢΛᵢ₊₁, +where i is the `bond`, and replaces the Γᵢ₋₁ΛᵢΓᵢ tensors with θ. +""" +function contract_2sitewf!(ψ::Chain, bond) + # TODO Check if ψ is in canonical form + + sitel, siter = bond # TODO Check if bond is valid + (0 < id(sitel) < nsites(ψ) || 0 < id(siter) < nsites(ψ)) || + throw(ArgumentError("The sites in the bond must be between 1 and $(nsites(ψ))")) + + Λᵢ₋₁ = id(sitel) == 1 ? nothing : select(ψ, :between, Site(id(sitel) - 1), sitel) + Λᵢ₊₁ = id(sitel) == nsites(ψ) - 1 ? nothing : select(ψ, :between, siter, Site(id(siter) + 1)) + + !isnothing(Λᵢ₋₁) && contract!(ψ, :between, Site(id(sitel) - 1), sitel; direction = :right, delete_Λ = false) + !isnothing(Λᵢ₊₁) && contract!(ψ, :between, siter, Site(id(siter) + 1); direction = :left, delete_Λ = false) + + contract!(TensorNetwork(ψ), select(ψ, :bond, bond...)) + + return ψ +end + +""" + unpack_2sitewf!(ψ::Chain, bond) + +For a given [`Chain`](@ref) that contains a two-site wave function θ in a bond, it decomposes θ into the canonical +form: Γᵢ₋₁ΛᵢΓᵢ, where i is the `bond`. +""" +function unpack_2sitewf!(ψ::Chain, bond, left_inds, right_inds, virtualind) + # TODO Check if ψ is in canonical form + + sitel, siter = bond # TODO Check if bond is valid + (0 < id(sitel) < nsites(ψ) || 0 < id(site_r) < nsites(ψ)) || + throw(ArgumentError("The sites in the bond must be between 1 and $(nsites(ψ))")) + + Λᵢ₋₁ = id(sitel) == 1 ? nothing : select(ψ, :between, Site(id(sitel) - 1), sitel) + Λᵢ₊₁ = id(siter) == nsites(ψ) ? nothing : select(ψ, :between, siter, Site(id(siter) + 1)) + + # do svd of the θ tensor + θ = select(ψ, :tensor, sitel) + U, s, Vt = svd(θ; left_inds, right_inds, virtualind) + + # contract with the inverse of Λᵢ and Λᵢ₊₂ + Γᵢ₋₁ = + isnothing(Λᵢ₋₁) ? U : + contract(U, Tensor(diag(pinv(Diagonal(parent(Λᵢ₋₁)), atol = 1e-32)), inds(Λᵢ₋₁)), dims = ()) + Γᵢ = + isnothing(Λᵢ₊₁) ? Vt : + contract(Tensor(diag(pinv(Diagonal(parent(Λᵢ₊₁)), atol = 1e-32)), inds(Λᵢ₊₁)), Vt, dims = ()) + + delete!(TensorNetwork(ψ), θ) + + push!(TensorNetwork(ψ), Γᵢ₋₁) + push!(TensorNetwork(ψ), s) + push!(TensorNetwork(ψ), Γᵢ) + + return ψ +end + function expect(ψ::Chain, observables) # contract observable with TN ϕ = copy(ψ) diff --git a/src/Quantum.jl b/src/Quantum.jl index 449dc2c..4001bcd 100644 --- a/src/Quantum.jl +++ b/src/Quantum.jl @@ -11,30 +11,39 @@ Represents a physical index. - Should we store here some information about quantum numbers? """ -struct Site - id::Int +struct Site{N} + id::NTuple{N,Int} dual::Bool - Site(id; dual = false) = new(id, dual) + Site(id::NTuple{N,Int}; dual = false) where {N} = new{N}(id, dual) end +Site(id::Int; kwargs...) = Site((id,); kwargs...) +Site(id::Vararg{Int,N}; kwargs...) where {N} = Site(id; kwargs...) + +id(site::Site{1}) = only(site.id) +id(site::Site) = site.id + +Base.CartesianIndex(site::Site) = CartesianIndex(id(site)) + isdual(site::Site) = site.dual -Base.show(io::IO, site::Site) = print(io, "$(site.id)$(site.dual ? "'" : "")") -Base.adjoint(site::Site) = Site(site.id; dual = !site.dual) -Base.isless(a::Site, b::Site) = a.id < b.id +Base.show(io::IO, site::Site) = print(io, "$(id(site))$(site.dual ? "'" : "")") +Base.adjoint(site::Site) = Site(id(site); dual = !site.dual) +Base.isless(a::Site, b::Site) = id(a) < id(b) macro site_str(str) - m = match(r"^(\d+)('?)$", str) + m = match(r"^(\d+,)*\d+('?)$", str) if isnothing(m) error("Invalid site string: $str") end - id = parse(Int, m.captures[1]) - dual = m.captures[2] == "'" + id = tuple(map(eachmatch(r"(\d+)", str)) do match + parse(Int, only(match.captures)) + end...) - quote - Site($id; dual = $dual) - end + dual = endswith(str, "'") + + return :(Site($id; dual = $dual)) end """ diff --git a/test/Site_test.jl b/test/Site_test.jl index 1578dfa..a097858 100644 --- a/test/Site_test.jl +++ b/test/Site_test.jl @@ -1,25 +1,63 @@ @testset "Site" begin + using Qrochet: id + s = Site(1) - @test s.id == 1 - @test s.dual == false + @test id(s) == 1 + @test CartesianIndex(s) == CartesianIndex(1) + @test isdual(s) == false s = Site(1; dual = true) - @test s.id == 1 - @test s.dual == true + @test id(s) == 1 + @test CartesianIndex(s) == CartesianIndex(1) + @test isdual(s) == true + + s = Site(1, 2) + @test id(s) == (1, 2) + @test CartesianIndex(s) == CartesianIndex((1, 2)) + @test isdual(s) == false + + s = Site(1, 2; dual = true) + @test id(s) == (1, 2) + @test CartesianIndex(s) == CartesianIndex((1, 2)) + @test isdual(s) == true s = site"1" - @test s.id == 1 - @test s.dual == false + @test id(s) == 1 + @test CartesianIndex(s) == CartesianIndex(1) + @test isdual(s) == false s = site"1'" - @test s.id == 1 - @test s.dual == true + @test id(s) == 1 + @test CartesianIndex(s) == CartesianIndex(1) + @test isdual(s) == true + + s = site"1,2" + @test id(s) == (1, 2) + @test CartesianIndex(s) == CartesianIndex((1, 2)) + @test isdual(s) == false + + s = site"1,2'" + @test id(s) == (1, 2) + @test CartesianIndex(s) == CartesianIndex((1, 2)) + @test isdual(s) == true s = site"1" |> adjoint - @test s.id == 1 - @test s.dual == true + @test id(s) == 1 + @test CartesianIndex(s) == CartesianIndex(1) + @test isdual(s) == true s = site"1'" |> adjoint - @test s.id == 1 - @test s.dual == false + @test id(s) == 1 + @test CartesianIndex(s) == CartesianIndex(1) + @test isdual(s) == false + + s = site"1,2" |> adjoint + @test id(s) == (1, 2) + @test CartesianIndex(s) == CartesianIndex((1, 2)) + @test isdual(s) == true + + s = site"1,2'" |> adjoint + @test id(s) == (1, 2) + @test CartesianIndex(s) == CartesianIndex((1, 2)) + @test isdual(s) == false end