Skip to content

Commit

Permalink
Move gateset construction to adapter (#146)
Browse files Browse the repository at this point in the history
* Put gateset construction in a more logical place and made constants and methods private as needed.
* Also skip gateset construction when running circuits verbatim.
  • Loading branch information
speller26 authored Feb 8, 2024
1 parent b129cfd commit d5eb538
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 53 deletions.
54 changes: 39 additions & 15 deletions qiskit_braket_provider/providers/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
GateModelSimulatorParadigmProperties,
)
from braket.devices import LocalSimulator
from braket.ir.openqasm.modifiers import Control

from numpy import pi

Expand All @@ -38,7 +39,7 @@
from qiskit_ionq import ionq_gates
from qiskit_braket_provider.exception import QiskitBraketException

BRAKET_TO_QISKIT_NAMES = {
_BRAKET_TO_QISKIT_NAMES = {
"u": "u",
"phaseshift": "p",
"cnot": "cx",
Expand Down Expand Up @@ -79,7 +80,7 @@

_EPS = 1e-10 # global variable used to chop very small numbers to zero

GATE_NAME_TO_BRAKET_GATE: dict[str, Callable] = {
_GATE_NAME_TO_BRAKET_GATE: dict[str, Callable] = {
"u1": lambda lam: [braket_gates.PhaseShift(lam)],
"u2": lambda phi, lam: [
braket_gates.PhaseShift(lam),
Expand Down Expand Up @@ -146,12 +147,12 @@
}

_TRANSLATABLE_QISKIT_GATE_NAMES = (
set(GATE_NAME_TO_BRAKET_GATE.keys())
set(_GATE_NAME_TO_BRAKET_GATE.keys())
.union(set(_QISKIT_CONTROLLED_GATE_NAMES_TO_BRAKET_GATES))
.union({"measure", "barrier", "reset"})
)

GATE_NAME_TO_QISKIT_GATE: dict[str, Optional[QiskitInstruction]] = {
_GATE_NAME_TO_QISKIT_GATE: dict[str, Optional[QiskitInstruction]] = {
"u": qiskit_gates.UGate(Parameter("theta"), Parameter("phi"), Parameter("lam")),
"u1": qiskit_gates.U1Gate(Parameter("theta")),
"u2": qiskit_gates.U2Gate(Parameter("theta"), Parameter("lam")),
Expand Down Expand Up @@ -193,7 +194,30 @@
}


def get_controlled_gateset(max_qubits: Optional[int] = None) -> set[str]:
def gateset_from_properties(properties: OpenQASMDeviceActionProperties) -> set[str]:
"""Returns the gateset supported by a Braket device with the given properties
Args:
properties (OpenQASMDeviceActionProperties): The action properties of the Braket device.
Returns:
set[str]: The names of the gates supported by the device
"""
gateset = {
_BRAKET_TO_QISKIT_NAMES[op.lower()]
for op in properties.supportedOperations
if op.lower() in _BRAKET_TO_QISKIT_NAMES
}
max_control = 0
for modifier in properties.supportedModifiers:
if isinstance(modifier, Control):
max_control = modifier.max_qubits
break
gateset.update(_get_controlled_gateset(max_control))
return gateset


def _get_controlled_gateset(max_qubits: Optional[int] = None) -> set[str]:
"""Returns the Qiskit gates expressible as controlled versions of existing Braket gates
This set can be filtered by the maximum number of control qubits.
Expand Down Expand Up @@ -227,7 +251,7 @@ def local_simulator_to_target(simulator: LocalSimulator) -> Target:
target = Target()

instructions = [
inst for inst in GATE_NAME_TO_QISKIT_GATE.values() if inst is not None
inst for inst in _GATE_NAME_TO_QISKIT_GATE.values() if inst is not None
]
properties = simulator.properties
paradigm: GateModelSimulatorParadigmProperties = properties.paradigm
Expand Down Expand Up @@ -283,7 +307,7 @@ def aws_device_to_target(device: AwsDevice) -> Target:
instructions: list[QiskitInstruction] = []

for operation in action_properties.supportedOperations:
instruction = GATE_NAME_TO_QISKIT_GATE.get(operation.lower(), None)
instruction = _GATE_NAME_TO_QISKIT_GATE.get(operation.lower(), None)
if instruction is not None:
# TODO: remove when target will be supporting > 2 qubit gates # pylint:disable=fixme
if instruction.num_qubits <= 2:
Expand Down Expand Up @@ -381,7 +405,7 @@ def convert_continuous_qubit_indices(
instructions = []

for operation in simulator_action_properties.supportedOperations:
instruction = GATE_NAME_TO_QISKIT_GATE.get(operation.lower(), None)
instruction = _GATE_NAME_TO_QISKIT_GATE.get(operation.lower(), None)
if instruction is not None:
# TODO: remove when target will be supporting > 2 qubit gates # pylint:disable=fixme
if instruction.num_qubits <= 2:
Expand Down Expand Up @@ -424,13 +448,13 @@ def convert_continuous_qubit_indices(

def to_braket(
circuit: QuantumCircuit,
gateset: Optional[Iterable[str]] = None,
basis_gates: Optional[Iterable[str]] = None,
verbatim: bool = False,
) -> Circuit:
"""Return a Braket quantum circuit from a Qiskit quantum circuit.
Args:
circuit (QuantumCircuit): Qiskit quantum circuit
gateset (Optional[Iterable[str]]): The gateset to transpile to.
basis_gates (Optional[Iterable[str]]): The gateset to transpile to.
If `None`, the transpiler will use all gates defined in the Braket SDK.
Default: `None`.
verbatim (bool): Whether to translate the circuit without any modification, in other
Expand All @@ -439,15 +463,15 @@ def to_braket(
Returns:
Circuit: Braket circuit
"""
gateset = gateset or _TRANSLATABLE_QISKIT_GATE_NAMES
basis_gates = basis_gates or _TRANSLATABLE_QISKIT_GATE_NAMES
if not isinstance(circuit, QuantumCircuit):
raise TypeError(f"Expected a QuantumCircuit, got {type(circuit)} instead.")

braket_circuit = Circuit()
if not verbatim and not {gate.name for gate, _, _ in circuit.data}.issubset(
gateset
basis_gates
):
circuit = transpile(circuit, basis_gates=gateset, optimization_level=0)
circuit = transpile(circuit, basis_gates=basis_gates, optimization_level=0)

# handle qiskit to braket conversion
for circuit_instruction in circuit.data:
Expand Down Expand Up @@ -493,7 +517,7 @@ def to_braket(
)
braket_circuit += instruction
else:
for gate in GATE_NAME_TO_BRAKET_GATE[gate_name](*params):
for gate in _GATE_NAME_TO_BRAKET_GATE[gate_name](*params):
instruction = Instruction(
operator=gate,
target=[circuit.find_bit(qubit).index for qubit in qubits],
Expand Down Expand Up @@ -605,7 +629,7 @@ def to_qiskit(circuit: Circuit) -> QuantumCircuit:
def _create_qiskit_gate(
gate_name: str, gate_params: list[Union[float, Parameter]]
) -> Instruction:
gate_instance = GATE_NAME_TO_QISKIT_GATE.get(gate_name, None)
gate_instance = _GATE_NAME_TO_QISKIT_GATE.get(gate_name, None)
new_gate_params = []
for param_expression, value in zip(gate_instance.params, gate_params):
param = list(param_expression.parameters)[0]
Expand Down
32 changes: 8 additions & 24 deletions qiskit_braket_provider/providers/braket_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,15 @@
from braket.circuits import Circuit
from braket.device_schema import DeviceActionType
from braket.devices import Device, LocalSimulator
from braket.ir.openqasm.modifiers import Control
from braket.tasks.local_quantum_task import LocalQuantumTask
from qiskit import QuantumCircuit
from qiskit.providers import BackendV2, QubitProperties, Options, Provider

from .adapter import (
aws_device_to_target,
BRAKET_TO_QISKIT_NAMES,
gateset_from_properties,
local_simulator_to_target,
to_braket,
get_controlled_gateset,
)
from .braket_job import AmazonBraketTask
from .. import version
Expand All @@ -31,7 +29,7 @@
logger = logging.getLogger(__name__)


TASK_ID_DIVIDER = ";"
_TASK_ID_DIVIDER = ";"


class BraketBackend(BackendV2, ABC):
Expand All @@ -55,20 +53,7 @@ def _validate_meas_level(self, meas_level: Union[enum.Enum, int]):

def _get_gateset(self) -> Optional[set[str]]:
action = self._device.properties.action.get(DeviceActionType.OPENQASM)
if not action:
return None
gateset = {
BRAKET_TO_QISKIT_NAMES[op.lower()]
for op in action.supportedOperations
if op.lower() in BRAKET_TO_QISKIT_NAMES
}
max_control = 0
for modifier in action.supportedModifiers:
if isinstance(modifier, Control):
max_control = modifier.max_qubits
break
gateset.update(get_controlled_gateset(max_control))
return gateset
return gateset_from_properties(action) if action else None


class BraketLocalBackend(BraketBackend):
Expand Down Expand Up @@ -143,8 +128,8 @@ def run(
convert_input = (
[run_input] if isinstance(run_input, QuantumCircuit) else list(run_input)
)
gateset = self._get_gateset()
verbatim = options.pop("verbatim", False)
gateset = self._get_gateset() if not verbatim else None
circuits: list[Circuit] = [
to_braket(circ, gateset, verbatim) for circ in convert_input
]
Expand Down Expand Up @@ -172,7 +157,7 @@ def run(
logger.error("State of %s: %s.", task.id, task.state())
raise ex

task_id = TASK_ID_DIVIDER.join(task.id for task in tasks)
task_id = _TASK_ID_DIVIDER.join(task.id for task in tasks)

return AmazonBraketTask(
task_id=task_id,
Expand Down Expand Up @@ -235,7 +220,7 @@ def retrieve_job(self, task_id: str) -> AmazonBraketTask:
Returns:
The job with the given ID.
"""
task_ids = task_id.split(TASK_ID_DIVIDER)
task_ids = task_id.split(_TASK_ID_DIVIDER)

return AmazonBraketTask(
task_id=task_id,
Expand Down Expand Up @@ -328,20 +313,19 @@ def run(self, run_input, **options):
else:
raise QiskitBraketException(f"Unsupported input type: {type(run_input)}")

gateset = self._get_gateset()

if "meas_level" in options:
self._validate_meas_level(options["meas_level"])
del options["meas_level"]

verbatim = options.pop("verbatim", False)
gateset = self._get_gateset() if not verbatim else None
braket_circuits = [to_braket(circ, gateset, verbatim) for circ in circuits]

batch_task: AwsQuantumTaskBatch = self._device.run_batch(
braket_circuits, **options
)
tasks: list[AwsQuantumTask] = batch_task.tasks
task_id = TASK_ID_DIVIDER.join(task.id for task in tasks)
task_id = _TASK_ID_DIVIDER.join(task.id for task in tasks)

return AmazonBraketTask(
task_id=task_id, tasks=tasks, backend=self, shots=options.get("shots")
Expand Down
28 changes: 14 additions & 14 deletions tests/providers/test_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@
to_braket,
convert_qiskit_to_braket_circuit,
convert_qiskit_to_braket_circuits,
GATE_NAME_TO_BRAKET_GATE,
GATE_NAME_TO_QISKIT_GATE,
get_controlled_gateset,
_GATE_NAME_TO_BRAKET_GATE,
_GATE_NAME_TO_QISKIT_GATE,
_get_controlled_gateset,
)
from qiskit_braket_provider.providers.braket_backend import BraketLocalBackend

Expand Down Expand Up @@ -345,12 +345,12 @@ def test_mappers(self):

self.assertEqual(
set(qiskit_to_braket_gate_names.keys()),
set(GATE_NAME_TO_BRAKET_GATE.keys()),
set(_GATE_NAME_TO_BRAKET_GATE.keys()),
)

self.assertEqual(
set(qiskit_to_braket_gate_names.values()),
set(GATE_NAME_TO_QISKIT_GATE.keys()),
set(_GATE_NAME_TO_QISKIT_GATE.keys()),
)

def test_type_error_on_bad_input(self):
Expand Down Expand Up @@ -495,12 +495,12 @@ def test_get_controlled_gateset(self):
max1 = {"ch", "cs", "csdg", "csx", "crx", "cry", "crz", "ccz"}
max3 = max1.union({"c3sx"})
unlimited = max3.union({"mcx"})
assert get_controlled_gateset(0) == set()
assert get_controlled_gateset(1) == max1
assert get_controlled_gateset(2) == max1
assert get_controlled_gateset(3) == max3
assert get_controlled_gateset(4) == max3
assert get_controlled_gateset() == unlimited
assert _get_controlled_gateset(0) == set()
assert _get_controlled_gateset(1) == max1
assert _get_controlled_gateset(2) == max1
assert _get_controlled_gateset(3) == max3
assert _get_controlled_gateset(4) == max3
assert _get_controlled_gateset() == unlimited


class TestFromBraket(TestCase):
Expand All @@ -523,7 +523,7 @@ def test_all_standard_gates(self):

for gate_name in gate_set:
if (
gate_name.lower() not in GATE_NAME_TO_QISKIT_GATE
gate_name.lower() not in _GATE_NAME_TO_QISKIT_GATE
or gate_name.lower() in ["gpi", "gpi2", "ms"]
):
continue
Expand All @@ -543,7 +543,7 @@ def test_all_standard_gates(self):

expected_qiskit_circuit = QuantumCircuit(op.qubit_count)
expected_qiskit_circuit.append(
GATE_NAME_TO_QISKIT_GATE.get(gate_name.lower()), target
_GATE_NAME_TO_QISKIT_GATE.get(gate_name.lower()), target
)
expected_qiskit_circuit.measure_all()
expected_qiskit_circuit = expected_qiskit_circuit.assign_parameters(
Expand All @@ -564,7 +564,7 @@ def test_all_ionq_gates(self):
for gate_name in gate_set:
gate = getattr(Gate, gate_name)
value = 0.1
qiskit_gate_cls = GATE_NAME_TO_QISKIT_GATE.get(gate_name.lower()).__class__
qiskit_gate_cls = _GATE_NAME_TO_QISKIT_GATE.get(gate_name.lower()).__class__
qiskit_value = 0.1 / (2 * np.pi)
if issubclass(gate, AngledGate):
op = gate(value)
Expand Down

0 comments on commit d5eb538

Please sign in to comment.