Skip to content
This repository has been archived by the owner on Jul 7, 2024. It is now read-only.

Extend evolve! for Chain in canonical form #31

Merged
merged 35 commits into from
Mar 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
dfbd7f6
Add first implementation
jofrevalles Mar 15, 2024
a04f1de
Fix typo
jofrevalles Mar 15, 2024
abd5298
Fix pinv atol
jofrevalles Mar 15, 2024
1ec68bc
Add Quac integration tests for evolve
jofrevalles Mar 15, 2024
0d71bc6
Fix format
jofrevalles Mar 15, 2024
b4fa57e
Add missing import
jofrevalles Mar 15, 2024
6581346
Fix location of tests
jofrevalles Mar 15, 2024
7a13651
Remove unnecessary import
jofrevalles Mar 15, 2024
a5c7448
Remove unnecessary import
jofrevalles Mar 15, 2024
e1938ce
Remove tests
jofrevalles Mar 18, 2024
048d14e
Change function names
jofrevalles Mar 18, 2024
5299d5c
Fix code
jofrevalles Mar 18, 2024
39c9d51
Fix typo
jofrevalles Mar 18, 2024
e6ef963
Replace condition for isnothing function
jofrevalles Mar 18, 2024
b45a491
Replace condition for isnothing function
jofrevalles Mar 18, 2024
f62e387
Fix format
jofrevalles Mar 18, 2024
7366f97
Lower the atol pinv threshold
jofrevalles Mar 19, 2024
5c6d53d
Add delete_lambda as kwarg argument in contract
jofrevalles Mar 25, 2024
582db60
Refactor code from main functions
jofrevalles Mar 25, 2024
dad2a14
Fix typo
jofrevalles Mar 25, 2024
e475877
Merge branch 'master' into feature/evolve-canonized
mofeing Mar 25, 2024
1cc57b5
Fix default delete_lambda kwarg
jofrevalles Mar 25, 2024
e8fc71a
Format code
jofrevalles Mar 25, 2024
a8b74cb
Fix format
jofrevalles Mar 25, 2024
2819826
Fix format
jofrevalles Mar 25, 2024
54d1396
Add docstrings
jofrevalles Mar 25, 2024
395af8a
Change name to contract_2sitewf!
jofrevalles Mar 25, 2024
2130c15
Format code
jofrevalles Mar 25, 2024
c9194b7
Create unpack_2sitewf! function
jofrevalles Mar 25, 2024
fd880b6
Refactor `Site` to N-dimensional coordinates
mofeing Mar 24, 2024
5dde45d
Refactor code
jofrevalles Mar 25, 2024
910b357
Fix typo
jofrevalles Mar 25, 2024
e4cebc8
Fix typo
jofrevalles Mar 25, 2024
8db348e
Fix typo
jofrevalles Mar 25, 2024
bccf668
Fix typo
jofrevalles Mar 25, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/Ansatz.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ struct MissingSchmidtCoefficientsException <: Base.Exception
bond::NTuple{2,Site}
end

MissingSchmidtCoefficientsException(bond::Vector{Site}) = MissingSchmidtCoefficientsException(tuple(bond...))
MissingSchmidtCoefficientsException(bond::Vector{<:Site}) = MissingSchmidtCoefficientsException(tuple(bond...))

function Base.showerror(io::IO, e::MissingSchmidtCoefficientsException)
print(io, "Can't access the spectrum on bond $(e.bond)")
Expand Down
110 changes: 95 additions & 15 deletions src/Ansatz/Chain.jl
Original file line number Diff line number Diff line change
Expand Up @@ -117,12 +117,12 @@
end

leftsite(tn::Chain, site::Site) = leftsite(boundary(tn), tn, site)
leftsite(::Open, tn::Chain, site::Site) = site.id ∈ range(2, nlanes(tn)) ? Site(site.id - 1) : nothing
leftsite(::Periodic, tn::Chain, site::Site) = Site(mod1(site.id - 1, nlanes(tn)))
leftsite(::Open, tn::Chain, site::Site) = id(site) ∈ range(2, nlanes(tn)) ? Site(id(site) - 1) : nothing
leftsite(::Periodic, tn::Chain, site::Site) = Site(mod1(id(site) - 1, nlanes(tn)))

rightsite(tn::Chain, site::Site) = rightsite(boundary(tn), tn, site)
rightsite(::Open, tn::Chain, site::Site) = site.id ∈ range(1, nlanes(tn) - 1) ? Site(site.id + 1) : nothing
rightsite(::Periodic, tn::Chain, site::Site) = Site(mod1(site.id + 1, nlanes(tn)))
rightsite(::Open, tn::Chain, site::Site) = id(site) ∈ range(1, nlanes(tn) - 1) ? Site(id(site) + 1) : nothing
rightsite(::Periodic, tn::Chain, site::Site) = Site(mod1(id(site) + 1, nlanes(tn)))

leftindex(tn::Chain, site::Site) = leftindex(boundary(tn), tn, site)
leftindex(::Open, tn::Chain, site::Site) = site == site"1" ? nothing : leftindex(Periodic(), tn, site)
Expand Down Expand Up @@ -230,7 +230,21 @@
Tenet.contract(tn::Chain, query::Symbol, args...; kwargs...) = contract!(copy(tn), Val(query), args...; kwargs...)
Tenet.contract!(tn::Chain, query::Symbol, args...; kwargs...) = contract!(tn, Val(query), args...; kwargs...)

function Tenet.contract!(tn::Chain, ::Val{:between}, site1::Site, site2::Site; direction::Symbol = :left)
"""
Tenet.contract!(tn::Chain, ::Val{:between}, site1::Site, site2::Site; direction::Symbol = :left, delete_Λ = true)

For a given [`Chain`](@ref) tensor network, contracts the singular values Λ between two sites `site1` and `site2`.
The `direction` keyword argument specifies the direction of the contraction, and the `delete_Λ` keyword argument
specifies whether to delete the singular values tensor after the contraction.
"""
function Tenet.contract!(
tn::Chain,
::Val{:between},
site1::Site,
site2::Site;
direction::Symbol = :left,
delete_Λ = true,
)
Λᵢ = select(tn, :between, site1, site2)
Λᵢ === nothing && return tn

Expand All @@ -244,7 +258,7 @@
throw(ArgumentError("Unknown direction=:$direction"))
end

delete!(TensorNetwork(tn), Λᵢ)
delete_Λ && delete!(TensorNetwork(tn), Λᵢ)

return tn
end
Expand Down Expand Up @@ -418,12 +432,12 @@
"""
function mixed_canonize!(::Open, tn::Chain, center::Site) # TODO: center could be a range of sites
# left-to-right QR sweep (left-canonical tensors)
for i in 1:center.id-1
for i in 1:id(center)-1
canonize_site!(tn, Site(i); direction = :right, method = :qr)
end

# right-to-left QR sweep (right-canonical tensors)
for i in nsites(tn):-1:center.id+1
for i in nsites(tn):-1:id(center)+1
canonize_site!(tn, Site(i); direction = :left, method = :qr)
end

Expand All @@ -447,7 +461,7 @@

Applies a local operator `gate` to the [`Chain`](@ref) tensor network.
"""
function evolve!(qtn::Chain, gate::Dense; threshold = nothing, maxdim = nothing)
function evolve!(qtn::Chain, gate::Dense; threshold = nothing, maxdim = nothing, iscanonical = false)

Check warning on line 464 in src/Ansatz/Chain.jl

View check run for this annotation

Codecov / codecov/patch

src/Ansatz/Chain.jl#L464

Added line #L464 was not covered by tests
jofrevalles marked this conversation as resolved.
Show resolved Hide resolved
# check gate is a valid operator
if !(socket(gate) isa Operator)
throw(ArgumentError("Gate must be an operator, but got $(socket(gate))"))
Expand All @@ -468,13 +482,13 @@
elseif nlanes(gate) == 2
# check gate sites are contiguous
# TODO refactor this out?
gate_inputs = sort!(map(x -> x.id, inputs(gate)))
gate_inputs = sort!(map(id, inputs(gate)))

Check warning on line 485 in src/Ansatz/Chain.jl

View check run for this annotation

Codecov / codecov/patch

src/Ansatz/Chain.jl#L485

Added line #L485 was not covered by tests
range = UnitRange(extrema(gate_inputs)...)

range != gate_inputs && throw(ArgumentError("Gate lanes must be contiguous"))

# TODO check correctly for periodic boundary conditions
evolve_2site!(qtn, gate; threshold, maxdim)
evolve_2site!(qtn, gate; threshold, maxdim, iscanonical = iscanonical)

Check warning on line 491 in src/Ansatz/Chain.jl

View check run for this annotation

Codecov / codecov/patch

src/Ansatz/Chain.jl#L491

Added line #L491 was not covered by tests
else
# TODO generalize for more than 2 lanes
throw(ArgumentError("Invalid number of lanes $(nlanes(gate)), maximum is 2"))
Expand Down Expand Up @@ -502,17 +516,18 @@
contract!(TensorNetwork(qtn), contracting_index)
end

function evolve_2site!(qtn::Chain, gate::Dense; threshold, maxdim)
# TODO: Maybe rename iscanonical kwarg ?
function evolve_2site!(qtn::Chain, gate::Dense; threshold, maxdim, iscanonical = false)

Check warning on line 520 in src/Ansatz/Chain.jl

View check run for this annotation

Codecov / codecov/patch

src/Ansatz/Chain.jl#L520

Added line #L520 was not covered by tests
jofrevalles marked this conversation as resolved.
Show resolved Hide resolved
# shallow copy to avoid problems if errors in mid execution
gate = copy(gate)

bond = sitel, siter = minmax(outputs(gate)...)
left_inds::Vector{Symbol} = !isnothing(leftindex(qtn, sitel)) ? [leftindex(qtn, sitel)] : Symbol[]
right_inds::Vector{Symbol} = !isnothing(rightindex(qtn, siter)) ? [rightindex(qtn, siter)] : Symbol[]

# contract virtual index
virtualind::Symbol = select(qtn, :bond, bond...)
contract!(TensorNetwork(qtn), virtualind)

iscanonical ? contract_2sitewf!(qtn, bond) : contract!(TensorNetwork(qtn), virtualind)

Check warning on line 530 in src/Ansatz/Chain.jl

View check run for this annotation

Codecov / codecov/patch

src/Ansatz/Chain.jl#L530

Added line #L530 was not covered by tests

# reindex contracting index
contracting_inds = [gensym(:tmp) for _ in inputs(gate)]
Expand All @@ -537,8 +552,12 @@
# decompose using SVD
push!(left_inds, select(qtn, :index, sitel))
push!(right_inds, select(qtn, :index, siter))
svd!(TensorNetwork(qtn); left_inds, right_inds, virtualind)

if iscanonical
unpack_2sitewf!(qtn, bond, left_inds, right_inds, virtualind)

Check warning on line 557 in src/Ansatz/Chain.jl

View check run for this annotation

Codecov / codecov/patch

src/Ansatz/Chain.jl#L556-L557

Added lines #L556 - L557 were not covered by tests
else
svd!(TensorNetwork(qtn); left_inds, right_inds, virtualind)

Check warning on line 559 in src/Ansatz/Chain.jl

View check run for this annotation

Codecov / codecov/patch

src/Ansatz/Chain.jl#L559

Added line #L559 was not covered by tests
end
# truncate virtual index
if any(!isnothing, [threshold, maxdim])
truncate!(qtn, bond; threshold, maxdim)
Expand All @@ -547,6 +566,67 @@
return qtn
end

"""
contract_2sitewf!(ψ::Chain, bond)

For a given [`Chain`](@ref) in the canonical form, creates the two-site wave function θ with Λᵢ₋₁Γᵢ₋₁ΛᵢΓᵢΛᵢ₊₁,
where i is the `bond`, and replaces the Γᵢ₋₁ΛᵢΓᵢ tensors with θ.
"""
function contract_2sitewf!(ψ::Chain, bond)

Check warning on line 575 in src/Ansatz/Chain.jl

View check run for this annotation

Codecov / codecov/patch

src/Ansatz/Chain.jl#L575

Added line #L575 was not covered by tests
# TODO Check if ψ is in canonical form

sitel, siter = bond # TODO Check if bond is valid
(0 < id(sitel) < nsites(ψ) || 0 < id(siter) < nsites(ψ)) ||

Check warning on line 579 in src/Ansatz/Chain.jl

View check run for this annotation

Codecov / codecov/patch

src/Ansatz/Chain.jl#L578-L579

Added lines #L578 - L579 were not covered by tests
throw(ArgumentError("The sites in the bond must be between 1 and $(nsites(ψ))"))

Λᵢ₋₁ = id(sitel) == 1 ? nothing : select(ψ, :between, Site(id(sitel) - 1), sitel)
Λᵢ₊₁ = id(sitel) == nsites(ψ) - 1 ? nothing : select(ψ, :between, siter, Site(id(siter) + 1))

Check warning on line 583 in src/Ansatz/Chain.jl

View check run for this annotation

Codecov / codecov/patch

src/Ansatz/Chain.jl#L582-L583

Added lines #L582 - L583 were not covered by tests

!isnothing(Λᵢ₋₁) && contract!(ψ, :between, Site(id(sitel) - 1), sitel; direction = :right, delete_Λ = false)
!isnothing(Λᵢ₊₁) && contract!(ψ, :between, siter, Site(id(siter) + 1); direction = :left, delete_Λ = false)

Check warning on line 586 in src/Ansatz/Chain.jl

View check run for this annotation

Codecov / codecov/patch

src/Ansatz/Chain.jl#L585-L586

Added lines #L585 - L586 were not covered by tests

contract!(TensorNetwork(ψ), select(ψ, :bond, bond...))

Check warning on line 588 in src/Ansatz/Chain.jl

View check run for this annotation

Codecov / codecov/patch

src/Ansatz/Chain.jl#L588

Added line #L588 was not covered by tests

return ψ

Check warning on line 590 in src/Ansatz/Chain.jl

View check run for this annotation

Codecov / codecov/patch

src/Ansatz/Chain.jl#L590

Added line #L590 was not covered by tests
end

"""
unpack_2sitewf!(ψ::Chain, bond)

For a given [`Chain`](@ref) that contains a two-site wave function θ in a bond, it decomposes θ into the canonical
form: Γᵢ₋₁ΛᵢΓᵢ, where i is the `bond`.
"""
function unpack_2sitewf!(ψ::Chain, bond, left_inds, right_inds, virtualind)

Check warning on line 599 in src/Ansatz/Chain.jl

View check run for this annotation

Codecov / codecov/patch

src/Ansatz/Chain.jl#L599

Added line #L599 was not covered by tests
# TODO Check if ψ is in canonical form

sitel, siter = bond # TODO Check if bond is valid
(0 < id(sitel) < nsites(ψ) || 0 < id(site_r) < nsites(ψ)) ||

Check warning on line 603 in src/Ansatz/Chain.jl

View check run for this annotation

Codecov / codecov/patch

src/Ansatz/Chain.jl#L602-L603

Added lines #L602 - L603 were not covered by tests
throw(ArgumentError("The sites in the bond must be between 1 and $(nsites(ψ))"))

Λᵢ₋₁ = id(sitel) == 1 ? nothing : select(ψ, :between, Site(id(sitel) - 1), sitel)
Λᵢ₊₁ = id(siter) == nsites(ψ) ? nothing : select(ψ, :between, siter, Site(id(siter) + 1))

Check warning on line 607 in src/Ansatz/Chain.jl

View check run for this annotation

Codecov / codecov/patch

src/Ansatz/Chain.jl#L606-L607

Added lines #L606 - L607 were not covered by tests

# do svd of the θ tensor
θ = select(ψ, :tensor, sitel)
U, s, Vt = svd(θ; left_inds, right_inds, virtualind)

Check warning on line 611 in src/Ansatz/Chain.jl

View check run for this annotation

Codecov / codecov/patch

src/Ansatz/Chain.jl#L610-L611

Added lines #L610 - L611 were not covered by tests

# contract with the inverse of Λᵢ and Λᵢ₊₂
Γᵢ₋₁ =

Check warning on line 614 in src/Ansatz/Chain.jl

View check run for this annotation

Codecov / codecov/patch

src/Ansatz/Chain.jl#L614

Added line #L614 was not covered by tests
isnothing(Λᵢ₋₁) ? U :
contract(U, Tensor(diag(pinv(Diagonal(parent(Λᵢ₋₁)), atol = 1e-32)), inds(Λᵢ₋₁)), dims = ())
Γᵢ =

Check warning on line 617 in src/Ansatz/Chain.jl

View check run for this annotation

Codecov / codecov/patch

src/Ansatz/Chain.jl#L617

Added line #L617 was not covered by tests
isnothing(Λᵢ₊₁) ? Vt :
contract(Tensor(diag(pinv(Diagonal(parent(Λᵢ₊₁)), atol = 1e-32)), inds(Λᵢ₊₁)), Vt, dims = ())

delete!(TensorNetwork(ψ), θ)

Check warning on line 621 in src/Ansatz/Chain.jl

View check run for this annotation

Codecov / codecov/patch

src/Ansatz/Chain.jl#L621

Added line #L621 was not covered by tests

push!(TensorNetwork(ψ), Γᵢ₋₁)
push!(TensorNetwork(ψ), s)
push!(TensorNetwork(ψ), Γᵢ)

Check warning on line 625 in src/Ansatz/Chain.jl

View check run for this annotation

Codecov / codecov/patch

src/Ansatz/Chain.jl#L623-L625

Added lines #L623 - L625 were not covered by tests

return ψ

Check warning on line 627 in src/Ansatz/Chain.jl

View check run for this annotation

Codecov / codecov/patch

src/Ansatz/Chain.jl#L627

Added line #L627 was not covered by tests
end

function expect(ψ::Chain, observables)
# contract observable with TN
ϕ = copy(ψ)
Expand Down
33 changes: 21 additions & 12 deletions src/Quantum.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,30 +11,39 @@

- Should we store here some information about quantum numbers?
"""
struct Site
id::Int
struct Site{N}
id::NTuple{N,Int}
dual::Bool

Site(id; dual = false) = new(id, dual)
Site(id::NTuple{N,Int}; dual = false) where {N} = new{N}(id, dual)
end

Site(id::Int; kwargs...) = Site((id,); kwargs...)
Site(id::Vararg{Int,N}; kwargs...) where {N} = Site(id; kwargs...)

id(site::Site{1}) = only(site.id)
id(site::Site) = site.id

Base.CartesianIndex(site::Site) = CartesianIndex(id(site))

isdual(site::Site) = site.dual
Base.show(io::IO, site::Site) = print(io, "$(site.id)$(site.dual ? "'" : "")")
Base.adjoint(site::Site) = Site(site.id; dual = !site.dual)
Base.isless(a::Site, b::Site) = a.id < b.id
Base.show(io::IO, site::Site) = print(io, "$(id(site))$(site.dual ? "'" : "")")

Check warning on line 30 in src/Quantum.jl

View check run for this annotation

Codecov / codecov/patch

src/Quantum.jl#L30

Added line #L30 was not covered by tests
Base.adjoint(site::Site) = Site(id(site); dual = !site.dual)
Base.isless(a::Site, b::Site) = id(a) < id(b)

macro site_str(str)
m = match(r"^(\d+)('?)$", str)
m = match(r"^(\d+,)*\d+('?)$", str)
if isnothing(m)
error("Invalid site string: $str")
end

id = parse(Int, m.captures[1])
dual = m.captures[2] == "'"
id = tuple(map(eachmatch(r"(\d+)", str)) do match
parse(Int, only(match.captures))
end...)

quote
Site($id; dual = $dual)
end
dual = endswith(str, "'")

return :(Site($id; dual = $dual))
end

"""
Expand Down
62 changes: 50 additions & 12 deletions test/Site_test.jl
Original file line number Diff line number Diff line change
@@ -1,25 +1,63 @@
@testset "Site" begin
using Qrochet: id

s = Site(1)
@test s.id == 1
@test s.dual == false
@test id(s) == 1
@test CartesianIndex(s) == CartesianIndex(1)
@test isdual(s) == false

s = Site(1; dual = true)
@test s.id == 1
@test s.dual == true
@test id(s) == 1
@test CartesianIndex(s) == CartesianIndex(1)
@test isdual(s) == true

s = Site(1, 2)
@test id(s) == (1, 2)
@test CartesianIndex(s) == CartesianIndex((1, 2))
@test isdual(s) == false

s = Site(1, 2; dual = true)
@test id(s) == (1, 2)
@test CartesianIndex(s) == CartesianIndex((1, 2))
@test isdual(s) == true

s = site"1"
@test s.id == 1
@test s.dual == false
@test id(s) == 1
@test CartesianIndex(s) == CartesianIndex(1)
@test isdual(s) == false

s = site"1'"
@test s.id == 1
@test s.dual == true
@test id(s) == 1
@test CartesianIndex(s) == CartesianIndex(1)
@test isdual(s) == true

s = site"1,2"
@test id(s) == (1, 2)
@test CartesianIndex(s) == CartesianIndex((1, 2))
@test isdual(s) == false

s = site"1,2'"
@test id(s) == (1, 2)
@test CartesianIndex(s) == CartesianIndex((1, 2))
@test isdual(s) == true

s = site"1" |> adjoint
@test s.id == 1
@test s.dual == true
@test id(s) == 1
@test CartesianIndex(s) == CartesianIndex(1)
@test isdual(s) == true

s = site"1'" |> adjoint
@test s.id == 1
@test s.dual == false
@test id(s) == 1
@test CartesianIndex(s) == CartesianIndex(1)
@test isdual(s) == false

s = site"1,2" |> adjoint
@test id(s) == (1, 2)
@test CartesianIndex(s) == CartesianIndex((1, 2))
@test isdual(s) == true

s = site"1,2'" |> adjoint
@test id(s) == (1, 2)
@test CartesianIndex(s) == CartesianIndex((1, 2))
@test isdual(s) == false
end
Loading