Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix nondeterminacy in Circuit.insert (simplified) #7043

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
155 changes: 91 additions & 64 deletions cirq-core/cirq/circuits/circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@


_TGate = TypeVar('_TGate', bound='cirq.Gate')
_MOMENT_OR_OP = Union['cirq.Moment', 'cirq.Operation']

CIRCUIT_TYPE = TypeVar('CIRCUIT_TYPE', bound='AbstractCircuit')
document(
Expand Down Expand Up @@ -2095,49 +2096,6 @@ def earliest_available_moment(
last_available = k
return last_available

def _pick_or_create_inserted_op_moment_index(
self, splitter_index: int, op: 'cirq.Operation', strategy: 'cirq.InsertStrategy'
) -> int:
"""Determines and prepares where an insertion will occur.

Args:
splitter_index: The index to insert at.
op: The operation that will be inserted.
strategy: The insertion strategy.

Returns:
The index of the (possibly new) moment where the insertion should
occur.

Raises:
ValueError: Unrecognized append strategy.
"""

if strategy is InsertStrategy.NEW or strategy is InsertStrategy.NEW_THEN_INLINE:
self._moments.insert(splitter_index, Moment())
self._mutated()
return splitter_index

if strategy is InsertStrategy.INLINE:
if 0 <= splitter_index - 1 < len(self._moments) and self._can_add_op_at(
splitter_index - 1, op
):
return splitter_index - 1

return self._pick_or_create_inserted_op_moment_index(
splitter_index, op, InsertStrategy.NEW
)

if strategy is InsertStrategy.EARLIEST:
if self._can_add_op_at(splitter_index, op):
return self.earliest_available_moment(op, end_moment_index=splitter_index)

return self._pick_or_create_inserted_op_moment_index(
splitter_index, op, InsertStrategy.INLINE
)

raise ValueError(f'Unrecognized append strategy: {strategy}')

def _can_add_op_at(self, moment_index: int, operation: 'cirq.Operation') -> bool:
if not 0 <= moment_index < len(self._moments):
return True
Expand All @@ -2147,7 +2105,7 @@ def _can_add_op_at(self, moment_index: int, operation: 'cirq.Operation') -> bool
def insert(
self,
index: int,
moment_or_operation_tree: Union['cirq.Operation', 'cirq.OP_TREE'],
moment_or_operation_tree: 'cirq.OP_TREE',
strategy: 'cirq.InsertStrategy' = InsertStrategy.EARLIEST,
) -> int:
"""Inserts operations into the circuit.
Expand All @@ -2170,24 +2128,57 @@ def insert(
"""
# limit index to 0..len(self._moments), also deal with indices smaller 0
k = max(min(index if index >= 0 else len(self._moments) + index, len(self._moments)), 0)
if strategy != InsertStrategy.EARLIEST or index != len(self._moments):
if strategy != InsertStrategy.EARLIEST or k != len(self._moments):
self._placement_cache = None
for moment_or_op in list(ops.flatten_to_ops_or_moments(moment_or_operation_tree)):
if self._placement_cache:
p = self._placement_cache.append(moment_or_op)
elif isinstance(moment_or_op, Moment):
p = k
else:
p = self._pick_or_create_inserted_op_moment_index(k, moment_or_op, strategy)
if isinstance(moment_or_op, Moment):
self._moments.insert(p, moment_or_op)
elif p == len(self._moments):
self._moments.append(Moment(moment_or_op))
else:
self._moments[p] = self._moments[p].with_operation(moment_or_op)
k = max(k, p + 1)
if strategy is InsertStrategy.NEW_THEN_INLINE:
strategy = InsertStrategy.INLINE
mops = list(ops.flatten_to_ops_or_moments(moment_or_operation_tree))
if self._placement_cache:
batches = [mops] # Any grouping would work here; this just happens to be the fastest.
elif strategy is InsertStrategy.NEW:
batches = [[mop] for mop in mops] # Each op goes into its own moment.
else:
batches = list(_group_into_moment_compatible(mops))
for batch in batches:
# Insert a moment if inline/earliest and _any_ op in the batch requires it.
if (
not self._placement_cache
and not isinstance(batch[0], Moment)
and strategy in [InsertStrategy.INLINE, InsertStrategy.EARLIEST]
and not all(
(strategy is InsertStrategy.EARLIEST and self._can_add_op_at(k, op))
or (k > 0 and self._can_add_op_at(k - 1, op))
for op in cast(List['cirq.Operation'], batch)
)
):
self._moments.insert(k, Moment())
if strategy is InsertStrategy.INLINE:
k += 1
max_p = 0
for moment_or_op in batch:
# Determine Placement
if self._placement_cache:
p = self._placement_cache.append(moment_or_op)
elif isinstance(moment_or_op, Moment):
p = k
elif strategy in [InsertStrategy.NEW, InsertStrategy.NEW_THEN_INLINE]:
self._moments.insert(k, Moment())
p = k
elif strategy is InsertStrategy.INLINE:
p = k - 1
else: # InsertStrategy.EARLIEST:
p = self.earliest_available_moment(moment_or_op, end_moment_index=k)
# Place
if isinstance(moment_or_op, Moment):
self._moments.insert(p, moment_or_op)
elif p == len(self._moments):
self._moments.append(Moment(moment_or_op))
else:
self._moments[p] = self._moments[p].with_operation(moment_or_op)
# Iterate
max_p = max(p, max_p)
if strategy is InsertStrategy.NEW_THEN_INLINE:
strategy = InsertStrategy.INLINE
k += 1
k = max(k, max_p + 1)
self._mutated(preserve_placement_cache=True)
return k

Expand Down Expand Up @@ -2450,7 +2441,7 @@ def batch_insert(self, insertions: Iterable[Tuple[int, 'cirq.OP_TREE']]) -> None

def append(
self,
moment_or_operation_tree: Union['cirq.Moment', 'cirq.OP_TREE'],
moment_or_operation_tree: 'cirq.OP_TREE',
strategy: 'cirq.InsertStrategy' = InsertStrategy.EARLIEST,
) -> None:
"""Appends operations onto the end of the circuit.
Expand Down Expand Up @@ -2841,8 +2832,44 @@ def _group_until_different(items: Iterable[_TIn], key: Callable[[_TIn], _TKey],
return ((k, [val(i) for i in v]) for (k, v) in itertools.groupby(items, key))


def _group_into_moment_compatible(inputs: Sequence[_MOMENT_OR_OP]) -> Iterator[List[_MOMENT_OR_OP]]:
"""Groups sequential ops into those that can coexist in a single moment.

This function will go through the input sequence in order, emitting lists of sequential
operations that can go into a single moment. It does not try to rearrange the elements or try
to move them to open slots in earlier moments; it simply processes them in order and outputs
them. i.e. the output, if flattened, will equal the input.

Actual Moments in the input will always be emitted by themselves as a single-element list.

Examples:
[X(a), X(b), X(a)] -> [[X(a), X(b)], [X(a)]]
[X(a), X(a), X(b)] -> [[X(a)], [X(a), X(b)]]
[X(a), Moment(X(b)), X(c)] -> [[X(a)], [Moment(X(b))], [X(c)]]"""
i = 0
batch: List[_MOMENT_OR_OP] = []
while i < len(inputs):
if isinstance(inputs[i], Moment):
yield [inputs[i]]
i += 1
continue
batch_qubits: Set['cirq.Qid'] = set()
while i < len(inputs):
mop = inputs[i]
qs = mop.qubits
if isinstance(mop, Moment) or not batch_qubits.isdisjoint(qs):
yield batch
batch = []
break
batch.append(mop)
batch_qubits.update(qs)
i += 1
if batch:
yield batch


def get_earliest_accommodating_moment_index(
moment_or_operation: Union['cirq.Moment', 'cirq.Operation'],
moment_or_operation: _MOMENT_OR_OP,
qubit_indices: Dict['cirq.Qid', int],
mkey_indices: Dict['cirq.MeasurementKey', int],
ckey_indices: Dict['cirq.MeasurementKey', int],
Expand Down Expand Up @@ -2938,7 +2965,7 @@ def __init__(self) -> None:
# For keeping track of length of the circuit thus far.
self._length = 0

def append(self, moment_or_operation: Union['cirq.Moment', 'cirq.Operation']) -> int:
def append(self, moment_or_operation: _MOMENT_OR_OP) -> int:
"""Find placement for moment/operation and update cache.

Determines the placement index of the provided operation, assuming
Expand Down
46 changes: 46 additions & 0 deletions cirq-core/cirq/circuits/circuit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3555,6 +3555,52 @@ def test_insert_operations_random_circuits(circuit):
assert circuit == other_circuit


def test_insert_zero_index():
# Should always go to moment[0], independent of qubit order or earliest/inline strategy.
q0, q1 = cirq.LineQubit.range(2)
c0 = cirq.Circuit(cirq.X(q0))
c0.insert(0, cirq.Y.on_each(q0, q1), strategy=cirq.InsertStrategy.EARLIEST)
c1 = cirq.Circuit(cirq.X(q0))
c1.insert(0, cirq.Y.on_each(q1, q0), strategy=cirq.InsertStrategy.EARLIEST)
c2 = cirq.Circuit(cirq.X(q0))
c2.insert(0, cirq.Y.on_each(q0, q1), strategy=cirq.InsertStrategy.INLINE)
c3 = cirq.Circuit(cirq.X(q0))
c3.insert(0, cirq.Y.on_each(q1, q0), strategy=cirq.InsertStrategy.INLINE)
expected = cirq.Circuit(cirq.Moment(cirq.Y(q0), cirq.Y(q1)), cirq.Moment(cirq.X(q0)))
assert c0 == expected
assert c1 == expected
assert c2 == expected
assert c3 == expected


def test_insert_earliest_on_previous_moment():
q = cirq.LineQubit(0)
c = cirq.Circuit(cirq.Moment(cirq.X(q)), cirq.Moment(), cirq.Moment(), cirq.Moment(cirq.Z(q)))
c.insert(3, cirq.Y(q), strategy=cirq.InsertStrategy.EARLIEST)
# Should fall back to moment[1] since EARLIEST
assert c == cirq.Circuit(
cirq.Moment(cirq.X(q)), cirq.Moment(cirq.Y(q)), cirq.Moment(), cirq.Moment(cirq.Z(q))
)


def test_insert_inline_end_of_circuit():
# If end index is specified, INLINE should place all ops there independent of qubit order.
q0, q1 = cirq.LineQubit.range(2)
c0 = cirq.Circuit(cirq.X(q0))
c0.insert(1, cirq.Y.on_each(q0, q1), strategy=cirq.InsertStrategy.INLINE)
c1 = cirq.Circuit(cirq.X(q0))
c1.insert(1, cirq.Y.on_each(q1, q0), strategy=cirq.InsertStrategy.INLINE)
c2 = cirq.Circuit(cirq.X(q0))
c2.insert(5, cirq.Y.on_each(q0, q1), strategy=cirq.InsertStrategy.INLINE)
c3 = cirq.Circuit(cirq.X(q0))
c3.insert(5, cirq.Y.on_each(q1, q0), strategy=cirq.InsertStrategy.INLINE)
expected = cirq.Circuit(cirq.Moment(cirq.X(q0)), cirq.Moment(cirq.Y(q0), cirq.Y(q1)))
assert c0 == expected
assert c1 == expected
assert c2 == expected
assert c3 == expected


def test_insert_operations_errors():
a, b, c = (cirq.NamedQubit(s) for s in 'abc')
with pytest.raises(ValueError):
Expand Down