diff --git a/runtime/estimator.json b/runtime/estimator.json index c729b1d6a..313f3a687 100644 --- a/runtime/estimator.json +++ b/runtime/estimator.json @@ -9,6 +9,7 @@ {"name": "parameters", "description": "The parameters to be bound.", "type": "Union[list[float], list[list[float]]]", "required": false}, {"name": "class_name", "description": "The name of the evaluator class.", "type": "str", "required": false}, {"name": "transpile_options", "description": "Options for transpile.", "type": "dict", "required": false}, - {"name": "run_options", "description": "Options for backend.run.", "type": "dict", "required": false} + {"name": "run_options", "description": "Options for backend run.", "type": "dict", "required": false}, + {"name": "measurement_error_mitigation", "description": "Whether to apply measurement error mitigation. Default is False.", "type": "bool", "required": false} ] } diff --git a/runtime/estimator.py b/runtime/estimator.py index 57d23c66b..3f7086b0c 100644 --- a/runtime/estimator.py +++ b/runtime/estimator.py @@ -1,11 +1,12 @@ -import numpy as np from dataclasses import asdict +import numpy as np +from mthree.utils import final_measurement_mapping +from qiskit.evaluators.backends import ReadoutErrorMitigation from qiskit.quantum_info import SparsePauliOp def main(backend, user_messenger, **kwargs): - state = kwargs.pop("state") observable = kwargs.pop("observable") if isinstance(observable, list): @@ -14,11 +15,28 @@ def main(backend, user_messenger, **kwargs): class_name = kwargs.pop("class_name", "PauliExpectationValue") transpile_options = kwargs.pop("transpile_options", {}) run_options = kwargs.pop("run_options", {}) + measurement_error_mitigation = kwargs.pop("measurement_error_mitigation", False) try: if class_name == "PauliExpectationValue": from qiskit.evaluators import PauliExpectationValue as Evaluator + + if measurement_error_mitigation: + expval = Evaluator(state, observable, backend) + expval.set_transpile_options(**transpile_options) + mapping = final_measurement_mapping(expval.transpiled_circuits[0]) + backend = ReadoutErrorMitigation( + backend=backend, + mitigation="mthree", + refresh=1800, # refresh the calibration data every 1800 seconds + shots=8192, # use M3's default shot number + qubits=list(mapping), + ) + else: + # use backend as is + pass elif class_name == "ExactExpectationValue": + # Note: ExactExpectationValue works only with Aer backend from qiskit.evaluators import ExactExpectationValue as Evaluator except ModuleNotFoundError: raise RuntimeError("You are not authorized to use this program.") @@ -28,9 +46,6 @@ def main(backend, user_messenger, **kwargs): expval.set_run_options(**run_options) result = expval.evaluate(parameters) - # for debug - # print(result) - ret = { key: val.tolist() if isinstance(val, np.ndarray) else val for key, val in asdict(result).items()