Skip to content
This repository has been archived by the owner on Jul 7, 2024. It is now read-only.

Add mixed_canonize function #10

Merged
merged 9 commits into from
Feb 20, 2024
42 changes: 37 additions & 5 deletions src/Ansatz/Chain.jl
Original file line number Diff line number Diff line change
Expand Up @@ -144,24 +144,24 @@ function rightindex(::Union{Open, Periodic}, tn::Chain, site::Site)
end
end

canonize(tn::Chain, args...; kwargs...) = canonize!(deepcopy(tn), args...; kwargs...)
canonize!(tn::Chain, args...; kwargs...) = canonize!(boundary(tn), tn, args...; kwargs...)
canonize_site(tn::Chain, args...; kwargs...) = canonize_site!(deepcopy(tn), args...; kwargs...)
canonize_site!(tn::Chain, args...; kwargs...) = canonize_site!(boundary(tn), tn, args...; kwargs...)

# NOTE: in mode == :svd the spectral weights are stored in a vector connected to the now virtual hyperindex!
function canonize!(::Open, tn::Chain, site::Site; direction::Symbol, mode = :qr)
function canonize_site!(::Open, tn::Chain, site::Site; direction::Symbol, mode = :qr)
left_inds = Symbol[]
right_inds = Symbol[]

virtualind = if direction === :left
site == Site(nsites(tn)) && throw(ArgumentError("Cannot right-canonize right-most tensor"))
site == Site(nsites(tn)) && throw(ArgumentError("Cannot left-canonize right-most tensor"))
push!(right_inds, rightindex(tn, site))

site == Site(1) || push!(left_inds, leftindex(tn, site))
push!(left_inds, Quantum(tn)[site])

only(right_inds)
elseif direction === :right
site == Site(1) && throw(ArgumentError("Cannot left-canonize left-most tensor"))
site == Site(1) && throw(ArgumentError("Cannot right-canonize left-most tensor"))
push!(right_inds, leftindex(tn, site))

site == Site(nsites(tn)) || push!(left_inds, rightindex(tn, site))
Expand All @@ -186,3 +186,35 @@ function canonize!(::Open, tn::Chain, site::Site; direction::Symbol, mode = :qr)

return tn
end

mixed_canonize(tn::Chain, args...; kwargs...) = mixed_canonize!(deepcopy(tn), args...; kwargs...)
mixed_canonize!(tn::Chain, args...; kwargs...) = mixed_canonize!(boundary(tn), tn, args...; kwargs...)

"""
mixed_canonize!(boundary::Boundary, tn::Chain, center::Site)

Transform a `Chain` tensor network into the mixed-canonical form, that is,
for i < center the tensors are left-canonical and for i > center the tensors are right-canonical,
and in the center there is a matrix with singular values.
"""
function mixed_canonize!(::Open, tn::Chain, center::Site) # TODO: center could be a range of sites
N = length(sites(tn))

# Left-to-right QR sweep -> get left-canonical tensors
for i in 1:N-1
canonize_site!(tn, Site(i); direction = :left, mode = :qr)
end

# Right-to-left QR sweep -> get left-canonical tensors for i > center
for i in N:-1:1
if i > center.id
canonize_site!(tn, Site(i); direction = :right, mode = :qr)
elseif i == center.id
canonize_site!(tn, Site(i); direction = :left, mode = :svd)
else
canonize_site!(tn, Site(i); direction = :left, mode = :qr)
end
end

return tn
end
3 changes: 2 additions & 1 deletion src/Qrochet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ export Product
include("Ansatz/Chain.jl")
export Chain
export MPS, pMPS, MPO, pMPO
export leftindex, rightindex, canonize, canonize!
export leftindex, rightindex, canonize_site, canonize_site!
export mixed_canonize, mixed_canonize!

# reexports from Tenet
using Tenet
Expand Down
91 changes: 54 additions & 37 deletions test/Ansatz/Chain_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,51 +51,68 @@
@test rightsite(qtn, Site(1)) == Site(2)
end

@testset "canonize" begin
@testset "Canonization" begin
using Tenet

function is_left_canonical(qtn, s::Site)
label_r = rightindex(qtn, s)
A = select(qtn, :tensor, s)
try
contracted = contract(A, replace(conj(A), label_r => :new_ind_name))
return isapprox(contracted, Matrix{Float64}(I, size(A, label_r), size(A, label_r)), atol=1e-12)
catch
return false
end
end
label_r = rightindex(qtn, s)
A = select(qtn, :tensor, s)
try
contracted = contract(A, replace(conj(A), label_r => :new_ind_name))
return isapprox(contracted, Matrix{Float64}(I, size(A, label_r), size(A, label_r)), atol=1e-12)
catch
return false
end
end

function is_right_canonical(qtn, s::Site)
label_l = leftindex(qtn, s)
A = select(qtn, :tensor, s)
try
contracted = contract(A, replace(conj(A), label_l => :new_ind_name))
return isapprox(contracted, Matrix{Float64}(I, size(A, label_l), size(A, label_l)), atol=1e-12)
catch
return false
end
end

qtn = Chain(State(), Open(), [rand(4, 4), rand(4, 4, 4), rand(4, 4)])

@test_throws ArgumentError canonize!(qtn, Site(1); direction=:right)
@test_throws ArgumentError canonize!(qtn, Site(3); direction=:left)

for mode in [:qr, :svd]
for i in 1:length(sites(qtn))
if i != 1
canonized = canonize(qtn, Site(i); direction=:right, mode=mode)
@test is_right_canonical(canonized, Site(i))
@test isapprox(contract(transform(TensorNetwork(canonized), Tenet.HyperindConverter())), contract(TensorNetwork(qtn)))
elseif i != length(sites(qtn))
canonized = canonize(qtn, Site(i); direction=:left, mode=mode)
@test is_left_canonical(canonized, Site(i))
@test isapprox(contract(transform(TensorNetwork(canonized), Tenet.HyperindConverter())), contract(TensorNetwork(qtn)))
label_l = leftindex(qtn, s)
A = select(qtn, :tensor, s)
try
contracted = contract(A, replace(conj(A), label_l => :new_ind_name))
return isapprox(contracted, Matrix{Float64}(I, size(A, label_l), size(A, label_l)), atol=1e-12)
catch
return false
end
end

@testset "canonize_site" begin
qtn = Chain(State(), Open(), [rand(4, 4), rand(4, 4, 4), rand(4, 4)])

@test_throws ArgumentError canonize_site!(qtn, Site(1); direction=:right)
@test_throws ArgumentError canonize_site!(qtn, Site(3); direction=:left)

for mode in [:qr, :svd]
for i in 1:length(sites(qtn))
if i != 1
canonized = canonize_site(qtn, Site(i); direction=:right, mode=mode)
@test is_right_canonical(canonized, Site(i))
@test isapprox(contract(transform(TensorNetwork(canonized), Tenet.HyperindConverter())), contract(TensorNetwork(qtn)))
elseif i != length(sites(qtn))
canonized = canonize_site(qtn, Site(i); direction=:left, mode=mode)
@test is_left_canonical(canonized, Site(i))
@test isapprox(contract(transform(TensorNetwork(canonized), Tenet.HyperindConverter())), contract(TensorNetwork(qtn)))
end
end
end

# Ensure that svd creates a new tensor
@test length(tensors(canonize_site(qtn, Site(2); direction=:right, mode=:svd))) == 4
end

# Ensure that svd creates a new tensor
@test length(tensors(canonize(qtn, Site(2); direction=:right, mode=:svd))) == 4
@testset "mixed_canonize" begin
qtn = Chain(State(), Open(), [rand(4, 4), rand(4, 4, 4), rand(4, 4, 4), rand(4, 4, 4), rand(4, 4)])
canonized = mixed_canonize(qtn, Site(3))

@test is_left_canonical(canonized, Site(1))
@test is_left_canonical(canonized, Site(2))
@test is_left_canonical(canonized, Site(3))
@test is_right_canonical(canonized, Site(4))
@test is_right_canonical(canonized, Site(5))

@test length(tensors(canonized)) == 6 # 5 tensors + 1 singular value matrix

@test isapprox(contract(transform(TensorNetwork(canonized), Tenet.HyperindConverter())), contract(TensorNetwork(qtn)))
end
end
end
Loading