diff --git a/qiskit_braket_provider/providers/braket_backend.py b/qiskit_braket_provider/providers/braket_backend.py index 027cdfeb..420e8213 100644 --- a/qiskit_braket_provider/providers/braket_backend.py +++ b/qiskit_braket_provider/providers/braket_backend.py @@ -3,6 +3,7 @@ import datetime import logging +import enum from abc import ABC from typing import Iterable, Union, List @@ -36,6 +37,15 @@ class BraketBackend(BackendV2, ABC): def __repr__(self): return f"BraketBackend[{self.name}]" + def _validate_meas_level(self, meas_level: Union[enum.Enum, int]): + if isinstance(meas_level, enum.Enum): + meas_level = meas_level.value + if meas_level != 2: + raise QiskitBraketException( + f"Device {self.name} only supports classified measurement " + f"results, received meas_level={meas_level}." + ) + class BraketLocalBackend(BraketBackend): """BraketLocalBackend.""" @@ -109,6 +119,9 @@ def run( shots = options["shots"] if "shots" in options else 1024 if shots == 0: circuits = list(map(lambda x: x.state_vector(), circuits)) + if "meas_level" in options: + self._validate_meas_level(options["meas_level"]) + del options["meas_level"] tasks = [] try: for circuit in circuits: @@ -278,6 +291,10 @@ def run(self, run_input, **options): else: raise QiskitBraketException(f"Unsupported input type: {type(run_input)}") + if "meas_level" in options: + self._validate_meas_level(options["meas_level"]) + del options["meas_level"] + braket_circuits = list(convert_qiskit_to_braket_circuits(circuits)) if options.pop("verbatim", False): diff --git a/tests/providers/test_adapter.py b/tests/providers/test_adapter.py index 1d3a80b2..991013ff 100644 --- a/tests/providers/test_adapter.py +++ b/tests/providers/test_adapter.py @@ -10,11 +10,11 @@ from qiskit import ( QuantumCircuit, - BasicAer, QuantumRegister, ClassicalRegister, transpile, ) +from qiskit.providers.basicaer import BasicAer from qiskit.circuit import Parameter from qiskit.circuit.library import PauliEvolutionGate from qiskit.quantum_info import SparsePauliOp diff --git a/tests/providers/test_braket_backend.py b/tests/providers/test_braket_backend.py index d0bee49d..4e95e1f9 100644 --- a/tests/providers/test_braket_backend.py +++ b/tests/providers/test_braket_backend.py @@ -6,7 +6,8 @@ from botocore import errorfactory from braket.aws.queue_information import QueueDepthInfo, QueueType -from qiskit import QuantumCircuit, transpile, BasicAer +from qiskit import QuantumCircuit, transpile +from qiskit.providers.basicaer import BasicAer from qiskit.algorithms.minimum_eigensolvers import VQE, VQEResult @@ -20,7 +21,7 @@ from qiskit.transpiler import Target from qiskit.primitives import BackendEstimator -from qiskit_braket_provider import AWSBraketProvider, version +from qiskit_braket_provider import AWSBraketProvider, version, exception from qiskit_braket_provider.providers import AWSBraketBackend, BraketLocalBackend from qiskit_braket_provider.providers.adapter import aws_device_to_target from tests.providers.mocks import ( @@ -141,6 +142,23 @@ def test_local_backend_circuit_shots0(self): self.assertEqual(statevector[2], 0.0 + 0.0j) self.assertEqual(statevector[3], 1.0 + 0.0j) + def test_meas_level_2(self): + """Check that there's no error for asking for classified measurement results.""" + backend = BraketLocalBackend(name="default") + circuit = QuantumCircuit(1, 1) + circuit.h(0) + circuit.measure(0, 0) + backend.run(circuit, shots=10, meas_level=2) + + def test_meas_level_1(self): + """Check that there's an exception for asking for raw measurement results.""" + backend = BraketLocalBackend(name="default") + circuit = QuantumCircuit(1, 1) + circuit.h(0) + circuit.measure(0, 0) + with self.assertRaises(exception.QiskitBraketException): + backend.run(circuit, shots=10, meas_level=1) + def test_vqe(self): """Tests VQE.""" local_simulator = BraketLocalBackend(name="default")