105 changes: 61 additions & 44 deletions onnxscript/optimizer/_constant_folding.py
@@ -935,6 +935,8 @@ def _record_contributing_values(original_node: ir.Node, replacement: Replacement
         assert input.name is not None
         folded_from.add(input.name)
 
+
+
     for new_output in replacement.new_outputs:
         if new_output is None:
             continue
@@ -959,9 +961,9 @@ class FoldConstantsPass(ir.passes.InPlacePass):
     def __init__(
         self,
         *,
-        shape_inference: bool,
-        input_size_limit: int,
-        output_size_limit: int,
+        shape_inference: bool = True,
+        input_size_limit: int = DEFAULT_CONSTANT_FOLD_INPUT_SIZE_LIMIT,
+        output_size_limit: int = DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT,
         should_fold: Callable[[ir.Node], bool | None] = lambda node: None,
     ) -> None:
         self.shape_inference = shape_inference
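
Note on the signature change above: every keyword-only parameter now has a default, so the pass can be constructed without arguments. A minimal sketch of what this enables (importing from this private module is illustrative, not a documented entry point):

```python
from onnxscript.optimizer._constant_folding import FoldConstantsPass

# All keyword-only parameters are now defaulted, so the common case is simply:
pass_ = FoldConstantsPass()

# Callers that previously had to spell out every argument still can:
pass_explicit = FoldConstantsPass(
    shape_inference=True,
    input_size_limit=1024 * 1024,   # hypothetical explicit limits
    output_size_limit=1024 * 1024,
)
```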
@@ -1038,51 +1040,34 @@ def get_type(value: ir.Value) -> onnx.TypeProto | None:
                 e,
             )
 
-    def new_constant(self, node: ir.Node, value) -> ir.Node | None:
-        irvalue = node.outputs[0]
-        if not isinstance(value, np.ndarray):
+    def new_initializer(self, old_value, array) -> ir.Value | None:
+        if not isinstance(array, np.ndarray):
             # ONNX does not have a way to represent non-tensor constants, eg. a sequence.
             # So, a constant-value of type sequence is not folded, but it can be used
             # to optimize subsequent operations when possible.
             logger.info(
                 "Skip storing constant folded value %s due to unsupported type %s.",
-                irvalue.name,
-                type(value),
+                old_value.name,
+                type(array),
             )
             return None
 
-        tensor = ir.tensor(value)
-        tensor.name = irvalue.name
-        irvalue.const_value = tensor
-
-        if value.size > self.output_size_limit:
-            # Handle examples like Transpose(weight) to be folded even if the size is large,
-            # as long as weight has no other uses. This won't increase model size.
-            removed_input_size = 0
-            for input in node.inputs:
-                if (input is not None) and (len(input.uses()) == 1):
-                    array = _get_numpy_value(input)
-                    if array is not None:
-                        removed_input_size += array.size
-            increased_size = value.size - removed_input_size
-            if increased_size > 0:
-                logger.info(
-                    "Skip storing constant folded nvalue %s due to large size %s.",
-                    irvalue.name,
-                    value.size,
-                )
-                return None
+        tensor = ir.tensor(array)
+        tensor.name = old_value.name
+        new_value = ir.Value(
+            name=old_value.name,
+            type=ir.TensorType(ir.DataType(tensor.dtype)),
+            shape=tensor.shape,
+            const_value=tensor,
+        )
 
         logger.debug(
             "New constant for value %s dtype: %s shape: %s",
-            irvalue.name,
-            value.dtype,
-            value.shape,
+            old_value.name,
+            new_value.dtype,
+            new_value.shape,
         )
-
-        attributes = ir.convenience.convert_attributes({"value": tensor})
-        node = ir.Node("", "Constant", inputs=[], attributes=attributes, num_outputs=1)
-        return node
+        return new_value
 
     def process_node(self, node: ir.Node) -> Replacement | None:
         """Process a node and return a Replacement if the node can be replaced."""
@@ -1221,16 +1206,48 @@ def convert(av):
 
         if outputs is None:
             return None
-        if len(node.outputs) == 1 and not isinstance(outputs, (tuple, list)):
-            replacement = self.new_constant(node, outputs)
-            if replacement is None:
-                return None
-            return Replacement(replacement.outputs, [replacement])
-        else:
+
+        if not isinstance(outputs, (list, tuple)):
+            outputs = [outputs]
+
+        if len(outputs) != len(node.outputs):
             logger.warning(
-                "Skipping constant folding for op %s with multiple outputs.", node.op_type
+                "Skipping constant folding for op %s because number of outputs do not match: %d => %d",
+                node.op_type,
+                len(node.outputs),
+                len(outputs),
             )
-        return None
+            return None
+
+        # Whether we will fold the node regardless of sizes
+        can_ignore_output_limit = (
+            should_fold is True or (node.domain, node.op_type) in _DEFAULT_ALWAYS_FOLD_OPS
+        )
+        replacement_values: list[ir.Value] = []
+        for i, array in enumerate(outputs):
+            new_initializer = self.new_initializer(node.outputs[i], array)
+            if new_initializer is None:
+                # Could not create a new initializer for the output
+                return None
+            if (
+                new_initializer.const_value.size > self.output_size_limit
+                and not can_ignore_output_limit
+            ):
+                logger.info(
+                    "Skipping constant folding for node %r because output size %d exceeds limit %d",
+                    node.name,
+                    new_initializer.const_value.size,
+                    self.output_size_limit,
+                )
+                return None
+            assert new_initializer.const_value is not None
+            replacement_values.append(new_initializer)
+
+        for value in replacement_values:
+            assert node.graph is not None
+            node.graph.initializers.add(value)
+
+        return Replacement(replacement_values, [])
 
     def replace_node(
         self, node: ir.Node, replacement: Replacement, root: ir.Graph | ir.Function
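
One behavioral nuance in the new `process_node` path: the output-size limit now has an explicit escape hatch. `can_ignore_output_limit` is true when the node's op is in `_DEFAULT_ALWAYS_FOLD_OPS` or when the `should_fold` callback returned `True` for it (the binding of the callback's result to the local `should_fold` is inferred from this diff, not shown in it). A hedged sketch of a callback that uses this:

```python
from onnxscript import ir
from onnxscript.optimizer._constant_folding import FoldConstantsPass


def fold_expand_aggressively(node: ir.Node) -> bool | None:
    """Force-fold Expand nodes regardless of output_size_limit; defer otherwise."""
    if node.domain == "" and node.op_type == "Expand":
        return True  # True also makes process_node ignore the output size limit
    return None  # None defers to the pass's default folding policy


pass_ = FoldConstantsPass(should_fold=fold_expand_aggressively)
```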