Normalize #192
base: main
Changes from all commits
```julia
using LinearAlgebra

function LinearAlgebra.normalize(tn::AbstractITensorNetwork; alg="exact", kwargs...)
  return normalize(Algorithm(alg), tn; kwargs...)
end

function LinearAlgebra.normalize(alg::Algorithm"exact", tn::AbstractITensorNetwork)
  norm_tn = norm_sqr_network(tn)
  log_norm = logscalar(alg, norm_tn)
  tn = copy(tn)
  L = length(vertices(tn))
  c = exp(log_norm / L)
  for v in vertices(tn)
    tn[v] = tn[v] / sqrt(c)
  end
  return tn
end
```
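The `alg="exact"` method spreads the computed norm evenly over the network's vertices: with `c = exp(log_norm / L)`, dividing each of the `L` site tensors by `sqrt(c)` divides the contracted squared norm `⟨ψ|ψ⟩ = exp(log_norm)` by `c^L`, leaving exactly one. A toy sketch of that arithmetic with plain numbers (the per-site factors below are made up, not ITensors):

```julia
# Toy illustration: pretend the squared norm of the network is a product of
# per-site factors (hypothetical values standing in for tensor contractions).
site_norms_sqr = [4.0, 9.0, 0.25, 16.0]
L = length(site_norms_sqr)
log_norm = sum(log, site_norms_sqr)  # log of the toy ⟨ψ|ψ⟩
c = exp(log_norm / L)                # even per-site share of the norm

# Dividing each ket tensor by sqrt(c) divides each squared factor by c,
# so the rescaled squared norm is exp(log_norm) / c^L == 1.
rescaled = site_norms_sqr ./ c
new_norm_sqr = prod(rescaled)
new_norm_sqr ≈ 1.0  # true (up to floating point)
```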
```julia
function LinearAlgebra.normalize(
  alg::Algorithm"bp",
  tn::AbstractITensorNetwork;
  (cache!)=nothing,
  update_cache=isnothing(cache!),
  cache_update_kwargs=default_cache_update_kwargs(cache!),
)
```
Comment on lines +19 to +25:

> So it seems like a basic design question here is if normalizing should refer to treating the […] It seems reasonable to define it such that […] The current implementation feels a bit too "in the weeds" dealing with quadratic forms, bras, kets, etc. and seems like something that could be abstracted and generalized.

> Also defining a function like […]

> Relatedly, […]

Reply:

> Yeah, I see what you mean; that's a nice idea to split it apart like that. Will change it to do that.
```julia
  if isnothing(cache!)
    cache! = Ref(BeliefPropagationCache(QuadraticFormNetwork(tn)))
  end

  if update_cache
    cache![] = update(cache![]; cache_update_kwargs...)
  end

  tn = copy(tn)
  cache![] = normalize_messages(cache![])
  norm_tn = tensornetwork(cache![])

  vertices_states = Dictionary()
  for v in vertices(tn)
    v_ket, v_bra = ket_vertex(norm_tn, v), bra_vertex(norm_tn, v)
    pv = only(partitionvertices(cache![], [v_ket]))
    vn = region_scalar(cache![], pv)
    state = tn[v] / sqrt(vn)
    state_dag = copy(dag(state))
    state_dag = replaceinds(
      state_dag, inds(state_dag), dual_index_map(norm_tn).(inds(state_dag))
    )
    set!(vertices_states, v_ket, state)
    set!(vertices_states, v_bra, state_dag)
    tn[v] = state
  end

  cache![] = update_factors(cache![], vertices_states)

  return tn
end
```
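In the `alg="bp"` method, once the messages are normalized, each vertex's region scalar `vn` plays the role of a local squared norm, and replacing the ket tensor by `tn[v] / sqrt(vn)` (with the bra set to its conjugate) drives every vertex scalar to one. A toy sketch with made-up scalar stand-ins (no ITensors):

```julia
# Toy sketch: each region scalar is quadratic in the local state, so
# rescaling the ket by 1/sqrt(vn) (bra and ket each contribute one factor)
# multiplies that region scalar by 1/vn, making it one.
vns = [2.0, 0.5, 8.0, 1.25]      # hypothetical region scalars per vertex
rescale = 1 ./ sqrt.(vns)        # factor applied to each ket tensor
new_vns = vns .* abs2.(rescale)  # bra and ket each pick up one factor of rescale
all(x -> x ≈ 1.0, new_vns)       # true
```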
```julia
@eval module $(gensym())
using ITensorNetworks:
  BeliefPropagationCache,
  QuadraticFormNetwork,
  edge_scalars,
  norm_sqr_network,
  random_tensornetwork,
  vertex_scalars
using ITensors: dag, inner, siteinds, scalar
using Graphs: SimpleGraph, uniform_tree
using LinearAlgebra: normalize
using NamedGraphs: NamedGraph
using NamedGraphs.NamedGraphGenerators: named_grid
using StableRNGs: StableRNG
using Test: @test, @testset
@testset "Normalize" begin
  # First let's do a tree
  L = 6
  χ = 2
  rng = StableRNG(1234)

  g = NamedGraph(SimpleGraph(uniform_tree(L)))
  s = siteinds("S=1/2", g)
  x = random_tensornetwork(rng, s; link_space=χ)

  ψ = normalize(x; alg="exact")
  @test scalar(norm_sqr_network(ψ); alg="exact") ≈ 1.0

  ψ = normalize(x; alg="bp")
  @test scalar(norm_sqr_network(ψ); alg="exact") ≈ 1.0

  # Now a loopy graph
  Lx, Ly = 3, 2
  χ = 2
  rng = StableRNG(1234)

  g = named_grid((Lx, Ly))
  s = siteinds("S=1/2", g)
  x = random_tensornetwork(rng, s; link_space=χ)

  ψ = normalize(x; alg="exact")
  @test scalar(norm_sqr_network(ψ); alg="exact") ≈ 1.0

  ψIψ_bpc = Ref(BeliefPropagationCache(QuadraticFormNetwork(x)))
  ψ = normalize(x; alg="bp", (cache!)=ψIψ_bpc, update_cache=true)
  ψIψ_bpc = ψIψ_bpc[]
  @test all(x -> x ≈ 1.0, edge_scalars(ψIψ_bpc))
  @test all(x -> x ≈ 1.0, vertex_scalars(ψIψ_bpc))
  @test scalar(QuadraticFormNetwork(ψ); alg="bp") ≈ 1.0
end
end
```
Comment on the message normalization:

> I think it would be good to first normalize the messages against themselves and then normalize them against each other. That could help to make sure certain messages don't end up with very large or small norms. I realize that in general during message tensor updates you would normalize, but that might not be the case, and large norms could accumulate (say when computing expectation values of extensive operators like Hamiltonians).
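The two-step scheme suggested here can be sketched with plain vectors standing in for the pair of message tensors on one edge (made-up values; real messages are ITensors on partition edges): first scale each message to unit norm, then rescale the pair so their mutual overlap, the edge scalar, is one.

```julia
using LinearAlgebra: dot, norm, normalize

# Plain-vector stand-ins for the two messages on one edge (hypothetical values).
m_fwd = [3.0, 4.0]
m_bwd = [8.0, 6.0]

# Step 1: normalize each message against itself, so no single message
# carries a very large or very small norm.
m_fwd, m_bwd = normalize(m_fwd), normalize(m_bwd)

# Step 2: normalize the pair against each other, splitting the correction
# symmetrically between the two messages, so the edge scalar becomes one.
n = dot(m_fwd, m_bwd)
m_fwd, m_bwd = m_fwd / sqrt(n), m_bwd / sqrt(n)

dot(m_fwd, m_bwd) ≈ 1.0  # true
```

Step 1 alone already bounds each message's norm, which is the point of the comment: even if individual updates skip normalization, accumulated scale factors get absorbed before they can blow up the edge scalars.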