Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion onnxscript/optimizer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,9 @@ def optimize(model: ir.Model, *args, **kwargs) -> ir.Model:
return legacy_optimizer.optimize(model, *args, **kwargs)


def fold_constants(model: ir.Model | onnx.ModelProto, *args, **kwargs) -> bool:
def fold_constants(
model: ir.Model | onnx.ModelProto, *args, **kwargs
) -> constant_folding.FoldConstantsResult | bool:
if isinstance(model, ir.Model):
return constant_folding.fold_constants(model, *args, **kwargs)
else:
Expand Down
54 changes: 39 additions & 15 deletions onnxscript/optimizer/_constant_folding.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,18 +148,24 @@
# Currently, we assume that symbolic dimensions are also guaranteed to be non-negative.
# TODO: Add support for negative symbolic dimensions.

SymbolicValue = Union[ir.Value, list[ir.Value], ir.Shape]


class OptimizerState:
def __init__(self):
self._sym_value_map: dict[ir.Value, Any] = {}
self._sym_value_map: dict[ir.Value, SymbolicValue] = {}
self._initializer_inputs: list[set[ir.Value]] = []

def get_sym_value(self, value: ir.Value | None) -> Any:
@property
def symbolic_value_map(self) -> dict[ir.Value, SymbolicValue]:
return self._sym_value_map

def get_sym_value(self, value: ir.Value | None) -> SymbolicValue | None:
if value is None:
return None
return self._sym_value_map.get(value)

def set_sym_value(self, value: ir.Value, sym_value: Any) -> None:
def set_sym_value(self, value: ir.Value, sym_value: SymbolicValue) -> None:
self._sym_value_map[value] = sym_value

def push_initializer_inputs(self) -> None:
Expand Down Expand Up @@ -1094,7 +1100,17 @@
for function in model.functions.values():
# TODO(rama): Should we specialize functions?
self.visit_function(function)
return ir.passes.PassResult(model, self.modified)
return FoldConstantsResult(model, self.modified, self._state.symbolic_value_map)


@dataclasses.dataclass
class FoldConstantsResult(ir.passes.PassResult):
    """Result returned by the constant-folding pass.

    Extends ``ir.passes.PassResult`` (which carries the model and the
    ``modified`` flag) with the symbolic-value map accumulated while folding.
    """

    # Map from an IR value to the symbolic value inferred for it during
    # folding; a SymbolicValue is an ir.Value, a list of ir.Value, or an
    # ir.Shape (per the SymbolicValue alias defined earlier in this module).
    symbolic_value_map: dict[ir.Value, SymbolicValue]

    # Add conversion to bool for backward compatibility. The previously returned value
    # for the fold_constants method was a boolean indicating whether the model was modified.
    def __bool__(self) -> bool:
        return self.modified

Check warning on line 1113 in onnxscript/optimizer/_constant_folding.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/optimizer/_constant_folding.py#L1113

Added line #L1113 was not covered by tests


def fold_constants(
Expand All @@ -1104,23 +1120,31 @@
onnx_shape_inference: bool = False,
input_size_limit: int = DEFAULT_CONSTANT_FOLD_INPUT_SIZE_LIMIT,
output_size_limit: int = DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT,
) -> bool:
) -> FoldConstantsResult:
"""
Applies constant folding optimization to the model.
Returns true iff the model was modified.

Args:
model: The ONNX model to optimize.
external_data_folder: Path to the folder containing external data
for the model. Defaults to an empty string.
onnx_shape_inference: Whether to enable ONNX shape inference during
constant folding. Defaults to False.
input_size_limit: The maximum size (in bytes) of input tensors
that can be considered for constant folding. Defaults to
`DEFAULT_CONSTANT_FOLD_INPUT_SIZE_LIMIT`.
output_size_limit: The maximum size (in bytes) of output tensors
that can be stored after constant folding. Defaults to
`DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT`.

Returns:
An instance of `FoldConstantsResult`.

"""
folder_pass = FoldConstantsPass(
external_data_folder=external_data_folder,
shape_inference=onnx_shape_inference,
input_size_limit=input_size_limit,
output_size_limit=output_size_limit,
)
folder_pass(model)
for op in folder_pass.counts:
logger.info(
"Constant-folded '%s' %s times, with %s size.",
op,
folder_pass.counts[op],
folder_pass.sizes[op],
)
return folder_pass.modified
return folder_pass(model) # type: ignore[return-value]
Loading