Skip to content

Commit

Permalink
Split flatten into a pass
Browse files Browse the repository at this point in the history
  • Loading branch information
TsafrirA committed Mar 18, 2024
1 parent b6dbf59 commit f9193c1
Show file tree
Hide file tree
Showing 6 changed files with 322 additions and 265 deletions.
2 changes: 1 addition & 1 deletion qiskit/pulse/compiler/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,4 @@
"""Pass-based Qiskit pulse program compiler."""

from .passmanager import BlockTranspiler, BlockToIrCompiler
from .passes import MapMixedFrame, SetSequence, SetSchedule
from .passes import MapMixedFrame, SetSequence, SetSchedule, Flatten
1 change: 1 addition & 0 deletions qiskit/pulse/compiler/passes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,4 @@
from .map_mixed_frames import MapMixedFrame
from .set_sequence import SetSequence
from .schedule import SetSchedule
from .flatten import Flatten
96 changes: 96 additions & 0 deletions qiskit/pulse/compiler/passes/flatten.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
# This code is part of Qiskit.
#
# (C) Copyright IBM 2024.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.

"""A flattening pass for Qiskit PulseIR compilation."""

from __future__ import annotations

from qiskit.pulse.compiler.basepasses import TransformationPass
from qiskit.pulse.ir import SequenceIR
from qiskit.pulse.exceptions import PulseCompilerError


class Flatten(TransformationPass):
"""Flatten ``SequenceIR`` object.
The flattening process includes breaking up nested IRs until only instructions remain.
After flattening the object will contain all instructions, timing information, and the
complete sequence graph. However, the alignment of nested IRs will be lost. Because alignment
information is essential for scheduling, flattening an unscheduled IR is not allowed.
One should apply :class:`~qiskit.pulse.compiler.passes.SetSchedule` first.
"""

def __init__(self):
"""Create new Flatten pass"""
super().__init__(target=None)

def run(
self,
passmanager_ir: SequenceIR,
) -> SequenceIR:
"""Run the pass."""

self._flatten(passmanager_ir)
return passmanager_ir

# pylint: disable=cell-var-from-loop
def _flatten(self, prog: SequenceIR) -> SequenceIR:
"""Recursively flatten the SequenceIR.
Returns:
A flattened ``SequenceIR`` object.
Raises:
PulseCompilerError: If ``prog`` is not scheduled.
"""
# TODO : Verify that the block\sub blocks are sequenced correctly?

# TODO : Consider replacing the alignment to "NullAlignment", as the original alignment
# has no meaning.

def edge_map(_x, _y, _node):
if _y == _node:
return 0
if _x == _node:
return 1
return None

if any(prog.time_table[x] is None for x in prog.sequence.node_indices() if x not in (0, 1)):
raise PulseCompilerError(
"Can not flatten unscheduled IR. Use SetSchedule pass before Flatten."
)

for ind in prog.sequence.node_indices():
if isinstance(sub_block := prog.sequence.get_node_data(ind), SequenceIR):
sub_block.flatten(inplace=True)
initial_time = prog.time_table[ind]
nodes_mapping = prog.sequence.substitute_node_with_subgraph(
ind, sub_block.sequence, lambda x, y, _: edge_map(x, y, ind)
)
if initial_time is not None:
for old_node in nodes_mapping.keys():
if old_node not in (0, 1):
prog.time_table[nodes_mapping[old_node]] = (
initial_time + sub_block.time_table[old_node]
)

del prog.time_table[ind]
prog.sequence.remove_node_retain_edges(nodes_mapping[0])
prog.sequence.remove_node_retain_edges(nodes_mapping[1])

return prog

def __hash__(self):
return hash((self.__class__.__name__,))

def __eq__(self, other):
return self.__class__.__name__ == other.__class__.__name__
64 changes: 2 additions & 62 deletions qiskit/pulse/ir/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,12 +197,8 @@ def duration(self) -> int | None:
except TypeError:
return None

def draw(self, recursive: bool = False):
def draw(self):
"""Draw the graph of the SequenceIR"""
if recursive:
draw_sequence = self.flatten().sequence
else:
draw_sequence = self.sequence

def _draw_nodes(n):
if n is SequenceIR._InNode or n is SequenceIR._OutNode:
Expand All @@ -214,66 +210,10 @@ def _draw_nodes(n):
return {"label": f"{n.__class__.__name__}" + name}

return graphviz_draw(
draw_sequence,
self.sequence,
node_attr_fn=_draw_nodes,
)

# pylint: disable=cell-var-from-loop
def flatten(self, inplace: bool = False) -> SequenceIR:
"""Recursively flatten the SequenceIR.
The flattening process includes breaking up nested IRs until only instructions remain.
The flattened object will contain all instructions, timing information, and the
complete sequence graph. However, the alignment of nested IRs will be lost. Because of
this, flattening an unscheduled IR is not allowed.
Args:
inplace: If ``True`` flatten the object itself. If ``False`` return a flattened copy.
Returns:
A flattened ``SequenceIR`` object.
Raises:
PulseError: If the IR (or nested IRs) are not scheduled.
"""
# TODO : Verify that the block\sub blocks are sequenced correctly.
if inplace:
block = self
else:
block = self.copy()

def edge_map(_x, _y, _node):
if _y == _node:
return 0
if _x == _node:
return 1
return None

if any(
block.time_table[x] is None for x in block.sequence.node_indices() if x not in (0, 1)
):
raise PulseError("Can not flatten unscheduled IR")

for ind in block.sequence.node_indices():
if isinstance(sub_block := block.sequence.get_node_data(ind), SequenceIR):
sub_block.flatten(inplace=True)
initial_time = block.time_table[ind]
nodes_mapping = block._sequence.substitute_node_with_subgraph(
ind, sub_block.sequence, lambda x, y, _: edge_map(x, y, ind)
)
if initial_time is not None:
for old_node in nodes_mapping.keys():
if old_node not in (0, 1):
block._time_table[nodes_mapping[old_node]] = (
initial_time + sub_block.time_table[old_node]
)

del block._time_table[ind]
block._sequence.remove_node_retain_edges(nodes_mapping[0])
block._sequence.remove_node_retain_edges(nodes_mapping[1])

return block

def copy(self) -> SequenceIR:
"""Create a copy of ``SequenceIR``.
Expand Down
Loading

0 comments on commit f9193c1

Please sign in to comment.