Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions varipeps/expectation/three_sites.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,7 @@ def calc_three_sites_triangle_without_bottom_right_multiple_gates(
The order of the PEPS sequence have to be
[top-left, top-right, bottom-left, bottom-left].

The gate is applied in the order [top-left, top-right, bottom-right].
The gate is applied in the order [top-left, top-right, bottom-left].

Args:
peps_tensors (:term:`sequence` of :obj:`jax.numpy.ndarray`):
Expand Down Expand Up @@ -430,7 +430,7 @@ def calc_three_sites_triangle_without_bottom_right_multiple_gates(
)

density_matrix_bottom_left = apply_contraction(
"density_matrix_four_sites_left",
"density_matrix_four_sites_bottom_left",
[peps_tensors[2]],
[peps_tensor_objs[2]],
[],
Expand Down
3 changes: 3 additions & 0 deletions varipeps/optimization/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,6 +463,9 @@ def optimize_peps_network(

As convergence criterion the norm of the gradient is used.

If the first CTMRG calculation does not converge. OptimizeResult is returned with Success=False.
This should be handled by the script calling this function.

Args:
input_tensors (:obj:`~varipeps.peps.PEPS_Unit_Cell` or :term:`sequence` of :obj:`jax.numpy.ndarray`):
The PEPS unitcell to work on or the tensors which should be mapped by
Expand Down
29 changes: 20 additions & 9 deletions varipeps/utils/svd.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,17 +336,28 @@ def gauge_fixed_svd(
:obj:`tuple`\\ (:obj:`jnp.ndarray`, :obj:`jnp.ndarray`, :obj:`jnp.ndarray`):
Tuple with sign-fixed U, S and Vh of the SVD.
"""
if only_u_or_vh is None:
if any(d.platform == "gpu" for d in jax.devices()):
U, S, Vh = svd_wrapper(matrix, use_qr=use_qr)
gauge_unitary = U
elif only_u_or_vh == "U":
U, S = svd_only_u(matrix, use_qr=use_qr)
gauge_unitary = U
elif only_u_or_vh == "Vh":
S, Vh = svd_only_vt(matrix, use_qr=use_qr)
gauge_unitary = Vh.T.conj()
if only_u_or_vh is None:
gauge_unitary = U
elif only_u_or_vh == "U":
gauge_unitary = U
elif only_u_or_vh == "Vh":
gauge_unitary = Vh.T.conj()
else:
raise ValueError("Invalid value for parameter 'only_u_or_vh'.")
else:
raise ValueError("Invalid value for parameter 'only_u_or_vh'.")
if only_u_or_vh is None:
U, S, Vh = svd_wrapper(matrix, use_qr=use_qr)
gauge_unitary = U
elif only_u_or_vh == "U":
U, S = svd_only_u(matrix, use_qr=use_qr)
gauge_unitary = U
elif only_u_or_vh == "Vh":
S, Vh = svd_only_vt(matrix, use_qr=use_qr)
gauge_unitary = Vh.T.conj()
else:
raise ValueError("Invalid value for parameter 'only_u_or_vh'.")

# Fix the gauge of the SVD
abs_gauge_unitary = jnp.abs(gauge_unitary)
Expand Down