Skip to content

Commit

Permalink
Avoid exploring extraneous minima in the cut-finder search space (#585)…
Browse files Browse the repository at this point in the history
… (#588)

* Avoid exploring extraneous minima in the search space

* fix failing test

* fix coverage

* black

* update doc string

* update doc string

Co-authored-by: Jim Garrison <garrison@ibm.com>

* add new tests and modify states check

* update test description

* style

* change to namedtuple, add release note

* update return

Co-authored-by: Jim Garrison <garrison@ibm.com>

* change type hints

---------

Co-authored-by: Jim Garrison <garrison@ibm.com>
(cherry picked from commit 2bcbe7f)

Co-authored-by: Ibrahim Shehzad <75153717+ibrahim-shehzad@users.noreply.github.com>
  • Loading branch information
mergify[bot] and ibrahim-shehzad authored May 14, 2024
1 parent a721b90 commit 7e1a3ab
Show file tree
Hide file tree
Showing 7 changed files with 129 additions and 35 deletions.
42 changes: 27 additions & 15 deletions circuit_knitting/cutting/cut_finding/best_first_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

import heapq
import numpy as np
from typing import TYPE_CHECKING, Callable, cast
from typing import TYPE_CHECKING, Callable, cast, NamedTuple
from itertools import count

from .optimization_settings import OptimizationSettings
Expand All @@ -26,6 +26,21 @@
from .cut_optimization import CutOptimizationFuncArgs


class SearchStats(NamedTuple):
"""NamedTuple for collecting search statistics.
It carries information about the number of states visited
(dequeued from the search queue), the number of next-states generated,
the number of next-states that are enqueued after cost pruning,
and the number of backjumps performed.
"""

states_visited: int
next_states_generated: int
states_enqueued: int
backjumps: int


class BestFirstPriorityQueue:
"""Class that implements priority queues for best-first search.
Expand Down Expand Up @@ -149,6 +164,8 @@ class BestFirstSearch:
``stop_at_first_min`` (Boolean) is a flag that indicates whether or not to
stop the search after the first minimum-cost goal state has been reached.
In the absence of any non-LO QPD assignments, it always makes sense to stop once
the first minimum has been reached and therefore, we set this bool to ``True``.
``max_backjumps`` (int or None) is the maximum number of backjump operations that
can be performed before the search is forced to terminate. None indicates
Expand Down Expand Up @@ -185,7 +202,7 @@ def __init__(
self,
optimization_settings: OptimizationSettings,
search_functions: SearchFunctions,
stop_at_first_min: bool = False,
stop_at_first_min: bool = True,
):
"""Initialize an instance of :class:`BestFirstSearch`.
Expand Down Expand Up @@ -213,7 +230,7 @@ def __init__(
self.num_next_states = 0
self.num_enqueues = 0
self.num_backjumps = 0
self.penultimate_stats: np.typing.NDArray | None = None
self.penultimate_stats: SearchStats | None = None

def initialize(
self,
Expand Down Expand Up @@ -258,7 +275,6 @@ def optimization_pass(
self.mincost_bound = self.mincost_bound_func(*args) # type: ignore

prev_depth = None

while (
self.pqueue.qsize() > 0
and (not self.stop_at_first_min or not self.min_reached)
Expand All @@ -267,7 +283,6 @@ def optimization_pass(
state, depth, cost = self.pqueue.get()

self.update_minimum_reached(cost)

if cost is None or self.cost_bounds_exceeded(cost):
return None, None

Expand Down Expand Up @@ -299,10 +314,10 @@ def minimum_reached(self) -> bool:
"""Return True if the optimization reached a global minimum."""
return self.min_reached

def get_stats(self, penultimate: bool = False) -> np.typing.NDArray[np.int_] | None:
def get_stats(self, penultimate: bool = False) -> SearchStats | None:
"""Return statistics of the search that was performed.
This is a Numpy array containing the number of states visited
This is a NamedTuple containing the number of states visited
(dequeued), the number of next-states generated, the number of
next-states that are enqueued after cost pruning, and the number
of backjumps performed. Return None if no search is performed.
Expand All @@ -312,14 +327,11 @@ def get_stats(self, penultimate: bool = False) -> np.typing.NDArray[np.int_] | N
if penultimate:
return self.penultimate_stats

return np.array(
(
self.num_states_visited,
self.num_next_states,
self.num_enqueues,
self.num_backjumps,
),
dtype=int,
return SearchStats(
states_visited=self.num_states_visited,
next_states_generated=self.num_next_states,
states_enqueued=self.num_enqueues,
backjumps=self.num_backjumps,
)

def get_upperbound_cost(
Expand Down
6 changes: 3 additions & 3 deletions circuit_knitting/cutting/cut_finding/cut_optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import numpy as np
from dataclasses import dataclass
from typing import cast
from numpy.typing import NDArray
from .search_space_generator import ActionNames
from .cco_utils import select_search_engine, greedy_best_first_search
from .cutting_actions import disjoint_subcircuit_actions
Expand All @@ -27,6 +26,7 @@
SearchFunctions,
SearchSpaceGenerator,
)
from .best_first_search import SearchStats
from .disjoint_subcircuits_state import DisjointSubcircuitsState
from .circuit_interface import SimpleGateList, GateSpec
from .optimization_settings import OptimizationSettings
Expand Down Expand Up @@ -261,7 +261,7 @@ def __init__(
"CutOptimization",
self.settings,
self.search_funcs,
stop_at_first_min=False,
stop_at_first_min=True,
)
sq.initialize([start_state], self.func_args)

Expand Down Expand Up @@ -299,7 +299,7 @@ def minimum_reached(self) -> bool:
"""
return self.search_engine.minimum_reached()

def get_stats(self, penultimate: bool = False) -> NDArray[np.int_]:
def get_stats(self, penultimate: bool = False) -> SearchStats | None:
"""Return the search-engine statistics.
This is a Numpy array containing the number of states visited
Expand Down
9 changes: 3 additions & 6 deletions circuit_knitting/cutting/cut_finding/lo_cuts_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
"""File containing the wrapper class for optimizing LO gate and wire cuts."""
from __future__ import annotations

from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, NamedTuple

from .cut_optimization import CutOptimization
from .cut_optimization import disjoint_subcircuit_actions
Expand All @@ -21,9 +21,6 @@
from .cut_optimization import cut_optimization_min_cost_bound_func
from .cut_optimization import cut_optimization_upper_bound_cost_func
from .search_space_generator import SearchFunctions, SearchSpaceGenerator

import numpy as np
from numpy.typing import NDArray
from .disjoint_subcircuits_state import DisjointSubcircuitsState

if TYPE_CHECKING: # pragma: no cover
Expand Down Expand Up @@ -155,10 +152,10 @@ def get_results(self) -> DisjointSubcircuitsState | None:
"""Return the optimization results."""
return self.best_result

def get_stats(self, penultimate=False) -> dict[str, NDArray[np.int_]]:
def get_stats(self, penultimate=False) -> dict[str, NamedTuple | None]:
"""Return a dictionary containing optimization results.
The value is a Numpy array containing the number of states visited
The value is a NamedTuple containing the number of states visited
(dequeued), the number of next-states generated, the number of
next-states that are enqueued after cost pruning, and the number
of backjumps performed. Return None if no search is performed.
Expand Down
14 changes: 7 additions & 7 deletions docs/circuit_cutting/tutorials/04_automatic_cut_finding.ipynb

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
---
upgrade:
- |
The search engine inside the automated cut-finder has been primed to avoid extraneous searches and is therefore expected to run faster.
83 changes: 82 additions & 1 deletion test/cutting/cut_finding/test_best_first_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,31 @@
CircuitElement,
GateSpec,
)
from circuit_knitting.cutting.cut_finding.cut_optimization import CutOptimization
from circuit_knitting.cutting.cut_finding.cut_optimization import (
cut_optimization_next_state_func,
cut_optimization_min_cost_bound_func,
cut_optimization_cost_func,
cut_optimization_goal_state_func,
cut_optimization_upper_bound_cost_func,
CutOptimizationFuncArgs,
CutOptimization,
)
from circuit_knitting.cutting.cut_finding.optimization_settings import (
OptimizationSettings,
)
from circuit_knitting.cutting.automated_cut_finding import DeviceConstraints
from circuit_knitting.cutting.cut_finding.disjoint_subcircuits_state import (
get_actions_list,
)
from circuit_knitting.cutting.cut_finding.cutting_actions import (
disjoint_subcircuit_actions,
DisjointSubcircuitsState,
)

from circuit_knitting.cutting.cut_finding.best_first_search import (
BestFirstSearch,
SearchFunctions,
)


@fixture
Expand Down Expand Up @@ -124,3 +141,67 @@ def test_best_first_search(test_circuit: SimpleGateList):
assert op.get_upperbound_cost() == (27, inf)
assert op.minimum_reached() is True
assert out is None


def test_best_first_search_termination():
"""Test that if the best first search is run multiple times, it terminates once no further feasible cut states can be found,
in which case None is returned for both the cost and the state. This test also serves to describe the workflow of the optimizer
at a granular level."""

# Specify circuit
circuit = [
CircuitElement(name="cx", params=[], qubits=[0, 1], gamma=3),
CircuitElement(name="cx", params=[], qubits=[2, 3], gamma=3),
CircuitElement(name="cx", params=[], qubits=[1, 2], gamma=3),
]

interface = SimpleGateList(circuit)

# Specify optimization settings, search engine, and device constraints.
settings = OptimizationSettings(seed=123)
settings.set_engine_selection("CutOptimization", "BestFirst")

constraints = DeviceConstraints(qubits_per_subcircuit=3)

# Initialize and pass arguments to search space generating object.
func_args = CutOptimizationFuncArgs()
func_args.entangling_gates = interface.get_multiqubit_gates()
func_args.search_actions = disjoint_subcircuit_actions
func_args.max_gamma = settings.get_max_gamma
func_args.qpu_width = constraints.get_qpu_width()

# Initialize search functions object, needed to explore a search space.
cut_optimization_search_funcs = SearchFunctions(
cost_func=cut_optimization_cost_func,
upperbound_cost_func=cut_optimization_upper_bound_cost_func,
next_state_func=cut_optimization_next_state_func,
goal_state_func=cut_optimization_goal_state_func,
mincost_bound_func=cut_optimization_min_cost_bound_func,
)

# Initialize disjoint subcircuits state object
# while specifying number of qubits and max allowed wire cuts.
state = DisjointSubcircuitsState(interface.get_num_qubits(), 2)

# Initialize bfs object.
bfs = BestFirstSearch(
optimization_settings=settings, search_functions=cut_optimization_search_funcs
)

# Push an input state.
bfs.initialize([state], func_args)

counter = 0

cut_state = state
while cut_state is not None:
cut_state, cut_cost = bfs.optimization_pass(func_args)
counter += 1

# There are 5 possible cut states that can be found for this circuit,
# given that there need to be 3 qubits per subcircuit. These correspond
# to 3 gate cuts (i.e cutting any of the 3 gates) and cutting either of
# the input wires to the CNOT between qubits 1 and 2.
# After these 5 possible cuts are returned, at the 6th iteration, None
# is returned for both the state and the cost.
assert counter == 6 and cut_cost is None
6 changes: 3 additions & 3 deletions test/cutting/cut_finding/test_cut_finder_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from __future__ import annotations

import numpy as np
from numpy import array
from pytest import fixture, raises
from qiskit import QuantumCircuit
from typing import Callable
Expand Down Expand Up @@ -190,8 +189,9 @@ def test_four_qubit_circuit_two_qubit_qpu(
) # circuit separated into 2 subcircuits.

assert (
optimization_pass.get_stats()["CutOptimization"] == array([15, 46, 15, 6])
).all() # matches known stats.
optimization_pass.get_stats()["CutOptimization"].backjumps
<= settings.max_backjumps
)


def test_seven_qubit_circuit_two_qubit_qpu(
Expand Down

0 comments on commit 7e1a3ab

Please sign in to comment.