Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add flamegraph visualizations to visualize T-complexity of algorithms #732

Merged
merged 35 commits into from
Apr 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
01c1803
Add flamegraph visualizations to visualize T-complexity of algorithms
tanujkhattar Mar 2, 2024
967ac5d
Add flamegraph.pl script
tanujkhattar Mar 2, 2024
c3597ed
Fix pylint and failing tests
tanujkhattar Mar 2, 2024
358883a
Merge branch 'main' into flame_graph
tanujkhattar Mar 2, 2024
f064ca8
Use qubitization of quantum walks infrastructure for THC Select and P…
tanujkhattar Mar 2, 2024
32c32ab
Merge branch 'main' of https://github.com/quantumlib/Qualtran into fl…
tanujkhattar Mar 4, 2024
e69ff29
Merge branch 'main' of https://github.com/quantumlib/Qualtran into fl…
tanujkhattar Mar 4, 2024
36d49ce
Fix pylint error
tanujkhattar Mar 4, 2024
9de89da
Merge branch 'main' of https://github.com/quantumlib/Qualtran into fl…
tanujkhattar Mar 4, 2024
9e82868
Merge commit
tanujkhattar Mar 5, 2024
d1e1cdd
Remove redundant _t_complexity_ override for THCRotations Bloq
tanujkhattar Mar 5, 2024
8a54976
Merge branch 'main' into flame_graph
tanujkhattar Mar 5, 2024
eda30a9
Add back space to _bloq_to_cirq.py
tanujkhattar Mar 5, 2024
f7f6f7d
Merge branch 'flame_graph' of https://github.com/tanujkhattar/Qualtra…
tanujkhattar Mar 5, 2024
0ae4fe9
Fix infinite recursion bug in t_complexity protocol
tanujkhattar Mar 5, 2024
24bc6d0
Fix pylint
tanujkhattar Mar 5, 2024
b7a4fd2
Fix cirq to bloq interop to automatically replace basic gates with th…
tanujkhattar Mar 5, 2024
d68b0a1
Fix failing tests by updating T-complexity of CZPowGate
tanujkhattar Mar 5, 2024
a1e45f0
Merge branch 'main' of https://github.com/quantumlib/Qualtran into ci…
tanujkhattar Mar 5, 2024
eb9d518
Merge branch 'main' of https://github.com/quantumlib/Qualtran into fl…
tanujkhattar Mar 5, 2024
149941f
Merge branch 'cirq_to_bloq_interop_improve' into flame_graph
tanujkhattar Mar 5, 2024
34a819b
Merge branch 'main' of https://github.com/quantumlib/Qualtran into fl…
tanujkhattar Mar 6, 2024
02823e2
Revert changes to thc_compilation
tanujkhattar Mar 6, 2024
e054aeb
Merge branch 'main' of https://github.com/quantumlib/Qualtran into fl…
tanujkhattar Apr 1, 2024
f0a5198
Fix merge conflicts
tanujkhattar Apr 1, 2024
dafe2f8
Move flame graph script to third_party and use subprocess.run with a …
tanujkhattar Apr 1, 2024
7eb49db
Restore bloqs to main
tanujkhattar Apr 1, 2024
4d5980a
Fix formatting
tanujkhattar Apr 1, 2024
fd01889
Revert more unrelated changes, make improvements to the flame graph code
tanujkhattar Apr 2, 2024
6ec8a5f
Add tests for flame_graph.py
tanujkhattar Apr 2, 2024
72a1f27
Use sorted instead of set to compared unordered lists
tanujkhattar Apr 2, 2024
3f5d26e
Merge branch 'main' of https://github.com/quantumlib/Qualtran into fl…
tanujkhattar Apr 2, 2024
f932674
Merge branch 'main' of https://github.com/quantumlib/Qualtran into fl…
tanujkhattar Apr 3, 2024
0b844c6
Add lots of docstrings and a LICENSE and METADATA for third_party fla…
tanujkhattar Apr 3, 2024
6021909
Merge branch 'main' into flame_graph
tanujkhattar Apr 3, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 65 additions & 3 deletions qualtran/bloqs/chemistry/thc/thc.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
"source": [
"from qualtran import Bloq, CompositeBloq, BloqBuilder, Signature, Register\n",
"from qualtran import QBit, QInt, QUInt, QAny\n",
"from qualtran.drawing import show_bloq, show_call_graph, show_counts_sigma\n",
"from qualtran.drawing import show_bloq, show_call_graph, show_counts_sigma, show_flame_graph\n",
"from typing import *\n",
"import numpy as np\n",
"import sympy\n",
Expand Down Expand Up @@ -155,6 +155,24 @@
"show_counts_sigma(thc_uni_sigma)"
]
},
{
"cell_type": "markdown",
"id": "176027fd-e452-40d4-a526-9e1b9e86896a",
"metadata": {},
"source": [
"### Flame Graph"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "65f3c175-c87d-4bf5-8e5f-8bde322152f4",
"metadata": {},
"outputs": [],
"source": [
"show_flame_graph(thc_uni)"
]
},
{
"cell_type": "markdown",
"id": "f2cbc8e6",
Expand Down Expand Up @@ -360,6 +378,24 @@
"show_counts_sigma(thc_prep_sigma)"
]
},
{
"cell_type": "markdown",
"id": "c70647ea-256b-410c-abf7-33c9e2e12ac5",
"metadata": {},
"source": [
"### Flame Graph"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7c5e44f2-d481-4e41-b9f3-10e9ec5de498",
"metadata": {},
"outputs": [],
"source": [
"show_flame_graph(thc_prep)"
]
},
{
"cell_type": "markdown",
"id": "85818cf5",
Expand Down Expand Up @@ -543,11 +579,37 @@
"show_call_graph(thc_sel_g)\n",
"show_counts_sigma(thc_sel_sigma)"
]
},
{
"cell_type": "markdown",
"id": "e98a837b-7df6-41b8-9685-f00817b33550",
"metadata": {},
"source": [
"### Flame Graph"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ec11328c-f2f0-4217-8ad5-00f711d8f78c",
"metadata": {},
"outputs": [],
"source": [
"show_flame_graph(thc_sel)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "002749a5-7118-4c86-9a36-da44c8ffb3a4",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
Expand All @@ -561,7 +623,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.6"
"version": "3.11.8"
}
},
"nbformat": 4,
Expand Down
24 changes: 23 additions & 1 deletion qualtran/bloqs/phase_estimation_of_quantum_walk.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -189,12 +189,34 @@
"outputs": [],
"source": [
"from qualtran.bloqs.phase_estimation.qubitization_qpe import _qubitization_qpe_hubbard_model_large\n",
"from qualtran.drawing import show_call_graph\n",
"qpe = _qubitization_qpe_hubbard_model_large.make()\n",
"t_complexity.cache_clear()\n",
"%time result = qpe.t_complexity()\n",
"print(result)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Flame Graphs to visualize cost for QPE on Qubitized walk operator for 2D Hubbard model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from qualtran.bloqs.phase_estimation.qubitization_qpe import _qubitization_qpe_hubbard_model_small\n",
"from qualtran.drawing import show_flame_graph\n",
"\n",
"qpe_small = _qubitization_qpe_hubbard_model_small.make()\n",
"\n",
"print(qpe_small.t_complexity())\n",
"\n",
"show_flame_graph(qpe_small)"
]
}
],
"metadata": {
Expand Down
5 changes: 3 additions & 2 deletions qualtran/bloqs/reflection_using_prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,8 @@ def decompose_from_registers(
# 0. Allocate new ancillas, if needed.
phase_target = qm.qalloc(1)[0] if self.control_val is None else quregs.pop('control')[0]
state_prep_ancilla = {
reg.name: qm.qalloc(reg.total_bits()) for reg in self.prepare_gate.junk_registers
reg.name: np.array(qm.qalloc(reg.total_bits())).reshape(reg.shape + (reg.bitsize,))
for reg in self.prepare_gate.junk_registers
}
state_prep_selection_regs = quregs
prepare_op = self.prepare_gate.on_registers(
Expand All @@ -113,7 +114,7 @@ def decompose_from_registers(
yield prepare_op

# 4. Deallocate ancilla.
qm.qfree([q for anc in state_prep_ancilla.values() for q in anc])
qm.qfree([q for anc in state_prep_ancilla.values() for q in anc.flatten()])
if self.control_val is None:
qm.qfree([phase_target])

Expand Down
2 changes: 1 addition & 1 deletion qualtran/drawing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,4 +41,4 @@

from .bloq_counts_graph import GraphvizCounts, format_counts_sigma, format_counts_graph_markdown

from ._show_funcs import show_bloq, show_bloqs, show_call_graph, show_counts_sigma
from ._show_funcs import show_bloq, show_bloqs, show_call_graph, show_counts_sigma, show_flame_graph
7 changes: 7 additions & 0 deletions qualtran/drawing/_show_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import ipywidgets

from .bloq_counts_graph import format_counts_sigma, GraphvizCounts
from .flame_graph import get_flame_graph_svg_data
from .graphviz import PrettyGraphDrawer, TypedGraphDrawer
from .musical_score import draw_musical_score, get_musical_score_data

Expand Down Expand Up @@ -77,3 +78,9 @@ def show_call_graph(g: 'nx.DiGraph') -> None:
def show_counts_sigma(sigma: Dict['Bloq', Union[int, 'sympy.Expr']]):
"""Display nicely formatted bloq counts sums `sigma`."""
IPython.display.display(IPython.display.Markdown(format_counts_sigma(sigma)))


def show_flame_graph(*bloqs: 'Bloq', **kwargs):
"""Display hiearchical decomposition and T-complexity costs as a Flame Graph."""
svg_data = get_flame_graph_svg_data(*bloqs, **kwargs)
IPython.display.display(IPython.display.SVG(svg_data))
192 changes: 192 additions & 0 deletions qualtran/drawing/flame_graph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Classes for drawing bloqs with FlameGraph. This relies on third party flamegraph.pl"""
import functools
import pathlib
import subprocess
import tempfile
from typing import Any, List, Optional, Sequence, Union

import networkx as nx
import numpy as np

from qualtran import Bloq
from qualtran.resource_counting.bloq_counts import _compute_sigma
from qualtran.resource_counting.t_counts_from_sigma import t_counts_from_sigma


def _pretty_arg(val: Any) -> str:
if isinstance(val, (tuple, np.ndarray)):
return f'{val.shape if isinstance(val, np.ndarray) else len(val)}'
if isinstance(val, Bloq):
return _pretty_name(val)
if isinstance(val, float):
if np.isclose(val, 0):
val = 0
return f'{val:0.2g}'
return f'{val}'


def _pretty_name(bloq: Bloq) -> str:
"""Hack to get a reasonably concise, reasonably informative description of a bloq.

This should be removed once we have a better way to get a descriptive and concise
representation for a bloq. See https://github.com/quantumlib/Qualtran/issues/791
"""

from qualtran.serialization.bloq import _iter_fields

ret = bloq.pretty_name()
if bloq.pretty_name.__qualname__.startswith('Bloq.'):
for field in _iter_fields(bloq):
ret += f'[{_pretty_arg(getattr(bloq, field.name))}]'
return ret


@functools.lru_cache(maxsize=1024)
def _t_counts_for_bloq(bloq: Bloq, graph: nx.DiGraph) -> int:
sigma = _compute_sigma(bloq, graph)
return t_counts_from_sigma(sigma)
Comment on lines +60 to +61
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why not t_complexity(bloq).t?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It doesn't sum up costs for rotation bloqs anymore, since we moved ceil(1.149 * log2(1.0 / eps) + 9.2) to TComplexity.rotation_cost and but we don't track the eps in the TComplexity object.



def _keep_if_small(bloq: Bloq) -> bool:
from qualtran.bloqs.basic_gates import Toffoli, TwoBitCSwap
from qualtran.bloqs.mcmt.and_bloq import And

if isinstance(bloq, (And, Toffoli, TwoBitCSwap)):
return True


def _is_leaf_node(callees: List[Bloq]) -> bool:
from qualtran.bloqs.basic_gates import TGate

return len(callees) == 0 or (
len(callees) == 1 and callees[0] in [TGate(), TGate(is_adjoint=True)]
)


def _populate_flame_graph_data(
bloq: Bloq, graph: nx.DiGraph, graph_t: nx.DiGraph, prefix: List[str]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

document arguments -- prefix is modified

) -> List[str]:
"""Populates data for the flame graph.

Args:
bloq: Bloq to get the flame graph data for.
graph: Callgraph of `bloq` with custom kwargs so users can control the depth / leaf nodes
for the flame graph. This is the graph we do a DFS ordering on.
graph_t: Callgraph to use to derive T-complexity of the Bloq. Ideally, we should just be able
to invoke the `bloq.t_complexity().t` but this hides the T-costs due to rotations, so we
use a temporary solution to invoke `_t_counts_for_bloq(bloq, graph_t)`. The graph is not
mutated over the course of DFS and hence can be used as a hash key for better performance.
prefix: A string that represents the bloqs visited in the path so far. This is mutated as the
graph is traversed and represents the current path from the root node to the current node
in the DFS traversal. This is used to populate the flame graph data once we hit leaf nodes
in `graph`.

Returns:
The Flame graph data for the subgraph of `graph` for which `bloq` is a root node.
"""

callees = [x for x in list(graph.successors(bloq)) if _t_counts_for_bloq(x, graph_t) > 0]
total_t_counts = _t_counts_for_bloq(bloq, graph_t)
prefix.append(_pretty_name(bloq) + f'(T:{total_t_counts})')
data = []
if _is_leaf_node(callees):
data += [';'.join(prefix) + '\t' + str(total_t_counts)]
else:
succ_t_counts = 0
for succ in callees:
curr_data = _populate_flame_graph_data(succ, graph, graph_t, prefix)
succ_t_counts += (
_t_counts_for_bloq(succ, graph_t) * graph.get_edge_data(bloq, succ)['n']
)
data += curr_data * graph.get_edge_data(bloq, succ)['n']
# TODO: This assertion should be enabled once, for each bloq, we verify that
# `assert_equivalent_bloq_example_counts` is True for `TGate`. This is currently not True
# and is tracked in https://github.com/quantumlib/Qualtran/issues/858
# assert total_t_counts == succ_t_counts, f'{bloq=}, {total_t_counts=}, {succ_t_counts=}'
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

uncomment or remove

prefix.pop()
return data


def get_flame_graph_data(
*bloqs: Bloq,
file_path: Union[None, pathlib.Path, str] = None,
keep: Optional[Sequence['Bloq']] = _keep_if_small,
**kwargs,
) -> List[str]:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

docstring

"""Get the flame graph data for visualizing T-costs distribution of a sequence of bloqs.

For each bloq in the input, this will do a DFS ordering over all edges in the DAG and
add an entry corresponding to each leaf node in the call graph. The string representation
added for a leaf node encodes the entire path taken from the root node to the leaf node
and is repeated a number of times that's equivalent to the weight of that path. Thus, the
length of the output would be roughly equal to the number of T-gates in the Bloq and can be
very high. If you want to limit the output size, consider specifying a `keep` predicate where
the leaf nodes are higher level Bloqs with a larger T-count weight.

Args:
bloqs: Bloqs to plot the flame graph for
file_path: If specified, the output is stored at the file.
keep: A predicate to determine the leaf nodes in the call graph. The flame graph would use
these Bloqs as leaf nodes and thus would not contain decompositions for these nodes.
**kwargs: Additional arguments to be passed to `bloq.call_graph`, like generalizers etc.

Returns:
A list of strings, one for each path from root node to the leaf node in the call graph x
the weight of the path, that can be passed to the `third_party/flame_graph/flame_graph.pl`
script.
"""
from qualtran.resource_counting.generalizers import cirq_to_bloqs

data = []
for bloq in bloqs:
call_graph, _ = bloq.call_graph(keep=keep, **kwargs, generalizer=cirq_to_bloqs)
call_graph_t_counts, _ = bloq.call_graph()
data += _populate_flame_graph_data(bloq, call_graph, call_graph_t_counts, prefix=[])
if file_path:
with open(file_path, 'w') as f:
f.write('\n'.join(data))
else:
return data


def get_flame_graph_svg_data(
*bloqs: Bloq, file_path: Union[None, pathlib.Path, str] = None, **kwargs
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

docstring

) -> Optional[str]:
"""Invokes the `third_party/flamegraph/flamegraph.pl` using data from `get_flame_graph_data`."""

data = get_flame_graph_data(*bloqs, **kwargs)

data_file = tempfile.NamedTemporaryFile(mode='w')
flame_graph_path = (
pathlib.Path(__file__).resolve().parent.parent / "third_party/flamegraph/flamegraph.pl"
)

data_file.write('\n'.join(data))
data_file.flush()
svg_data = subprocess.run(
[flame_graph_path, "--countname", "TCounts", f'{data_file.name}'],
capture_output=True,
text=True,
check=True,
).stdout
data_file.close()

if file_path:
with open(file_path, 'w') as f:
f.write(svg_data)
else:
return svg_data
Loading
Loading