Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bitarray postselect (backport #12693) #12836

Merged
merged 1 commit into from
Jul 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 91 additions & 0 deletions qiskit/primitives/containers/bit_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,6 +470,97 @@ def slice_shots(self, indices: int | Sequence[int]) -> "BitArray":
arr = arr[..., indices, :]
return BitArray(arr, self.num_bits)

def postselect(
self,
indices: Sequence[int] | int,
selection: Sequence[bool | int] | bool | int,
) -> BitArray:
"""Post-select this bit array based on sliced equality with a given bitstring.

.. note::
If this bit array contains any shape axes, it is first flattened into a long list of shots
before applying post-selection. This is done because :class:`~BitArray` cannot handle
ragged numbers of shots across axes.

Args:
indices: A list of the indices of the cbits on which to postselect.
If this bit array was produced by a sampler, then an index ``i`` corresponds to the
:class:`~.ClassicalRegister` location ``creg[i]`` (as in :meth:`~slice_bits`).
Negative indices are allowed.

selection: A list of binary values (will be cast to ``bool``) of length matching
``indices``, with ``indices[i]`` corresponding to ``selection[i]``. Shots will be
discarded unless all cbits specified by ``indices`` have the values given by
``selection``.

Returns:
A new bit array with ``shape=(), num_bits=data.num_bits, num_shots<=data.num_shots``.

Raises:
IndexError: If ``max(indices)`` is greater than or equal to :attr:`num_bits`.
IndexError: If ``min(indices)`` is less than negative :attr:`num_bits`.
ValueError: If the lengths of ``selection`` and ``indices`` do not match.
"""
if isinstance(indices, int):
indices = (indices,)
if isinstance(selection, (bool, int)):
selection = (selection,)
selection = np.asarray(selection, dtype=bool)

num_indices = len(indices)

if len(selection) != num_indices:
raise ValueError("Lengths of indices and selection do not match.")

num_bytes = self._array.shape[-1]
indices = np.asarray(indices)

if num_indices > 0:
if indices.max() >= self.num_bits:
raise IndexError(
f"index {int(indices.max())} out of bounds for the number of bits {self.num_bits}."
)
if indices.min() < -self.num_bits:
raise IndexError(
f"index {int(indices.min())} out of bounds for the number of bits {self.num_bits}."
)

flattened = self.reshape((), self.size * self.num_shots)

# If no conditions, keep all data, but flatten as promised:
if num_indices == 0:
return flattened

# Make negative bit indices positive:
indices %= self.num_bits

# Handle special-case of contradictory conditions:
if np.intersect1d(indices[selection], indices[np.logical_not(selection)]).size > 0:
return BitArray(np.empty((0, num_bytes), dtype=np.uint8), num_bits=self.num_bits)

# Recall that creg[0] is the LSb:
byte_significance, bit_significance = np.divmod(indices, 8)
# least-significant byte is at last position:
byte_idx = (num_bytes - 1) - byte_significance
# least-significant bit is at position 0:
bit_offset = bit_significance.astype(np.uint8)

# Get bitpacked representation of `indices` (bitmask):
bitmask = np.zeros(num_bytes, dtype=np.uint8)
np.bitwise_or.at(bitmask, byte_idx, np.uint8(1) << bit_offset)

# Get bitpacked representation of `selection` (desired bitstring):
selection_bytes = np.zeros(num_bytes, dtype=np.uint8)
## This assumes no contradictions present, since those were already checked for:
np.bitwise_or.at(
selection_bytes, byte_idx, np.asarray(selection, dtype=np.uint8) << bit_offset
)

return BitArray(
flattened._array[((flattened._array & bitmask) == selection_bytes).all(axis=-1)],
num_bits=self.num_bits,
)

def expectation_values(self, observables: ObservablesArrayLike) -> NDArray[np.float64]:
"""Compute the expectation values of the provided observables, broadcasted against
this bit array.
Expand Down
11 changes: 11 additions & 0 deletions releasenotes/notes/bitarray-postselect-659b8f7801ccaa60.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
---
features_primitives:
- |
Added a new method :meth:`.BitArray.postselect` that returns all shots containing specified bit values.
Example usage::

from qiskit.primitives.containers import BitArray

ba = BitArray.from_counts({'110': 2, '100': 4, '000': 3})
print(ba.postselect([0,2], [0,1]).get_counts())
# {'110': 2, '100': 4}
79 changes: 79 additions & 0 deletions test/python/primitives/containers/test_bit_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -719,3 +719,82 @@ def test_expectation_values(self):
_ = ba.expectation_values("Z")
with self.assertRaisesRegex(ValueError, "is not diagonal"):
_ = ba.expectation_values("X" * ba.num_bits)

def test_postselection(self):
"""Test the postselection method."""

flat_data = np.array(
[
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 1, 1, 1, 1, 1],
[0, 1, 0, 1, 0, 1, 0, 1, 0, 1],
],
dtype=bool,
)

shaped_data = np.array(
[
[
[
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 1, 1, 1, 1, 1],
[0, 1, 0, 1, 0, 1, 0, 1, 0, 1],
],
[
[1, 0, 1, 0, 1, 0, 1, 0, 1, 0],
[1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
],
]
],
dtype=bool,
)

for dataname, bool_array in zip(["flat", "shaped"], [flat_data, shaped_data]):

bit_array = BitArray.from_bool_array(bool_array, order="little")
# indices value of i <-> creg[i] <-> bool_array[..., i]

num_bits = bool_array.shape[-1]
bool_array = bool_array.reshape(-1, num_bits)

test_cases = [
("basic", [0, 1], [0, 0]),
("multibyte", [0, 9], [0, 1]),
("repeated", [5, 5, 5], [0, 0, 0]),
("contradict", [5, 5, 5], [1, 0, 0]),
("unsorted", [5, 0, 9, 3], [1, 0, 1, 0]),
("negative", [-5, 1, -2, -10], [1, 0, 1, 0]),
("negcontradict", [4, -6], [1, 0]),
("trivial", [], []),
("bareindex", 6, 0),
]

for name, indices, selection in test_cases:
with self.subTest("_".join([dataname, name])):
postselected_bools = np.unpackbits(
bit_array.postselect(indices, selection).array[:, ::-1],
count=num_bits,
axis=-1,
bitorder="little",
).astype(bool)
if isinstance(indices, int):
indices = (indices,)
if isinstance(selection, bool):
selection = (selection,)
answer = bool_array[np.all(bool_array[:, indices] == selection, axis=-1)]
if name in ["contradict", "negcontradict"]:
self.assertEqual(len(answer), 0)
else:
self.assertGreater(len(answer), 0)
np.testing.assert_equal(postselected_bools, answer)

error_cases = [
("aboverange", [0, 6, 10], [True, True, False], IndexError),
("belowrange", [0, 6, -11], [True, True, False], IndexError),
("mismatch", [0, 1, 2], [False, False], ValueError),
]
for name, indices, selection, error in error_cases:
with self.subTest(dataname + "_" + name):
with self.assertRaises(error):
bit_array.postselect(indices, selection)
Loading