Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Default run_options for Primitives #8513

Merged
merged 8 commits into from
Sep 5, 2022
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 40 additions & 7 deletions qiskit/primitives/base_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@
from qiskit.exceptions import QiskitError
from qiskit.opflow import PauliSumOp
from qiskit.providers import JobV1 as Job
from qiskit.providers import Options
from qiskit.quantum_info.operators import SparsePauliOp
from qiskit.quantum_info.operators.base_operator import BaseOperator
from qiskit.utils.deprecation import deprecate_arguments, deprecate_function
Expand All @@ -126,13 +127,14 @@ class BaseEstimator(ABC):
Base class for Estimator that estimates expectation values of quantum circuits and observables.
"""

__hash__ = None # type: ignore
__hash__ = None

def __init__(
self,
circuits: Iterable[QuantumCircuit] | QuantumCircuit | None = None,
observables: Iterable[SparsePauliOp] | SparsePauliOp | None = None,
parameters: Iterable[Iterable[Parameter]] | None = None,
run_options: dict | None = None,
):
"""
Creating an instance of an Estimator, or using one in a ``with`` context opens a session that
Expand All @@ -145,6 +147,7 @@ def __init__(
will be bound. Defaults to ``[circ.parameters for circ in circuits]``
The indexing is such that ``parameters[i, j]`` is the j-th formal parameter of
``circuits[i]``.
run_options: runtime options.
Cryoris marked this conversation as resolved.
Show resolved Hide resolved

Raises:
QiskitError: For mismatch of circuits and parameters list.
Expand Down Expand Up @@ -185,6 +188,9 @@ def __init__(
f"Different numbers of parameters of {i}-th circuit: "
f"expected {circ.num_parameters}, actual {len(params)}."
)
self._run_options = Options()
if run_options is not None:
self._run_options.update_options(**run_options)

def __new__(
cls,
Expand Down Expand Up @@ -258,6 +264,23 @@ def parameters(self) -> tuple[ParameterView, ...]:
"""
return tuple(self._parameters)

@property
def run_options(self) -> Options:
"""Return options values for the estimator.

Returns:
run_options
"""
return self._run_options

def set_run_options(self, **fields) -> BaseEstimator:
"""Set options values for the estimator.

Args:
**fields: The fields to update the options
"""
self._run_options.update_options(**fields)

@deprecate_function(
"The BaseSampler.__call__ method is deprecated as of Qiskit Terra 0.21.0 "
"and will be removed no sooner than 3 months after the releasedate. "
Expand Down Expand Up @@ -296,7 +319,7 @@ def __call__(
circuits: the list of circuit indices or circuit objects.
observables: the list of observable indices or observable objects.
parameter_values: concrete parameters to be bound.
run_options: runtime options used for circuit execution.
run_options: Default runtime options used for circuit execution.

Returns:
EstimatorResult: The result of the estimator.
Expand All @@ -312,7 +335,7 @@ def __call__(

# Allow objects
circuits = [
self._circuit_ids.get(id(circuit)) # type: ignore
self._circuit_ids.get(id(circuit))
if not isinstance(circuit, (int, np.integer))
else circuit
for circuit in circuits
Expand All @@ -323,7 +346,7 @@ def __call__(
"initialize the session."
)
observables = [
self._observable_ids.get(id(observable)) # type: ignore
self._observable_ids.get(id(observable))
if not isinstance(observable, (int, np.integer))
else observable
for observable in observables
Expand Down Expand Up @@ -386,12 +409,14 @@ def __call__(
f"The number of circuits is {len(self.observables)}, "
f"but the index {max(observables)} is given."
)
run_opts = copy(self.run_options)
run_opts.update_options(**run_options)

return self._call(
circuits=circuits,
observables=observables,
parameter_values=parameter_values,
**run_options,
**run_opts.__dict__,
)

def run(
Expand Down Expand Up @@ -495,8 +520,16 @@ def run(
f"not match the number of qubits of the {i}-th observable "
f"({observable.num_qubits})."
)

return self._run(circuits, observables, parameter_values, parameter_views, **run_options)
run_opts = copy(self.run_options)
run_opts.update_options(**run_options)

return self._run(
circuits,
observables,
parameter_values,
parameter_views,
**run_opts.__dict__,
)

@abstractmethod
def _call(
Expand Down
42 changes: 37 additions & 5 deletions qiskit/primitives/base_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@
from qiskit.circuit.parametertable import ParameterView
from qiskit.exceptions import QiskitError
from qiskit.providers import JobV1 as Job
from qiskit.providers import Options
from qiskit.utils.deprecation import deprecate_arguments, deprecate_function

from .sampler_result import SamplerResult
Expand All @@ -112,18 +113,20 @@ class BaseSampler(ABC):
Base class of Sampler that calculates quasi-probabilities of bitstrings from quantum circuits.
"""

__hash__ = None # type: ignore
__hash__ = None

def __init__(
self,
circuits: Iterable[QuantumCircuit] | QuantumCircuit | None = None,
parameters: Iterable[Iterable[Parameter]] | None = None,
run_options: dict | None = None,
):
"""
Args:
circuits: Quantum circuits to be executed.
parameters: Parameters of each of the quantum circuits.
Defaults to ``[circ.parameters for circ in circuits]``.
run_options: Default runtime options.

Raises:
QiskitError: For mismatch of circuits and parameters list.
Expand Down Expand Up @@ -153,6 +156,9 @@ def __init__(
f"Different number of parameters ({len(self._parameters)}) "
f"and circuits ({len(self._circuits)})"
)
self._run_options = Options()
if run_options is not None:
self._run_options.update_options(**run_options)

def __new__(
cls,
Expand Down Expand Up @@ -209,6 +215,23 @@ def parameters(self) -> tuple[ParameterView, ...]:
"""
return tuple(self._parameters)

@property
def run_options(self) -> Options:
"""Return options values for the estimator.

Returns:
run_options
"""
return self._run_options

def set_run_options(self, **fields) -> BaseSampler:
"""Set options values for the estimator.

Args:
**fields: The fields to update the options
"""
self._run_options.update_options(**fields)

@deprecate_function(
"The BaseSampler.__call__ method is deprecated as of Qiskit Terra 0.21.0 "
"and will be removed no sooner than 3 months after the releasedate. "
Expand Down Expand Up @@ -243,7 +266,7 @@ def __call__(

# Allow objects
circuits = [
self._circuit_ids.get(id(circuit)) # type: ignore
self._circuit_ids.get(id(circuit))
if not isinstance(circuit, (int, np.integer))
else circuit
for circuit in circuits
Expand Down Expand Up @@ -285,11 +308,13 @@ def __call__(
f"The number of circuits is {len(self.circuits)}, "
f"but the index {max(circuits)} is given."
)
run_opts = copy(self.run_options)
run_opts.update_options(**run_options)

return self._call(
circuits=circuits,
parameter_values=parameter_values,
**run_options,
**run_opts.__dict__,
)

def run(
Expand Down Expand Up @@ -358,8 +383,15 @@ def run(
f"The number of values ({len(parameter_value)}) does not match "
f"the number of parameters ({circuit.num_parameters}) for the {i}-th circuit."
)

return self._run(circuits, parameter_values, parameter_views, **run_options)
run_opts = copy(self.run_options)
run_opts.update_options(**run_options)

return self._run(
circuits,
parameter_values,
parameter_views,
**run_opts.__dict__,
)

@abstractmethod
def _call(
Expand Down
13 changes: 13 additions & 0 deletions qiskit/primitives/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,19 @@ def __init__(
circuits: QuantumCircuit | Iterable[QuantumCircuit] | None = None,
observables: BaseOperator | PauliSumOp | Iterable[BaseOperator | PauliSumOp] | None = None,
parameters: Iterable[Iterable[Parameter]] | None = None,
run_options: dict | None = None,
):
"""
Args:
circuits: circuits that represent quantum states.
observables: observables to be estimated.
parameters: Parameters of each of the quantum circuits.
Defaults to ``[circ.parameters for circ in circuits]``.
run_options: Default runtime options.

Raises:
QiskitError: if some classical bits are not used for measurements.
"""
if isinstance(circuits, QuantumCircuit):
circuits = (circuits,)
if circuits is not None:
Expand All @@ -69,6 +81,7 @@ def __init__(
circuits=circuits,
observables=observables, # type: ignore
parameters=parameters,
run_options=run_options,
)
self._is_closed = False

Expand Down
8 changes: 5 additions & 3 deletions qiskit/primitives/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from __future__ import annotations

from collections.abc import Iterable, Sequence
from typing import Any, cast
from typing import Any

import numpy as np

Expand Down Expand Up @@ -53,12 +53,14 @@ def __init__(
self,
circuits: QuantumCircuit | Iterable[QuantumCircuit] | None = None,
parameters: Iterable[Iterable[Parameter]] | None = None,
run_options: dict | None = None,
):
"""
Args:
circuits: circuits to be executed
parameters: Parameters of each of the quantum circuits.
Defaults to ``[circ.parameters for circ in circuits]``.
run_options: Default runtime options.

Raises:
QiskitError: if some classical bits are not used for measurements.
Expand All @@ -74,7 +76,7 @@ def __init__(
preprocessed_circuits.append(circuit)
else:
preprocessed_circuits = None
super().__init__(preprocessed_circuits, parameters)
super().__init__(preprocessed_circuits, parameters, run_options)
self._is_closed = False

def _call(
Expand Down Expand Up @@ -165,5 +167,5 @@ def _preprocess_circuit(circuit: QuantumCircuit):
)
c_q_mapping = sorted((c, q) for q, c in q_c_mapping.items())
qargs = [q for _, q in c_q_mapping]
circuit = cast(QuantumCircuit, circuit.remove_final_measurements(inplace=False))
circuit = circuit.remove_final_measurements(inplace=False)
return circuit, qargs
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
---
features:
- |
Added ``run_options`` arguments in constructor of primitives and ``run_options`` methods to
primitives. It is now possible to set default ``run_options`` in addition to passing
``run_options`` at runtime.
18 changes: 18 additions & 0 deletions test/python/primitives/test_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -568,6 +568,24 @@ def test_run_with_shots_option(self):
self.assertIsInstance(result, EstimatorResult)
np.testing.assert_allclose(result.values, [-1.307397243478641])

def test_run_options(self):
"""Test for run_options"""
with self.subTest("init"):
estimator = Estimator(run_options={"shots": 3000})
self.assertEqual(estimator.run_options.get("shots"), 3000)
with self.subTest("set_run_options"):
estimator.set_run_options(shots=1024, seed=15)
self.assertEqual(estimator.run_options.get("shots"), 1024)
self.assertEqual(estimator.run_options.get("seed"), 15)
with self.subTest("run"):
result = estimator.run(
[self.ansatz],
[self.observable],
parameter_values=[[0, 1, 1, 2, 3, 5]],
).result()
self.assertIsInstance(result, EstimatorResult)
np.testing.assert_allclose(result.values, [-1.307397243478641])


if __name__ == "__main__":
unittest.main()
14 changes: 14 additions & 0 deletions test/python/primitives/test_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -658,6 +658,20 @@ def test_primitive_job_status_done(self):
job = sampler.run(circuits=[bell])
self.assertEqual(job.status(), JobStatus.DONE)

def test_run_options(self):
"""Test for run_options"""
with self.subTest("init"):
sampler = Sampler(run_options={"shots": 3000})
self.assertEqual(sampler.run_options.get("shots"), 3000)
with self.subTest("set_run_options"):
sampler.set_run_options(shots=1024, seed=15)
self.assertEqual(sampler.run_options.get("shots"), 1024)
self.assertEqual(sampler.run_options.get("seed"), 15)
with self.subTest("run"):
params, target = self._generate_params_target([1])
result = sampler.run([self._pqc], parameter_values=params).result()
self._compare_probs(result.quasi_dists, target)


if __name__ == "__main__":
unittest.main()