Skip to content

Commit

Permalink
inital progress
Browse files Browse the repository at this point in the history
  • Loading branch information
apchytr committed Mar 10, 2025
1 parent fe76606 commit 98c0fcc
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 14 deletions.
2 changes: 1 addition & 1 deletion mrmustard/physics/ansatz/array_ansatz.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,9 +155,9 @@ def from_function(cls, fn: Callable, batch_dims: int = 0, **kwargs: Any) -> Arra
def contract(
self,
other: ArrayAnsatz,
batch_str: str = "",
idx1: int | tuple[int, ...] = tuple(),
idx2: int | tuple[int, ...] = tuple(),
batch_str: str = "",
) -> ArrayAnsatz:
r"""
Contracts two ansatze across the specified variables and batch dimensions.
Expand Down
2 changes: 1 addition & 1 deletion mrmustard/physics/ansatz/polyexp_ansatz.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,9 +335,9 @@ def from_function(cls, fn: Callable, **kwargs: Any) -> PolyExpAnsatz:
def contract(
self,
other: PolyExpAnsatz,
batch_str: str = "",
idx1: int | tuple[int, ...] = tuple(),
idx2: int | tuple[int, ...] = tuple(),
batch_str: str = "",
) -> PolyExpAnsatz:
r"""
Contracts two ansatze across the specified CV variables and batch dimensions.
Expand Down
12 changes: 6 additions & 6 deletions mrmustard/physics/gaussian_integrals.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,12 +219,12 @@ def join_Abc(Abc1: tuple, Abc2: tuple, batch_string: str) -> tuple:
# 0. unpack and prepare inputs
A1, b1, c1 = Abc1
A2, b2, c2 = Abc2
# A1 = math.atleast_3d(A1, dtype=math.complex128)
# A2 = math.atleast_3d(A2, dtype=math.complex128)
# b1 = math.atleast_2d(b1, dtype=math.complex128)
# b2 = math.atleast_2d(b2, dtype=math.complex128)
c1 = math.astensor(c1, dtype=math.complex128)
c2 = math.astensor(c2, dtype=math.complex128)
A1 = math.atleast_3d(A1, dtype=math.complex128)
A2 = math.atleast_3d(A2, dtype=math.complex128)
b1 = math.atleast_2d(b1, dtype=math.complex128)
b2 = math.atleast_2d(b2, dtype=math.complex128)
c1 = math.atleast_1d(c1, dtype=math.complex128)
c2 = math.atleast_1d(c2, dtype=math.complex128)

# 1. Parse the batch string
if "->" not in batch_string:
Expand Down
12 changes: 6 additions & 6 deletions tests/test_physics/test_gaussian_integrals.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def test_join_Abc_nonbatched():
b2 = np.array([12, 13])
c2 = np.array(10)

A, b, c = join_Abc((A1, b1, c1), (A2, b2, c2))
A, b, c = join_Abc((A1, b1, c1), (A2, b2, c2), batch_string="i,j->ij")

assert np.allclose(A, np.array([[1, 2, 0, 0], [3, 4, 0, 0], [0, 0, 8, 9], [0, 0, 10, 11]]))
assert np.allclose(b, np.array([5, 6, 12, 13]))
Expand All @@ -153,7 +153,7 @@ def test_join_Abc_batched_zip():
b2 = np.array([[12, 13], [14, 15]])
c2 = np.array([10, 100])

A, b, c = join_Abc((A1, b1, c1), (A2, b2, c2), mode="zip")
A, b, c = join_Abc((A1, b1, c1), (A2, b2, c2), batch_string="i,i->i")

assert np.allclose(
A,
Expand All @@ -178,7 +178,7 @@ def test_join_Abc_batched_kron():
b2 = np.array([[12, 13], [14, 15]])
c2 = np.array([10, 100])

A, b, c = join_Abc((A1, b1, c1), (A2, b2, c2), mode="kron")
A, b, c = join_Abc((A1, b1, c1), (A2, b2, c2), batch_string="i,j->ij")

assert np.allclose(
A,
Expand Down Expand Up @@ -246,7 +246,7 @@ def test_complex_gaussian_integral_2_batched():
c2 = math.astensor([c2a, c2b, c2c])
c3 = math.astensor([c3a, c3b, c3c])

res = complex_gaussian_integral_2((A1, b1, c1), (A2, b2, c2), [0], [1], mode="zip")
res = complex_gaussian_integral_2((A1, b1, c1), (A2, b2, c2), [0], [1], batch_string="i,i->i")
assert np.allclose(res[0], A3)
assert np.allclose(res[1], b3)
assert np.allclose(res[2], c3)
Expand All @@ -266,7 +266,7 @@ def test_complex_gaussian_integral_1_not_batched():
A2, b2, c2 = triples.displacement_gate_Abc(x=[0.1, 0.2], y=0.3)
A3, b3, c3 = triples.displaced_squeezed_vacuum_state_Abc(x=[0.1, 0.2], y=0.3)

A, b, c = join_Abc((A1, b1, c1), (A2, b2, c2), mode="zip")
A, b, c = join_Abc((A1, b1, c1), (A2, b2, c2), batch_string="i,i->i")

res = complex_gaussian_integral_1((A, b, c), [0, 1], [4, 5])
assert np.allclose(res[0], A3)
Expand All @@ -293,7 +293,7 @@ def test_complex_gaussian_integral_1_batched():
c2 = math.astensor([c2a, c2b, c2c])
c3 = math.astensor([c3a, c3b, c3c])

A, b, c = join_Abc((A1, b1, c1), (A2, b2, c2), mode="zip")
A, b, c = join_Abc((A1, b1, c1), (A2, b2, c2), batch_string="i,i->i")
res1 = complex_gaussian_integral_1((A, b, c), [0], [2])
assert np.allclose(res1[0], A3)
assert np.allclose(res1[1], b3)
Expand Down

0 comments on commit 98c0fcc

Please sign in to comment.