import logging
import math
import typing
-from typing import Any, Callable, Iterable, Sequence, Union
+from typing import Any, Callable, Collection, Iterable, Sequence, Union

import numpy as np
import onnx

DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT = 1024 * 1024


-def is_control_flow_op(node: ir.Node) -> bool:
-    graph_types = {ir.AttributeType.GRAPH, ir.AttributeType.GRAPHS}
-    return any(attr.type in graph_types for attr in node.attributes.values())
-
-
-non_deterministic_ops = frozenset(
+_NON_DETERMINISTIC_OPS = frozenset(
    {
        "RandomUniform",
        "RandomNormal",
@@ -40,21 +35,21 @@ def is_control_flow_op(node: ir.Node) -> bool:
)


-def is_non_deterministic_op(node: ir.Node) -> bool:
-    return node.op_type in non_deterministic_ops and utils.is_onnx_domain(node.domain)
+logger = logging.getLogger(__name__)


-def is_onnx_op(node: ir.Node, op_type: str) -> bool:
-    return node.op_type == op_type and utils.is_onnx_domain(node.domain)
+def _is_control_flow_op(node: ir.Node) -> bool:
+    graph_types = {ir.AttributeType.GRAPH, ir.AttributeType.GRAPHS}
+    return any(attr.type in graph_types for attr in node.attributes.values())


-def is_constant_op(node: ir.Node) -> bool:
-    return node.op_type in {"Constant", "ConstantOfShape"} and utils.is_onnx_domain(
-        node.domain
-    )
+def _is_non_deterministic_op(node: ir.Node) -> bool:
+    return node.op_type in _NON_DETERMINISTIC_OPS and utils.is_onnx_domain(node.domain)


-logger = logging.getLogger(__name__)
+def _is_onnx_op(node: ir.Node, op_type: str) -> bool:
+    return node.op_type == op_type and utils.is_onnx_domain(node.domain)
+

# "Standard" evaluators are used to perform constant-folding.
# The API below works only for non-control-flow ops (ops without any graph-attributes).
@@ -168,19 +163,6 @@ def get_sym_value(self, value: ir.Value | None) -> SymbolicValue | None:
    def set_sym_value(self, value: ir.Value, sym_value: SymbolicValue) -> None:
        self._sym_value_map[value] = sym_value

-    def push_initializer_inputs(self) -> None:
-        self._initializer_inputs.append(set())
-
-    def pop_initializer_inputs(self) -> None:
-        self._initializer_inputs.pop()
-
-    def add_initializer_input(self, value: ir.Value) -> None:
-        assert self._initializer_inputs
-        self._initializer_inputs[-1].add(value)
-
-    def is_initializer_input(self, value: ir.Value) -> bool:
-        return any(value in inputs for inputs in self._initializer_inputs)
-
    def get_shape_value(self, value: ir.Value | None) -> ir.Shape | None:
        const_value = _get_numpy_value(value, ir.DataType.INT64, size_limit=10)
        if const_value is not None:
@@ -301,6 +283,11 @@ def _get_numpy_value(
        array = const_value.numpy().view(const_value.dtype.numpy())
    except FileNotFoundError:
        # External data is not available.
+        logger.warning(
+            "External data for value '%s' is not available. "
+            "This may lead to incorrect constant folding.",
+            val.name,
+        )
        return None
    assert isinstance(array, np.ndarray)
    return array
@@ -841,28 +828,48 @@ def merge_dims(dim1, dim2):


class FoldConstantsPass(ir.passes.InPlacePass):
+    """A pass that folds constant expressions in the model.
+
+    Attributes:
+        shape_inference: Whether to perform shape inference.
+        input_size_limit: Maximum size of input tensors to fold.
+        output_size_limit: Maximum size of output tensors to fold.
+        always_fold_ops: Collection of op types that should always be folded.
+            For ops from the default opset, only the op_type is needed (e.g. "Transpose");
+            otherwise specify the domain with ``{domain}::{op_type}``.
+    """
+
    def __init__(
        self,
        *,
        shape_inference: bool,
        input_size_limit: int,
        output_size_limit: int,
+        always_fold_ops: Collection[str] = frozenset(["Transpose"]),
    ) -> None:
-        self._shape_inference = shape_inference
-        self._input_size_limit = input_size_limit
-        self._output_size_limit = output_size_limit
-        self.opset_imports: dict[str, int] = {}
-        self.counts: dict[str, int] = {}
-        self.sizes: dict[str, int] = {}
-        self.modified: bool = False
+        self.shape_inference = shape_inference
+        self.input_size_limit = input_size_limit
+        self.output_size_limit = output_size_limit
+        ops = []
+        for name in always_fold_ops:
+            domain, op_type = name.split("::", 1) if "::" in name else ("", name)
+            if domain == "ai.onnx":
+                domain = ""
+            ops.append((domain, op_type))
+        self.always_fold_ops: frozenset[tuple[str, str]] = frozenset(ops)
+
+        self._opset_imports: dict[str, int] = {}
+        self._counts: dict[str, int] = {}
+        self._sizes: dict[str, int] = {}
+        self._modified: bool = False
        self._state = OptimizerState()
        self._reset()

    def _reset(self) -> None:
        """Reset internal states for a new run."""
-        self.counts = {}
-        self.sizes = {}
-        self.modified = False
+        self._counts = {}
+        self._sizes = {}
+        self._modified = False
        self._state = OptimizerState()

    def _do_inference(self, node: ir.Node) -> None:
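Note on the new `always_fold_ops` handling above: each entry is normalized to a `(domain, op_type)` pair, with both `""` and `"ai.onnx"` treated as the default ONNX domain. A minimal sketch of that normalization as a standalone helper (hypothetical, for illustration only; the pass does this inline in `__init__`):

```python
def normalize_always_fold_ops(names: list[str]) -> frozenset[tuple[str, str]]:
    """Map "Transpose" to ("", "Transpose") and "com.microsoft::Gelu" to ("com.microsoft", "Gelu")."""
    ops = []
    for name in names:
        # Split "{domain}::{op_type}"; a bare op type means the default opset.
        domain, op_type = name.split("::", 1) if "::" in name else ("", name)
        if domain == "ai.onnx":  # "ai.onnx" is another spelling of the default domain
            domain = ""
        ops.append((domain, op_type))
    return frozenset(ops)


assert normalize_always_fold_ops(["Transpose", "ai.onnx::Cast"]) == {("", "Transpose"), ("", "Cast")}
```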
@@ -896,7 +903,7 @@ def get_type(value: ir.Value) -> onnx.TypeProto | None:
        # TODO: pass in constant values, ir_version
        try:
            schema = onnx.defs.get_schema(
-                node.op_type, self.opset_imports[node.domain], node.domain
+                node.op_type, self._opset_imports[node.domain], node.domain
            )
            output_types = onnx.shape_inference.infer_node_outputs(
                schema,
@@ -937,7 +944,7 @@ def new_constant(self, node: ir.Node, value) -> ir.Node | None:
        tensor.name = irvalue.name
        irvalue.const_value = tensor

-        if value.nbytes > self._output_size_limit:
+        if value.nbytes > self.output_size_limit:
            # Handle examples like Transpose(weight) to be folded even if the size is large,
            # as long as weight has no other uses. This won't increase model size.
            removed_input_size = 0
@@ -967,6 +974,7 @@ def new_constant(self, node: ir.Node, value) -> ir.Node | None:
        return node

    def process_node(self, node: ir.Node) -> Replacement | None:
+        """Process a node and return a Replacement if the node can be replaced."""
        for i, value in enumerate(node.inputs):
            sym_value = self._state.get_sym_value(value)
            if isinstance(sym_value, ir.Value):
@@ -977,16 +985,16 @@ def process_node(self, node: ir.Node) -> Replacement | None:
                    sym_value.name,
                )
                node.replace_input_with(i, sym_value)
-                self.modified = True
+                self._modified = True
                # TODO(rama): consider merging type/other info from both values

        # Do incremental shape inference
-        if self._shape_inference and not is_control_flow_op(node):
+        if self.shape_inference and not _is_control_flow_op(node):
            self._do_inference(node)

-        if node.domain not in self.opset_imports:
+        if node.domain not in self._opset_imports:
            return None
-        version = self.opset_imports[node.domain]
+        version = self._opset_imports[node.domain]
        op_optimizers = registry.lookup_evaluators(node.domain, node.op_type, version)
        for optimizer in op_optimizers:
            assert optimizer
@@ -999,31 +1007,58 @@ def process_node(self, node: ir.Node) -> Replacement | None:
                output = [output]
            return Replacement(output, context.nodes)

-        if is_control_flow_op(node) or is_non_deterministic_op(node):
+        if _is_control_flow_op(node) or _is_non_deterministic_op(node):
            return None

-        if is_onnx_op(node, "Constant"):
+        if _is_onnx_op(node, "Constant"):
            _process_constant_node(node)
            return None

-        input_values = [_get_numpy_value(x) for x in node.inputs]
-        if any(x is None for x in input_values):
-            return None
-
-        if any(self._state.is_initializer_input(x) for x in node.inputs):  # type: ignore[arg-type]
+        if any(x.is_graph_input() for x in node.inputs if x is not None):
+            # Do not fold any graph inputs to preserve graph signature
            return None

-        if any(input.nbytes > self._input_size_limit for input in input_values):  # type: ignore[union-attr]
+        # Ensure all node inputs are constants
+        if any(x.const_value is None for x in node.inputs if x is not None):
            if logger.isEnabledFor(logging.DEBUG):
-                input_sizes = [input.size for input in input_values]  # type: ignore[union-attr]
                logger.debug(
-                    "Skipping constant folding for op %s due to large input size: %s",
-                    node.op_type,
-                    input_sizes,
+                    "Skipping constant folding for node %s because it has non-constant inputs",
+                    node,
+                    [x.name for x in node.inputs if x is not None],
                )
            return None

-        # Filter out bfloat16 cases?
+        input_tensors = [x.const_value if x is not None else None for x in node.inputs]
+
+        if any(
+            tensor.nbytes > self.input_size_limit
+            for tensor in input_tensors
+            if tensor is not None
+        ):
+            if (node.domain, node.op_type) in self.always_fold_ops and all(
+                len(input.consumers()) == 1 for input in node.inputs if input is not None
+            ):
+                # If the op is in always_fold_ops and all inputs are used only by this node,
+                # we can still fold it even if the input size exceeds the limit.
+                logger.debug(
+                    "Folding large constant for node %s because it is in the always_fold_ops list",
+                    node,
+                )
+            else:
+                # Skip folding large tensors
+                if logger.isEnabledFor(logging.DEBUG):
+                    input_sizes = [
+                        tensor.nbytes for tensor in input_tensors if tensor is not None
+                    ]
+                    logger.debug(
+                        "Skipping constant folding for node %s due to large input size: %s",
+                        node,
+                        input_sizes,
+                    )
+                return None
+
+        input_values = [_get_numpy_value(x) for x in node.inputs]
+
        def convert(av):
            if av.type == ir.AttributeType.TENSOR:
                return ir.serde.serialize_tensor(av.value)
@@ -1038,7 +1073,7 @@ def convert(av):
            return None
        if len(node.outputs) == 1 and not isinstance(outputs, (tuple, list)):
            replacement = self.new_constant(node, outputs)
-            if is_onnx_op(node, "ConstantOfShape") or replacement is None:
+            if _is_onnx_op(node, "ConstantOfShape") or replacement is None:
                return None
            return Replacement(replacement.outputs, [replacement])
        else:
@@ -1054,7 +1089,7 @@ def replace_node(self, node: ir.Node, replacement, root: ir.Graph | ir.Function)
            root, node, [node], replacement.new_nodes, node.outputs, replacement.new_outputs
        )

-        self.modified = True
+        self._modified = True

        # TODO: what about new opset_imports?
        # TODO: track statistics about replaced nodes and sizes of new constants
@@ -1079,13 +1114,6 @@ def visit_node(self, node: ir.Node, root: ir.Graph | ir.Function) -> None:
            self.replace_node(node, replacement, root)

    def visit_graph(self, graph: ir.Graph) -> None:
-        # Track inputs that have a const_value (which is really a default-value, and should not
-        # be used for constant-folding).
-        self._state.push_initializer_inputs()
-        for input in graph.inputs:
-            if input.const_value is not None:
-                self._state.add_initializer_input(input)
-
        for node in graph:
            self.visit_node(node, graph)

@@ -1103,22 +1131,20 @@ def visit_graph(self, graph: ir.Graph) -> None:
                # Rename sym_value to match the output name
                sym_value.name = output.name
                graph.outputs[i] = sym_value
-                self.modified = True
-
-        self._state.pop_initializer_inputs()
+                self._modified = True

    def visit_function(self, function: ir.Function) -> None:
        for node in function:
            self.visit_node(node, function)

-    def call(self, model: ir.Model) -> ir.passes.PassResult:
+    def call(self, model: ir.Model) -> FoldConstantsResult:
        self._reset()
-        self.opset_imports = model.opset_imports
+        self._opset_imports = model.opset_imports
        self.visit_graph(model.graph)
        for function in model.functions.values():
            # TODO(rama): Should we specialize functions?
            self.visit_function(function)
-        return FoldConstantsResult(model, self.modified, self._state.symbolic_value_map)
+        return FoldConstantsResult(model, self._modified, self._state.symbolic_value_map)


def _sym_value_can_replace_graph_output(
@@ -1155,6 +1181,7 @@ def fold_constants(
    onnx_shape_inference: bool = False,
    input_size_limit: int = DEFAULT_CONSTANT_FOLD_INPUT_SIZE_LIMIT,
    output_size_limit: int = DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT,
+    always_fold_ops: Collection[str] = frozenset(["Transpose"]),
) -> FoldConstantsResult:
    """
    Applies constant folding optimization to the model.
@@ -1169,6 +1196,10 @@ def fold_constants(
        output_size_limit: The maximum size (in bytes) of output tensors
            that can be stored after constant folding. Defaults to
            `DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT`.
+        always_fold_ops: A collection of op types that should always be folded,
+            regardless of their input or output sizes. For ops from the default opset,
+            only the op_type is needed (e.g. "Transpose"); otherwise specify the domain
+            with ``{domain}::{op_type}``.

    Returns:
        An instance of `FoldConstantsResult`.
@@ -1178,5 +1209,6 @@ def fold_constants(
        shape_inference=onnx_shape_inference,
        input_size_limit=input_size_limit,
        output_size_limit=output_size_limit,
+        always_fold_ops=always_fold_ops,
    )
    return folder_pass(model)  # type: ignore[return-value]
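For reference, a usage sketch of the extended `fold_constants` entry point. This is a hedged example: the import path (`onnxscript.optimizer`), the model file name, and the `com.microsoft::FusedMatMul` entry are assumptions, not part of this diff.

```python
import onnx

from onnxscript import ir
from onnxscript.optimizer import fold_constants  # assumed public import path

# Load an ONNX model into the IR (file name is a placeholder).
model = ir.serde.deserialize_model(onnx.load("model.onnx"))

result = fold_constants(
    model,
    onnx_shape_inference=False,
    input_size_limit=1024 * 1024,
    output_size_limit=1024 * 1024,
    # These ops are folded even when a constant input exceeds input_size_limit,
    # provided each such input is consumed only by that node.
    always_fold_ops=["Transpose", "com.microsoft::FusedMatMul"],
)
print(result.modified)  # PassResult-style flag: True if any node was folded
```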