Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add BaseResult functionality to primitives module #8091

Merged
merged 28 commits into from
Sep 14, 2022
Merged
Show file tree
Hide file tree
Changes from 21 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
b23ab7a
Add BaseResult functionality to primitives module
pedrorrivero May 19, 2022
229f00a
Translate tests from pytest to unittest
pedrorrivero May 20, 2022
d64770d
Update docstring
pedrorrivero May 22, 2022
3d61f62
Update __post_init__ docstring
pedrorrivero May 24, 2022
41f384e
Add experiments property and update types
pedrorrivero May 26, 2022
38492e7
Refactor variable names
pedrorrivero May 26, 2022
5e24972
Refactor _field_values
pedrorrivero May 26, 2022
28a21d2
Finilize docstring
pedrorrivero Jun 8, 2022
c94241f
Add renlease note
pedrorrivero Jun 8, 2022
407dc8c
Fix lint errors
pedrorrivero Jun 8, 2022
6a80827
Fix lint docstring errors
pedrorrivero Jun 8, 2022
d942927
Fix formatting
pedrorrivero Jun 8, 2022
f366018
Fix variable names
pedrorrivero Jun 8, 2022
cfe5cd5
Remove BaseResult from primitives init and update test imports
pedrorrivero Jun 9, 2022
bad3567
Rename BaseResult to BasePrimitiveResult
pedrorrivero Jun 24, 2022
b107c9f
Update qiskit/primitives/base_result.py
pedrorrivero Jun 24, 2022
795e1d3
Merge branch 'main' into primitive-base-result
pedrorrivero Sep 8, 2022
0d505a1
Cache properties in base result since child dataclasses are frozen
pedrorrivero Sep 8, 2022
7398fab
Revert "Cache properties in base result since child dataclasses are f…
pedrorrivero Sep 8, 2022
75c8c2a
Update experiments type, add decompose method, and validate data types
pedrorrivero Sep 9, 2022
9d82b99
Fix lint
pedrorrivero Sep 9, 2022
6cc107c
Fix BaseResult post_init docstring
pedrorrivero Sep 9, 2022
689439e
Merge branch 'main' into primitive-base-result
pedrorrivero Sep 9, 2022
9232e2b
Index BasePrimitiveResult in primitives module __init__
pedrorrivero Sep 9, 2022
161d4e0
Merge branch 'main' into primitive-base-result
pedrorrivero Sep 12, 2022
2d6d137
Add bytes case to TestBasePrimitiveResult
pedrorrivero Sep 14, 2022
5045805
Exclude bytes from possible field types in BasePrimitiveResult
pedrorrivero Sep 14, 2022
2d0d40a
Merge branch 'main' into primitive-base-result
pedrorrivero Sep 14, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 87 additions & 0 deletions qiskit/primitives/base_result.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
# This code is part of Qiskit.
#
# (C) Copyright IBM 2022.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.
"""
Primitive result abstract base class
"""

from __future__ import annotations

from abc import ABC
from collections.abc import Iterator, Sequence
from dataclasses import fields
from typing import Any, Dict

from numpy import ndarray


ExperimentData = Dict[str, Any]


class BasePrimitiveResult(ABC):
"""Primitive result abstract base class.

Base class for Primitive results meant to provide common functionality to all inheriting
result dataclasses.
"""

def __post_init__(self) -> None:
"""
Verify that all fields in any inheriting result dataclass are consistent, after
instantiation, with the number of experiments being represented.

This magic method is specific of `dataclasses.dataclass`, therefore all inheriting
classes must have this decorator.

Raises:
TypeError: If one of the data fields is not a Sequence or `numpy.ndarray`.
pedrorrivero marked this conversation as resolved.
Show resolved Hide resolved
ValueError: Inconsistent number of experiments across data fields.
"""
for value in self._field_values: # type: Sequence
# TODO: enforce all data fields to be tuples instead of sequences
if not isinstance(value, (Sequence, ndarray)) or isinstance(value, str):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm wondering whether the check isinstance(value, str) is enough or not. What if bytes object is given?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we check each entry of Sequence perhaps? Do you have any suggestion, @ikkoham?

Copy link
Member Author

@pedrorrivero pedrorrivero Sep 14, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The problem with checking each sequence entry is that the types will differ for each data attribute (e.g. values and metadata), and between primitives (e.g. sampler and estimator). So those checks should probably live under the particular primitives' __post_init__ methods (e.g. Estimator.__post_init__), instead of inside BasePrimitiveResult.

Is this what you are referring to?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see. I just wondered whether there is any simpler way of this check.

raise TypeError(
f"Expected sequence or `numpy.ndarray`, provided {type(value)} instead."
)
if len(value) != self.num_experiments:
raise ValueError("Inconsistent number of experiments across data fields.")

@property # TODO: functools.cached_property when py37 is droppped
def num_experiments(self) -> int:
"""Number of experiments in any inheriting result dataclass."""
value: Sequence = self._field_values[0]
return len(value)

@property # TODO: functools.cached_property when py37 is droppped
def experiments(self) -> tuple[ExperimentData, ...]:
"""Experiment data dicts in any inheriting result dataclass."""
return tuple(self._generate_experiments())

def _generate_experiments(self) -> Iterator[ExperimentData]:
"""Generate experiment data dicts in any inheriting result dataclass."""
names: tuple[str, ...] = self._field_names
for values in zip(*self._field_values):
yield dict(zip(names, values))

def decompose(self) -> Iterator[BasePrimitiveResult]:
"""Generate single experiment result objects from self."""
for values in zip(*self._field_values):
yield self.__class__(*[(v,) for v in values])

@property # TODO: functools.cached_property when py37 is droppped
def _field_names(self) -> tuple[str, ...]:
"""Tuple of field names in any inheriting result dataclass."""
return tuple(field.name for field in fields(self))

@property # TODO: functools.cached_property when py37 is droppped
def _field_values(self) -> tuple:
"""Tuple of field values in any inheriting result dataclass."""
return tuple(getattr(self, name) for name in self._field_names)
4 changes: 3 additions & 1 deletion qiskit/primitives/estimator_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,14 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any

from .base_result import BasePrimitiveResult

if TYPE_CHECKING:
import numpy as np


@dataclass(frozen=True)
class EstimatorResult:
class EstimatorResult(BasePrimitiveResult):
"""Result of Estimator.

.. code-block:: python
Expand Down
4 changes: 3 additions & 1 deletion qiskit/primitives/sampler_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,11 @@

from qiskit.result import QuasiDistribution

from .base_result import BasePrimitiveResult


@dataclass(frozen=True)
class SamplerResult:
class SamplerResult(BasePrimitiveResult):
"""Result of Sampler.

.. code-block:: python
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
---
features:
- |
Adds a primitive result :class:`qiskit.primitives.base_result.BasePrimitiveResult` to provide common functionality to all inheriting result dataclasses.
* Adds :py:meth:`result.num_experiments` property.
* Adds :py:meth:`result.experiments` property.
* Adds :py:meth:`result.decompose`.
* Validates data types after instantiation (i.e. on dataclass :py:meth:`__post_init__`).
* Checks for consistency in the number of experiments across data fields after instantiation (i.e. on dataclass :py:meth:`__post_init__`).
93 changes: 93 additions & 0 deletions test/python/primitives/test_result.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
# This code is part of Qiskit.
#
# (C) Copyright IBM 2022.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.

"""Tests for BasePrimitiveResult."""

from __future__ import annotations

from collections.abc import Collection
from dataclasses import dataclass
from typing import Any

from ddt import data, ddt, unpack

from qiskit.primitives.base_result import BasePrimitiveResult
from qiskit.test import QiskitTestCase


################################################################################
## STUB DATACLASS
################################################################################
@dataclass
class Result(BasePrimitiveResult):
"""Dummy result dataclass implementing BasePrimitiveResult."""

field_1: Collection[Any]
field_2: Collection[Any]


################################################################################
## TESTS
################################################################################
@ddt
class TestBasePrimitiveResult(QiskitTestCase):
"""Tests BasePrimitiveResult."""

@data(0, 1.2, True, "sequence", {"name": "value"})
pedrorrivero marked this conversation as resolved.
Show resolved Hide resolved
def test_post_init_type_error(self, field_1):
"""Tests post init type error."""
self.assertRaises(TypeError, Result, *(field_1, []))

@data(([1], []), ([], [1]), ([1, 2], []), ([1], [1, 2]))
@unpack
def test_post_init_value_error(self, field_1, field_2):
"""Tests post init value error."""
self.assertRaises(ValueError, Result, *(field_1, field_2))

@data(0, 1, 2, 3)
def test_num_experiments(self, num_experiments):
"""Tests {num_experiments} num_experiments."""
result = Result([0] * num_experiments, [1] * num_experiments)
self.assertEqual(num_experiments, result.num_experiments)

@data(0, 1, 2, 3)
def test_experiments(self, num_experiments):
"""Test experiment data."""
field_1 = list(range(num_experiments))
field_2 = [i + 1 for i in range(num_experiments)]
experiments = Result(field_1, field_2).experiments
self.assertIsInstance(experiments, tuple)
for i, exp in enumerate(experiments):
self.assertEqual(exp, {"field_1": i, "field_2": i + 1})

@data(0, 1, 2, 3)
def test_decompose(self, num_experiments):
"""Test decompose."""
field_1 = list(range(num_experiments))
field_2 = [i + 1 for i in range(num_experiments)]
result = Result(field_1, field_2)
for i, res in enumerate(result.decompose()):
self.assertIsInstance(res, Result)
f1, f2 = (i,), (i + 1,)
self.assertEqual(res, Result(f1, f2))

def test_field_names(self):
"""Tests field names ("field_1", "field_2")."""
result = Result([], [])
self.assertEqual(result._field_names, ("field_1", "field_2"))

@data(([], []), ([0], [0]), ([0], [1]))
@unpack
def test_field_values(self, field_1, field_2):
"""Tests field values ({field_1}, {field_2})."""
result = Result(field_1, field_2)
self.assertEqual(result._field_values, (field_1, field_2))