
Commit 1df9290

justinchuby and Copilot authored
Always fold the Transpose node in the constant folder (#2355)
- Create an `always_fold_ops` option to let users specify which ops should always be folded
- Refactor `FoldConstantsPass` to hide internal attributes
- Update the logic to check for graph inputs directly, removing the need to track initializer inputs in the pass state

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent c7d5786 commit 1df9290
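
As a usage note (not part of the commit itself), the new option can be exercised roughly as follows. This is a minimal sketch that assumes an already-loaded onnxscript `ir.Model` named `model` and imports from the private module edited in this diff; a public re-export may also be available elsewhere in `onnxscript.optimizer`.

```python
# Sketch only (not from the commit): exercising the new `always_fold_ops` option.
# `model` is assumed to be an existing onnxscript `ir.Model`.
from onnxscript.optimizer._constant_folding import fold_constants

result = fold_constants(
    model,
    onnx_shape_inference=False,
    # "Transpose" is already the default entry; ops outside the default opset
    # would be written as "{domain}::{op_type}".
    always_fold_ops=["Transpose", "Reshape"],
)
# `result` is a FoldConstantsResult describing the outcome of the pass.
```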

File tree: 2 files changed, +138 −85 lines

onnxscript/optimizer/_constant_folding.py

Lines changed: 105 additions & 73 deletions
@@ -9,7 +9,7 @@
 import logging
 import math
 import typing
-from typing import Any, Callable, Iterable, Sequence, Union
+from typing import Any, Callable, Collection, Iterable, Sequence, Union
 
 import numpy as np
 import onnx
@@ -24,12 +24,7 @@
 DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT = 1024 * 1024
 
 
-def is_control_flow_op(node: ir.Node) -> bool:
-    graph_types = {ir.AttributeType.GRAPH, ir.AttributeType.GRAPHS}
-    return any(attr.type in graph_types for attr in node.attributes.values())
-
-
-non_deterministic_ops = frozenset(
+_NON_DETERMINISTIC_OPS = frozenset(
     {
         "RandomUniform",
         "RandomNormal",
@@ -40,21 +35,21 @@ def is_control_flow_op(node: ir.Node) -> bool:
 )
 
 
-def is_non_deterministic_op(node: ir.Node) -> bool:
-    return node.op_type in non_deterministic_ops and utils.is_onnx_domain(node.domain)
+logger = logging.getLogger(__name__)
 
 
-def is_onnx_op(node: ir.Node, op_type: str) -> bool:
-    return node.op_type == op_type and utils.is_onnx_domain(node.domain)
+def _is_control_flow_op(node: ir.Node) -> bool:
+    graph_types = {ir.AttributeType.GRAPH, ir.AttributeType.GRAPHS}
+    return any(attr.type in graph_types for attr in node.attributes.values())
 
 
-def is_constant_op(node: ir.Node) -> bool:
-    return node.op_type in {"Constant", "ConstantOfShape"} and utils.is_onnx_domain(
-        node.domain
-    )
+def _is_non_deterministic_op(node: ir.Node) -> bool:
+    return node.op_type in _NON_DETERMINISTIC_OPS and utils.is_onnx_domain(node.domain)
 
 
-logger = logging.getLogger(__name__)
+def _is_onnx_op(node: ir.Node, op_type: str) -> bool:
+    return node.op_type == op_type and utils.is_onnx_domain(node.domain)
+
 
 # "Standard" evaluators are used to perform constant-folding.
 # The API below works only for non-control-flow ops (ops without any graph-attributes).
@@ -168,19 +163,6 @@ def get_sym_value(self, value: ir.Value | None) -> SymbolicValue | None:
     def set_sym_value(self, value: ir.Value, sym_value: SymbolicValue) -> None:
         self._sym_value_map[value] = sym_value
 
-    def push_initializer_inputs(self) -> None:
-        self._initializer_inputs.append(set())
-
-    def pop_initializer_inputs(self) -> None:
-        self._initializer_inputs.pop()
-
-    def add_initializer_input(self, value: ir.Value) -> None:
-        assert self._initializer_inputs
-        self._initializer_inputs[-1].add(value)
-
-    def is_initializer_input(self, value: ir.Value) -> bool:
-        return any(value in inputs for inputs in self._initializer_inputs)
-
     def get_shape_value(self, value: ir.Value | None) -> ir.Shape | None:
         const_value = _get_numpy_value(value, ir.DataType.INT64, size_limit=10)
         if const_value is not None:
@@ -301,6 +283,11 @@ def _get_numpy_value(
         array = const_value.numpy().view(const_value.dtype.numpy())
     except FileNotFoundError:
         # External data is not available.
+        logger.warning(
+            "External data for value '%s' is not available. "
+            "This may lead to incorrect constant folding.",
+            val.name,
+        )
         return None
     assert isinstance(array, np.ndarray)
     return array
@@ -841,28 +828,48 @@ def merge_dims(dim1, dim2):
 
 
 class FoldConstantsPass(ir.passes.InPlacePass):
+    """A pass that folds constant expressions in the model.
+
+    Attributes:
+        shape_inference: Whether to perform shape inference.
+        input_size_limit: Maximum size of input tensors to fold.
+        output_size_limit: Maximum size of output tensors to fold.
+        always_fold_ops: Collection of op types that should always be folded.
+            For ops from the default opset, only op_type is needed (e.g. "Transpose"),
+            otherwise specify the domain with ``{domain}::{op_type}``.
+    """
+
     def __init__(
         self,
         *,
         shape_inference: bool,
         input_size_limit: int,
         output_size_limit: int,
+        always_fold_ops: Collection[str] = frozenset(["Transpose"]),
     ) -> None:
-        self._shape_inference = shape_inference
-        self._input_size_limit = input_size_limit
-        self._output_size_limit = output_size_limit
-        self.opset_imports: dict[str, int] = {}
-        self.counts: dict[str, int] = {}
-        self.sizes: dict[str, int] = {}
-        self.modified: bool = False
+        self.shape_inference = shape_inference
+        self.input_size_limit = input_size_limit
+        self.output_size_limit = output_size_limit
+        ops = []
+        for name in always_fold_ops:
+            domain, op_type = name.split("::", 1) if "::" in name else ("", name)
+            if domain == "ai.onnx":
+                domain = ""
+            ops.append((domain, op_type))
+        self.always_fold_ops: frozenset[tuple[str, str]] = frozenset(ops)
+
+        self._opset_imports: dict[str, int] = {}
+        self._counts: dict[str, int] = {}
+        self._sizes: dict[str, int] = {}
+        self._modified: bool = False
         self._state = OptimizerState()
         self._reset()
 
     def _reset(self) -> None:
         """Reset internal states for a new run."""
-        self.counts = {}
-        self.sizes = {}
-        self.modified = False
+        self._counts = {}
+        self._sizes = {}
+        self._modified = False
         self._state = OptimizerState()
 
     def _do_inference(self, node: ir.Node) -> None:
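
As a side note on the `always_fold_ops` handling added in `__init__` above, the sketch below (not part of the diff) shows how entries are normalized into `(domain, op_type)` pairs; the `com.example::CustomOp` name is made up for illustration.

```python
# Sketch (not from the commit): how `always_fold_ops` entries are normalized.
from onnxscript.optimizer._constant_folding import FoldConstantsPass

fold_pass = FoldConstantsPass(
    shape_inference=False,
    input_size_limit=1024 * 1024,
    output_size_limit=1024 * 1024,
    always_fold_ops=["Transpose", "ai.onnx::Cast", "com.example::CustomOp"],
)
# Bare names and the "ai.onnx" domain both map to the default domain "":
# {("", "Transpose"), ("", "Cast"), ("com.example", "CustomOp")}
print(fold_pass.always_fold_ops)
```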
@@ -896,7 +903,7 @@ def get_type(value: ir.Value) -> onnx.TypeProto | None:
         # TODO: pass in constant values, ir_version
         try:
             schema = onnx.defs.get_schema(
-                node.op_type, self.opset_imports[node.domain], node.domain
+                node.op_type, self._opset_imports[node.domain], node.domain
             )
             output_types = onnx.shape_inference.infer_node_outputs(
                 schema,
@@ -937,7 +944,7 @@ def new_constant(self, node: ir.Node, value) -> ir.Node | None:
         tensor.name = irvalue.name
         irvalue.const_value = tensor
 
-        if value.nbytes > self._output_size_limit:
+        if value.nbytes > self.output_size_limit:
             # Handle examples like Transpose(weight) to be folded even if the size is large,
             # as long as weight has no other uses. This won't increase model size.
             removed_input_size = 0
@@ -967,6 +974,7 @@ def new_constant(self, node: ir.Node, value) -> ir.Node | None:
         return node
 
     def process_node(self, node: ir.Node) -> Replacement | None:
+        """Process a node and return a Replacement if the node can be replaced."""
         for i, value in enumerate(node.inputs):
            sym_value = self._state.get_sym_value(value)
            if isinstance(sym_value, ir.Value):
@@ -977,16 +985,16 @@ def process_node(self, node: ir.Node) -> Replacement | None:
                     sym_value.name,
                 )
                 node.replace_input_with(i, sym_value)
-                self.modified = True
+                self._modified = True
             # TODO(rama): consider merging type/other info from both values
 
         # Do incremental shape inference
-        if self._shape_inference and not is_control_flow_op(node):
+        if self.shape_inference and not _is_control_flow_op(node):
             self._do_inference(node)
 
-        if node.domain not in self.opset_imports:
+        if node.domain not in self._opset_imports:
             return None
-        version = self.opset_imports[node.domain]
+        version = self._opset_imports[node.domain]
         op_optimizers = registry.lookup_evaluators(node.domain, node.op_type, version)
         for optimizer in op_optimizers:
             assert optimizer
@@ -999,31 +1007,58 @@ def process_node(self, node: ir.Node) -> Replacement | None:
                     output = [output]
                 return Replacement(output, context.nodes)
 
-        if is_control_flow_op(node) or is_non_deterministic_op(node):
+        if _is_control_flow_op(node) or _is_non_deterministic_op(node):
             return None
 
-        if is_onnx_op(node, "Constant"):
+        if _is_onnx_op(node, "Constant"):
             _process_constant_node(node)
             return None
 
-        input_values = [_get_numpy_value(x) for x in node.inputs]
-        if any(x is None for x in input_values):
-            return None
-
-        if any(self._state.is_initializer_input(x) for x in node.inputs):  # type: ignore[arg-type]
+        if any(x.is_graph_input() for x in node.inputs if x is not None):
+            # Do not fold any graph inputs to preserve graph signature
             return None
 
-        if any(input.nbytes > self._input_size_limit for input in input_values):  # type: ignore[union-attr]
+        # Ensure all node inputs are constants
+        if any(x.const_value is None for x in node.inputs if x is not None):
             if logger.isEnabledFor(logging.DEBUG):
-                input_sizes = [input.size for input in input_values]  # type: ignore[union-attr]
                 logger.debug(
-                    "Skipping constant folding for op %s due to large input size: %s",
-                    node.op_type,
-                    input_sizes,
+                    "Skipping constant folding for node %s because it has non-constant inputs",
+                    node,
+                    [x.name for x in node.inputs if x is not None],
                 )
             return None
 
-        # Filter out bfloat16 cases?
+        input_tensors = [x.const_value if x is not None else None for x in node.inputs]
+
+        if any(
+            tensor.nbytes > self.input_size_limit
+            for tensor in input_tensors
+            if tensor is not None
+        ):
+            if (node.domain, node.op_type) in self.always_fold_ops and all(
+                len(input.consumers()) == 1 for input in node.inputs if input is not None
+            ):
+                # If the op is in always_fold_ops and all inputs are used only by this node,
+                # we can still fold it even if the input size exceeds the limit.
+                logger.debug(
+                    "Folding large constant for node %s because it is in the always_fold_ops list",
+                    node,
+                )
+            else:
+                # Skip folding large tensors
+                if logger.isEnabledFor(logging.DEBUG):
+                    input_sizes = [
+                        tensor.nbytes for tensor in input_tensors if tensor is not None
+                    ]
+                    logger.debug(
+                        "Skipping constant folding for node %s due to large input size: %s",
+                        node,
+                        input_sizes,
+                    )
+                return None
+
+        input_values = [_get_numpy_value(x) for x in node.inputs]
+
         def convert(av):
             if av.type == ir.AttributeType.TENSOR:
                 return ir.serde.serialize_tensor(av.value)
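
In plain terms, the branch introduced above only overrides `input_size_limit` when the op was registered in `always_fold_ops` and none of its inputs have other consumers, so folding cannot duplicate a large tensor. A standalone paraphrase of that rule (hypothetical helper, not part of the diff):

```python
# Hypothetical paraphrase of the decision made in process_node above.
def should_fold_despite_size(
    exceeds_input_size_limit: bool,
    is_always_fold_op: bool,
    all_inputs_single_consumer: bool,
) -> bool:
    if not exceeds_input_size_limit:
        return True  # within input_size_limit: fold normally
    # Oversized inputs are folded only for always_fold_ops whose inputs are
    # used solely by this node, so the fold cannot grow the model.
    return is_always_fold_op and all_inputs_single_consumer
```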
@@ -1038,7 +1073,7 @@ def convert(av):
             return None
         if len(node.outputs) == 1 and not isinstance(outputs, (tuple, list)):
             replacement = self.new_constant(node, outputs)
-            if is_onnx_op(node, "ConstantOfShape") or replacement is None:
+            if _is_onnx_op(node, "ConstantOfShape") or replacement is None:
                 return None
             return Replacement(replacement.outputs, [replacement])
         else:
@@ -1054,7 +1089,7 @@ def replace_node(self, node: ir.Node, replacement, root: ir.Graph | ir.Function)
             root, node, [node], replacement.new_nodes, node.outputs, replacement.new_outputs
         )
 
-        self.modified = True
+        self._modified = True
 
         # TODO: what about new opset_imports?
         # TODO: track statistics about replaced nodes and sizes of new constants
@@ -1079,13 +1114,6 @@ def visit_node(self, node: ir.Node, root: ir.Graph | ir.Function) -> None:
             self.replace_node(node, replacement, root)
 
     def visit_graph(self, graph: ir.Graph) -> None:
-        # Track inputs that have a const_value (which is really a default-value, and should not
-        # be used for constant-folding).
-        self._state.push_initializer_inputs()
-        for input in graph.inputs:
-            if input.const_value is not None:
-                self._state.add_initializer_input(input)
-
         for node in graph:
             self.visit_node(node, graph)
 
@@ -1103,22 +1131,20 @@ def visit_graph(self, graph: ir.Graph) -> None:
                 # Rename sym_value to match the output name
                 sym_value.name = output.name
                 graph.outputs[i] = sym_value
-                self.modified = True
-
-        self._state.pop_initializer_inputs()
+                self._modified = True
 
     def visit_function(self, function: ir.Function) -> None:
         for node in function:
             self.visit_node(node, function)
 
-    def call(self, model: ir.Model) -> ir.passes.PassResult:
+    def call(self, model: ir.Model) -> FoldConstantsResult:
         self._reset()
-        self.opset_imports = model.opset_imports
+        self._opset_imports = model.opset_imports
         self.visit_graph(model.graph)
         for function in model.functions.values():
             # TODO(rama): Should we specialize functions?
             self.visit_function(function)
-        return FoldConstantsResult(model, self.modified, self._state.symbolic_value_map)
+        return FoldConstantsResult(model, self._modified, self._state.symbolic_value_map)
 
 
 def _sym_value_can_replace_graph_output(
@@ -1155,6 +1181,7 @@ def fold_constants(
     onnx_shape_inference: bool = False,
     input_size_limit: int = DEFAULT_CONSTANT_FOLD_INPUT_SIZE_LIMIT,
    output_size_limit: int = DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT,
+    always_fold_ops: Collection[str] = frozenset(["Transpose"]),
 ) -> FoldConstantsResult:
     """
     Applies constant folding optimization to the model.
@@ -1169,6 +1196,10 @@ def fold_constants(
         output_size_limit: The maximum size (in bytes) of output tensors
             that can be stored after constant folding. Defaults to
             `DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT`.
+        always_fold_ops: A collection of op types that should always be folded,
+            regardless of their input or output sizes. For ops from the default opset,
+            only op_type is needed (e.g. "Transpose"), otherwise specify the domain
+            with ``{domain}::{op_type}``.
 
     Returns:
         An instance of `FoldConstantsResult`.
@@ -1178,5 +1209,6 @@ def fold_constants(
         shape_inference=onnx_shape_inference,
         input_size_limit=input_size_limit,
        output_size_limit=output_size_limit,
+        always_fold_ops=always_fold_ops,
    )
    return folder_pass(model)  # type: ignore[return-value]
