Skip to content
This repository was archived by the owner on Nov 7, 2024. It is now read-only.

Commit aa3caec

Browse files
AidanGGChase Roberts
andcommitted
Return empty dict for empty sequence input to MPS left_envs and right_envs (#440)
* Return empty dict for empty input to MPS envs * Add tests for empty sequence input to MPS envs * Use explicit sequences for MPS envs tests Co-authored-by: Chase Roberts <chaseriley@google.com>
1 parent c56c1fa commit aa3caec

File tree

2 files changed

+47
-12
lines changed

2 files changed

+47
-12
lines changed

tensornetwork/matrixproductstates/finite_mps.py

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -29,25 +29,25 @@
2929

3030
class FiniteMPS(BaseMPS):
3131
"""
32-
An MPS class for finite systems.
32+
An MPS class for finite systems.
3333
3434
MPS tensors are stored as a list of `Node` objects in the `FiniteMPS.nodes`
3535
attribute.
36-
`FiniteMPS` has a central site, also called orthogonality center.
37-
The position of this central site is stored in `FiniteMPS.center_position`,
38-
and it can be be shifted using the `FiniteMPS.position` method.
36+
`FiniteMPS` has a central site, also called orthogonality center.
37+
The position of this central site is stored in `FiniteMPS.center_position`,
38+
and it can be be shifted using the `FiniteMPS.position` method.
3939
`FiniteMPS.position` uses QR and RQ methods to shift `center_position`.
40-
40+
4141
`FiniteMPS` can be initialized either from a `list` of tensors, or
4242
by calling the classmethod `FiniteMPS.random`.
43-
43+
4444
By default, `FiniteMPS` is initialized in *canonical* form, i.e.
45-
the state is normalized, and all tensors to the left of
46-
`center_position` are left orthogonal, and all tensors
45+
the state is normalized, and all tensors to the left of
46+
`center_position` are left orthogonal, and all tensors
4747
to the right of `center_position` are right orthogonal. The tensor
4848
at `FiniteMPS.center_position` is neither left nor right orthogonal.
4949
50-
Note that canonicalization can be computationally relatively
50+
Note that canonicalization can be computationally relatively
5151
costly and scales :math:`\\propto ND^3`.
5252
"""
5353

@@ -62,7 +62,7 @@ def __init__(self,
6262
tensors: A list of `Tensor` or `BaseNode` objects.
6363
center_position: The initial position of the center site.
6464
canonicalize: If `True` the mps is canonicalized at initialization.
65-
backend: The name of the backend that should be used to perform
65+
backend: The name of the backend that should be used to perform
6666
contractions. Available backends are currently 'numpy', 'tensorflow',
6767
'pytorch', 'jax'
6868
"""
@@ -160,10 +160,13 @@ def left_envs(self, sites: Sequence[int]) -> Dict:
160160
Args:
161161
sites (list of int): A list of sites of the MPS.
162162
Returns:
163-
`dict` mapping `int` to `Tensor`: The left-reduced density matrices
163+
`dict` mapping `int` to `Tensor`: The left-reduced density matrices
164164
at each site in `sites`.
165165
166166
"""
167+
if not sites:
168+
return {}
169+
167170
n2 = max(sites)
168171
sites = np.array(sites) #enable logical indexing
169172

@@ -227,9 +230,11 @@ def right_envs(self, sites: Sequence[int]) -> Dict:
227230
Args:
228231
sites (list of int): A list of sites of the MPS.
229232
Returns:
230-
`dict` mapping `int` to `Tensor`: The right-reduced density matrices
233+
`dict` mapping `int` to `Tensor`: The right-reduced density matrices
231234
at each site in `sites`.
232235
"""
236+
if not sites:
237+
return {}
233238

234239
n1 = min(sites)
235240
sites = np.array(sites)

tensornetwork/matrixproductstates/finite_mps_test.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,3 +121,33 @@ def test_correlation_measurement_finite_mps(backend_dtype_values):
121121
actual[N // 2] = 0.25
122122
np.testing.assert_almost_equal(result_1, actual)
123123
np.testing.assert_allclose(result_2, np.ones(N) * 0.25)
124+
125+
126+
def test_left_envs_empty_seq(backend_dtype_values):
127+
backend = backend_dtype_values[0]
128+
dtype = backend_dtype_values[1]
129+
130+
D, d, N = 1, 2, 10
131+
tensors = [np.ones((1, d, D), dtype=dtype)] + [
132+
np.ones((D, d, D), dtype=dtype) for _ in range(N - 2)
133+
] + [np.ones((D, d, 1), dtype=dtype)]
134+
mps = FiniteMPS(tensors, center_position=0, backend=backend)
135+
136+
assert mps.left_envs(()) == {}
137+
assert mps.left_envs([]) == {}
138+
assert mps.left_envs(range(0)) == {}
139+
140+
141+
def test_right_envs_empty_seq(backend_dtype_values):
142+
backend = backend_dtype_values[0]
143+
dtype = backend_dtype_values[1]
144+
145+
D, d, N = 1, 2, 10
146+
tensors = [np.ones((1, d, D), dtype=dtype)] + [
147+
np.ones((D, d, D), dtype=dtype) for _ in range(N - 2)
148+
] + [np.ones((D, d, 1), dtype=dtype)]
149+
mps = FiniteMPS(tensors, center_position=0, backend=backend)
150+
151+
assert mps.right_envs(()) == {}
152+
assert mps.right_envs([]) == {}
153+
assert mps.right_envs(range(0)) == {}

0 commit comments

Comments
 (0)