Skip to content

Commit

Permalink
Improve controlled gate support (#164)
Browse files Browse the repository at this point in the history
Right now, a controlled gate is included in the supported gates as long as its control qubit count is supported. This change checks whether the base gate of the controlled gate (e.g. rx for crx) is supported as well.
  • Loading branch information
speller26 authored Mar 6, 2024
1 parent b1dda9f commit e62177d
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 32 deletions.
78 changes: 52 additions & 26 deletions qiskit_braket_provider/providers/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,19 @@
}

_CONTROLLED_GATES_BY_QUBIT_COUNT = {
1: {"ch", "cs", "csdg", "csx", "crx", "cry", "crz", "ccz"},
3: {"c3sx"},
1: {
"ch": "h",
"cs": "s",
"csdg": "sdg",
"csx": "sx",
"crx": "rx",
"cry": "ry",
"crz": "rz",
"ccz": "cz",
},
3: {"c3sx": "sx"},
}
_ARBITRARY_CONTROLLED_GATES = {"mcx"}
_ARBITRARY_CONTROLLED_GATES = {"mcx": "cx"}

_ADDITIONAL_U_GATES = {"u1", "u2", "u3"}

Expand Down Expand Up @@ -122,16 +131,9 @@
}

_QISKIT_CONTROLLED_GATE_NAMES_TO_BRAKET_GATES: dict[str, Callable] = {
"ch": braket_gates.H,
"cs": braket_gates.S,
"csdg": braket_gates.Si,
"csx": braket_gates.V,
"ccz": braket_gates.CZ,
"c3sx": braket_gates.V,
"mcx": braket_gates.CNot,
"crx": braket_gates.Rx,
"cry": braket_gates.Ry,
"crz": braket_gates.Rz,
controlled_gate: _GATE_NAME_TO_BRAKET_GATE[base_gate]
for gate_map in _CONTROLLED_GATES_BY_QUBIT_COUNT.values()
for controlled_gate, base_gate in gate_map.items()
}

_TRANSLATABLE_QISKIT_GATE_NAMES = (
Expand Down Expand Up @@ -197,23 +199,26 @@ def gateset_from_properties(properties: OpenQASMDeviceActionProperties) -> set[s
for op in properties.supportedOperations
if op.lower() in _BRAKET_TO_QISKIT_NAMES
}
if "u" in gateset:
gateset.update(_ADDITIONAL_U_GATES)
max_control = 0
for modifier in properties.supportedModifiers:
if isinstance(modifier, Control):
max_control = modifier.max_qubits
break
gateset.update(_get_controlled_gateset(max_control))
if "u" in gateset:
gateset.update(_ADDITIONAL_U_GATES)
gateset.update(_get_controlled_gateset(gateset, max_control))
return gateset


def _get_controlled_gateset(max_qubits: Optional[int] = None) -> set[str]:
def _get_controlled_gateset(
base_gateset: set[str], max_qubits: Optional[int] = None
) -> set[str]:
"""Returns the Qiskit gates expressible as controlled versions of existing Braket gates
This set can be filtered by the maximum number of control qubits.
Args:
base_gateset (set[str]): The base (without control modifiers) gates supported
max_qubits (Optional[int]): The maximum number of control qubits that can be used to express
the Qiskit gate as a controlled Braket gate. If `None`, then there is no limit to the
number of control qubits. Default: `None`.
Expand All @@ -222,11 +227,30 @@ def _get_controlled_gateset(max_qubits: Optional[int] = None) -> set[str]:
set[str]: The names of the controlled gates.
"""
if max_qubits is None:
gateset = set().union(*[g for _, g in _CONTROLLED_GATES_BY_QUBIT_COUNT.items()])
gateset = set().union(
[
controlled_gate
for gate_map in _CONTROLLED_GATES_BY_QUBIT_COUNT.values()
for controlled_gate, base_gate in gate_map.items()
if base_gate in base_gateset
]
)
gateset.update(
[
controlled_gate
for controlled_gate, base_gate in _ARBITRARY_CONTROLLED_GATES.items()
if base_gate in base_gateset
]
)
gateset.update(_ARBITRARY_CONTROLLED_GATES)
return gateset
return set().union(
*[g for q, g in _CONTROLLED_GATES_BY_QUBIT_COUNT.items() if q <= max_qubits]
[
controlled_gate
for control_count, gate_map in _CONTROLLED_GATES_BY_QUBIT_COUNT.items()
for controlled_gate, base_gate in gate_map.items()
if control_count <= max_qubits and base_gate in base_gateset
]
)


Expand Down Expand Up @@ -453,13 +477,15 @@ def to_braket(
qubit_indices = [circuit.find_bit(qubit).index for qubit in qubits]
params = _create_free_parameters(operation)
if gate_name in _QISKIT_CONTROLLED_GATE_NAMES_TO_BRAKET_GATES:
gate = _QISKIT_CONTROLLED_GATE_NAMES_TO_BRAKET_GATES[gate_name](*params)
gate_qubit_count = gate.qubit_count
braket_circuit += Instruction(
operator=gate,
target=qubit_indices[-gate_qubit_count:],
control=qubit_indices[:-gate_qubit_count],
)
for gate in _QISKIT_CONTROLLED_GATE_NAMES_TO_BRAKET_GATES[gate_name](
*params
):
gate_qubit_count = gate.qubit_count
braket_circuit += Instruction(
operator=gate,
target=qubit_indices[-gate_qubit_count:],
control=qubit_indices[:-gate_qubit_count],
)
else:
for gate in _GATE_NAME_TO_BRAKET_GATE[gate_name](*params):
braket_circuit += Instruction(
Expand Down
21 changes: 15 additions & 6 deletions tests/providers/test_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,15 +440,24 @@ def test_invalid_ctrl_state(self, mock_transpile):

def test_get_controlled_gateset(self):
"""Tests that the correct controlled gateset is returned for all maximum qubit counts."""
full_gateset = {"h", "s", "sdg", "sx", "rx", "ry", "rz", "cz"}
restricted_gateset = {"rx", "cx", "sx"}
max1 = {"ch", "cs", "csdg", "csx", "crx", "cry", "crz", "ccz"}
max3 = max1.union({"c3sx"})
unlimited = max3.union({"mcx"})
assert _get_controlled_gateset(0) == set()
assert _get_controlled_gateset(1) == max1
assert _get_controlled_gateset(2) == max1
assert _get_controlled_gateset(3) == max3
assert _get_controlled_gateset(4) == max3
assert _get_controlled_gateset() == unlimited
assert _get_controlled_gateset(full_gateset, 0) == set()
assert _get_controlled_gateset(full_gateset, 1) == max1
assert _get_controlled_gateset(full_gateset, 2) == max1
assert _get_controlled_gateset(full_gateset, 3) == max3
assert _get_controlled_gateset(full_gateset, 4) == max3
assert _get_controlled_gateset(full_gateset) == unlimited
assert _get_controlled_gateset(restricted_gateset, 3) == {"crx", "csx", "c3sx"}
assert _get_controlled_gateset(restricted_gateset) == {
"crx",
"csx",
"c3sx",
"mcx",
}


class TestFromBraket(TestCase):
Expand Down

0 comments on commit e62177d

Please sign in to comment.