Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Validate qubit types in css #1050

Merged
merged 15 commits into from
Sep 3, 2024
Merged
45 changes: 43 additions & 2 deletions cirq-superstaq/cirq_superstaq/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,48 @@

import cirq

SUPPORTED_QID_TYPES = (
cirq.LineQubit,
cirq.LineQid,
cirq.GridQubit,
cirq.GridQid,
cirq.NamedQubit,
cirq.NamedQid,
)


def validate_qubit_types(circuits: cirq.Circuit | Sequence[cirq.Circuit]) -> None:
"""Verifies that `circuits` consists of valid (`cirq-core`) qubit types only.

Args:
circuits: The input circuit(s) to validate.

Raises:
TypeError: If an unsupported qubit type is found in `circuits`.
"""
circuits_to_check = [circuits] if isinstance(circuits, cirq.Circuit) else circuits

all_qubits_present: set[cirq.Qid] = set()
for circuit in circuits_to_check:
all_qubits_present.update(circuit.all_qubits())

if not all(isinstance(q, SUPPORTED_QID_TYPES) for q in all_qubits_present):
invalid_qubit_types = ", ".join(
map(str, (set(type(q) for q in all_qubits_present) - set(SUPPORTED_QID_TYPES)))
)
raise TypeError(
f"Input circuit(s) contains unsupported qubit types: {invalid_qubit_types}. "
"Valid qubit types are: `cirq.LineQubit`, `cirq.LineQid`, `cirq.GridQubit`, "
"`cirq.GridQid`, `cirq.NamedQubit`, and `cirq.NamedQid`."
)


def validate_cirq_circuits(circuits: object, require_measurements: bool = False) -> None:
"""Validates that the input is either a single `cirq.Circuit` or a list of `cirq.Circuit`
instances.
"""Validates that the input is an acceptable `cirq-core` object for `cirq-superstaq`.

In particular, this function verifies that `circuits` is either a single `cirq.Circuit`
or a list of `cirq.Circuit` instances. Additionally, also validates that `circuits`
contains supported qubit types only.

Args:
circuits: The circuit(s) to run.
Expand All @@ -16,6 +54,7 @@ def validate_cirq_circuits(circuits: object, require_measurements: bool = False)

Raises:
ValueError: If the input is not a `cirq.Circuit` or a list of `cirq.Circuit` instances.
TypeError: If an unsupported qubit type is found in `circuits`.
"""

if not (
Expand All @@ -37,3 +76,5 @@ def validate_cirq_circuits(circuits: object, require_measurements: bool = False)
# TODO: only raise if the run method actually requires samples (and not for e.g. a
# statevector simulation)
raise ValueError("Circuit has no measurements to sample.")

validate_qubit_types(circuits)
23 changes: 23 additions & 0 deletions cirq-superstaq/cirq_superstaq/validation_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
# pylint: disable=missing-function-docstring,missing-class-docstring
from __future__ import annotations

import re
from unittest.mock import MagicMock, create_autospec

import cirq
import pytest

Expand All @@ -27,3 +30,23 @@ def test_validate_cirq_circuits() -> None:

with pytest.raises(ValueError, match="Circuit has no measurements to sample"):
css.validation.validate_cirq_circuits(circuit, require_measurements=True)
bharat-thotakura marked this conversation as resolved.
Show resolved Hide resolved


def test_validate_qubit_type() -> None:
invalid_qubit_type = MagicMock(spec=str) # E.g., in practice, `cirq_rigetti.AspenQubit`
mock_cirq_circuit = create_autospec(cirq.Circuit, spec_set=True)
mock_cirq_circuit.all_qubits.return_value = frozenset({invalid_qubit_type, cirq.LineQubit(0)})
q0 = cirq.LineQubit(0)
q1, q2 = cirq.NamedQubit.range(2, prefix="q")
q3 = cirq.GridQubit(0, 0)
q4 = cirq.q(4)
valid_qubits = frozenset({q0, q1, q2, q3, q4})
valid_circuit = cirq.Circuit(cirq.H(q) for q in valid_qubits)
valid_circuit += cirq.measure(*valid_qubits)

css.validation.validate_qubit_types(valid_circuit)
with pytest.raises(
TypeError,
match=re.escape("Input circuit(s) contains unsupported qubit types:"),
):
css.validation.validate_cirq_circuits([mock_cirq_circuit, valid_circuit])