From 1adca713be88c8b0947a67a510e02c42ffd5e69e Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Sun, 29 Jun 2025 15:43:16 -0700 Subject: [PATCH 01/31] Scaffold Signed-off-by: Justin Chu --- src/onnx_ir/_shape_inference/__init__.py | 0 src/onnx_ir/_shape_inference/_inferencer.py | 3473 +++++++++++++++++++ 2 files changed, 3473 insertions(+) create mode 100644 src/onnx_ir/_shape_inference/__init__.py create mode 100644 src/onnx_ir/_shape_inference/_inferencer.py diff --git a/src/onnx_ir/_shape_inference/__init__.py b/src/onnx_ir/_shape_inference/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/onnx_ir/_shape_inference/_inferencer.py b/src/onnx_ir/_shape_inference/_inferencer.py new file mode 100644 index 00000000..8538bfe0 --- /dev/null +++ b/src/onnx_ir/_shape_inference/_inferencer.py @@ -0,0 +1,3473 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import argparse +import logging + +import numpy as np +import onnx +import sympy +from onnx import helper, numpy_helper, shape_inference + +logger = logging.getLogger(__name__) + + +def get_attribute(node, attr_name, default_value=None): + """Retrieve the value of an attribute from an ONNX node, returning a default if the attribute is not found.""" + found = [attr for attr in node.attribute if attr.name == attr_name] + return helper.get_attribute_value(found[0]) if found else default_value + + +def get_dim_from_proto(dim): + """Retrieve the dimension value from the ONNX protobuf object if it is a string.""" + return ( + getattr(dim, dim.WhichOneof("value")) if type(dim.WhichOneof("value")) is str else None + ) + + +def is_sequence(type_proto): + """Check if the given ONNX proto type is a sequence.""" + cls_type = type_proto.WhichOneof("value") + assert cls_type in {"tensor_type", "sequence_type"} + return cls_type == "sequence_type" + + +def get_shape_from_type_proto(type_proto): + """Extract the shape of a tensor from an ONNX type proto if available, otherwise return None.""" + assert not is_sequence(type_proto) + if type_proto.tensor_type.HasField("shape"): + return [get_dim_from_proto(d) for d in type_proto.tensor_type.shape.dim] + else: + return None # note no shape is different from shape without dim (scalar) + + +def get_elem_type_from_type_proto(type_proto): + """Return the element type from a given TypeProto object, either from sequence type or tensor type.""" + if is_sequence(type_proto): + return type_proto.sequence_type.elem_type.tensor_type.elem_type + else: + return type_proto.tensor_type.elem_type + + +def get_shape_from_value_info(vi): + """Return the shape from the given ValueInfoProto object, either from sequence type or tensor type.""" + cls_type = vi.type.WhichOneof("value") + if cls_type is None: + return None + if not is_sequence(vi.type): + return get_shape_from_type_proto(vi.type) + if vi.type.sequence_type.elem_type.WhichOneof("value") == "tensor_type": + return get_shape_from_type_proto(vi.type.sequence_type.elem_type) + else: + return None + + +def make_named_value_info(name): + """Create and return an ONNX ValueInfoProto object with the specified name.""" + vi = onnx.ValueInfoProto() + vi.name = name + return vi + + +def get_shape_from_sympy_shape(sympy_shape): + """Convert a sympy shape to a list with int, str, or None elements.""" + return [None if i is None else (int(i) if is_literal(i) else str(i)) for i in sympy_shape] + + +def is_literal(dim): + """Check if a dimension is a literal number (int, np.int64, np.int32, sympy.Integer) or has an 
'is_number' + attribute. + """ + return type(dim) in {int, np.int64, np.int32, sympy.Integer} or ( + hasattr(dim, "is_number") and dim.is_number + ) + + +def handle_negative_axis(axis, rank): + """Convert a potentially negative axis to a positive axis based on the given rank.""" + assert axis < rank and axis >= -rank + return axis if axis >= 0 else rank + axis + + +def get_opset(mp, domain=None): + """Retrieve the opset version for a given model namespace, defaulting to common ONNX domains if no specific domain + is provided. + """ + domain = domain or ["", "onnx", "ai.onnx"] + if type(domain) != list: + domain = [domain] + for opset in mp.opset_import: + if opset.domain in domain: + return opset.version + + return None + + +def as_scalar(x): + """Convert input to scalar if input is a list with a single item or a NumPy ndarray.""" + if type(x) == list: + assert len(x) == 1 + return x[0] + elif type(x) == np.ndarray: + return x.item() + else: + return x + + +def as_list(x, keep_none): + """Convert input to list, optionally preserving None values.""" + if type(x) == list: + return x + elif type(x) == np.ndarray: + return list(x) + elif keep_none and x is None: + return None + else: + return [x] + + +def sympy_reduce_product(x): + """Reduce a list or element to a product using Sympy's Integer.""" + if type(x) == list: + value = sympy.Integer(1) + for v in x: + value = value * v + else: + value = x + return value + + +class SymbolicShapeInference: + def __init__(self, int_max, auto_merge, guess_output_rank, verbose, prefix=""): + """Initializes the SymbolicShapeInference class with configuration parameters for symbolic shape inference.""" + self.dispatcher_ = { + "Add": self._infer_symbolic_compute_ops, + "ArrayFeatureExtractor": self._infer_ArrayFeatureExtractor, + "AveragePool": self._infer_Pool, + "BatchNormalization": self._infer_BatchNormalization, + "Cast": self._infer_Cast, + "CategoryMapper": self._infer_CategoryMapper, + "Compress": self._infer_Compress, + "Concat": self._infer_Concat, + "ConcatFromSequence": self._infer_ConcatFromSequence, + "Constant": self._infer_Constant, + "ConstantOfShape": self._infer_ConstantOfShape, + "Conv": self._infer_Conv, + "CumSum": self._pass_on_shape_and_type, + "Div": self._infer_symbolic_compute_ops, + "Einsum": self._infer_Einsum, + "Expand": self._infer_Expand, + "Equal": self._infer_symbolic_compute_ops, + "Floor": self._infer_symbolic_compute_ops, + "Gather": self._infer_Gather, + "GatherElements": self._infer_GatherElements, + "GatherND": self._infer_GatherND, + "Identity": self._pass_on_shape_and_type, + "AllReduce": self._pass_on_shape_and_type, + "If": self._infer_If, + "Loop": self._infer_Loop, + "MatMul": self._infer_MatMul, + "MatMulInteger16": self._infer_MatMulInteger, + "MaxPool": self._infer_Pool, + "Max": self._infer_symbolic_compute_ops, + "MemcpyFromHost": self._pass_on_shape_and_type, + "MemcpyToHost": self._pass_on_shape_and_type, + "Min": self._infer_symbolic_compute_ops, + "MoE": self._pass_on_shape_and_type, + "Mul": self._infer_symbolic_compute_ops, + "NonMaxSuppression": self._infer_NonMaxSuppression, + "NonZero": self._infer_NonZero, + "OneHot": self._infer_OneHot, + "Pad": self._infer_Pad, + "Range": self._infer_Range, + "Reciprocal": self._pass_on_shape_and_type, + "ReduceSum": self._infer_ReduceSum, + "ReduceProd": self._infer_ReduceProd, + "Reshape": self._infer_Reshape, + "Resize": self._infer_Resize, + "Round": self._pass_on_shape_and_type, + "Scan": self._infer_Scan, + "ScatterElements": self._infer_ScatterElements, 
+ "SequenceAt": self._infer_SequenceAt, + "SequenceInsert": self._infer_SequenceInsert, + "Shape": self._infer_Shape, + "Size": self._infer_Size, + "Slice": self._infer_Slice, + "SoftmaxCrossEntropyLoss": self._infer_SoftmaxCrossEntropyLoss, + "SoftmaxCrossEntropyLossInternal": self._infer_SoftmaxCrossEntropyLoss, + "NegativeLogLikelihoodLossInternal": self._infer_SoftmaxCrossEntropyLoss, + "Split": self._infer_Split, + "SplitToSequence": self._infer_SplitToSequence, + "Squeeze": self._infer_Squeeze, + "Sub": self._infer_symbolic_compute_ops, + "Tile": self._infer_Tile, + "TopK": self._infer_TopK, + "Transpose": self._infer_Transpose, + "Unsqueeze": self._infer_Unsqueeze, + "Where": self._infer_symbolic_compute_ops, + "ZipMap": self._infer_ZipMap, + "Neg": self._infer_symbolic_compute_ops, + # contrib ops: + "Attention": self._infer_Attention, + "BiasAdd": self._infer_BiasAdd, + "BiasGelu": self._infer_BiasGelu, + "BiasSplitGelu": self._infer_BiasSplitGelu, + "DecoderMaskedMultiHeadAttention": self._infer_DecoderMaskedMultiHeadAttention, + "DequantizeLinear": self._infer_DequantizeLinear, + "EmbedLayerNormalization": self._infer_EmbedLayerNormalization, + "FastGelu": self._infer_FastGelu, + "GatedRelativePositionBias": self._infer_GatedRelativePositionBias, + "Gelu": self._infer_Gelu, + "GemmFastGelu": self._infer_GemmFastGelu, + "GemmFloat8": self._infer_GemmFloat8, + "GroupNorm": self._infer_GroupNorm, + "SkipGroupNorm": self._infer_SkipGroupNorm, + "LayerNormalization": self._infer_LayerNormalization, + "LongformerAttention": self._infer_LongformerAttention, + "MultiHeadAttention": self._infer_MultiHeadAttention, + "NhwcConv": self._infer_NhwcConv, + "PackedAttention": self._infer_PackedAttention, + "PackedMultiHeadAttention": self._infer_PackedMultiHeadAttention, + "MultiScaleDeformableAttnTRT": self._infer_MultiScaleDeformableAttnTRT, + "PythonOp": self._infer_PythonOp, + "QuantizeLinear": self._infer_QuantizeLinear, + "QuickGelu": self._infer_FastGelu, + "RelativePositionBias": self._infer_RelativePositionBias, + "RemovePadding": self._infer_RemovePadding, + "RestorePadding": self._infer_RestorePadding, + "RotaryEmbedding": self._infer_RotaryEmbedding, + "SimplifiedLayerNormalization": self._infer_LayerNormalization, + "SkipLayerNormalization": self._infer_SkipLayerNormalization, + "SkipSimplifiedLayerNormalization": self._infer_SkipLayerNormalization, + } + self.aten_op_dispatcher_ = { + "embedding": self._infer_Gather, + "bitwise_or": self._infer_aten_bitwise_or, + "diagonal": self._infer_aten_diagonal, + "max_pool2d_with_indices": self._infer_aten_pool2d, + "max": self._infer_aten_minmax, + "min": self._infer_aten_minmax, + "multinomial": self._infer_aten_multinomial, + "unfold": self._infer_aten_unfold, + "argmax": self._infer_aten_argmax, + "avg_pool2d": self._infer_aten_pool2d, + "_adaptive_avg_pool2d": self._infer_aten_pool2d, + "numpy_T": self._infer_Transpose, + "native_group_norm": self._infer_aten_group_norm, + "upsample_nearest1d": self._infer_aten_upsample, + "upsample_nearest2d": self._infer_aten_upsample, + "upsample_nearest3d": self._infer_aten_upsample, + "upsample_bicubic2d": self._infer_aten_upsample, + } + self.run_ = True + self.suggested_merge_ = {} + self.symbolic_dims_ = {} + self.input_symbols_ = {} + self.auto_merge_ = auto_merge + self.guess_output_rank_ = guess_output_rank + self.verbose_ = verbose + self.int_max_ = int_max + self.subgraph_id_ = 0 + self.prefix_ = prefix + + def _add_suggested_merge(self, symbols, apply=False): + """Add suggested merges 
for input symbols, prioritizing literals, input symbolic dims, or existing symbolic + dims. + """ + assert all( + (type(s) == str and s in self.symbolic_dims_) or is_literal(s) for s in symbols + ) + symbols = set(symbols) + for k, v in self.suggested_merge_.items(): + if k in symbols: + symbols.remove(k) + symbols.add(v) + map_to = None + # if there is literal, map to it first + for s in symbols: + if is_literal(s): + map_to = s + break + # when no literals, map to input symbolic dims, then existing symbolic dims + if map_to is None: + for s in symbols: + if s in self.input_symbols_: + map_to = s + break + if map_to is None: + for s in symbols: + if type(self.symbolic_dims_[s]) == sympy.Symbol: + map_to = s + break + # when nothing to map to, use the shorter one + if map_to is None: + if self.verbose_ > 0: + logger.warning( + f"Potential unsafe merge between symbolic expressions: ({','.join(symbols)})" + ) + symbols_list = list(symbols) + lens = [len(s) for s in symbols_list] + map_to = symbols_list[lens.index(min(lens))] + symbols.remove(map_to) + + for s in symbols: + if s == map_to: + continue + if is_literal(map_to) and is_literal(s): + assert int(map_to) == int(s) + self.suggested_merge_[s] = int(map_to) if is_literal(map_to) else map_to + for k, v in self.suggested_merge_.items(): + if v == s: + self.suggested_merge_[k] = map_to + if apply and self.auto_merge_: + self._apply_suggested_merge() + + def _apply_suggested_merge(self, graph_input_only=False): + """Applies suggested merges to graph dimensions based on predefined rules in the `suggested_merge_` + dictionary. + """ + if not self.suggested_merge_: + return + for i in list(self.out_mp_.graph.input) + ( + [] if graph_input_only else list(self.out_mp_.graph.value_info) + ): + for d in i.type.tensor_type.shape.dim: + if d.dim_param in self.suggested_merge_: + v = self.suggested_merge_[d.dim_param] + if is_literal(v): + d.dim_value = int(v) + else: + d.dim_param = v + + def _preprocess(self, in_mp): + """Preprocess ONNX model by copying its structure and updating graph input and initializer dictionaries.""" + self.out_mp_ = onnx.ModelProto() + self.out_mp_.CopyFrom(in_mp) + self.graph_inputs_ = {i.name: i for i in list(self.out_mp_.graph.input)} + self.initializers_ = {i.name: i for i in self.out_mp_.graph.initializer} + self.known_vi_ = {i.name: i for i in list(self.out_mp_.graph.input)} + self.known_vi_.update( + { + i.name: helper.make_tensor_value_info(i.name, i.data_type, list(i.dims)) + for i in self.out_mp_.graph.initializer + } + ) + + def _merge_symbols(self, dims): + """Merge dimension symbols, handling automatic merging and validation of symbolic dimensions.""" + if any(type(d) != str for d in dims): + if not self.auto_merge_: + return None + unique_dims = list(set(dims)) + is_int = [is_literal(d) for d in unique_dims] + assert sum(is_int) <= 1 # if there are more than 1 unique ints, something is wrong + if sum(is_int) == 1: + int_dim = is_int.index(1) + if self.verbose_ > 0: + logger.debug( + f"dim {unique_dims[:int_dim] + unique_dims[int_dim + 1 :]} has been merged with value {unique_dims[int_dim]}" + ) + self._check_merged_dims(unique_dims, allow_broadcast=False) + return unique_dims[int_dim] + else: + if self.verbose_ > 0: + logger.debug( + f"dim {unique_dims[1:]} has been merged with dim {unique_dims[0]}" + ) + return dims[0] + if all(d == dims[0] for d in dims): + return dims[0] + merged = [self.suggested_merge_[d] if d in self.suggested_merge_ else d for d in dims] + if all(d == merged[0] for d in merged): + 
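# every dim resolves to the same value after applying suggested merges; that symbol must already be registered +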
assert merged[0] in self.symbolic_dims_ + return merged[0] + else: + return None + + # broadcast from right to left, and merge symbolic dims if needed + def _broadcast_shapes(self, shape1, shape2): + """Broadcast two shapes from right to left, merging symbolic dimensions if necessary.""" + new_shape = [] + rank1 = len(shape1) + rank2 = len(shape2) + new_rank = max(rank1, rank2) + for i in range(new_rank): + dim1 = shape1[rank1 - 1 - i] if i < rank1 else 1 + dim2 = shape2[rank2 - 1 - i] if i < rank2 else 1 + if dim1 in [1, dim2]: + new_dim = dim2 + elif dim2 == 1: + new_dim = dim1 + else: + new_dim = self._merge_symbols([dim1, dim2]) + if not new_dim: + # warn about unsupported broadcast when auto merge is disabled + # note that auto merge risks incorrectly merging symbols when one of them is 1 + # for example, 'a' = 1, 'b' = 5 at runtime is valid broadcasting, but with auto merge 'a' == 'b' + if self.auto_merge_: + self._add_suggested_merge([dim1, dim2], apply=True) + else: + logger.warning( + f"unsupported broadcast between {str(dim1)} {str(dim2)}" + ) + new_shape = [new_dim, *new_shape] + return new_shape + + def _get_shape(self, node, idx): + """Retrieve the shape of a tensor from a node's inputs based on known value info or initializers.""" + name = node.input[idx] + if name in self.known_vi_: + vi = self.known_vi_[name] + return get_shape_from_value_info(vi) + else: + assert name in self.initializers_ + return list(self.initializers_[name].dims) + + def _try_get_shape(self, node, idx): + """Attempt to retrieve the shape of the node's input at the specified index, if available.""" + if idx > len(node.input) - 1: + return None + name = node.input[idx] + if name in self.known_vi_: + vi = self.known_vi_[name] + return get_shape_from_value_info(vi) + if name in self.initializers_: + return list(self.initializers_[name].dims) + return None + + def _get_shape_rank(self, node, idx): + """Return the rank (number of dimensions) of the shape of the input tensor at the specified index for a given + node. + """ + return len(self._get_shape(node, idx)) + + def _get_sympy_shape(self, node, idx): + """Return the symbolic shape dimensions using SymPy for the given input tensor at the specified index for a + node. 
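+ For example, a shape ['batch', 3] is returned as [Symbol('batch'), 3], reusing a cached symbol for 'batch' when one exists.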
+ """ + sympy_shape = [] + for d in self._get_shape(node, idx): + if type(d) == str: + sympy_shape.append( + self.symbolic_dims_[d] + if d in self.symbolic_dims_ + else sympy.Symbol(d, integer=True, nonnegative=True) + ) + else: + assert None is not d + sympy_shape.append(d) + return sympy_shape + + def _get_value(self, node, idx): + """Retrieve the value associated with a node's input index from sympy_data_ or initializers_.""" + name = node.input[idx] + assert name in self.sympy_data_ or name in self.initializers_ + return ( + self.sympy_data_[name] + if name in self.sympy_data_ + else numpy_helper.to_array(self.initializers_[name]) + ) + + def _try_get_value(self, node, idx): + """Try to retrieve the value associated with a node's input index from sympy_data_ or initializers_.""" + if idx >= len(node.input): + return None + name = node.input[idx] + if name in self.sympy_data_ or name in self.initializers_: + return self._get_value(node, idx) + return None + + def _update_computed_dims(self, new_sympy_shape): + """Update dimensions in new_sympy_shape based on suggested merges and computational expressions.""" + for i, new_dim in enumerate(new_sympy_shape): + if not is_literal(new_dim) and type(new_dim) != str: + str_dim = str(new_dim) + if str_dim in self.suggested_merge_: + if not is_literal(self.suggested_merge_[str_dim]): + new_sympy_shape[i] = self.symbolic_dims_[ + self.suggested_merge_[str_dim] + ] + elif str_dim not in self.symbolic_dims_: + self.symbolic_dims_[str_dim] = new_dim + + def _onnx_infer_single_node(self, node): + """Performs ONNX shape inference for a single node, skipping inference for specified operation types.""" + skip_infer = node.op_type in { + "If", + "Loop", + "Scan", + "SplitToSequence", + "ZipMap", # contrib ops + "Attention", + "BiasGelu", + "EmbedLayerNormalization", + "FastGelu", + "Gelu", + "GemmFastGelu", + "LayerNormalization", + "LongformerAttention", + "DequantizeLinear", + "QuantizeLinear", + "RelativePositionBias", + "RemovePadding", + "RestorePadding", + "SimplifiedLayerNormalization", + "SkipLayerNormalization", + "SkipSimplifiedLayerNormalization", + "PackedAttention", + "PythonOp", + "MultiHeadAttention", + "GroupNorm", + "SkipGroupNorm", + "BiasSplitGelu", + "BiasAdd", + "NhwcConv", + "QuickGelu", + "RotaryEmbedding", + } + + if not skip_infer: + # Only pass initializers that satisfy the following condition: + # (1) Operator need value of some input for shape inference. + # For example, Unsqueeze in opset 13 uses the axes input to calculate shape of output. + # (2) opset version >= 9. In older version, initializer is required in graph input by onnx spec. + # (3) The initializer is not in graph input. The means the node input is "constant" in inference. 
+ initializers = [] + if (get_opset(self.out_mp_) >= 9) and node.op_type == "Unsqueeze": + initializers = [ + self.initializers_[name] + for name in node.input + if (name in self.initializers_ and name not in self.graph_inputs_) + ] + + if ( + node.op_type + in { + "Add", + "Sub", + "Mul", + "Div", + "MatMul", + "MatMulInteger", + "MatMulInteger16", + "Where", + "Sum", + } + and node.output[0] in self.known_vi_ + ): + vi = self.known_vi_[node.output[0]] + out_rank = len(get_shape_from_type_proto(vi.type)) + in_shapes = [self._get_shape(node, i) for i in range(len(node.input))] + for d in range( + out_rank + - ( + 2 + if node.op_type in {"MatMul", "MatMulInteger", "MatMulInteger16"} + else 0 + ) + ): + in_dims = [ + s[len(s) - out_rank + d] for s in in_shapes if len(s) + d >= out_rank + ] + if len(in_dims) > 1: + self._check_merged_dims(in_dims, allow_broadcast=True) + + # run single node inference with self.known_vi_ shapes + tmp_graph = helper.make_graph( + [node], + "tmp", + [self.known_vi_[i] for i in node.input if i], + [make_named_value_info(i) for i in node.output], + initializers, + ) + self.tmp_mp_.graph.CopyFrom(tmp_graph) + + self.tmp_mp_ = shape_inference.infer_shapes(self.tmp_mp_) + + for i_o in range(len(node.output)): + o = node.output[i_o] + if o: # skip optional output + vi = self.out_mp_.graph.value_info.add() + if not skip_infer: + vi.CopyFrom(self.tmp_mp_.graph.output[i_o]) + else: + vi.name = o + self.known_vi_[o] = vi + + def _onnx_infer_subgraph(self, node, subgraph, use_node_input=True, inc_subgraph_id=True): + """Infer shapes and types within a subgraph for a given ONNX node using temporary graphs and known value + information. + """ + if self.verbose_ > 2: + logger.debug( + f"Inferencing subgraph of node {node.name} with output({node.output[0]}...): {node.op_type}" + ) + # node inputs are not passed directly to the subgraph + # it's up to the node dispatcher to prepare subgraph input + # for example, with Scan/Loop, subgraph input shape would be trimmed from node input shape + # besides, inputs in subgraph could shadow implicit inputs + subgraph_inputs = {i.name for i in list(subgraph.initializer) + list(subgraph.input)} + subgraph_implicit_input = { + name for name in self.known_vi_ if name not in subgraph_inputs + } + tmp_graph = helper.make_graph( + list(subgraph.node), + "tmp", + list(subgraph.input) + [self.known_vi_[i] for i in subgraph_implicit_input], + [make_named_value_info(i.name) for i in subgraph.output], + ) + tmp_graph.initializer.extend( + [i for i in self.out_mp_.graph.initializer if i.name in subgraph_implicit_input] + ) + tmp_graph.initializer.extend(subgraph.initializer) + self.tmp_mp_.graph.CopyFrom(tmp_graph) + + symbolic_shape_inference = SymbolicShapeInference( + self.int_max_, + self.auto_merge_, + self.guess_output_rank_, + self.verbose_, + prefix=f"{self.prefix_}_{str(self.subgraph_id_)}", + ) + if inc_subgraph_id: + self.subgraph_id_ += 1 + + symbolic_shape_inference._preprocess(self.tmp_mp_) + symbolic_shape_inference.suggested_merge_ = self.suggested_merge_.copy() + while symbolic_shape_inference.run_: + symbolic_shape_inference._infer_impl(self.sympy_data_.copy()) + symbolic_shape_inference._update_output_from_vi() + if use_node_input: + # if subgraph uses node input, it needs to update to merged dims + subgraph.ClearField("input") + subgraph.input.extend( + symbolic_shape_inference.out_mp_.graph.input[: len(node.input)] + ) + subgraph.ClearField("output") + subgraph.output.extend(symbolic_shape_inference.out_mp_.graph.output) + 
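# also sync the inferred value_info and node list back into the subgraph proto +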
subgraph.ClearField("value_info") + subgraph.value_info.extend(symbolic_shape_inference.out_mp_.graph.value_info) + subgraph.ClearField("node") + subgraph.node.extend(symbolic_shape_inference.out_mp_.graph.node) + # for new symbolic dims from subgraph output, add to main graph symbolic dims + subgraph_shapes = [ + get_shape_from_value_info(o) for o in symbolic_shape_inference.out_mp_.graph.output + ] + subgraph_new_symbolic_dims = { + d + for s in subgraph_shapes + if s + for d in s + if type(d) == str and d not in self.symbolic_dims_ + } + new_dims = {} + for d in subgraph_new_symbolic_dims: + assert d in symbolic_shape_inference.symbolic_dims_ + new_dims[d] = symbolic_shape_inference.symbolic_dims_[d] + self.symbolic_dims_.update(new_dims) + return symbolic_shape_inference + + def _get_int_or_float_values(self, node, broadcast=False, allow_float_values=False): + """Extracts integer or float values from a node, with options for broadcasting and allowing float values.""" + + def int_or_float(value, allow_float_values): + """Converts a value to an integer unless precision loss occurs and allow_float_values is True.""" + return value if allow_float_values and value % 1 != 0 else int(value) + + values = [self._try_get_value(node, i) for i in range(len(node.input))] + if all(v is not None for v in values): + # some shape compute is in floating point, cast to int for sympy + for i, v in enumerate(values): + if type(v) != np.ndarray: + continue + if len(v.shape) > 1: + new_v = None # ignore value for rank > 1 + elif len(v.shape) == 0: + new_v = int_or_float(v.item(), allow_float_values) + else: + assert len(v.shape) == 1 + new_v = [int_or_float(vv, allow_float_values) for vv in v] + values[i] = new_v + values_len = [len(v) if isinstance(v, list) else 0 for v in values] + max_len = max(values_len) + if max_len >= 1 and broadcast: + # broadcast + for i, v in enumerate(values): + if v is None: + continue # don't broadcast if value is unknown + if isinstance(v, list): + if len(v) < max_len: + values[i] = v * max_len + else: + assert len(v) == max_len + else: + values[i] = [v] * max_len + return values + + def _compute_on_sympy_data(self, node, op_func): + """Calculate the result using Sympy data and a specified operation function.""" + assert len(node.output) == 1 + + # Before mul & div operations + # cast inputs into integer might lose decimal part and reduce precision + # keep them as float, finish the operation, then cast the result into integer + if node.op_type in {"Mul", "Div"}: + values = self._get_int_or_float_values( + node, broadcast=True, allow_float_values=True + ) + else: + values = self._get_int_or_float_values(node, broadcast=True) + + if all(v is not None for v in values): + is_list = [isinstance(v, list) for v in values] + as_list = any(is_list) + if as_list: + self.sympy_data_[node.output[0]] = [op_func(vs) for vs in zip(*values)] + else: + self.sympy_data_[node.output[0]] = op_func(values) + + def _pass_on_sympy_data(self, node): + """Pass Sympy data through a node, validating input length or node operation type 'Reshape', 'Unsqueeze', + 'Squeeze'. 
+ """ + assert len(node.input) == 1 or node.op_type in { + "Reshape", + "Unsqueeze", + "Squeeze", + } + self._compute_on_sympy_data(node, lambda x: x[0]) + + def _pass_on_shape_and_type(self, node): + """Propagates the shape and type information from input to output for a given node.""" + vi = self.known_vi_[node.output[0]] + vi.CopyFrom( + helper.make_tensor_value_info( + node.output[0], + get_elem_type_from_type_proto(self.known_vi_[node.input[0]].type), + self._get_shape(node, 0), + ) + ) + + def _new_symbolic_dim(self, prefix, dim): + """Create and return a new symbolic dimension, handling literal values and caching for repeated uses.""" + new_dim = f"{prefix}_d{dim}" + if new_dim in self.suggested_merge_: + v = self.suggested_merge_[new_dim] + new_symbolic_dim = sympy.Integer(int(v)) if is_literal(v) else v + else: + new_symbolic_dim = sympy.Symbol(new_dim, integer=True, nonnegative=True) + self.symbolic_dims_[new_dim] = new_symbolic_dim + return new_symbolic_dim + + def _new_symbolic_dim_from_output(self, node, out_idx=0, dim=0): + """Generates a new symbolic dimension for a given node's output using the node's operation type, prefix, and + output index. + """ + return self._new_symbolic_dim( + f"{node.op_type}{self.prefix_}_{list(self.out_mp_.graph.node).index(node)}_o{out_idx}_", + dim, + ) + + def _new_symbolic_shape(self, rank, node, out_idx=0): + """Generate a new symbolic shape for a node output based on its rank and index.""" + return [self._new_symbolic_dim_from_output(node, out_idx, i) for i in range(rank)] + + def _compute_conv_pool_shape(self, node, channels_last=False): + """Calculate the output shape of a convolutional or pooling layer node, optionally considering channels_last + format. + """ + sympy_shape = self._get_sympy_shape(node, 0) + if len(node.input) > 1: + W_shape = self._get_sympy_shape(node, 1) + rank = len(W_shape) - 2 # number of spatial axes + kernel_shape = W_shape[-rank - 1 : -1] if channels_last else W_shape[-rank:] + sympy_shape[3 if channels_last else 1] = W_shape[0] + else: + W_shape = None + kernel_shape = get_attribute(node, "kernel_shape") + rank = len(kernel_shape) + + assert len(sympy_shape) == rank + 2 + + # only need to symbolic shape inference if input has symbolic dims in spatial axes + spatial_shape = sympy_shape[-rank - 1 : -1] if channels_last else sympy_shape[-rank:] + is_symbolic_dims = [not is_literal(i) for i in spatial_shape] + + if not any(is_symbolic_dims): + shape = get_shape_from_value_info(self.known_vi_[node.output[0]]) + if len(shape) > 0: + assert len(sympy_shape) == len(shape) + if channels_last: + sympy_shape[-rank - 1 : -1] = [ + sympy.Integer(d) for d in shape[-rank - 1 : -1] + ] + else: + sympy_shape[-rank:] = [sympy.Integer(d) for d in shape[-rank:]] + return sympy_shape + + dilations = get_attribute(node, "dilations", [1] * rank) + strides = get_attribute(node, "strides", [1] * rank) + effective_kernel_shape = [(k - 1) * d + 1 for k, d in zip(kernel_shape, dilations)] + pads = get_attribute(node, "pads") + if pads is None: + pads = [0] * (2 * rank) + auto_pad = get_attribute(node, "auto_pad", b"NOTSET").decode("utf-8") + if auto_pad not in {"VALID", "NOTSET"}: + try: + residual = [sympy.Mod(d, s) for d, s in zip(sympy_shape[-rank:], strides)] + total_pads = [ + max(0, (k - s) if r == 0 else (k - r)) + for k, s, r in zip(effective_kernel_shape, strides, residual) + ] + except ( + TypeError + ): # sympy may throw TypeError: cannot determine truth value of Relational + total_pads = [ + max(0, (k - s)) for k, s in 
zip(effective_kernel_shape, strides) + ] # assuming no residual if sympy throws error + elif auto_pad == "VALID": + total_pads = [] + else: + total_pads = [0] * rank + else: + assert len(pads) == 2 * rank + total_pads = [p1 + p2 for p1, p2 in zip(pads[:rank], pads[rank:])] + + ceil_mode = get_attribute(node, "ceil_mode", 0) + for i in range(rank): + effective_input_size = sympy_shape[-rank + i + (-1 if channels_last else 0)] + if len(total_pads) > 0: + effective_input_size = effective_input_size + total_pads[i] + if ceil_mode: + strided_kernel_positions = sympy.ceiling( + (effective_input_size - effective_kernel_shape[i]) / strides[i] + ) + else: + strided_kernel_positions = ( + effective_input_size - effective_kernel_shape[i] + ) // strides[i] + sympy_shape[-rank + i + (-1 if channels_last else 0)] = ( + strided_kernel_positions + 1 + ) + return sympy_shape + + def _check_merged_dims(self, dims, allow_broadcast=True): + """Checks merged dimensions for consistency, optionally allowing broadcasting.""" + if allow_broadcast: + dims = [d for d in dims if not (is_literal(d) and int(d) <= 1)] + if any(d != dims[0] for d in dims): + self._add_suggested_merge(dims, apply=True) + + def _compute_matmul_shape(self, node, output_dtype=None): + """Compute the output shape for a matrix multiplication operation based on input shapes and optionally infer the + output data type. + """ + lhs_shape = self._get_shape(node, 0) + rhs_shape = self._get_shape(node, 1) + lhs_rank = len(lhs_shape) + rhs_rank = len(rhs_shape) + lhs_reduce_dim = 0 + rhs_reduce_dim = 0 + assert lhs_rank > 0 and rhs_rank > 0 + if lhs_rank == 1 and rhs_rank == 1: + new_shape = [] + elif lhs_rank == 1: + rhs_reduce_dim = -2 + new_shape = rhs_shape[:rhs_reduce_dim] + [rhs_shape[-1]] + elif rhs_rank == 1: + lhs_reduce_dim = -1 + new_shape = lhs_shape[:lhs_reduce_dim] + else: + lhs_reduce_dim = -1 + rhs_reduce_dim = -2 + new_shape = [ + *self._broadcast_shapes(lhs_shape[:-2], rhs_shape[:-2]), + lhs_shape[-2], + rhs_shape[-1], + ] + # merge reduce dim + self._check_merged_dims( + [lhs_shape[lhs_reduce_dim], rhs_shape[rhs_reduce_dim]], + allow_broadcast=False, + ) + if output_dtype is None: + # infer output_dtype from input type when not specified + output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type + vi = self.known_vi_[node.output[0]] + vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, new_shape)) + + def _fuse_tensor_type(self, node, out_idx, dst_type, src_type): + """Update dst_tensor_type to be compatible with src_tensor_type when dimension mismatches.""" + dst_tensor_type = ( + dst_type.sequence_type.elem_type.tensor_type + if is_sequence(dst_type) + else dst_type.tensor_type + ) + src_tensor_type = ( + src_type.sequence_type.elem_type.tensor_type + if is_sequence(src_type) + else src_type.tensor_type + ) + if dst_tensor_type.elem_type != src_tensor_type.elem_type: + node_id = node.name or node.op_type + raise ValueError( + f"For node {node_id}, dst_tensor_type.elem_type != src_tensor_type.elem_type: " + f"{onnx.onnx_pb.TensorProto.DataType.Name(dst_tensor_type.elem_type)} vs " + f"{onnx.onnx_pb.TensorProto.DataType.Name(src_tensor_type.elem_type)}" + ) + if dst_tensor_type.HasField("shape"): + for di, ds in enumerate(zip(dst_tensor_type.shape.dim, src_tensor_type.shape.dim)): + if ds[0] != ds[1]: + # create a new symbolic dimension for node/out_idx/mismatch dim id in dst_tensor_type for tensor_type + # for sequence_type, clear the dimension + new_dim = onnx.TensorShapeProto.Dimension() + if 
not is_sequence(dst_type): + new_dim.dim_param = str( + self._new_symbolic_dim_from_output(node, out_idx, di) + ) + dst_tensor_type.shape.dim[di].CopyFrom(new_dim) + else: + dst_tensor_type.CopyFrom(src_tensor_type) + + def _infer_ArrayFeatureExtractor(self, node): + """Infer and update the shape and type information for the ArrayFeatureExtractor node using input data and + indices shapes. + """ + data_shape = self._get_shape(node, 0) + indices_shape = self._get_shape(node, 1) + vi = self.known_vi_[node.output[0]] + vi.CopyFrom( + helper.make_tensor_value_info( + node.output[0], + self.known_vi_[node.input[0]].type.tensor_type.elem_type, + data_shape[:-1] + indices_shape, + ) + ) + + def _infer_symbolic_compute_ops(self, node): + """Handles symbolic computation operations for given node based on predefined functions.""" + funcs = { + "Add": lambda l: l[0] + l[1], + "Div": lambda l: ( + int(l[0] // l[1]) + if isinstance(l[0] // l[1], float) + else l[0] // l[1] + ), # integer div in sympy + "Equal": lambda l: l[0] == l[1], + "Floor": lambda l: sympy.floor(l[0]), + "Max": lambda l: ( + l[1] + if is_literal(l[0]) and int(l[0]) < -self.int_max_ + else ( + l[0] + if is_literal(l[1]) and int(l[1]) < -self.int_max_ + else sympy.Max(l[0], l[1]) + ) + ), + "Min": lambda l: ( + l[1] + if is_literal(l[0]) and int(l[0]) > self.int_max_ + else ( + l[0] + if is_literal(l[1]) and int(l[1]) > self.int_max_ + else sympy.Min(l[0], l[1]) + ) + ), + "Mul": lambda l: ( + int(l[0] * l[1]) if isinstance(l[0] * l[1], float) else l[0] * l[1] + ), + "Sub": lambda l: l[0] - l[1], + "Where": lambda l: l[1] if l[0] else l[2], + "Neg": lambda l: -l[0], + } + assert node.op_type in funcs + self._compute_on_sympy_data(node, funcs[node.op_type]) + + def _infer_Cast(self, node): + """Pass node's data to SymPy representation without alteration.""" + self._pass_on_sympy_data(node) + + def _infer_CategoryMapper(self, node): + """Infer and set output tensor type for ONNX CategoryMapper nodes based on input tensor type.""" + input_type = self.known_vi_[node.input[0]].type.tensor_type.elem_type + if input_type == onnx.TensorProto.STRING: + output_type = onnx.TensorProto.INT64 + else: + output_type = onnx.TensorProto.STRING + vi = self.known_vi_[node.output[0]] + vi.CopyFrom( + helper.make_tensor_value_info( + node.output[0], output_type, self._get_shape(node, 0) + ) + ) + + def _infer_Compress(self, node): + """Infer the output shape and type for the Compress operation based on input shape and axis attribute.""" + input_shape = self._get_shape(node, 0) + # create a new symbolic dimension for Compress output + compress_len = str(self._new_symbolic_dim_from_output(node)) + axis = get_attribute(node, "axis") + if axis is None: + # when axis is not specified, input is flattened before compress so output is 1D + output_shape = [compress_len] + else: + output_shape = input_shape + output_shape[handle_negative_axis(axis, len(input_shape))] = compress_len + vi = self.known_vi_[node.output[0]] + vi.CopyFrom( + helper.make_tensor_value_info( + node.output[0], + self.known_vi_[node.input[0]].type.tensor_type.elem_type, + output_shape, + ) + ) + + def _infer_Concat(self, node): + """Infer the output shape and type for the Concat operation based on input node values.""" + if any(i in self.sympy_data_ or i in self.initializers_ for i in node.input): + values = self._get_int_or_float_values(node) + if all(v is not None for v in values): + assert get_attribute(node, "axis") == 0 + self.sympy_data_[node.output[0]] = [] + for i in 
range(len(node.input)): + value = values[i] + if isinstance(value, list): + self.sympy_data_[node.output[0]].extend(value) + else: + self.sympy_data_[node.output[0]].append(value) + + sympy_shape = self._get_sympy_shape(node, 0) + axis = handle_negative_axis(get_attribute(node, "axis"), len(sympy_shape)) + for i_idx in range(1, len(node.input)): + input_shape = self._get_sympy_shape(node, i_idx) + if input_shape: + sympy_shape[axis] = sympy_shape[axis] + input_shape[axis] + self._update_computed_dims(sympy_shape) + # merge symbolic dims for non-concat axes + for d in range(len(sympy_shape)): + if d == axis: + continue + dims = [ + self._get_shape(node, i_idx)[d] + for i_idx in range(len(node.input)) + if self._get_shape(node, i_idx) + ] + if all(d == dims[0] for d in dims): + continue + merged = self._merge_symbols(dims) + if type(merged) == str: + sympy_shape[d] = self.symbolic_dims_[merged] if merged else None + else: + sympy_shape[d] = merged + vi = self.known_vi_[node.output[0]] + vi.CopyFrom( + helper.make_tensor_value_info( + node.output[0], + self.known_vi_[node.input[0]].type.tensor_type.elem_type, + get_shape_from_sympy_shape(sympy_shape), + ) + ) + + def _infer_ConcatFromSequence(self, node): + """Infers the output shape and type info for ConcatFromSequence operation in a computational graph node.""" + seq_shape = self._get_shape(node, 0) + new_axis = 1 if get_attribute(node, "new_axis") else 0 + axis = handle_negative_axis(get_attribute(node, "axis"), len(seq_shape) + new_axis) + concat_dim = str(self._new_symbolic_dim_from_output(node, 0, axis)) + new_shape = seq_shape + if new_axis: + new_shape = seq_shape[:axis] + [concat_dim] + seq_shape[axis:] + else: + new_shape[axis] = concat_dim + vi = self.known_vi_[node.output[0]] + vi.CopyFrom( + helper.make_tensor_value_info( + node.output[0], + self.known_vi_[ + node.input[0] + ].type.sequence_type.elem_type.tensor_type.elem_type, + new_shape, + ) + ) + + def _infer_Constant(self, node): + """Infer the constant value for a given node and store it in sympy_data_.""" + t = get_attribute(node, "value") + self.sympy_data_[node.output[0]] = numpy_helper.to_array(t) + + def _infer_ConstantOfShape(self, node): + """Infer the constant tensor of a given shape from a node and update sympy_data_.""" + sympy_shape = self._get_int_or_float_values(node)[0] + vi = self.known_vi_[node.output[0]] + if sympy_shape is not None: + if type(sympy_shape) != list: + sympy_shape = [sympy_shape] + self._update_computed_dims(sympy_shape) + # update sympy data if output type is int, and shape is known + if vi.type.tensor_type.elem_type == onnx.TensorProto.INT64 and all( + is_literal(x) for x in sympy_shape + ): + self.sympy_data_[node.output[0]] = np.ones( + [int(x) for x in sympy_shape], dtype=np.int64 + ) * numpy_helper.to_array(get_attribute(node, "value", 0)) + else: + # create new dynamic shape + # note input0 is a 1D vector of shape, the new symbolic shape has the rank of the shape vector length + sympy_shape = self._new_symbolic_shape(self._get_shape(node, 0)[0], node) + + vi.CopyFrom( + helper.make_tensor_value_info( + node.output[0], + vi.type.tensor_type.elem_type, + get_shape_from_sympy_shape(sympy_shape), + ) + ) + + def _infer_Conv(self, node): + """Infers the shape of the output tensor for a convolution operation node and updates the known value info.""" + sympy_shape = self._compute_conv_pool_shape(node) + self._update_computed_dims(sympy_shape) + vi = self.known_vi_[node.output[0]] + vi.CopyFrom( + helper.make_tensor_value_info( + 
node.output[0], + vi.type.tensor_type.elem_type, + get_shape_from_sympy_shape(sympy_shape), + ) + ) + + def _infer_NhwcConv(self, node): + """Infer the shape of the output tensor for a convolution operation with NHWC format.""" + sympy_shape = self._compute_conv_pool_shape(node, channels_last=True) + self._update_computed_dims(sympy_shape) + vi = self.known_vi_[node.output[0]] + vi.CopyFrom( + helper.make_tensor_value_info( + node.output[0], + self.known_vi_[node.input[0]].type.tensor_type.elem_type, + get_shape_from_sympy_shape(sympy_shape), + ) + ) + + def _infer_DequantizeLinear(self, node): + """Infer output type and shape for the DequantizeLinear node based on input 1's scale data type.""" + output_dtype = self.known_vi_[node.input[1]].type.tensor_type.elem_type + + # Get the output shape from the first input. + output_shape = self._get_shape(node, 0) + + vi = self.known_vi_[node.output[0]] + vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape)) + + def _infer_QuantizeLinear(self, node): + """Infer the output data type and shape for the QuantizeLinear ONNX node, defaulting to uint8 if not + specified. + """ + # Take the output dtype from the zero-point input when present; otherwise, default to uint8 + output_dtype = onnx.TensorProto.UINT8 + if len(node.input) > 2 and node.input[2]: + output_dtype = self.known_vi_[node.input[2]].type.tensor_type.elem_type + + # Get the output shape from the first input. + output_shape = self._get_shape(node, 0) + + vi = self.known_vi_[node.output[0]] + vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape)) + + def _infer_Einsum(self, node): + """Infer the output shape and type for the Einsum operation as per ONNX standards: https://github.com/onnx/onnx/blob/623dfaa/onnx/defs/math/defs.cc#L3275.""" + equation = get_attribute(node, "equation") + equation = equation.replace(b" ", b"") + mid_index = equation.find(b"->") + left_equation = equation[:mid_index] if mid_index != -1 else equation + + num_operands = 0 + num_ellipsis = 0 + num_ellipsis_indices = 0 + + letter_to_dim = {} + + terms = left_equation.split(b",") + for term in terms: + ellipsis_index = term.find(b"...") + shape = self._get_shape(node, num_operands) + rank = len(shape) + if ellipsis_index != -1: + if num_ellipsis == 0: + num_ellipsis_indices = rank - len(term) + 3 + num_ellipsis = num_ellipsis + 1 + for i in range(1, rank + 1): + letter = term[-i] + if letter != 46: # letter != b'.' + dim = shape[-i] + if letter not in letter_to_dim or type(dim) != sympy.Symbol: + letter_to_dim[letter] = dim + num_operands = num_operands + 1 + + new_sympy_shape = [] + from collections import OrderedDict + + num_letter_occurrences = OrderedDict() + if mid_index != -1: + right_equation = equation[mid_index + 2 :] + right_ellipsis_index = right_equation.find(b"...") + if right_ellipsis_index != -1: + for i in range(num_ellipsis_indices): + new_sympy_shape.append(shape[i]) + for c in right_equation: + if c != 46: # c != b'.' 
+ new_sympy_shape.append(letter_to_dim[c]) + else: + for i in range(num_ellipsis_indices): + new_sympy_shape.append(shape[i]) + for c in left_equation: + if c not in {44, 46}: # c != b',' and c != b'.': + if c in num_letter_occurrences: + num_letter_occurrences[c] = num_letter_occurrences[c] + 1 + else: + num_letter_occurrences[c] = 1 + for key, value in num_letter_occurrences.items(): + if value == 1: + new_sympy_shape.append(letter_to_dim[key]) + + output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type + vi = self.known_vi_[node.output[0]] + vi.CopyFrom( + helper.make_tensor_value_info(node.output[0], output_dtype, new_sympy_shape) + ) + + def _infer_Expand(self, node): + """Infers and updates the output shape for the Expand operation based on broadcasted input shapes.""" + expand_to_shape = as_list(self._try_get_value(node, 1), keep_none=True) + if expand_to_shape is not None: + # new_shape's dim can come from shape value + self._update_computed_dims(expand_to_shape) + shape = self._get_shape(node, 0) + new_shape = self._broadcast_shapes( + shape, get_shape_from_sympy_shape(expand_to_shape) + ) + vi = self.known_vi_[node.output[0]] + vi.CopyFrom( + helper.make_tensor_value_info( + node.output[0], + self.known_vi_[node.input[0]].type.tensor_type.elem_type, + new_shape, + ) + ) + + def _infer_Gather(self, node): + """Infer the output shape of the Gather operation based on the input data and indices shapes.""" + data_shape = self._get_shape(node, 0) + axis = handle_negative_axis(get_attribute(node, "axis", 0), len(data_shape)) + indices_shape = self._get_shape(node, 1) + vi = self.known_vi_[node.output[0]] + vi.CopyFrom( + helper.make_tensor_value_info( + node.output[0], + self.known_vi_[node.input[0]].type.tensor_type.elem_type, + data_shape[:axis] + indices_shape + data_shape[axis + 1 :], + ) + ) + # for 1D input, do some sympy compute + if ( + node.input[0] in self.sympy_data_ + and len(data_shape) == 1 + and get_attribute(node, "axis", 0) == 0 + ): + idx = self._try_get_value(node, 1) + if idx is not None: + data = self.sympy_data_[node.input[0]] + if type(data) == list: + if type(idx) == np.ndarray and len(idx.shape) == 1: + self.sympy_data_[node.output[0]] = [data[int(i)] for i in idx] + else: + self.sympy_data_[node.output[0]] = data[int(idx)] + else: + assert idx in {0, -1} + self.sympy_data_[node.output[0]] = data + + def _infer_GatherElements(self, node): + """Infers the output shape and type for the GatherElements node based on input tensors and updates the node's + value information. 
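+ The output takes the shape of the indices input and the element type of the data input.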
+ """ + indices_shape = self._get_shape(node, 1) + vi = self.known_vi_[node.output[0]] + vi.CopyFrom( + helper.make_tensor_value_info( + node.output[0], + self.known_vi_[node.input[0]].type.tensor_type.elem_type, + indices_shape, + ) + ) + + def _infer_GatherND(self, node): + """Infers the output shape and type for the GatherND operation based on input data and indices shapes.""" + data_shape = self._get_shape(node, 0) + data_rank = len(data_shape) + indices_shape = self._get_shape(node, 1) + len(indices_shape) + last_index_dimension = indices_shape[-1] + batch_dims = get_attribute(node, "batch_dims", 0) + assert ( + is_literal(last_index_dimension) + and is_literal(batch_dims) + and (batch_dims + last_index_dimension) <= data_rank + ) + new_shape = indices_shape[:-1] + data_shape[batch_dims + last_index_dimension :] + vi = self.known_vi_[node.output[0]] + vi.CopyFrom( + helper.make_tensor_value_info( + node.output[0], + self.known_vi_[node.input[0]].type.tensor_type.elem_type, + new_shape, + ) + ) + + def _infer_If(self, node): + """Infer the output shape for an If node, handling constant conditions to ensure shape consistency between + branches. + """ + subgraphs = [ + get_attribute(node, "then_branch"), + get_attribute(node, "else_branch"), + ] + cond = self._try_get_value(node, 0) + if cond is not None: + if as_scalar(cond) > 0: + subgraphs[1].CopyFrom(subgraphs[0]) + else: + subgraphs[0].CopyFrom(subgraphs[1]) + + for i_sub, subgraph in enumerate(subgraphs): + subgraph_infer = self._onnx_infer_subgraph(node, subgraph, use_node_input=False) + for i_out in range(len(node.output)): + vi = self.known_vi_[node.output[i_out]] + if i_sub == 0: + vi.CopyFrom(subgraph.output[i_out]) + vi.name = node.output[i_out] + else: + self._fuse_tensor_type(node, i_out, vi.type, subgraph.output[i_out].type) + + # pass on sympy data from subgraph, if cond is constant + if ( + cond is not None + and i_sub == (0 if as_scalar(cond) > 0 else 1) + and subgraph.output[i_out].name in subgraph_infer.sympy_data_ + ): + self.sympy_data_[vi.name] = subgraph_infer.sympy_data_[ + subgraph.output[i_out].name + ] + + def _infer_Loop(self, node): + """Infer the shape and type of variables produced by the 'Loop' operation in an ONNX graph.""" + subgraph = get_attribute(node, "body") + assert len(subgraph.input) == len(node.input) + num_loop_carried = len(node.input) - 2 # minus the length and initial loop condition + # when sequence_type is used as loop carried input + # needs to run subgraph infer twice if the tensor shape in sequence contains None + for i, si in enumerate(subgraph.input): + si_name = si.name + si.CopyFrom(self.known_vi_[node.input[i]]) + si.name = si_name + + self._onnx_infer_subgraph(node, subgraph) + + # check subgraph input/output for shape changes in loop carried variables + # for tensor_type, create new symbolic dim when changing, i.e., output = Concat(input, a) + # for sequence_type, propagate from output to input + need_second_infer = False + for i_out in range(1, num_loop_carried + 1): + so = subgraph.output[i_out] + so_shape = get_shape_from_value_info(so) + if is_sequence(so.type): + if so_shape and None in so_shape: + # copy shape from output to input + # note that loop input is [loop_len, cond, input_0, input_1, ...] + # while loop output is [cond, output_0, output_1, ...] 
+ subgraph.input[i_out + 1].type.sequence_type.elem_type.CopyFrom( + so.type.sequence_type.elem_type + ) + need_second_infer = True + else: + si = subgraph.input[i_out + 1] + si_shape = get_shape_from_value_info(si) + for di, dims in enumerate(zip(si_shape, so_shape)): + if dims[0] != dims[1]: + new_dim = onnx.TensorShapeProto.Dimension() + new_dim.dim_param = str( + self._new_symbolic_dim_from_output(node, i_out, di) + ) + si.type.tensor_type.shape.dim[di].CopyFrom(new_dim) + so.type.tensor_type.shape.dim[di].CopyFrom(new_dim) + need_second_infer = True + + if need_second_infer: + if self.verbose_ > 2: + logger.debug( + f"Rerun Loop: {node.name}({node.output[0]}...), because of sequence in loop carried variables" + ) + self._onnx_infer_subgraph(node, subgraph, inc_subgraph_id=False) + + # create a new symbolic dimension for iteration dependent dimension + loop_iter_dim = str(self._new_symbolic_dim_from_output(node)) + for i in range(len(node.output)): + vi = self.known_vi_[node.output[i]] + vi.CopyFrom( + subgraph.output[i + 1] + ) # first subgraph output is condition, not in node output + if i >= num_loop_carried: + assert not is_sequence( + vi.type + ) # TODO: handle loop accumulation in sequence_type + subgraph_vi_dim = subgraph.output[i + 1].type.tensor_type.shape.dim + vi.type.tensor_type.shape.ClearField("dim") + vi_dim = vi.type.tensor_type.shape.dim + vi_dim.add().dim_param = loop_iter_dim + vi_dim.extend(list(subgraph_vi_dim)) + vi.name = node.output[i] + + def _infer_MatMul(self, node): + """Infer the output shape of a matrix multiplication node.""" + self._compute_matmul_shape(node) + + def _infer_MatMulInteger(self, node): + """Infer the output shape of an integer matrix multiplication node.""" + self._compute_matmul_shape(node, onnx.TensorProto.INT32) + + def _infer_NonMaxSuppression(self, node): + """Infer the output shape of a NonMaxSuppression node and update the value info.""" + selected = str(self._new_symbolic_dim_from_output(node)) + vi = self.known_vi_[node.output[0]] + vi.CopyFrom( + helper.make_tensor_value_info( + node.output[0], onnx.TensorProto.INT64, [selected, 3] + ) + ) + + def _infer_NonZero(self, node): + """Infer the output shape of a NonZero node and update the value info.""" + input_rank = self._get_shape_rank(node, 0) + # create a new symbolic dimension for NonZero output + nz_len = str(self._new_symbolic_dim_from_output(node, 0, 1)) + vi = self.known_vi_[node.output[0]] + vi.CopyFrom( + helper.make_tensor_value_info( + node.output[0], vi.type.tensor_type.elem_type, [input_rank, nz_len] + ) + ) + + def _infer_OneHot(self, node): + """Infer the shape and type of the output tensor for the OneHot node operation.""" + sympy_shape = self._get_sympy_shape(node, 0) + depth = self._try_get_value(node, 1) + axis = get_attribute(node, "axis", -1) + axis = handle_negative_axis(axis, len(sympy_shape) + 1) + new_shape = get_shape_from_sympy_shape( + sympy_shape[:axis] + + [(depth if is_literal(depth) else self._new_symbolic_dim_from_output(node))] + + sympy_shape[axis:] + ) + vi = self.known_vi_[node.output[0]] + vi.CopyFrom( + helper.make_tensor_value_info( + node.output[0], + self.known_vi_[node.input[2]].type.tensor_type.elem_type, + new_shape, + ) + ) + + def _infer_Pad(self, node): + """Infers the output shape and type for the Pad operation based on ONNX node attributes and opset version.""" + if get_opset(self.out_mp_) <= 10: + pads = get_attribute(node, "pads") + else: + pads = self._try_get_value(node, 1) + + sympy_shape = self._get_sympy_shape(node, 0) + 
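# ONNX 'pads' is laid out as [x1_begin, x2_begin, ..., x1_end, x2_end, ...]; hence the pads[:rank] / pads[rank:] split below +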
rank = len(sympy_shape) + + if pads is not None: + assert len(pads) == 2 * rank + new_sympy_shape = [ + d + pad_up + pad_down + for d, pad_up, pad_down in zip(sympy_shape, pads[:rank], pads[rank:]) + ] + self._update_computed_dims(new_sympy_shape) + else: + # dynamic pads, create new symbolic dimensions + new_sympy_shape = self._new_symbolic_shape(rank, node) + output_tp = self.known_vi_[node.input[0]].type.tensor_type.elem_type + + vi = self.known_vi_[node.output[0]] + vi.CopyFrom( + helper.make_tensor_value_info( + node.output[0], output_tp, get_shape_from_sympy_shape(new_sympy_shape) + ) + ) + + def _infer_Pool(self, node): + """Infer and update dimensions for pooling layers based on the input node.""" + sympy_shape = self._compute_conv_pool_shape(node) + self._update_computed_dims(sympy_shape) + for o in node.output: + if not o: + continue + vi = self.known_vi_[o] + vi.CopyFrom( + helper.make_tensor_value_info( + o, + vi.type.tensor_type.elem_type, + get_shape_from_sympy_shape(sympy_shape), + ) + ) + + def _infer_aten_bitwise_or(self, node): + """Infers the output shape for Aten bitwise OR operation based on input node shapes.""" + shape0 = self._get_shape(node, 0) + shape1 = self._get_shape(node, 1) + new_shape = self._broadcast_shapes(shape0, shape1) + t0 = self.known_vi_[node.input[0]] + vi = self.known_vi_[node.output[0]] + vi.CopyFrom( + helper.make_tensor_value_info( + node.output[0], t0.type.tensor_type.elem_type, new_shape + ) + ) + + def _infer_aten_diagonal(self, node): + """Infers the shape of the diagonal of a tensor given a node, offset, and dimensions.""" + sympy_shape = self._get_sympy_shape(node, 0) + rank = len(sympy_shape) + offset = self._try_get_value(node, 1) + dim1 = self._try_get_value(node, 2) + dim2 = self._try_get_value(node, 3) + + assert offset is not None and dim1 is not None and dim2 is not None + dim1 = handle_negative_axis(dim1, rank) + dim2 = handle_negative_axis(dim2, rank) + + new_shape = [val for dim, val in enumerate(sympy_shape) if dim not in {dim1, dim2}] + shape1 = sympy_shape[dim1] + shape2 = sympy_shape[dim2] + if offset >= 0: + diag_shape = sympy.Max(0, sympy.Min(shape1, shape2 - offset)) + else: + diag_shape = sympy.Max(0, sympy.Min(shape1 + offset, shape2)) + new_shape.append(diag_shape) + + if node.output[0]: + vi = self.known_vi_[node.output[0]] + vi.CopyFrom( + helper.make_tensor_value_info( + node.output[0], + self.known_vi_[node.input[0]].type.tensor_type.elem_type, + get_shape_from_sympy_shape(new_shape), + ) + ) + + def _infer_aten_multinomial(self, node): + """Infers the output shape and type for the PyTorch multinomial operation in an ONNX graph node.""" + sympy_shape = self._get_sympy_shape(node, 0) + rank = len(sympy_shape) + assert rank in {1, 2} + num_samples = self._try_get_value(node, 1) + di = rank - 1 + last_dim = num_samples or str(self._new_symbolic_dim_from_output(node, 0, di)) + output_shape = sympy_shape[:-1] + [last_dim] + vi = self.known_vi_[node.output[0]] + vi.CopyFrom( + helper.make_tensor_value_info( + node.output[0], + onnx.TensorProto.INT64, + get_shape_from_sympy_shape(output_shape), + ) + ) + + def _infer_aten_pool2d(self, node): + """Infer the output shape of a 2D pooling operation in an ONNX graph node.""" + sympy_shape = self._get_sympy_shape(node, 0) + assert len(sympy_shape) == 4 + sympy_shape[-2:] = [self._new_symbolic_dim_from_output(node, 0, i) for i in {2, 3}] + self._update_computed_dims(sympy_shape) + for i, o in enumerate(node.output): + if not o: + continue + vi = self.known_vi_[o] + elem_type = ( 
+ onnx.TensorProto.INT64 + if i == 1 + else self.known_vi_[node.input[0]].type.tensor_type.elem_type + ) + vi.CopyFrom( + helper.make_tensor_value_info( + o, elem_type, get_shape_from_sympy_shape(sympy_shape) + ) + ) + + def _infer_aten_minmax(self, node): + """Infer the output shape and type for the ATen MinMax operation in an ONNX node.""" + vi = self.known_vi_[node.output[0]] + if len(node.input) == 1: + vi.CopyFrom( + helper.make_tensor_value_info( + node.output[0], + self.known_vi_[node.input[0]].type.tensor_type.elem_type, + [], + ) + ) + else: + assert len(node.input) == 3 + keepdim = self._try_get_value(node, 2) + assert keepdim is not None # can only handle known keepdim case. + dim = self._try_get_value(node, 1) + if dim is None: + rank = self._get_shape_rank(node, 0) + output_shape = self._new_symbolic_shape(rank if keepdim else rank - 1, node) + else: + shape = self._get_sympy_shape(node, 0) + dim = handle_negative_axis(dim, len(shape)) + output_shape = shape[:dim] + if keepdim: + output_shape += [1] + output_shape += shape[dim + 1 :] + + output_shape = get_shape_from_sympy_shape(output_shape) + vi.CopyFrom( + helper.make_tensor_value_info( + node.output[0], + self.known_vi_[node.input[0]].type.tensor_type.elem_type, + output_shape, + ) + ) + vi1 = self.known_vi_[node.output[1]] + vi1.CopyFrom( + helper.make_tensor_value_info( + node.output[1], onnx.TensorProto.INT64, output_shape + ) + ) + + def _infer_aten_unfold(self, node): + """Infer the tensor shape for the 'aten::unfold' operation based on input shape and parameters dimension, size, and step.""" + sympy_shape = self._get_sympy_shape(node, 0) + dimension = self._try_get_value(node, 1) + size = self._try_get_value(node, 2) + step = self._try_get_value(node, 3) + if dimension is not None and size is not None and step is not None: + assert dimension < len(sympy_shape) + sympy_shape[dimension] = (sympy_shape[dimension] - size) // step + 1 + sympy_shape.append(size) + else: + rank = len(sympy_shape) + sympy_shape = self._new_symbolic_shape(rank + 1, node) + self._update_computed_dims(sympy_shape) + if node.output[0]: + vi = self.known_vi_[node.output[0]] + vi.CopyFrom( + helper.make_tensor_value_info( + node.output[0], + self.known_vi_[node.input[0]].type.tensor_type.elem_type, + get_shape_from_sympy_shape(sympy_shape), + ) + ) + + def _infer_aten_argmax(self, node): + """Infers the output shape for the ONNX ATen argmax operation.""" + new_shape = None + if not node.input[1]: + # The argmax of the flattened input is returned. + new_shape = [] + else: + dim = self._try_get_value(node, 1) + keepdim = self._try_get_value(node, 2) + if keepdim is not None: + sympy_shape = self._get_sympy_shape(node, 0) + if dim is not None: + dim = handle_negative_axis(dim, len(sympy_shape)) + if keepdim: + sympy_shape[dim] = 1 + else: + del sympy_shape[dim] + else: + rank = len(sympy_shape) + sympy_shape = self._new_symbolic_shape(rank if keepdim else rank - 1, node) + self._update_computed_dims(sympy_shape) + new_shape = get_shape_from_sympy_shape(sympy_shape) + if node.output[0] and new_shape is not None: + vi = self.known_vi_[node.output[0]] + vi.CopyFrom( + helper.make_tensor_value_info( + node.output[0], onnx.TensorProto.INT64, new_shape + ) + ) + + def _infer_aten_group_norm(self, node): + """Infers the output shapes and types for the ATen GroupNorm operation based on the provided node + information. 
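+ Output 0 keeps the input shape and type; outputs 1 and 2 are given shape [N, group], falling back to fresh symbolic dims when N or group is unknown.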
+ """ + self._propagate_shape_and_type(node) + input_shape = self._get_shape(node, 0) + N = input_shape[0] if input_shape is not None and len(input_shape) != 0 else None + group = self._try_get_value(node, 6) + output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type + for i in {1, 2}: + if node.output[i]: + vi = self.known_vi_[node.output[i]] + vi.CopyFrom( + helper.make_tensor_value_info( + node.output[i], + output_dtype, + [ + ( + N + if N is not None + else str(self._new_symbolic_dim_from_output(node, i, 0)) + ), + ( + as_scalar(group) + if group is not None + else str(self._new_symbolic_dim_from_output(node, i, 1)) + ), + ], + ) + ) + + def _infer_aten_upsample(self, node): + """Infers the output shape for an aten::upsample operation based on the input shape and specified upsampling parameters.""" + new_shape = None + input_shape = self._get_shape(node, 0) + if input_shape is not None: + new_shape = input_shape[:2] + output_size = self._try_get_value(node, 1) + if output_size is not None: + new_shape += [ + dim_size.item() if type(dim_size) == np.int64 else dim_size + for dim_size in output_size + ] + else: + rank = len(input_shape) + new_shape += [ + str(self._new_symbolic_dim_from_output(node, 0, i)) for i in range(2, rank) + ] + if node.output[0] and new_shape is not None: + output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type + vi = self.known_vi_[node.output[0]] + vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, new_shape)) + + def _infer_BatchNormalization(self, node): + """Propagate the shape and type information for the BatchNormalization node.""" + self._propagate_shape_and_type(node) + + # this works for opsets < 14 and 14 since we check i < len(node.output) in the loop + for i in {1, 2, 3, 4}: + if i < len(node.output) and node.output[i]: + # all of these parameters have the same shape as the 1st input + self._propagate_shape_and_type(node, input_index=1, output_index=i) + + def _infer_Range(self, node): + """Infers the shape and type for Range nodes based on the provided start, limit, and delta values.""" + vi = self.known_vi_[node.output[0]] + input_data = self._get_int_or_float_values(node) + if all(i is not None for i in input_data): + start = as_scalar(input_data[0]) + limit = as_scalar(input_data[1]) + delta = as_scalar(input_data[2]) + new_sympy_shape = [sympy.Max(sympy.ceiling((limit - start) / delta), 0)] + else: + new_sympy_shape = [self._new_symbolic_dim_from_output(node)] + self._update_computed_dims(new_sympy_shape) + vi.CopyFrom( + helper.make_tensor_value_info( + node.output[0], + self.known_vi_[node.input[0]].type.tensor_type.elem_type, + get_shape_from_sympy_shape(new_sympy_shape), + ) + ) + + def _infer_ReduceSum(self, node): + """Infer output shape for ReduceSum operation based on input shape, axes, and keep_dims attribute.""" + keep_dims = get_attribute(node, "keepdims", 1) + if get_opset(self.out_mp_) >= 13 and len(node.input) > 1: + # ReduceSum changes axes to input[1] in opset 13 + axes = self._try_get_value(node, 1) + vi = self.known_vi_[node.output[0]] + if axes is None: + assert keep_dims # can only handle keep_dims==True when axes is unknown, by generating new ranks + vi.CopyFrom( + helper.make_tensor_value_info( + node.output[0], + self.known_vi_[node.input[0]].type.tensor_type.elem_type, + get_shape_from_sympy_shape( + self._new_symbolic_shape(self._get_shape_rank(node, 0), node) + ), + ) + ) + else: + shape = self._get_shape(node, 0) + output_shape = [] + axes = [handle_negative_axis(a, 
len(shape)) for a in axes] + for i, d in enumerate(shape): + if i in axes: + if keep_dims: + output_shape.append(1) + else: + output_shape.append(d) + vi.CopyFrom( + helper.make_tensor_value_info( + node.output[0], + self.known_vi_[node.input[0]].type.tensor_type.elem_type, + output_shape, + ) + ) + + def _infer_ReduceProd(self, node): + """Infer the ReduceProd operation on a node, considering axes and keep dimensions attributes.""" + axes = get_attribute(node, "axes") + keep_dims = get_attribute(node, "keepdims", 1) + if keep_dims == 0 and axes == [0]: + data = self._get_int_or_float_values(node)[0] + if data is not None: + self.sympy_data_[node.output[0]] = sympy_reduce_product(data) + + def _infer_RelativePositionBias(self, node): + """Infers the relative position bias for a given ONNX node.""" + seq_len = self._try_get_value(node, 1) + real_seq_len = self._try_get_value(node, 2) + if seq_len is None or real_seq_len is None: + return + num_heads = self._get_sympy_shape(node, 0)[1] + + new_shape = [1, num_heads, str(seq_len), str(real_seq_len)] + + output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type + vi = self.known_vi_[node.output[0]] + vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, new_shape)) + + def _infer_Reshape(self, node): + """Infer the output shape for the Reshape operation based on the provided input shape and reshape parameters.""" + shape_value = self._try_get_value(node, 1) + vi = self.known_vi_[node.output[0]] + if shape_value is None: + shape_shape = self._get_shape(node, 1) + assert len(shape_shape) == 1 + shape_rank = shape_shape[0] + assert is_literal(shape_rank) + vi.CopyFrom( + helper.make_tensor_value_info( + node.output[0], + vi.type.tensor_type.elem_type, + get_shape_from_sympy_shape(self._new_symbolic_shape(shape_rank, node)), + ) + ) + else: + input_sympy_shape = self._get_sympy_shape(node, 0) + total = 1 + for d in input_sympy_shape: + total = total * d + new_sympy_shape = [] + deferred_dim_idx = -1 + non_deferred_size = 1 + for i, d in enumerate(shape_value): + if type(d) == sympy.Symbol or d != 0: + new_sympy_shape.append(d) + else: + new_sympy_shape.append(input_sympy_shape[i]) + non_deferred_size = non_deferred_size * input_sympy_shape[i] + if d == -1: + deferred_dim_idx = i + elif d != 0: + non_deferred_size = non_deferred_size * d + + assert new_sympy_shape.count(-1) < 2 + if -1 in new_sympy_shape: + new_dim = total // non_deferred_size + new_sympy_shape[deferred_dim_idx] = new_dim + + self._update_computed_dims(new_sympy_shape) + vi.CopyFrom( + helper.make_tensor_value_info( + node.output[0], + vi.type.tensor_type.elem_type, + get_shape_from_sympy_shape(new_sympy_shape), + ) + ) + + self._pass_on_sympy_data(node) + + def _infer_Resize(self, node): + """Infers and updates the shape of the output tensor for a Resize node based on scales or sizes.""" + vi = self.known_vi_[node.output[0]] + input_sympy_shape = self._get_sympy_shape(node, 0) + if get_opset(self.out_mp_) <= 10: + scales = self._try_get_value(node, 1) + if scales is not None: + new_sympy_shape = [ + sympy.simplify(sympy.floor(d * s)) + for d, s in zip(input_sympy_shape, scales) + ] + self._update_computed_dims(new_sympy_shape) + vi.CopyFrom( + helper.make_tensor_value_info( + node.output[0], + self.known_vi_[node.input[0]].type.tensor_type.elem_type, + get_shape_from_sympy_shape(new_sympy_shape), + ) + ) + else: + roi = self._try_get_value(node, 1) + scales = self._try_get_value(node, 2) + sizes = self._try_get_value(node, 3) + if sizes is not None: 
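+                # Explicit output sizes fully determine the output shape. For example
+                # (hypothetical values): sizes == np.array([1, 3, 64, 64]) yields
+                # new_sympy_shape == [1, 3, 64, 64], regardless of any scales input.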
+ new_sympy_shape = [sympy.simplify(sympy.floor(s)) for s in sizes] + self._update_computed_dims(new_sympy_shape) + elif scales is not None: + rank = len(scales) + if ( + get_attribute(node, "coordinate_transformation_mode") + == "tf_crop_and_resize" + ): + assert len(roi) == 2 * rank + roi_start = list(roi)[:rank] + roi_end = list(roi)[rank:] + else: + roi_start = [0] * rank + roi_end = [1] * rank + if isinstance(scales, np.ndarray): + scales = scales.tolist() + else: + scales = list(scales) + new_sympy_shape = [ + (sympy.floor(d * (end - start) * scale)) + for d, start, end, scale in zip( + input_sympy_shape, roi_start, roi_end, scales + ) + ] + self._update_computed_dims(new_sympy_shape) + else: + new_sympy_shape = self._new_symbolic_shape(self._get_shape_rank(node, 0), node) + + vi.CopyFrom( + helper.make_tensor_value_info( + node.output[0], + self.known_vi_[node.input[0]].type.tensor_type.elem_type, + get_shape_from_sympy_shape(new_sympy_shape), + ) + ) + + def _infer_Scan(self, node): + """Infer shape and type information for the ONNX 'Scan' operator node.""" + subgraph = get_attribute(node, "body") + num_scan_inputs = get_attribute(node, "num_scan_inputs") + scan_input_axes = get_attribute(node, "scan_input_axes", [0] * num_scan_inputs) + num_scan_states = len(node.input) - num_scan_inputs + scan_input_axes = [ + handle_negative_axis(ax, self._get_shape_rank(node, i + num_scan_states)) + for i, ax in enumerate(scan_input_axes) + ] + # We may have cases where the subgraph has optional inputs that appear in both subgraph's input and initializer, + # but not in the node's input. In such cases, the input model might be invalid, but let's skip those optional inputs. + assert len(subgraph.input) >= len(node.input) + subgraph_inputs = subgraph.input[: len(node.input)] + for i, si in enumerate(subgraph_inputs): + subgraph_name = si.name + si.CopyFrom(self.known_vi_[node.input[i]]) + if i >= num_scan_states: + scan_input_dim = si.type.tensor_type.shape.dim + scan_input_dim.remove(scan_input_dim[scan_input_axes[i - num_scan_states]]) + si.name = subgraph_name + self._onnx_infer_subgraph(node, subgraph) + num_scan_outputs = len(node.output) - num_scan_states + scan_output_axes = get_attribute(node, "scan_output_axes", [0] * num_scan_outputs) + scan_input_dim = get_shape_from_type_proto(self.known_vi_[node.input[-1]].type)[ + scan_input_axes[-1] + ] + for i, o in enumerate(node.output): + vi = self.known_vi_[o] + if i >= num_scan_states: + shape = get_shape_from_type_proto(subgraph.output[i].type) + new_dim = handle_negative_axis( + scan_output_axes[i - num_scan_states], len(shape) + 1 + ) + shape = shape[:new_dim] + [scan_input_dim] + shape[new_dim:] + vi.CopyFrom( + helper.make_tensor_value_info( + o, subgraph.output[i].type.tensor_type.elem_type, shape + ) + ) + else: + vi.CopyFrom(subgraph.output[i]) + vi.name = o + + def _infer_ScatterElements(self, node): + """Infer the output shape and type for ScatterElements node and update known value infos.""" + data_shape = self._get_shape(node, 0) + vi = self.known_vi_[node.output[0]] + vi.CopyFrom( + helper.make_tensor_value_info( + node.output[0], + self.known_vi_[node.input[0]].type.tensor_type.elem_type, + data_shape, + ) + ) + + def _infer_SequenceAt(self, node): + """Infers the shape and type for the output of the 'SequenceAt' ONNX operation, handling symbolic dimensions if + necessary. 
+ """ + seq_shape = self._get_shape(node, 0) + if seq_shape is not None: + vi = self.known_vi_[node.output[0]] + for di, d in enumerate(seq_shape): + if d is not None: + continue + new_dim = onnx.TensorShapeProto.Dimension() + new_dim.dim_param = str(self._new_symbolic_dim_from_output(node, 0, di)) + vi.type.tensor_type.shape.dim[di].CopyFrom(new_dim) + + def _infer_SequenceInsert(self, node): + """Workaround ONNX's shape inference bug by inferring sequence insertion shapes and types for the provided + node. + """ + vi_seq = self.known_vi_[node.input[0]] + vi_tensor = self.known_vi_[node.input[1]] + vi_out_seq = self.known_vi_[node.output[0]] + vi_out_seq.CopyFrom(vi_seq) + vi_out_seq.name = node.output[0] + self._fuse_tensor_type(node, 0, vi_out_seq.type, vi_tensor.type) + + def _infer_Shape(self, node): + """Infers and sets the symbolic shape for the output node in the computation graph.""" + self.sympy_data_[node.output[0]] = self._get_sympy_shape(node, 0) + + def _infer_Size(self, node): + """Infers and sets the size of the output node by computing the product of its shape in the computation + graph. + """ + sympy_shape = self._get_sympy_shape(node, 0) + self.sympy_data_[node.output[0]] = sympy_reduce_product(sympy_shape) + self.known_vi_[node.output[0]].CopyFrom( + helper.make_tensor_value_info(node.output[0], onnx.TensorProto.INT64, []) + ) + + def _infer_Slice(self, node): + """Infer the shape and value information for the Slice node using SymPy and ONNX helper methods.""" + + # even when the relation holds for both `a` and `b`. + # + # When given `expr` of form `min(a, b) + ...`, this function returns `[a + ..., b + ...]`, + # so that we can prove inequalities for both expressions separately. + # + # If the number of `min(...)` subexpressions is not exactly one, this function just returns `[expr]`. + def flatten_min(expr): + """Returns a list with expressions split by min() for inequality proof or original expr if no single min() + found. + """ + assert isinstance(expr, sympy.Add), f"Expected a sum of two arguments, got {expr}" + min_positions = [ + idx for idx in range(len(expr.args)) if isinstance(expr.args[idx], sympy.Min) + ] + if len(min_positions) == 1: + min_pos = min_positions[0] + + def replace_min_with_arg(arg_idx): + """Replace the sympy.Min() function at a specified position in a sympy.Add() expression with one of + its arguments. 
+ """ + replaced = list(expr.args) + assert isinstance(replaced[min_pos], sympy.Min), ( + f"Expected a sympy.Min() at position {min_pos}, got {replaced[min_pos]}" + ) + assert len(replaced[min_pos].args) == 2, ( + f"Expected a sympy.Min() with exactly 2 arguments, got {replaced[min_pos]}" + ) + replaced[min_pos] = replaced[min_pos].args[arg_idx] + return sympy.Add(*replaced) + + return [ + replace_min_with_arg(0), + replace_min_with_arg(1), + ] + return [expr] + + def less_equal(x, y): + """Returns True if x is less than or equal to y, otherwise False.""" + try: + return x <= y + except TypeError: + pass + try: + return y >= x + except TypeError: + pass + try: + return -x >= -y + except TypeError: + pass + try: + return -y <= -x + except TypeError: + pass + try: + return y - x >= 0 + except TypeError: + # the last attempt; this may raise TypeError + return all(d >= 0 for d in flatten_min(y - x)) + + def handle_negative_index(index, bound): + """Normalizes a negative index to be in [0, bound).""" + try: + if not less_equal(0, index): + if is_literal(index) and index <= -self.int_max_: + # this case is handled separately + return index + return bound + index + except TypeError: + logger.warning(f"Cannot determine if {index} < 0") + return index + + if get_opset(self.out_mp_) <= 9: + axes = get_attribute(node, "axes") + starts = get_attribute(node, "starts") + ends = get_attribute(node, "ends") + if not axes: + axes = list(range(len(starts))) + steps = [1] * len(axes) + else: + starts = as_list(self._try_get_value(node, 1), keep_none=True) + ends = as_list(self._try_get_value(node, 2), keep_none=True) + axes = self._try_get_value(node, 3) + steps = self._try_get_value(node, 4) + if axes is None and (starts is not None or ends is not None): + axes = list(range(len(starts if starts is not None else ends))) + if steps is None and (starts is not None or ends is not None): + steps = [1] * len(starts if starts is not None else ends) + axes = as_list(axes, keep_none=True) + steps = as_list(steps, keep_none=True) + + new_sympy_shape = self._get_sympy_shape(node, 0) + if starts is None or ends is None: + if axes is None: + for i in range(len(new_sympy_shape)): + new_sympy_shape[i] = self._new_symbolic_dim_from_output(node, 0, i) + else: + new_sympy_shape = get_shape_from_sympy_shape(new_sympy_shape) + for i in axes: + new_sympy_shape[i] = self._new_symbolic_dim_from_output(node, 0, i) + else: + for i, s, e, t in zip(axes, starts, ends, steps): + if is_literal(e): + e = handle_negative_index(e, new_sympy_shape[i]) + if is_literal(e): + if e >= self.int_max_: + e = new_sympy_shape[i] + elif e <= -self.int_max_: + e = 0 if s > 0 else -1 + elif is_literal(new_sympy_shape[i]): + if e < 0: + e = max(0, e + new_sympy_shape[i]) + e = min(e, new_sympy_shape[i]) + else: + if e > 0: + e = ( + sympy.Min(e, new_sympy_shape[i]) if e > 1 else e + ) # special case for slicing first to make computation easier + else: + if is_literal(new_sympy_shape[i]): + if new_sympy_shape[i] < 0: + e = sympy.Min(e, new_sympy_shape[i]) + else: + try: + if not less_equal(e, new_sympy_shape[i]): + e = new_sympy_shape[i] + except Exception: + logger.warning( + f"Unable to determine if {e} <= {new_sympy_shape[i]}, treat as equal" + ) + e = new_sympy_shape[i] + + s = handle_negative_index(s, new_sympy_shape[i]) + if is_literal(new_sympy_shape[i]) and is_literal(s): + s = max(0, min(s, new_sympy_shape[i])) + + new_sympy_shape[i] = sympy.simplify((e - s + t + (-1 if t > 0 else 1)) // t) + + self._update_computed_dims(new_sympy_shape) + + vi 
= self.known_vi_[node.output[0]]
+        vi.CopyFrom(
+            helper.make_tensor_value_info(
+                node.output[0],
+                vi.type.tensor_type.elem_type,
+                get_shape_from_sympy_shape(new_sympy_shape),
+            )
+        )
+
+        # handle sympy_data if needed, for slice in shape computation
+        if (
+            node.input[0] in self.sympy_data_
+            and [0] == axes
+            and starts is not None
+            and len(starts) == 1
+            and ends is not None
+            and len(ends) == 1
+            and steps is not None
+            and len(steps) == 1
+        ):
+            input_sympy_data = self.sympy_data_[node.input[0]]
+            if isinstance(input_sympy_data, list) or (
+                isinstance(input_sympy_data, np.ndarray) and len(input_sympy_data.shape) == 1
+            ):
+                self.sympy_data_[node.output[0]] = input_sympy_data[
+                    starts[0] : ends[0] : steps[0]
+                ]
+
+    def _infer_SoftmaxCrossEntropyLoss(self, node):
+        """Infer the softmax cross-entropy loss for a given node in the computation graph."""
+        vi = self.known_vi_[node.output[0]]
+        elem_type = self.known_vi_[node.input[0]].type.tensor_type.elem_type
+
+        # If the output type is explicitly specified in an attribute, use it as the output tensor type.
+        specified_output_type = get_attribute(node, "output_type", None)
+        if specified_output_type is not None:
+            elem_type = specified_output_type
+
+        vi.type.tensor_type.elem_type = elem_type
+        vi.type.tensor_type.shape.CopyFrom(onnx.TensorShapeProto())
+
+        if len(node.output) > 1:
+            data_shape = self._get_shape(node, 0)
+            vi = self.known_vi_[node.output[1]]
+            vi.CopyFrom(helper.make_tensor_value_info(vi.name, elem_type, data_shape))
+
+    def _infer_Split_Common(self, node, make_value_info_func):
+        """Infers the output shape for the Split operator given an ONNX node and a function to create tensor value
+        info.
+        """
+        input_sympy_shape = self._get_sympy_shape(node, 0)
+        axis = handle_negative_axis(get_attribute(node, "axis", 0), len(input_sympy_shape))
+        op_set = get_opset(self.out_mp_)
+
+        # Depending on the op-set version, 'split' is provided as an attribute or via the 2nd input
+        if op_set < 13:
+            split = get_attribute(node, "split")
+            assert self._try_get_value(node, 1) is None
+        else:
+            split = self._try_get_value(node, 1)
+            assert get_attribute(node, "split") is None
+
+        if split is None:
+            num_outputs = len(node.output)
+            split = [input_sympy_shape[axis] / sympy.Integer(num_outputs)] * num_outputs
+            self._update_computed_dims(split)
+        else:
+            split = [sympy.Integer(s) for s in split]
+
+        for i_o in range(len(split)):
+            vi = self.known_vi_[node.output[i_o]]
+            vi.CopyFrom(
+                make_value_info_func(
+                    node.output[i_o],
+                    self.known_vi_[node.input[0]].type.tensor_type.elem_type,
+                    get_shape_from_sympy_shape(
+                        input_sympy_shape[:axis] + [split[i_o]] + input_sympy_shape[axis + 1 :]
+                    ),
+                )
+            )
+            self.known_vi_[vi.name] = vi
+
+    def _infer_Split(self, node):
+        """Infers the output shapes and types for the Split operation node."""
+        self._infer_Split_Common(node, helper.make_tensor_value_info)
+
+    def _infer_SplitToSequence(self, node):
+        """Infers the output shapes and types for the SplitToSequence operation node."""
+        self._infer_Split_Common(node, helper.make_sequence_value_info)
+
+    def _infer_Squeeze(self, node):
+        """Infers the output shapes and types for the Squeeze operation node."""
+        input_shape = self._get_shape(node, 0)
+        op_set = get_opset(self.out_mp_)
+
+        # Depending on the op-set version, 'axes' are provided as an attribute or via the 2nd input
+        if op_set < 13:
+            axes = get_attribute(node, "axes")
+            assert self._try_get_value(node, 1) is None
+        else:
+            axes = self._try_get_value(node, 1)
+            assert get_attribute(node, "axes") is None
+
+        if axes is None:
+            # No axes
have been provided (neither via attribute nor via input).
+            # In this case the 'Squeeze' op should remove all axes with dimension 1.
+            # For symbolic dimensions we guess they are != 1.
+            output_shape = [s for s in input_shape if s != 1]
+            if self.verbose_ > 0:
+                symbolic_dimensions = [s for s in input_shape if type(s) != int]
+                if symbolic_dimensions:
+                    logger.debug(
+                        f"Symbolic dimensions in input shape of op: '{node.op_type}' node: '{node.name}'. "
+                        f"Assuming the following dimensions are never equal to 1: {symbolic_dimensions}"
+                    )
+        else:
+            axes = [handle_negative_axis(a, len(input_shape)) for a in axes]
+            output_shape = []
+            for i in range(len(input_shape)):
+                if i not in axes:
+                    output_shape.append(input_shape[i])
+                else:
+                    assert input_shape[i] == 1 or type(input_shape[i]) != int
+                    if self.verbose_ > 0 and type(input_shape[i]) != int:
+                        logger.debug(
+                            f"Symbolic dimensions in input shape of op: '{node.op_type}' node: '{node.name}'. "
+                            f"Assuming the dimension '{input_shape[i]}' at index {i} of the input to be equal to 1."
+                        )
+
+        vi = self.known_vi_[node.output[0]]
+        vi.CopyFrom(
+            helper.make_tensor_value_info(
+                node.output[0],
+                self.known_vi_[node.input[0]].type.tensor_type.elem_type,
+                output_shape,
+            )
+        )
+        self._pass_on_sympy_data(node)
+
+    def _infer_Tile(self, node):
+        """Infers the output shape for the Tile operation in a computation graph based on input shape and repeat
+        values.
+        """
+        repeats_value = self._try_get_value(node, 1)
+        new_sympy_shape = []
+        if repeats_value is not None:
+            input_sympy_shape = self._get_sympy_shape(node, 0)
+            for i, d in enumerate(input_sympy_shape):
+                new_dim = d * repeats_value[i]
+                new_sympy_shape.append(new_dim)
+            self._update_computed_dims(new_sympy_shape)
+        else:
+            new_sympy_shape = self._new_symbolic_shape(self._get_shape_rank(node, 0), node)
+        vi = self.known_vi_[node.output[0]]
+        vi.CopyFrom(
+            helper.make_tensor_value_info(
+                node.output[0],
+                vi.type.tensor_type.elem_type,
+                get_shape_from_sympy_shape(new_sympy_shape),
+            )
+        )
+
+    def _infer_TopK(self, node):
+        """Infers the output shape for the TopK operation in an ONNX graph node based on input shape and specified
+        axis.
+        """
+        rank = self._get_shape_rank(node, 0)
+        axis = handle_negative_axis(get_attribute(node, "axis", -1), rank)
+        new_shape = self._get_shape(node, 0)
+
+        if get_opset(self.out_mp_) <= 9:
+            k = get_attribute(node, "k")
+        else:
+            k = self._get_int_or_float_values(node)[1]
+
+        k = self._new_symbolic_dim_from_output(node) if k is None else as_scalar(k)
+        if type(k) in {int, str}:
+            new_shape[axis] = k
+        else:
+            new_sympy_shape = self._get_sympy_shape(node, 0)
+            new_sympy_shape[axis] = k
+            self._update_computed_dims(
+                new_sympy_shape
+            )  # note that TopK dim could be computed in sympy_data, so need to update computed_dims when it enters shape
+            new_shape = get_shape_from_sympy_shape(new_sympy_shape)
+
+        for i_o in range(len(node.output)):
+            vi = self.known_vi_[node.output[i_o]]
+            vi.CopyFrom(
+                helper.make_tensor_value_info(
+                    node.output[i_o], vi.type.tensor_type.elem_type, new_shape
+                )
+            )
+
+    def _infer_Transpose(self, node):
+        """Infer and update the shape information for a Transpose node based on its input shape and permutation
+        attributes.
+ """ + if node.input[0] in self.sympy_data_: + data_shape = self._get_shape(node, 0) + perm = get_attribute(node, "perm", reversed(list(range(len(data_shape))))) + input_data = self.sympy_data_[node.input[0]] + self.sympy_data_[node.output[0]] = ( + np.transpose(np.array(input_data).reshape(*data_shape), axes=tuple(perm)) + .flatten() + .tolist() + ) + + def _infer_Unsqueeze(self, node): + """Infers the output shape for the Unsqueeze operation based on the input shape and operator set.""" + input_shape = self._get_shape(node, 0) + op_set = get_opset(self.out_mp_) + + # Depending on op-version 'axes' are provided as attribute or via 2nd input + if op_set < 13: + axes = get_attribute(node, "axes") + assert self._try_get_value(node, 1) is None + else: + axes = self._try_get_value(node, 1) + assert get_attribute(node, "axes") is None + + output_rank = len(input_shape) + len(axes) + axes = [handle_negative_axis(a, output_rank) for a in axes] + + input_axis = 0 + output_shape = [] + for i in range(output_rank): + if i in axes: + output_shape.append(1) + else: + output_shape.append(input_shape[input_axis]) + input_axis += 1 + + vi = self.known_vi_[node.output[0]] + vi.CopyFrom( + helper.make_tensor_value_info( + node.output[0], + self.known_vi_[node.input[0]].type.tensor_type.elem_type, + output_shape, + ) + ) + + self._pass_on_sympy_data(node) + + def _infer_ZipMap(self, node): + """Infer the type of keys for a ZipMap node based on its class labels attribute.""" + map_key_type = None + if get_attribute(node, "classlabels_int64s") is not None: + map_key_type = onnx.TensorProto.INT64 + elif get_attribute(node, "classlabels_strings") is not None: + map_key_type = onnx.TensorProto.STRING + + assert map_key_type is not None + new_vi = onnx.ValueInfoProto() + new_vi.name = node.output[0] + new_vi.type.sequence_type.elem_type.map_type.value_type.tensor_type.elem_type = ( + onnx.TensorProto.FLOAT + ) + new_vi.type.sequence_type.elem_type.map_type.key_type = map_key_type + vi = self.known_vi_[node.output[0]] + vi.CopyFrom(new_vi) + + def _infer_Attention(self, node): + """Infer shape and data type for ONNX Attention node outputs given input shapes and attributes.""" + shape = self._get_shape(node, 0) + shape_weights = self._get_shape(node, 1) + shape_bias = self._try_get_shape(node, 2) + if shape_bias is not None: + assert len(shape_bias) == 1 + tripled_hidden_size = shape_bias[0] if shape_bias is not None else shape_weights[1] + if shape and len(shape) == 3: + qkv_hidden_sizes_attr = get_attribute(node, "qkv_hidden_sizes") + if qkv_hidden_sizes_attr is not None: + assert len(qkv_hidden_sizes_attr) == 3 + shape[2] = int(qkv_hidden_sizes_attr[2]) + elif isinstance(tripled_hidden_size, int): + shape[2] = int(tripled_hidden_size / 3) + output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type + vi = self.known_vi_[node.output[0]] + vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, shape)) + + if len(node.output) > 1: + # input shape: (batch_size, sequence_length, hidden_size) + # past shape: (2, batch_size, num_heads, past_sequence_length, head_size) + # mask shape: (batch_size, total_sequence_length) or (batch_size, sequence_length, total_sequence_length) or (batch_size, 1, max_seq_len, max_seq_len) + # present shape: (2, batch_size, num_heads, total_sequence_length, head_size), where total_sequence_length=sequence_length+past_sequence_length + input_shape = self._get_shape(node, 0) + past_shape = ( + self._get_shape(node, 4) if len(node.input) > 4 and node.input[4] 
else []
+            )
+            mask_shape = (
+                self._get_shape(node, 3) if len(node.input) > 3 and node.input[3] else []
+            )
+
+            if past_shape and len(past_shape) == 5:
+                if mask_shape and len(mask_shape) in {2, 3}:
+                    past_shape[3] = mask_shape[-1]
+                elif input_shape and len(input_shape) == 3:
+                    if isinstance(input_shape[1], int) and isinstance(past_shape[3], int):
+                        past_shape[3] = input_shape[1] + past_shape[3]
+                    else:
+                        past_shape[3] = f"{past_shape[3]}+{input_shape[1]}"
+                vi = self.known_vi_[node.output[1]]
+                vi.CopyFrom(
+                    helper.make_tensor_value_info(vi.name, output_dtype, past_shape)
+                )
+            else:
+                num_heads = get_attribute(node, "num_heads")
+                head_size = input_shape[2] // num_heads
+                present_shape = [
+                    2,
+                    input_shape[0],
+                    num_heads,
+                    input_shape[1],
+                    head_size,
+                ]
+                vi = self.known_vi_[node.output[1]]
+                vi.CopyFrom(
+                    helper.make_tensor_value_info(vi.name, output_dtype, present_shape)
+                )
+
+    def _infer_GatedRelativePositionBias(self, node):
+        """Infer the shape for gated relative position bias given the node attributes."""
+        # When padding is removed:
+        #   query_layer: (token_count, num_heads x head_size)
+        #   token_offset: (batch_size, seq_len)
+        # Otherwise:
+        #   query_layer: (batch_size, seq_len, num_heads x head_size)
+        #   token_offset: None
+        # Output shape: (batch_size, num_heads, seq_len, seq_len)
+        num_heads = get_attribute(node, "num_heads")
+
+        token_offset_shape = self._try_get_shape(node, 6)
+        if token_offset_shape is not None:
+            output_shape = [
+                token_offset_shape[0],
+                num_heads,
+                token_offset_shape[1],
+                token_offset_shape[1],
+            ]
+        else:
+            query_layer_shape = self._get_shape(node, 0)
+            assert query_layer_shape is not None and len(query_layer_shape) == 3
+            output_shape = [
+                query_layer_shape[0],
+                num_heads,
+                query_layer_shape[1],
+                query_layer_shape[1],
+            ]
+
+        output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
+        vi = self.known_vi_[node.output[0]]
+        vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape))
+
+    def _infer_PackedAttention(self, node):
+        """Infer shape and data type for PackedAttention nodes in a given computational graph."""
+        shape = self._get_shape(node, 0)
+        shape_weights = self._get_shape(node, 1)
+        shape_bias = self._try_get_shape(node, 2)
+        if shape_bias is not None:
+            assert len(shape_bias) == 1
+        tripled_hidden_size = shape_bias[0] if shape_bias is not None else shape_weights[1]
+        if shape and len(shape) == 2:
+            qkv_hidden_sizes_attr = get_attribute(node, "qkv_hidden_sizes")
+            if qkv_hidden_sizes_attr is not None:
+                assert len(qkv_hidden_sizes_attr) == 3
+                shape[1] = int(qkv_hidden_sizes_attr[2])
+            elif isinstance(tripled_hidden_size, int):
+                shape[1] = int(tripled_hidden_size / 3)
+        output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
+        vi = self.known_vi_[node.output[0]]
+        vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, shape))
+
+    def _infer_PackedMultiHeadAttention(self, node):
+        """Infer the output shape for PackedMultiHeadAttention node in the computational graph."""
+        shape_value = self._try_get_shape(node, 2)
+        if shape_value is not None and len(shape_value) == 2:
+            output_shape = shape_value
+        else:
+            shape_query = self._get_shape(node, 0)
+            assert shape_query is not None and len(shape_query) == 4
+            output_shape = [shape_query[0], shape_query[1] * shape_query[3]]
+
+        output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
+        vi = self.known_vi_[node.output[0]]
+        vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape))
+
+    def
_infer_MultiScaleDeformableAttnTRT(self, node): + shape_value = self._try_get_shape(node, 0) + sampling_locations = self._try_get_shape(node, 3) + output_shape = shape_value + output_shape[1] = sampling_locations[1] + output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type + vi = self.known_vi_[node.output[0]] + vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape)) + + def _infer_RemovePadding(self, node): + """Infers the shape and data type for the output tensor after removing padding.""" + shape = self._get_shape(node, 0) + if shape and len(shape) == 3: + output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type + vi = self.known_vi_[node.output[0]] + vi.CopyFrom( + helper.make_tensor_value_info( + node.output[0], output_dtype, ["token_count", shape[2]] + ) + ) + + vi_token_offset = self.known_vi_[node.output[1]] + vi_token_offset.CopyFrom( + helper.make_tensor_value_info( + node.output[1], onnx.TensorProto.INT32, [shape[0], shape[1]] + ) + ) + + vi_cumulated_seq_len = self.known_vi_[node.output[2]] + vi_cumulated_seq_len.CopyFrom( + helper.make_tensor_value_info( + node.output[2], onnx.TensorProto.INT32, ["batch_size + 1"] + ) + ) + + vi_max_seq_len = self.known_vi_[node.output[3]] + vi_max_seq_len.CopyFrom( + helper.make_tensor_value_info(node.output[3], onnx.TensorProto.INT32, [1]) + ) + + def _infer_RestorePadding(self, node): + """Infers the output shape and type for the RestorePadding operation.""" + shape_input = self._get_shape(node, 0) + shape_token_offset = self._get_shape(node, 1) + if ( + shape_input + and len(shape_input) == 2 + and shape_token_offset + and len(shape_token_offset) == 2 + ): + output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type + vi = self.known_vi_[node.output[0]] + + output_shape = [ + shape_token_offset[0], + shape_token_offset[1], + shape_input[1], + ] + vi.CopyFrom( + helper.make_tensor_value_info(node.output[0], output_dtype, output_shape) + ) + + def _infer_BiasGelu(self, node): + """Propagate shape and type information for BiasGelu node during inference.""" + self._propagate_shape_and_type(node) + + def _infer_MultiHeadAttention(self, node): + """Propagate shape and type information for MultiHeadAttention node during inference.""" + # Q, K and V without packing: + # Input 0 (query) has shape (batch_size, sequence_length, hidden_size) + # Input 1 (key) has shape (batch_size, kv_sequence_length, hidden_size) or (batch_size, num_heads, kv_sequence_length, head_size) + # Input 2 (value) has shape (batch_size, kv_sequence_length, v_hidden_size) or (batch_size, num_heads, kv_sequence_length, head_size) + # Packed KV: + # Input 0 (query) has shape (batch_size, sequence_length, hidden_size) + # Input 1 (batch_size, kv_sequence_length, num_heads, 2, head_size) + # Input 2 nullptr + # Packed QKV: + # Input 0 (batch_size, sequence_length, num_heads, 3, head_size) + # Input 1 nullptr + # Input 2 nullptr + + query_shape = self._get_shape(node, 0) + total_sequence_length = None + output_dtype = None + if query_shape is not None: + if len(query_shape) == 3: + key_shape = self._try_get_shape(node, 1) + # By default, hidden size is same for Q/K/V. Only need check v_hidden_size when value is provided. 
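+                # For example (hypothetical shapes): query (B, S, 768) with 3D key/value
+                # shapes (B, S_kv, 768) and (B, S_kv, 512) gives output (B, S, 512); when no
+                # 3D value shape is known, the output keeps the query hidden size (B, S, 768).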
+ output_shape = query_shape + if key_shape is not None and len(key_shape) == 3: + value_shape = self._try_get_shape(node, 2) + if value_shape is not None and len(value_shape) == 3: + output_shape[2] = value_shape[2] + total_sequence_length = key_shape[1] + + output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type + vi = self.known_vi_[node.output[0]] + vi.CopyFrom( + helper.make_tensor_value_info(node.output[0], output_dtype, output_shape) + ) + + elif len(query_shape) == 5: + if isinstance(query_shape[2], int) and isinstance(query_shape[4], int): + output_shape = [ + query_shape[0], + query_shape[1], + query_shape[2] * query_shape[4], + ] + else: + output_shape = [ + query_shape[0], + query_shape[1], + f"{query_shape[2]}*{query_shape[4]}", + ] + + total_sequence_length = query_shape[1] + + output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type + vi = self.known_vi_[node.output[0]] + vi.CopyFrom( + helper.make_tensor_value_info(node.output[0], output_dtype, output_shape) + ) + + if len(node.output) > 1: + batch_size = query_shape[0] + num_heads = get_attribute(node, "num_heads") + + head_size = None + if len(query_shape) == 3: + head_size = ( + int(query_shape[2] / num_heads) + if isinstance(query_shape[2], int) + else f"{query_shape[2]}/{num_heads}" + ) + else: + head_size = query_shape[4] + + past_shape = self._try_get_shape(node, 6) + + if past_shape is not None: + if isinstance(past_shape[2], int) and isinstance( + total_sequence_length, int + ): + total_sequence_length = past_shape[2] + total_sequence_length + else: + total_sequence_length = f"{past_shape[2]}+{total_sequence_length}" + + present_shape = [ + batch_size, + num_heads, + total_sequence_length, + head_size, + ] + + assert output_dtype is not None + if len(node.output) > 2 and node.output[1] and node.output[2]: + vi = self.known_vi_[node.output[1]] + vi.CopyFrom( + helper.make_tensor_value_info(vi.name, output_dtype, present_shape) + ) + vi = self.known_vi_[node.output[2]] + vi.CopyFrom( + helper.make_tensor_value_info(vi.name, output_dtype, present_shape) + ) + + def _infer_DecoderMaskedMultiHeadAttention(self, node): + """Infers the output shape of the DecoderMaskedMultiHeadAttention node based on input shapes and attributes in + the computational graph. 
+ """ + # Q, K and V without packing: + # Input 0 (query) has shape (batch_size, 1, hidden_size) + # Input 5 (past_key) if exists has shape (batch_size, num_heads, max_sequence_length, head_size) + + query_shape = self._get_shape(node, 0) + if query_shape is not None: + output_shape = query_shape + output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type + assert output_dtype is not None + vi = self.known_vi_[node.output[0]] + vi.CopyFrom( + helper.make_tensor_value_info(node.output[0], output_dtype, output_shape) + ) + + if len(node.output) > 2 and node.output[1] and node.output[2]: + past_shape = self._try_get_shape(node, 5) + if past_shape is not None: + vi = self.known_vi_[node.output[1]] + vi.CopyFrom( + helper.make_tensor_value_info(vi.name, output_dtype, past_shape) + ) + vi = self.known_vi_[node.output[2]] + vi.CopyFrom( + helper.make_tensor_value_info(vi.name, output_dtype, past_shape) + ) + + def _infer_FastGelu(self, node): + """Infers the output shapes and types for the FastGelu node using shape propagation.""" + self._propagate_shape_and_type(node) + + def _infer_Gelu(self, node): + """Infers the output shapes and types for the Gelu node using shape propagation.""" + self._propagate_shape_and_type(node) + + def _infer_QuickGelu(self, node): + """Infers the output shapes and types for the QuickGelu node using shape propagation.""" + self._propagate_shape_and_type(node) + + def _infer_GemmFastGelu(self, node): + """Infers the output shapes and types for the GemmFastGelu node using matrix multiplication shape + computation. + """ + self._compute_matmul_shape(node) + + def _infer_GemmFloat8(self, node): + """Infers the output shapes and types for the GemmFloat8 node using matrix multiplication shape computation.""" + self._compute_matmul_shape(node) + + def _infer_LayerNormalization(self, node): + """Infers the output shapes and types for the LayerNormalization node, including handling mean and variance + outputs. 
+ """ + self._propagate_shape_and_type(node) + if len(node.output) > 1: + axis = get_attribute(node, "axis") + if axis is None: + axis = -1 + x_shape = self._get_shape(node, 0) + if x_shape is not None: + rank = len(x_shape) + axis = handle_negative_axis(axis, rank) + mean_shape = x_shape[:axis] + [1 for _ in range(rank - axis)] + mean_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type + if mean_dtype in { + onnx.TensorProto.FLOAT16, + onnx.TensorProto.BFLOAT16, + }: + mean_dtype = onnx.TensorProto.FLOAT + vi = self.known_vi_[node.output[1]] + vi.CopyFrom( + helper.make_tensor_value_info(node.output[1], mean_dtype, mean_shape) + ) + if len(node.output) > 2: + vi = self.known_vi_[node.output[2]] + vi.CopyFrom( + helper.make_tensor_value_info(node.output[2], mean_dtype, mean_shape) + ) + + def _infer_LongformerAttention(self, node): + """Infer and propagate shape and type information for a LongformerAttention node.""" + self._propagate_shape_and_type(node) + + def _infer_EmbedLayerNormalization(self, node): + """Infer and propagate shape and type information for an EmbedLayerNormalization node.""" + input_ids_shape = self._get_shape(node, 0) + word_embedding_shape = self._get_shape(node, 2) + assert len(input_ids_shape) == 2 and len(word_embedding_shape) == 2 + output_shape = [*input_ids_shape, word_embedding_shape[1]] + + word_embedding_dtype = self.known_vi_[node.input[2]].type.tensor_type.elem_type + vi = self.known_vi_[node.output[0]] + vi.CopyFrom( + helper.make_tensor_value_info(node.output[0], word_embedding_dtype, output_shape) + ) + + if len(node.output) > 1 and node.output[1]: + mask_index_shape = [input_ids_shape[0]] + vi = self.known_vi_[node.output[1]] + vi.CopyFrom( + helper.make_tensor_value_info( + node.output[1], onnx.TensorProto.INT32, mask_index_shape + ) + ) + + if len(node.output) > 2: + # Optional output of add before layer normalization is done + # shape is same as the output + vi = self.known_vi_[node.output[2]] + vi.CopyFrom( + helper.make_tensor_value_info( + node.output[2], word_embedding_dtype, output_shape + ) + ) + + def _infer_SkipLayerNormalization(self, node): + """Infer the output shape and type for a node with SkipLayerNormalization in an ONNX model.""" + self._propagate_shape_and_type(node) + + # If the SkipLayerNormalization node contains the optional + # output for inference, infer the shape and type for it too + if len(node.output) > 3: + self._propagate_shape_and_type(node, 0, 3) + + def _infer_GroupNorm(self, node): + """Infer the shape and type for Group Normalization in an ONNX model.""" + self._propagate_shape_and_type(node) + + def _infer_SkipGroupNorm(self, node): + """Infer the shape and type for Skip Group Normalization in an ONNX model.""" + self._propagate_shape_and_type(node, 0, 0) + if len(node.output) > 1: + self._propagate_shape_and_type(node, 0, 1) + + def _infer_BiasSplitGelu(self, node): + """Infer the shape and type for Bias Split Gelu in an ONNX model.""" + input_shape = self._get_shape(node, 0) + bias_shape = self._get_shape(node, 1) + if input_shape and bias_shape and isinstance(bias_shape[0], int): + output_shape = input_shape + output_shape[2] = int(bias_shape[0] / 2) + vi = self.known_vi_[node.output[0]] + output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type + vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, output_shape)) + + def _infer_BiasAdd(self, node): + """Infer the output shape and type for a BiasAdd node by propagating input shape and type information.""" + 
self._propagate_shape_and_type(node)
+
+    def _infer_RotaryEmbedding(self, node):
+        """Infer the output shape and type for a RotaryEmbedding node by appropriately propagating input shape and type
+        information.
+        """
+        if len(node.output) == 1:
+            self._propagate_shape_and_type(node)
+        elif len(node.output) == 2:
+            # Extraneous constant nodes output by a RotaryEmbedding function made with `export_modules_as_functions`
+            self._propagate_shape_and_type(node, input_index=1, output_index=0)
+            self._propagate_shape_and_type(node, input_index=0, output_index=1)  # true output
+        elif len(node.output) == 3:
+            # Extraneous constant nodes output by a RotaryEmbedding function made with `export_modules_as_functions`
+            self._propagate_shape_and_type(node, input_index=1, output_index=0)
+            self._propagate_shape_and_type(node, input_index=1, output_index=1)
+            self._propagate_shape_and_type(node, input_index=0, output_index=2)  # true output
+
+    def _infer_PythonOp(self, node):
+        """Infer and propagate the shape and type information for a PythonOp node in the computation graph."""
+        output_tensor_types = get_attribute(node, "output_tensor_types")
+        assert output_tensor_types, (
+            f"PythonOp '{node.name}' has no output_tensor_types attribute."
+        )
+        output_tensor_ranks = get_attribute(node, "output_tensor_ranks")
+        assert output_tensor_ranks, (
+            f"PythonOp '{node.name}' has no output_tensor_ranks attribute."
+        )
+
+        from onnxruntime.capi._pybind_state import get_shape_inference_function
+
+        func_name = get_attribute(node, "func_name").decode()
+        shape_inferer = get_shape_inference_function(func_name)
+
+        # Set the context output separately.
+        # The first output is torch.autograd.Function's context.
+        vi = self.known_vi_[node.output[0]]
+        vi.CopyFrom(helper.make_tensor_value_info(node.output[0], onnx.TensorProto.INT64, []))
+
+        if shape_inferer is not None:
+            input_shapes = []
+            input_dtypes = []
+            for input_index in range(len(node.input)):
+                shape = self._get_shape(node, input_index)
+                input_shapes.append(shape)
+                input_dtype = self.known_vi_[
+                    node.input[input_index]
+                ].type.tensor_type.elem_type
+                input_dtypes.append(input_dtype)
+            output_shapes, output_dtypes = shape_inferer(node, input_shapes, input_dtypes)
+            assert len(output_shapes) == len(output_dtypes) == (len(node.output) - 1), (
+                f"PythonOp '{func_name}' returned {len(output_shapes)} shapes and {len(output_dtypes)} dtypes, "
+                f"but expected {len(node.output) - 1} outputs."
+            )
+            for i in range(len(node.output) - 1):
+                output_index = i + 1
+                vi = self.known_vi_[node.output[output_index]]
+                vi.CopyFrom(
+                    helper.make_tensor_value_info(
+                        node.output[output_index], output_dtypes[i], output_shapes[i]
+                    )
+                )
+        else:
+            # General shape inference for PythonOp.
+            # Outputs after torch.autograd.Function's context are tensors.
+            # We assume their ranks are fixed for different model inputs.
+            for i in range(len(node.output) - 1):
+                # Process the i-th tensor outputs.
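+                # For example (hypothetical): output_tensor_ranks == [2] assigns output 1 a
+                # fresh rank-2 symbolic shape, with its dtype taken from output_tensor_types[0].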
+ vi = self.known_vi_[node.output[i + 1]] + sympy_shape = self._new_symbolic_shape(output_tensor_ranks[i], node) + shape = get_shape_from_sympy_shape(sympy_shape) + value_info = helper.make_tensor_value_info( + node.output[i + 1], output_tensor_types[i], shape + ) + vi.CopyFrom(value_info) + + def _propagate_shape_and_type(self, node, input_index=0, output_index=0): + """Propagates the shape and type information from input to output tensors in a given node.""" + shape = self._get_shape(node, input_index) + output_dtype = self.known_vi_[node.input[input_index]].type.tensor_type.elem_type + vi = self.known_vi_[node.output[output_index]] + vi.CopyFrom( + helper.make_tensor_value_info(node.output[output_index], output_dtype, shape) + ) + + def _is_none_dim(self, dim_value): + """Check if dimension value is a string representing an unknown dimension that is not in symbolic_dims_.""" + if type(dim_value) != str: + return False + return dim_value not in self.symbolic_dims_ if "unk__" in dim_value else False + + def _is_shape_contains_none_dim(self, out_shape): + """Check if any dimension in the given shape contains the 'None' dimension and return it if found.""" + for out in out_shape: + if self._is_none_dim(out): + return out + return None + + def _infer_impl(self, start_sympy_data=None): + """Infer implementation details and update symbolic data and input symbols.""" + self.sympy_data_ = start_sympy_data or {} + self.out_mp_.graph.ClearField("value_info") + self._apply_suggested_merge(graph_input_only=True) + self.input_symbols_ = set() + for i in self.out_mp_.graph.input: + input_shape = get_shape_from_value_info(i) + if input_shape is None: + continue + + if is_sequence(i.type): + input_dims = i.type.sequence_type.elem_type.tensor_type.shape.dim + else: + input_dims = i.type.tensor_type.shape.dim + + for i_dim, dim in enumerate(input_shape): + if dim is None: + # some models use None for symbolic dim in input, replace it with a string + input_dims[i_dim].dim_param = str(self._new_symbolic_dim(i.name, i_dim)) + + self.input_symbols_.update([d for d in input_shape if type(d) == str]) + + for s in self.input_symbols_: + if s in self.suggested_merge_: + s_merge = self.suggested_merge_[s] + assert s_merge in self.symbolic_dims_ + self.symbolic_dims_[s] = self.symbolic_dims_[s_merge] + else: + # Since inputs are not produced by other ops, we can assume positivity + self.symbolic_dims_[s] = sympy.Symbol(s, integer=True, positive=True) + # create a temporary ModelProto for single node inference + # note that we remove initializer to have faster inference + # for tensor ops like Reshape/Tile/Expand that read initializer, we need to do sympy computation based inference anyways + self.tmp_mp_ = onnx.ModelProto() + self.tmp_mp_.CopyFrom(self.out_mp_) + self.tmp_mp_.graph.ClearField("initializer") + + # compute prerequisite for node for topological sort + # node with subgraphs may have dependency on implicit inputs, which will affect topological sort + prereq_for_node = {} # map from node to all its inputs, including implicit ones in subgraph + + def get_prereq(node): + """Compute and return the prerequisite inputs for a given node, including implicit inputs from subgraphs.""" + names = {i for i in node.input if i} + subgraphs = [] + if node.op_type == "If": + subgraphs = [ + get_attribute(node, "then_branch"), + get_attribute(node, "else_branch"), + ] + elif node.op_type in {"Loop", "Scan"}: + subgraphs = [get_attribute(node, "body")] + for g in subgraphs: + g_outputs_and_initializers = {i.name for i in 
g.initializer} + g_prereq = set() + for n in g.node: + g_outputs_and_initializers.update(n.output) + for n in g.node: + g_prereq.update( + [i for i in get_prereq(n) if i not in g_outputs_and_initializers] + ) + names.update(g_prereq) + # remove subgraph inputs from g_prereq since those are local-only + for i in g.input: + if i.name in names: + names.remove(i.name) + return names + + for n in self.tmp_mp_.graph.node: + prereq_for_node[n.output[0]] = get_prereq(n) + + # topological sort nodes, note there might be dead nodes so we check if all graph outputs are reached to terminate + sorted_nodes = [] + sorted_known_vi = { + i.name + for i in list(self.out_mp_.graph.input) + list(self.out_mp_.graph.initializer) + } + if any(o.name in sorted_known_vi for o in self.out_mp_.graph.output): + # Loop/Scan will have some graph output in graph inputs, so don't do topological sort + sorted_nodes = self.out_mp_.graph.node + else: + while any(o.name not in sorted_known_vi for o in self.out_mp_.graph.output): + old_sorted_nodes_len = len(sorted_nodes) + for node in self.out_mp_.graph.node: + if node.output[0] not in sorted_known_vi and all( + i in sorted_known_vi for i in prereq_for_node[node.output[0]] if i + ): + sorted_known_vi.update(node.output) + sorted_nodes.append(node) + if old_sorted_nodes_len == len(sorted_nodes) and not all( + o.name in sorted_known_vi for o in self.out_mp_.graph.output + ): + raise Exception("Invalid model with cyclic graph") + + for node in sorted_nodes: + assert all([i in self.known_vi_ for i in node.input if i]) + self._onnx_infer_single_node(node) + known_aten_op = False + if node.op_type in self.dispatcher_: + self.dispatcher_[node.op_type](node) + elif node.op_type == "ConvTranspose": + # onnx shape inference ops like ConvTranspose may have empty shape for symbolic input + # before adding symbolic compute for them + # mark the output type as UNDEFINED to allow guessing of rank + vi = self.known_vi_[node.output[0]] + if len(vi.type.tensor_type.shape.dim) == 0: + vi.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED + elif node.op_type == "ATen" and node.domain == "org.pytorch.aten": + for attr in node.attribute: + # TODO: Is overload_name needed? + if attr.name == "operator": + aten_op_name = ( + attr.s.decode("utf-8") if isinstance(attr.s, bytes) else attr.s + ) + if aten_op_name in self.aten_op_dispatcher_: + known_aten_op = True + self.aten_op_dispatcher_[aten_op_name](node) + break + + if self.verbose_ > 2: + logger.debug(node.op_type + ": " + node.name) + for i, name in enumerate(node.input): + logger.debug( + " Input {}: {} {}".format( + i, name, "initializer" if name in self.initializers_ else "" + ) + ) + + # onnx automatically merge dims with value, i.e. 
Mul(['aaa', 'bbb'], [1000, 1]) -> [1000, 'bbb'] + # symbolic shape inference needs to apply merge of 'aaa' -> 1000 in this case + if node.op_type in { + "Add", + "Sub", + "Mul", + "Div", + "MatMul", + "MatMulInteger", + "MatMulInteger16", + "Where", + "Sum", + }: + vi = self.known_vi_[node.output[0]] + out_rank = len(get_shape_from_type_proto(vi.type)) + in_shapes = [self._get_shape(node, i) for i in range(len(node.input))] + for d in range( + out_rank + - ( + 2 + if node.op_type in {"MatMul", "MatMulInteger", "MatMulInteger16"} + else 0 + ) + ): + in_dims = [ + s[len(s) - out_rank + d] for s in in_shapes if len(s) + d >= out_rank + ] + if len(in_dims) > 1: + self._check_merged_dims(in_dims, allow_broadcast=True) + + for i_o in range(len(node.output)): + # Special cases: + # 1) We do not care about the training related outputs of SkipLayerNormalization + # 2) We do not care about the extraneous constant outputs in RotaryEmbedding because + # the RotaryEmbedding op created during export can be replaced by the RotaryEmbedding + # contrib op + if node.op_type in { + "SkipLayerNormalization", + "SkipSimplifiedLayerNormalization", + } and i_o in {1, 2}: + continue + if node.op_type == "RotaryEmbedding" and len(node.output) > 1: + # Skip symbolic shape inference for RotaryEmbedding functions that have extraneous outputs + # generated by `export_modules_as_functions` + continue + + vi = self.known_vi_[node.output[i_o]] + out_type = vi.type + out_type_kind = out_type.WhichOneof("value") + + # do not process shape for non-tensors + if out_type_kind not in {"tensor_type", "sparse_tensor_type", None}: + if self.verbose_ > 2: + if out_type_kind == "sequence_type": + seq_cls_type = out_type.sequence_type.elem_type.WhichOneof("value") + if seq_cls_type == "tensor_type": + logger.debug( + " {}: sequence of {} {}".format( + node.output[i_o], + str(get_shape_from_value_info(vi)), + onnx.TensorProto.DataType.Name( + vi.type.sequence_type.elem_type.tensor_type.elem_type + ), + ) + ) + else: + logger.debug( + f" {node.output[i_o]}: sequence of {seq_cls_type}" + ) + else: + logger.debug(f" {node.output[i_o]}: {out_type_kind}") + continue + + out_shape = get_shape_from_value_info(vi) + out_type_undefined = ( + out_type.tensor_type.elem_type == onnx.TensorProto.UNDEFINED + ) + if self.verbose_ > 2: + logger.debug( + f" {node.output[i_o]}: {str(out_shape)} {onnx.TensorProto.DataType.Name(vi.type.tensor_type.elem_type)}" + ) + if node.output[i_o] in self.sympy_data_: + logger.debug( + " Sympy Data: " + str(self.sympy_data_[node.output[i_o]]) + ) + + # onnx >= 1.11.0, use unk__#index instead of None when the shape dim is uncertain + if ( + out_shape is not None + and (None in out_shape or self._is_shape_contains_none_dim(out_shape)) + ) or out_type_undefined: + if self.auto_merge_: + if node.op_type in { + "Add", + "Sub", + "Mul", + "Div", + "MatMul", + "MatMulInteger", + "MatMulInteger16", + "Concat", + "Where", + "Sum", + "Equal", + "Less", + "Greater", + "LessOrEqual", + "GreaterOrEqual", + "Min", + "Max", + }: + shapes = [self._get_shape(node, i) for i in range(len(node.input))] + if node.op_type in { + "MatMul", + "MatMulInteger", + "MatMulInteger16", + } and ( + None in out_shape + or self._is_shape_contains_none_dim(out_shape) + ): + if None in out_shape: + idx = out_shape.index(None) + else: + idx = out_shape.index( + self._is_shape_contains_none_dim(out_shape) + ) + dim_idx = [len(s) - len(out_shape) + idx for s in shapes] + # only support auto merge for MatMul for dim < rank-2 when rank > 2 + assert 
len(shapes[0]) > 2 and dim_idx[0] < len(shapes[0]) - 2 + assert len(shapes[1]) > 2 and dim_idx[1] < len(shapes[1]) - 2 + elif node.op_type == "Expand": + # auto merge for cases like Expand([min(batch, 1), min(seq, 512)], [batch, seq]) + shapes = [ + self._get_shape(node, 0), + self._get_value(node, 1), + ] + else: + shapes = [] + + if shapes: + for idx in range(len(out_shape)): + if out_shape[idx] is not None and not self._is_none_dim( + out_shape[idx] + ): + continue + # note that the broadcasting rule aligns from right to left + # if a tensor has a lower rank (dim_idx[idx] < 0), it would automatically broadcast and need no merge + dim_idx = [len(s) - len(out_shape) + idx for s in shapes] + if dim_idx: + self._add_suggested_merge( + [ + s[i] if is_literal(s[i]) else str(s[i]) + for s, i in zip(shapes, dim_idx) + if i >= 0 + ] + ) + self.run_ = True + else: + self.run_ = False + else: + self.run_ = False + + # create new dynamic dims for ops not handled by symbolic shape inference + if ( + not self.run_ + and node.op_type not in self.dispatcher_ + and not known_aten_op + ): + is_unknown_op = out_type_undefined and ( + out_shape is None or len(out_shape) == 0 + ) + if is_unknown_op: + # unknown op to ONNX, maybe from higher opset or other domain + # only guess the output rank from input 0 when using guess_output_rank option + out_rank = ( + self._get_shape_rank(node, 0) + if self.guess_output_rank_ + else -1 + ) + else: + # valid ONNX op, but not handled by symbolic shape inference, just assign dynamic shape + out_rank = len(out_shape) + + if out_rank >= 0: + new_shape = self._new_symbolic_shape(out_rank, node, i_o) + if out_type_undefined: + # guess output data type from input vi if not defined + out_dtype = self.known_vi_[ + node.input[0] + ].type.tensor_type.elem_type + else: + # otherwise, use original data type + out_dtype = vi.type.tensor_type.elem_type + vi.CopyFrom( + helper.make_tensor_value_info( + vi.name, + out_dtype, + get_shape_from_sympy_shape(new_shape), + ) + ) + + if self.verbose_ > 0: + if is_unknown_op: + logger.debug( + f"Possible unknown op: {node.op_type} node: {node.name}, guessing {vi.name} shape" + ) + if self.verbose_ > 2: + logger.debug( + f" {node.output[i_o]}: {str(new_shape)} {vi.type.tensor_type.elem_type}" + ) + self.run_ = True + continue # continue the inference after guess, no need to stop as no merge is needed + + if self.verbose_ > 0 or not self.auto_merge_ or out_type_undefined: + logger.debug( + "Stopping at incomplete shape inference at " + + node.op_type + + ": " + + node.name + ) + logger.debug("node inputs:") + for i in node.input: + if i in self.known_vi_: + logger.debug(self.known_vi_[i]) + else: + logger.debug(f"not in known_vi_ for {i}") + logger.debug("node outputs:") + for o in node.output: + if o in self.known_vi_: + logger.debug(self.known_vi_[o]) + else: + logger.debug(f"not in known_vi_ for {o}") + if self.auto_merge_ and not out_type_undefined: + logger.debug("Merging: " + str(self.suggested_merge_)) + return False + + self.run_ = False + return True + + def _update_output_from_vi(self): + """Update output attributes using known value information dictionary.""" + for output in self.out_mp_.graph.output: + if output.name in self.known_vi_: + output.CopyFrom(self.known_vi_[output.name]) + + @staticmethod + def infer_shapes( + in_mp, int_max=2**31 - 1, auto_merge=False, guess_output_rank=False, verbose=0 + ): + """Perform symbolic shape inference on an ONNX model using the specified options to handle model shapes + efficiently. 
+
+        """
+        onnx_opset = get_opset(in_mp)
+        if (not onnx_opset) or onnx_opset < 7:
+            logger.warning("Only models of ONNX opset 7 and above are supported.")
+            return None
+        symbolic_shape_inference = SymbolicShapeInference(
+            int_max, auto_merge, guess_output_rank, verbose
+        )
+        all_shapes_inferred = False
+        symbolic_shape_inference._preprocess(in_mp)
+        while symbolic_shape_inference.run_:
+            all_shapes_inferred = symbolic_shape_inference._infer_impl()
+        symbolic_shape_inference._update_output_from_vi()
+        if not all_shapes_inferred:
+            raise Exception("Incomplete symbolic shape inference")
+        return symbolic_shape_inference.out_mp_
+
+
+def parse_arguments():
+    """Parse command-line arguments for ONNX model transformation options."""
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--input", required=True, help="The input model file")
+    parser.add_argument("--output", help="The output model file")
+    parser.add_argument(
+        "--auto_merge",
+        help="Automatically merge symbolic dims when conflicts happen",
+        action="store_true",
+        default=False,
+    )
+    parser.add_argument(
+        "--int_max",
+        help="Maximum value for an integer to be treated as boundless for ops like Slice",
+        type=int,
+        default=2**31 - 1,
+    )
+    parser.add_argument(
+        "--guess_output_rank",
+        help="Guess the output rank to be the same as input 0 for unknown ops",
+        action="store_true",
+        default=False,
+    )
+    parser.add_argument(
+        "--verbose",
+        help="Verbosity of inference logging, 0: turn off, 1: warnings, 3: detailed",
+        type=int,
+        default=0,
+    )
+    parser.add_argument(
+        "--save_as_external_data",
+        help="Save the ONNX model with external data",
+        action="store_true",
+        default=False,
+    )
+    parser.add_argument(
+        "--all_tensors_to_one_file",
+        help="Save all external data to one file",
+        action="store_true",
+        default=False,
+    )
+    parser.add_argument(
+        "--external_data_location",
+        help="The file location to save the external data",
+        default="./",
+    )
+    parser.add_argument(
+        "--external_data_size_threshold",
+        help="The size threshold for external data",
+        type=int,
+        default=1024,
+    )
+    return parser.parse_args()
+
+
+if __name__ == "__main__":
+    args = parse_arguments()
+    logger.info(f"input model: {args.input}")
+    if args.output:
+        logger.info(f"output model: {args.output}")
+    logger.info("Doing symbolic shape inference...")
+    out_mp = SymbolicShapeInference.infer_shapes(
+        onnx.load(args.input),
+        args.int_max,
+        args.auto_merge,
+        args.guess_output_rank,
+        args.verbose,
+    )
+    if args.output and out_mp:
+        if args.save_as_external_data:
+            onnx.save_model(
+                out_mp,
+                args.output,
+                save_as_external_data=True,
+                all_tensors_to_one_file=args.all_tensors_to_one_file,
+                location=args.external_data_location,
+                size_threshold=args.external_data_size_threshold,
+                convert_attribute=False,
+            )
+        else:
+            onnx.save(out_mp, args.output)
+        logger.info("Done!")

From f48537b4e6ae4fe5edace45ea93fe0121577e5ce Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Sun, 29 Jun 2025 15:47:20 -0700
Subject: [PATCH 02/31] Add support for expr in SymbolicDim

Signed-off-by: Justin Chu
---
 src/onnx_ir/_core.py | 24 ++++++++++++++++++++----
 1 file changed, 20 insertions(+), 4 deletions(-)

diff --git a/src/onnx_ir/_core.py b/src/onnx_ir/_core.py
index af5a2581..195f1902 100644
--- a/src/onnx_ir/_core.py
+++ b/src/onnx_ir/_core.py
@@ -62,6 +62,7 @@
 if typing.TYPE_CHECKING:
     import numpy.typing as npt
     from typing_extensions import TypeGuard
+    import sympy
 
 TArrayCompatible = typing.TypeVar(
     "TArrayCompatible",
@@ -1115,13 +1116,14 @@ class
SymbolicDim(_protocols.SymbolicDimProtocol, _display.PrettyPrintable): It is immutable and can be compared or hashed. """ - __slots__ = ("_value",) + __slots__ = ("_expr", "_value") - def __init__(self, value: str | None) -> None: + def __init__(self, value: str | None, /, expr: sympy.Expr | None) -> None: """Initialize a symbolic dimension. Args: value: The value of the dimension. It should not be an int. + expr: An optional sympy expression representing the dimension. Raises: TypeError: If value is an int. @@ -1132,6 +1134,7 @@ def __init__(self, value: str | None) -> None: "If you are creating a Shape, use int directly instead of SymbolicDim." ) self._value = value + self._expr: sympy.Expr | None = None def __eq__(self, other: object) -> bool: """Check equality with another SymbolicDim or string/None.""" @@ -1148,11 +1151,24 @@ def value(self) -> str | None: """The value of the symbolic dimension (string or None).""" return self._value + @property + def expr(self) -> sympy.Expr | None: + """The sympy expression representing the symbolic dimension.""" + return self._expr + def __str__(self) -> str: - return f"{self._value}" + if self._value is not None: + return str(self._value) + if self._expr is not None: + return str(self._expr) + return "?" def __repr__(self) -> str: - return f"{self.__class__.__name__}({self._value})" + if self._expr is not None: + expr_text = f", expr={self._expr!r}" + else: + expr_text = "" + return f"{self.__class__.__name__}({self._value}{expr_text})" def _is_int_compatible(value: object) -> TypeIs[SupportsInt]: From 7464af1f295240eed4504d8edbd841e81013986e Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Sun, 29 Jun 2025 16:15:35 -0700 Subject: [PATCH 03/31] wip Signed-off-by: Justin Chu --- src/onnx_ir/_shape_inference/__init__.py | 50 ++++++++++++++++++++++++ 1 file changed, 50 insertions(+) diff --git a/src/onnx_ir/_shape_inference/__init__.py b/src/onnx_ir/_shape_inference/__init__.py index e69de29b..e847ec2b 100644 --- a/src/onnx_ir/_shape_inference/__init__.py +++ b/src/onnx_ir/_shape_inference/__init__.py @@ -0,0 +1,50 @@ +"""Symbolic shape inference for ONNX IR.""" + +from typing import TYPE_CHECKING + +import numpy as np + +import onnx_ir as ir + +if TYPE_CHECKING: + import sympy + + +def get_expr(shape: ir.Shape, index: int) -> sympy.Expr: + """Get the expression or value at a specific index in the shape. + + Args: + shape: The shape to get the expression from. + index: The index of the dimension to get. + + Returns: + The expression or value at the specified index. + """ + import sympy + + dim = shape[index] + if isinstance(dim, ir.SymbolicDim): + if dim.expr is not None: + return dim.expr + return sympy.Symbol(dim.value) + return sympy.Integer(dim) + + +def set_expr(shape: ir.Shape, index: int, expr: sympy.Expr | int) -> None: + """Set the expression or value at a specific index in the shape. + + Args: + shape: The shape to set the expression in. + index: The index of the dimension to set. + expr: The expression or value to set at the specified index. 
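+
+    Example (an illustrative sketch)::
+
+        import sympy
+
+        shape = ir.Shape(["batch", 4])
+        set_expr(shape, 0, sympy.Symbol("batch") * 2)  # stored as SymbolicDim("2*batch")
+        set_expr(shape, 1, sympy.Integer(8))  # integral expressions collapse to int 8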
+ """ + from sympy.utilities.misc import as_int + if isinstance(expr, (int, np.integer)): + shape[index] = int(expr) + return + assert isinstance(expr, sympy.Expr), f"Expected sympy.Expr or int, got {type(expr)}" + expr = sympy.sympify(expr) + if expr.is_integer: + shape[index] = as_int(expr) + return + shape[index] = ir.SymbolicDim(str(expr), expr=expr) From 27ae0ca7f27809878a28872b97030b6beae2f058 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Sun, 29 Jun 2025 16:15:39 -0700 Subject: [PATCH 04/31] wip Signed-off-by: Justin Chu --- src/onnx_ir/_shape_inference/__init__.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/onnx_ir/_shape_inference/__init__.py b/src/onnx_ir/_shape_inference/__init__.py index e847ec2b..9fac0f22 100644 --- a/src/onnx_ir/_shape_inference/__init__.py +++ b/src/onnx_ir/_shape_inference/__init__.py @@ -48,3 +48,6 @@ def set_expr(shape: ir.Shape, index: int, expr: sympy.Expr | int) -> None: shape[index] = as_int(expr) return shape[index] = ir.SymbolicDim(str(expr), expr=expr) + + +class NodeInferencer: From 1afe64a5e7d70a8ce53926252bde30605837c557 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Sun, 29 Jun 2025 16:20:47 -0700 Subject: [PATCH 05/31] Create NodeInferencer Signed-off-by: Justin Chu --- .../__init__.py | 37 ++++++++++++++++++- .../_inferencer.py | 0 2 files changed, 36 insertions(+), 1 deletion(-) rename src/onnx_ir/{_shape_inference => _shape_type_inference}/__init__.py (54%) rename src/onnx_ir/{_shape_inference => _shape_type_inference}/_inferencer.py (100%) diff --git a/src/onnx_ir/_shape_inference/__init__.py b/src/onnx_ir/_shape_type_inference/__init__.py similarity index 54% rename from src/onnx_ir/_shape_inference/__init__.py rename to src/onnx_ir/_shape_type_inference/__init__.py index 9fac0f22..d76d1436 100644 --- a/src/onnx_ir/_shape_inference/__init__.py +++ b/src/onnx_ir/_shape_type_inference/__init__.py @@ -1,5 +1,7 @@ """Symbolic shape inference for ONNX IR.""" +import abc +from collections.abc import Collection, Sequence from typing import TYPE_CHECKING import numpy as np @@ -50,4 +52,37 @@ def set_expr(shape: ir.Shape, index: int, expr: sympy.Expr | int) -> None: shape[index] = ir.SymbolicDim(str(expr), expr=expr) -class NodeInferencer: +class NodeInferencer(abc.ABC): + """Base class for node inferencers. + + This class provides a common interface for all node inferencers. + """ + + def __init__(self, op_type: str, opsets: Collection[int], domain: str = "") -> None: + """Initialize the node inferencer. + + Args: + op_type: The type of the operation. + opsets: A collection of ONNX opset versions supported by this inferencer. + domain: The domain of the operation, default is an empty string. + """ + self.op_type = op_type + self.opsets = opsets + self.domain = domain + + @abc.abstractmethod + def check(self, node: ir.Node) -> None: + """Check if the node is valid for this inferencer.""" + raise NotImplementedError + + @abc.abstractmethod + def infer(self, node: ir.Node) -> Sequence[ir.Value]: + """Infer the shape for the node. + + Args: + node: The ONNX node to infer the type and shape for. + + Returns: + A sequence of ONNX values containing the inferred shapes. 
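+
+        Example (an illustrative sketch; ``ReluInferrer`` and its opset range
+        are hypothetical)::
+
+            class ReluInferrer(NodeInferencer):
+                def __init__(self) -> None:
+                    super().__init__("Relu", opsets=range(14, 24))
+
+                def check(self, node: ir.Node) -> None:
+                    assert node.op_type == "Relu"
+
+                def infer(self, node: ir.Node) -> Sequence[ir.Value]:
+                    inp = node.inputs[0]
+                    return (ir.Value(shape=inp.shape, type=inp.type),)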
+ """ + raise NotImplementedError diff --git a/src/onnx_ir/_shape_inference/_inferencer.py b/src/onnx_ir/_shape_type_inference/_inferencer.py similarity index 100% rename from src/onnx_ir/_shape_inference/_inferencer.py rename to src/onnx_ir/_shape_type_inference/_inferencer.py From 78ad6e005c92b42c83a263410cdee90fe29ffabc Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Sun, 29 Jun 2025 16:22:53 -0700 Subject: [PATCH 06/31] inference_common Signed-off-by: Justin Chu --- .../_shape_type_inference/{__init__.py => _common.py} | 0 src/onnx_ir/_shape_type_inference/ops/standard_ops.py | 5 +++++ 2 files changed, 5 insertions(+) rename src/onnx_ir/_shape_type_inference/{__init__.py => _common.py} (100%) create mode 100644 src/onnx_ir/_shape_type_inference/ops/standard_ops.py diff --git a/src/onnx_ir/_shape_type_inference/__init__.py b/src/onnx_ir/_shape_type_inference/_common.py similarity index 100% rename from src/onnx_ir/_shape_type_inference/__init__.py rename to src/onnx_ir/_shape_type_inference/_common.py diff --git a/src/onnx_ir/_shape_type_inference/ops/standard_ops.py b/src/onnx_ir/_shape_type_inference/ops/standard_ops.py new file mode 100644 index 00000000..851fa1fc --- /dev/null +++ b/src/onnx_ir/_shape_type_inference/ops/standard_ops.py @@ -0,0 +1,5 @@ +"""Standard inferencers for ONNX IR nodes.""" + + +from onnx_ir._shape_type_inference import _common as inference_common + From 5aa2df7b64540576feafd80bdb785f07adff650d Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Sun, 29 Jun 2025 16:43:14 -0700 Subject: [PATCH 07/31] Update shapes Signed-off-by: Justin Chu --- src/onnx_ir/_shape_type_inference/_common.py | 25 +++--- .../_shape_type_inference/ops/standard_ops.py | 79 ++++++++++++++++++- 2 files changed, 91 insertions(+), 13 deletions(-) diff --git a/src/onnx_ir/_shape_type_inference/_common.py b/src/onnx_ir/_shape_type_inference/_common.py index d76d1436..e6305cde 100644 --- a/src/onnx_ir/_shape_type_inference/_common.py +++ b/src/onnx_ir/_shape_type_inference/_common.py @@ -1,7 +1,9 @@ """Symbolic shape inference for ONNX IR.""" +from __future__ import annotations import abc from collections.abc import Collection, Sequence +import dataclasses from typing import TYPE_CHECKING import numpy as np @@ -52,18 +54,24 @@ def set_expr(shape: ir.Shape, index: int, expr: sympy.Expr | int) -> None: shape[index] = ir.SymbolicDim(str(expr), expr=expr) -class NodeInferencer(abc.ABC): - """Base class for node inferencers. +@dataclasses.dataclass +class InferenceResult: + values: Sequence[ir.Value] | None = None + failure: str | None = None - This class provides a common interface for all node inferencers. + +class NodeInferrer(abc.ABC): + """Base class for node inferrers. + + This class provides a common interface for all node inferrers. """ def __init__(self, op_type: str, opsets: Collection[int], domain: str = "") -> None: - """Initialize the node inferencer. + """Initialize the node inferrer. Args: op_type: The type of the operation. - opsets: A collection of ONNX opset versions supported by this inferencer. + opsets: A collection of ONNX opset versions supported by this inferrer. domain: The domain of the operation, default is an empty string. 
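+
+        Example (illustrative; the contrib op and opset range are hypothetical)::
+
+            super().__init__("Gelu", opsets=range(1, 2**31), domain="com.microsoft")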
""" self.op_type = op_type @@ -71,12 +79,7 @@ def __init__(self, op_type: str, opsets: Collection[int], domain: str = "") -> N self.domain = domain @abc.abstractmethod - def check(self, node: ir.Node) -> None: - """Check if the node is valid for this inferencer.""" - raise NotImplementedError - - @abc.abstractmethod - def infer(self, node: ir.Node) -> Sequence[ir.Value]: + def infer(self, node: ir.Node) -> InferenceResult: """Infer the shape for the node. Args: diff --git a/src/onnx_ir/_shape_type_inference/ops/standard_ops.py b/src/onnx_ir/_shape_type_inference/ops/standard_ops.py index 851fa1fc..80e683fa 100644 --- a/src/onnx_ir/_shape_type_inference/ops/standard_ops.py +++ b/src/onnx_ir/_shape_type_inference/ops/standard_ops.py @@ -1,5 +1,80 @@ -"""Standard inferencers for ONNX IR nodes.""" +"""Standard Inferrers for ONNX IR nodes.""" +from __future__ import annotations -from onnx_ir._shape_type_inference import _common as inference_common +import sys +from collections.abc import Collection +import onnx_ir as ir +from onnx_ir._shape_type_inference import _common + + +class ElementwiseInferrer(_common.NodeInferrer): + """Base class for elementwise operation inferrers.""" + + def __init__(self, op_type: str, opsets: Collection[int] | None = None) -> None: + """Initialize the elementwise inferrer with the operation type.""" + if opsets is None: + opsets = range(sys.maxsize) + super().__init__(op_type, opsets=opsets) + + def infer(self, node: ir.Node) -> _common.InferenceResult: + """Infer the output shape and type for elementwise operations.""" + if len(node.inputs) != 1: + return _common.InferenceResult( + failure=f"Elementwise operation must have exactly one input, got {len(node.inputs)}." + ) + if node.inputs[0] is None: + return _common.InferenceResult( + failure="Elementwise operation input cannot be None." + ) + if len(node.outputs) != 1: + return _common.InferenceResult( + failure=f"Elementwise operation must have exactly one output, got {len(node.outputs)}." + ) + + return _common.InferenceResult( + (ir.Value(shape=node.inputs[0].shape, type=node.inputs[0].type),) + ) + + +def broadcast_shapes_bidirectional(shape1: ir.Shape, shape2: ir.Shape) -> ir.Shape: + """Broadcast two shapes bidirectionally. + + Args: + shape1: The first shape to broadcast. + shape2: The second shape to broadcast. + + Returns: + A new shape that is the result of broadcasting both shapes. + """ + # TODO: Use _common.get_expr and use sympy for broadcasting logic + + +class BinaryInferrer(_common.NodeInferrer): + """Base class for binary operation inferrers.""" + + def __init__(self, op_type: str, opsets: Collection[int] | None = None) -> None: + """Initialize the binary inferrer with the operation type.""" + if opsets is None: + opsets = range(sys.maxsize) + super().__init__(op_type, opsets=opsets) + + def infer(self, node: ir.Node) -> _common.InferenceResult: + """Infer the output shape and type for binary operations.""" + if len(node.inputs) != 2: + return _common.InferenceResult( + failure=f"Binary operation must have exactly two inputs, got {len(node.inputs)}." + ) + if node.inputs[0] is None or node.inputs[1] is None: + return _common.InferenceResult(failure="Binary operation inputs cannot be None.") + if len(node.outputs) != 1: + return _common.InferenceResult( + failure=f"Binary operation must have exactly one output, got {len(node.outputs)}." 
+ ) + first_type = node.inputs[0].type + second_type = node.inputs[1].type + if first_type is not None and second_type is not None and first_type != second_type: + return _common.InferenceResult( + failure=f"Input types do not match: {first_type} vs {second_type}." + ) From dbc35932b9c11d1a6f9afc6fe96126ac3dce6ba0 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Sun, 29 Jun 2025 16:45:05 -0700 Subject: [PATCH 08/31] update Signed-off-by: Justin Chu --- src/onnx_ir/_shape_type_inference/__init__.py | 2 ++ 1 file changed, 2 insertions(+) create mode 100644 src/onnx_ir/_shape_type_inference/__init__.py diff --git a/src/onnx_ir/_shape_type_inference/__init__.py b/src/onnx_ir/_shape_type_inference/__init__.py new file mode 100644 index 00000000..7cb7d817 --- /dev/null +++ b/src/onnx_ir/_shape_type_inference/__init__.py @@ -0,0 +1,2 @@ +class SymbolicInferenceEngine: + pass From b9f0528166f031787875c380a4c648c9b88f0eb6 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Sun, 29 Jun 2025 17:05:39 -0700 Subject: [PATCH 09/31] Claude - add sympy import Signed-off-by: Justin Chu --- pyproject.toml | 2 +- src/onnx_ir/_core.py | 5 ++++- src/onnx_ir/_shape_type_inference/_common.py | 8 +------- 3 files changed, 6 insertions(+), 9 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 204f1d68..f98771e0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,7 +21,7 @@ classifiers = [ "Programming Language :: Python :: 3.13", "License :: OSI Approved :: Apache Software License", ] -dependencies = ["numpy", "onnx>=1.16", "typing_extensions>=4.10", "ml_dtypes"] +dependencies = ["numpy", "onnx>=1.16", "typing_extensions>=4.10", "ml_dtypes", "sympy"] [project.urls] Homepage = "https://onnx.ai/ir-py" diff --git a/src/onnx_ir/_core.py b/src/onnx_ir/_core.py index 195f1902..8667f64a 100644 --- a/src/onnx_ir/_core.py +++ b/src/onnx_ir/_core.py @@ -58,11 +58,11 @@ _protocols, _type_casting, ) +import sympy if typing.TYPE_CHECKING: import numpy.typing as npt from typing_extensions import TypeGuard - import sympy TArrayCompatible = typing.TypeVar( "TArrayCompatible", @@ -1204,6 +1204,9 @@ def _maybe_convert_to_symbolic_dim( """ if dim is None or isinstance(dim, str): return SymbolicDim(dim) + if isinstance(dim, sympy.Expr): + # If the dimension is a sympy expression, we create a SymbolicDim with it + return SymbolicDim(str(dim), expr=dim) if _is_int_compatible(dim): return int(dim) if isinstance(dim, SymbolicDim): diff --git a/src/onnx_ir/_shape_type_inference/_common.py b/src/onnx_ir/_shape_type_inference/_common.py index e6305cde..9b4abe45 100644 --- a/src/onnx_ir/_shape_type_inference/_common.py +++ b/src/onnx_ir/_shape_type_inference/_common.py @@ -4,15 +4,11 @@ import abc from collections.abc import Collection, Sequence import dataclasses -from typing import TYPE_CHECKING import numpy as np import onnx_ir as ir - -if TYPE_CHECKING: - import sympy - +import sympy def get_expr(shape: ir.Shape, index: int) -> sympy.Expr: """Get the expression or value at a specific index in the shape. @@ -24,8 +20,6 @@ def get_expr(shape: ir.Shape, index: int) -> sympy.Expr: Returns: The expression or value at the specified index. 
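+
+    Example (an illustrative sketch)::
+
+        shape = ir.Shape(["batch", 4])
+        get_expr(shape, 0)  # sympy.Symbol("batch")
+        get_expr(shape, 1)  # sympy.Integer(4)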
""" - import sympy - dim = shape[index] if isinstance(dim, ir.SymbolicDim): if dim.expr is not None: From c9a35b7f7997b380c4af980f52e54d7d07472bd0 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Sun, 29 Jun 2025 17:07:02 -0700 Subject: [PATCH 10/31] Claude and lint Signed-off-by: Justin Chu --- src/onnx_ir/_core.py | 2 +- src/onnx_ir/_shape_type_inference/_common.py | 7 +- .../_shape_type_inference/ops/concat.py | 92 +++++++++++++ .../_shape_type_inference/ops/matmul.py | 86 ++++++++++++ .../_shape_type_inference/ops/reshape.py | 122 ++++++++++++++++++ .../_shape_type_inference/ops/squeeze.py | 101 +++++++++++++++ .../_shape_type_inference/ops/standard_ops.py | 50 ++++++- .../_shape_type_inference/ops/transpose.py | 78 +++++++++++ .../_shape_type_inference/ops/unsqueeze.py | 94 ++++++++++++++ 9 files changed, 628 insertions(+), 4 deletions(-) create mode 100644 src/onnx_ir/_shape_type_inference/ops/concat.py create mode 100644 src/onnx_ir/_shape_type_inference/ops/matmul.py create mode 100644 src/onnx_ir/_shape_type_inference/ops/reshape.py create mode 100644 src/onnx_ir/_shape_type_inference/ops/squeeze.py create mode 100644 src/onnx_ir/_shape_type_inference/ops/transpose.py create mode 100644 src/onnx_ir/_shape_type_inference/ops/unsqueeze.py diff --git a/src/onnx_ir/_core.py b/src/onnx_ir/_core.py index 8667f64a..5e1f13eb 100644 --- a/src/onnx_ir/_core.py +++ b/src/onnx_ir/_core.py @@ -45,6 +45,7 @@ import ml_dtypes import numpy as np +import sympy from typing_extensions import TypeIs import onnx_ir @@ -58,7 +59,6 @@ _protocols, _type_casting, ) -import sympy if typing.TYPE_CHECKING: import numpy.typing as npt diff --git a/src/onnx_ir/_shape_type_inference/_common.py b/src/onnx_ir/_shape_type_inference/_common.py index 9b4abe45..6fae2350 100644 --- a/src/onnx_ir/_shape_type_inference/_common.py +++ b/src/onnx_ir/_shape_type_inference/_common.py @@ -1,14 +1,16 @@ """Symbolic shape inference for ONNX IR.""" + from __future__ import annotations import abc -from collections.abc import Collection, Sequence import dataclasses +from collections.abc import Collection, Sequence import numpy as np +import sympy import onnx_ir as ir -import sympy + def get_expr(shape: ir.Shape, index: int) -> sympy.Expr: """Get the expression or value at a specific index in the shape. @@ -37,6 +39,7 @@ def set_expr(shape: ir.Shape, index: int, expr: sympy.Expr | int) -> None: expr: The expression or value to set at the specified index. """ from sympy.utilities.misc import as_int + if isinstance(expr, (int, np.integer)): shape[index] = int(expr) return diff --git a/src/onnx_ir/_shape_type_inference/ops/concat.py b/src/onnx_ir/_shape_type_inference/ops/concat.py new file mode 100644 index 00000000..60eccb39 --- /dev/null +++ b/src/onnx_ir/_shape_type_inference/ops/concat.py @@ -0,0 +1,92 @@ +"""Concat operation inferrer for ONNX IR nodes.""" + +import sys +from collections.abc import Collection + +import onnx_ir as ir +from onnx_ir._shape_type_inference import _common + + +class ConcatInferrer(_common.NodeInferrer): + """Inferrer for Concat operations.""" + + def __init__(self, opsets: Collection[int] | None = None) -> None: + """Initialize the Concat inferrer.""" + if opsets is None: + opsets = range(sys.maxsize) + super().__init__("Concat", opsets=opsets) + + def infer(self, node: ir.Node) -> _common.InferenceResult: + """Infer the output shape and type for Concat operations.""" + if len(node.inputs) < 1: + return _common.InferenceResult( + failure="Concat operation must have at least one input." 
+ ) + if any(inp is None for inp in node.inputs): + return _common.InferenceResult(failure="Concat operation inputs cannot be None.") + if len(node.outputs) != 1: + return _common.InferenceResult( + failure=f"Concat operation must have exactly one output, got {len(node.outputs)}." + ) + + # Get axis attribute + axis_attr = None + for attr in node.attributes: + if attr.name == "axis": + axis_attr = attr.value.i + break + + if axis_attr is None: + return _common.InferenceResult(failure="Concat operation requires axis attribute.") + + # Get first input shape as base + first_shape = node.inputs[0].shape + if first_shape is None: + return _common.InferenceResult(failure="Concat input shapes cannot be None.") + + rank = len(first_shape) + if rank == 0: + return _common.InferenceResult(failure="Concat inputs cannot be scalars.") + + # Handle negative axis + if axis_attr < 0: + axis_attr += rank + + if axis_attr < 0 or axis_attr >= rank: + return _common.InferenceResult( + failure=f"Concat axis {axis_attr} is out of bounds for rank {rank}." + ) + + # Check that all inputs have compatible shapes + output_shape = ir.Shape(list(first_shape)) + concat_dim_size = _common.get_expr(first_shape, axis_attr) + + for i, inp in enumerate(node.inputs[1:], 1): + if inp.shape is None: + return _common.InferenceResult(failure=f"Input {i} shape cannot be None.") + + input_shape = inp.shape + if len(input_shape) != rank: + return _common.InferenceResult( + failure=f"All inputs must have same rank. Input {i} has rank {len(input_shape)}, expected {rank}." + ) + + # Check non-concat dimensions are compatible + for dim_idx in range(rank): + if dim_idx == axis_attr: + # Accumulate concat dimension + concat_dim_size = concat_dim_size + _common.get_expr(input_shape, dim_idx) + else: + # Check compatibility of other dimensions + dim1 = _common.get_expr(first_shape, dim_idx) + dim2 = _common.get_expr(input_shape, dim_idx) + # For symbolic inference, we assume they are compatible + # In practice, this would need runtime verification + + # Set the concat dimension in output shape + _common.set_expr(output_shape, axis_attr, concat_dim_size) + + output_type = node.inputs[0].type + return _common.InferenceResult( + values=(ir.Value(shape=output_shape, type=output_type),) + ) diff --git a/src/onnx_ir/_shape_type_inference/ops/matmul.py b/src/onnx_ir/_shape_type_inference/ops/matmul.py new file mode 100644 index 00000000..ded876b6 --- /dev/null +++ b/src/onnx_ir/_shape_type_inference/ops/matmul.py @@ -0,0 +1,86 @@ +"""MatMul operation inferrer for ONNX IR nodes.""" + +import sys +from collections.abc import Collection + +import sympy + +import onnx_ir as ir +from onnx_ir._shape_type_inference import _common +from onnx_ir._shape_type_inference.ops.standard_ops import broadcast_shapes_bidirectional + + +class MatMulInferrer(_common.NodeInferrer): + """Inferrer for MatMul operations.""" + + def __init__(self, opsets: Collection[int] | None = None) -> None: + """Initialize the MatMul inferrer.""" + if opsets is None: + opsets = range(sys.maxsize) + super().__init__("MatMul", opsets=opsets) + + def infer(self, node: ir.Node) -> _common.InferenceResult: + """Infer the output shape and type for MatMul operations.""" + if len(node.inputs) != 2: + return _common.InferenceResult( + failure=f"MatMul operation must have exactly two inputs, got {len(node.inputs)}." 
+ ) + if node.inputs[0] is None or node.inputs[1] is None: + return _common.InferenceResult(failure="MatMul operation inputs cannot be None.") + if len(node.outputs) != 1: + return _common.InferenceResult( + failure=f"MatMul operation must have exactly one output, got {len(node.outputs)}." + ) + + lhs_shape = node.inputs[0].shape + rhs_shape = node.inputs[1].shape + if lhs_shape is None or rhs_shape is None: + return _common.InferenceResult(failure="MatMul input shapes cannot be None.") + + lhs_rank = len(lhs_shape) + rhs_rank = len(rhs_shape) + + if lhs_rank == 0 or rhs_rank == 0: + return _common.InferenceResult(failure="MatMul inputs cannot be scalars.") + + # Compute output shape based on matrix multiplication rules + if lhs_rank == 1 and rhs_rank == 1: + # Vector dot product: (n,) × (n,) -> scalar + output_shape = ir.Shape([]) + elif lhs_rank == 1: + # Matrix-vector: (n,) × (..., n, k) -> (..., k) + output_shape = ir.Shape(rhs_shape[:-2] + rhs_shape[-1:]) + elif rhs_rank == 1: + # Vector-matrix: (..., m, n) × (n,) -> (..., m) + output_shape = ir.Shape(lhs_shape[:-1]) + else: + # Matrix-matrix: (..., m, n) × (..., n, k) -> (..., m, k) + # Broadcast batch dimensions + lhs_batch = lhs_shape[:-2] + rhs_batch = rhs_shape[:-2] + if lhs_batch and rhs_batch: + batch_shape = broadcast_shapes_bidirectional( + ir.Shape(lhs_batch), ir.Shape(rhs_batch) + ) + output_shape = ir.Shape(list(batch_shape) + [lhs_shape[-2], rhs_shape[-1]]) + elif lhs_batch: + output_shape = ir.Shape(list(lhs_batch) + [lhs_shape[-2], rhs_shape[-1]]) + elif rhs_batch: + output_shape = ir.Shape(list(rhs_batch) + [lhs_shape[-2], rhs_shape[-1]]) + else: + output_shape = ir.Shape([lhs_shape[-2], rhs_shape[-1]]) + + # Check dimension compatibility for matrix multiplication + if lhs_rank >= 1 and rhs_rank >= 1: + lhs_reduce_dim = ( + _common.get_expr(lhs_shape, -1) if lhs_rank >= 1 else sympy.Integer(1) + ) + rhs_reduce_dim = _common.get_expr(rhs_shape, -2 if rhs_rank >= 2 else 0) + + # For symbolic inference, we assume dimensions are compatible + # In practice, this would need runtime verification + + output_type = node.inputs[0].type + return _common.InferenceResult( + values=(ir.Value(shape=output_shape, type=output_type),) + ) diff --git a/src/onnx_ir/_shape_type_inference/ops/reshape.py b/src/onnx_ir/_shape_type_inference/ops/reshape.py new file mode 100644 index 00000000..f98b0318 --- /dev/null +++ b/src/onnx_ir/_shape_type_inference/ops/reshape.py @@ -0,0 +1,122 @@ +"""Reshape operation inferrer for ONNX IR nodes.""" + +import sys +from collections.abc import Collection + +import onnx_ir as ir +from onnx_ir._shape_type_inference import _common + + +class ReshapeInferrer(_common.NodeInferrer): + """Inferrer for Reshape operations.""" + + def __init__(self, opsets: Collection[int] | None = None) -> None: + """Initialize the Reshape inferrer.""" + if opsets is None: + opsets = range(sys.maxsize) + super().__init__("Reshape", opsets=opsets) + + def infer(self, node: ir.Node) -> _common.InferenceResult: + """Infer the output shape and type for Reshape operations.""" + if len(node.inputs) != 2: + return _common.InferenceResult( + failure=f"Reshape operation must have exactly two inputs, got {len(node.inputs)}." + ) + if node.inputs[0] is None or node.inputs[1] is None: + return _common.InferenceResult(failure="Reshape operation inputs cannot be None.") + if len(node.outputs) != 1: + return _common.InferenceResult( + failure=f"Reshape operation must have exactly one output, got {len(node.outputs)}." 
+ ) + + input_shape = node.inputs[0].shape + shape_input = node.inputs[1] + + if input_shape is None: + return _common.InferenceResult(failure="Reshape input shape cannot be None.") + + # Try to get the shape values from the second input + # For symbolic inference, we may not have concrete values + if ( + hasattr(shape_input, "initializer_value") + and shape_input.initializer_value is not None + ): + shape_values = shape_input.initializer_value.tolist() + return self._infer_with_shape_values( + input_shape, shape_values, node.inputs[0].type + ) + else: + # Handle symbolic case where shape is not known at compile time + shape_shape = shape_input.shape + if shape_shape is None or len(shape_shape) != 1: + return _common.InferenceResult( + failure="Reshape shape input must be a 1D tensor." + ) + + shape_rank = shape_shape[0] + if isinstance(shape_rank, int): + # Create symbolic dimensions for the output + output_shape = ir.Shape([]) + for i in range(shape_rank): + output_shape.append(ir.SymbolicDim(f"reshape_dim_{i}")) + + return _common.InferenceResult( + values=(ir.Value(shape=output_shape, type=node.inputs[0].type),) + ) + else: + return _common.InferenceResult( + failure="Cannot infer reshape output shape with symbolic rank." + ) + + def _infer_with_shape_values( + self, input_shape: ir.Shape, shape_values: list, input_type + ) -> _common.InferenceResult: + """Infer output shape when shape values are known.""" + # Calculate total elements in input + total_elements = sympy.Integer(1) + for dim in input_shape: + if isinstance(dim, ir.SymbolicDim): + if dim.expr is not None: + total_elements *= dim.expr + else: + total_elements *= sympy.Symbol(dim.value) + else: + total_elements *= sympy.Integer(dim) + + # Process shape values + output_dims = [] + deferred_dim_idx = -1 + non_deferred_size = sympy.Integer(1) + + for i, dim_value in enumerate(shape_values): + if dim_value == -1: + if deferred_dim_idx != -1: + return _common.InferenceResult( + failure="Reshape can have at most one -1 dimension." + ) + deferred_dim_idx = i + output_dims.append(None) # Placeholder + elif dim_value == 0: + # Copy from input shape + if i >= len(input_shape): + return _common.InferenceResult( + failure=f"Cannot copy dimension {i} from input shape of rank {len(input_shape)}." 
+ ) + dim_expr = _common.get_expr(input_shape, i) + output_dims.append(dim_expr) + non_deferred_size *= dim_expr + else: + output_dims.append(sympy.Integer(dim_value)) + non_deferred_size *= sympy.Integer(dim_value) + + # Calculate deferred dimension + if deferred_dim_idx != -1: + deferred_dim = total_elements // non_deferred_size + output_dims[deferred_dim_idx] = deferred_dim + + # Create output shape + output_shape = ir.Shape([0] * len(output_dims)) + for i, dim_expr in enumerate(output_dims): + _common.set_expr(output_shape, i, dim_expr) + + return _common.InferenceResult(values=(ir.Value(shape=output_shape, type=input_type),)) diff --git a/src/onnx_ir/_shape_type_inference/ops/squeeze.py b/src/onnx_ir/_shape_type_inference/ops/squeeze.py new file mode 100644 index 00000000..95f4c8ba --- /dev/null +++ b/src/onnx_ir/_shape_type_inference/ops/squeeze.py @@ -0,0 +1,101 @@ +"""Squeeze operation inferrer for ONNX IR nodes.""" + +import sys +from collections.abc import Collection + +import onnx_ir as ir +from onnx_ir._shape_type_inference import _common + + +class SqueezeInferrer(_common.NodeInferrer): + """Inferrer for Squeeze operations.""" + + def __init__(self, opsets: Collection[int] | None = None) -> None: + """Initialize the Squeeze inferrer.""" + if opsets is None: + opsets = range(sys.maxsize) + super().__init__("Squeeze", opsets=opsets) + + def infer(self, node: ir.Node) -> _common.InferenceResult: + """Infer the output shape and type for Squeeze operations.""" + if len(node.inputs) < 1 or len(node.inputs) > 2: + return _common.InferenceResult( + failure=f"Squeeze operation must have 1 or 2 inputs, got {len(node.inputs)}." + ) + if node.inputs[0] is None: + return _common.InferenceResult(failure="Squeeze operation input cannot be None.") + if len(node.outputs) != 1: + return _common.InferenceResult( + failure=f"Squeeze operation must have exactly one output, got {len(node.outputs)}." + ) + + input_shape = node.inputs[0].shape + if input_shape is None: + return _common.InferenceResult(failure="Squeeze input shape cannot be None.") + + rank = len(input_shape) + + # Get axes to squeeze + axes = None + + # Check for axes in second input (opset >= 13) + if len(node.inputs) == 2 and node.inputs[1] is not None: + if ( + hasattr(node.inputs[1], "initializer_value") + and node.inputs[1].initializer_value is not None + ): + axes = node.inputs[1].initializer_value.tolist() + if not isinstance(axes, list): + axes = [axes] + else: + # Check for axes attribute (opset < 13) + for attr in node.attributes: + if attr.name == "axes": + axes = list(attr.value.ints) + break + + if axes is None: + # No axes specified - squeeze all dimensions of size 1 + output_dims = [] + for i, dim in enumerate(input_shape): + dim_expr = _common.get_expr(input_shape, i) + # For symbolic dimensions, we assume they are not 1 + # Only squeeze literal 1s + if isinstance(dim, int) and dim == 1: + continue # Skip dimension of size 1 + else: + output_dims.append(dim_expr) + else: + # Normalize negative axes + normalized_axes = [] + for axis in axes: + if axis < 0: + axis += rank + if axis < 0 or axis >= rank: + return _common.InferenceResult( + failure=f"Squeeze axis {axis} is out of bounds for rank {rank}." + ) + normalized_axes.append(axis) + + # Validate that specified axes have dimension 1 (for literal dimensions) + for axis in normalized_axes: + dim = input_shape[axis] + if isinstance(dim, int) and dim != 1: + return _common.InferenceResult( + failure=f"Cannot squeeze axis {axis} with dimension {dim} (must be 1)." 
+ ) + + # Build output shape excluding squeezed axes + output_dims = [] + for i in range(rank): + if i not in normalized_axes: + output_dims.append(_common.get_expr(input_shape, i)) + + # Create output shape + output_shape = ir.Shape([0] * len(output_dims)) + for i, dim_expr in enumerate(output_dims): + _common.set_expr(output_shape, i, dim_expr) + + return _common.InferenceResult( + values=(ir.Value(shape=output_shape, type=node.inputs[0].type),) + ) diff --git a/src/onnx_ir/_shape_type_inference/ops/standard_ops.py b/src/onnx_ir/_shape_type_inference/ops/standard_ops.py index 80e683fa..27b16ed5 100644 --- a/src/onnx_ir/_shape_type_inference/ops/standard_ops.py +++ b/src/onnx_ir/_shape_type_inference/ops/standard_ops.py @@ -5,6 +5,8 @@ import sys from collections.abc import Collection +import sympy + import onnx_ir as ir from onnx_ir._shape_type_inference import _common @@ -48,7 +50,40 @@ def broadcast_shapes_bidirectional(shape1: ir.Shape, shape2: ir.Shape) -> ir.Sha Returns: A new shape that is the result of broadcasting both shapes. """ - # TODO: Use _common.get_expr and use sympy for broadcasting logic + rank1 = len(shape1) + rank2 = len(shape2) + new_rank = max(rank1, rank2) + new_dims = [] + + for i in range(new_rank): + dim1_idx = rank1 - 1 - i + dim2_idx = rank2 - 1 - i + + # Get expressions for dimensions + dim1_expr = _common.get_expr(shape1, dim1_idx) if i < rank1 else sympy.Integer(1) + dim2_expr = _common.get_expr(shape2, dim2_idx) if i < rank2 else sympy.Integer(1) + + # Broadcasting rules + if dim1_expr == 1: + new_dim_expr = dim2_expr + elif dim2_expr == 1: + new_dim_expr = dim1_expr + elif dim1_expr == dim2_expr: + new_dim_expr = dim1_expr + else: + # Incompatible dimensions - this should be caught at runtime + # For symbolic inference, we assume they can be broadcast + new_dim_expr = sympy.Max(dim1_expr, dim2_expr) + + # Add to the front to maintain right-to-left processing order + new_dims.insert(0, new_dim_expr) + + # Create new shape and set dimensions + new_shape = ir.Shape([0] * new_rank) + for i, expr in enumerate(new_dims): + _common.set_expr(new_shape, i, expr) + + return new_shape class BinaryInferrer(_common.NodeInferrer): @@ -78,3 +113,16 @@ def infer(self, node: ir.Node) -> _common.InferenceResult: return _common.InferenceResult( failure=f"Input types do not match: {first_type} vs {second_type}." 
) + + # Broadcast the input shapes + first_shape = node.inputs[0].shape + second_shape = node.inputs[1].shape + if first_shape is None or second_shape is None: + return _common.InferenceResult(failure="Input shapes cannot be None.") + + output_shape = broadcast_shapes_bidirectional(first_shape, second_shape) + output_type = first_type if first_type is not None else second_type + + return _common.InferenceResult( + values=(ir.Value(shape=output_shape, type=output_type),) + ) diff --git a/src/onnx_ir/_shape_type_inference/ops/transpose.py b/src/onnx_ir/_shape_type_inference/ops/transpose.py new file mode 100644 index 00000000..dd2b9ff1 --- /dev/null +++ b/src/onnx_ir/_shape_type_inference/ops/transpose.py @@ -0,0 +1,78 @@ +"""Transpose operation inferrer for ONNX IR nodes.""" + +import sys +from collections.abc import Collection + +import onnx_ir as ir +from onnx_ir._shape_type_inference import _common + + +class TransposeInferrer(_common.NodeInferrer): + """Inferrer for Transpose operations.""" + + def __init__(self, opsets: Collection[int] | None = None) -> None: + """Initialize the Transpose inferrer.""" + if opsets is None: + opsets = range(sys.maxsize) + super().__init__("Transpose", opsets=opsets) + + def infer(self, node: ir.Node) -> _common.InferenceResult: + """Infer the output shape and type for Transpose operations.""" + if len(node.inputs) != 1: + return _common.InferenceResult( + failure=f"Transpose operation must have exactly one input, got {len(node.inputs)}." + ) + if node.inputs[0] is None: + return _common.InferenceResult(failure="Transpose operation input cannot be None.") + if len(node.outputs) != 1: + return _common.InferenceResult( + failure=f"Transpose operation must have exactly one output, got {len(node.outputs)}." + ) + + input_shape = node.inputs[0].shape + if input_shape is None: + return _common.InferenceResult(failure="Transpose input shape cannot be None.") + + rank = len(input_shape) + + # Get permutation from attributes + perm = None + for attr in node.attributes: + if attr.name == "perm": + perm = list(attr.value.ints) + break + + # Default permutation is reversed order + if perm is None: + perm = list(reversed(range(rank))) + + # Validate permutation + if len(perm) != rank: + return _common.InferenceResult( + failure=f"Permutation length {len(perm)} does not match input rank {rank}." + ) + + if sorted(perm) != list(range(rank)): + return _common.InferenceResult( + failure=f"Invalid permutation {perm}. Must be a permutation of [0, 1, ..., {rank - 1}]." + ) + + # Apply permutation to create output shape + output_shape = ir.Shape([0] * rank) + for i, axis in enumerate(perm): + # Handle negative axis + if axis < 0: + axis += rank + + if axis < 0 or axis >= rank: + return _common.InferenceResult( + failure=f"Permutation axis {axis} is out of bounds for rank {rank}." 
+ ) + + # Copy dimension from input to output according to permutation + input_dim_expr = _common.get_expr(input_shape, axis) + _common.set_expr(output_shape, i, input_dim_expr) + + return _common.InferenceResult( + values=(ir.Value(shape=output_shape, type=node.inputs[0].type),) + ) diff --git a/src/onnx_ir/_shape_type_inference/ops/unsqueeze.py b/src/onnx_ir/_shape_type_inference/ops/unsqueeze.py new file mode 100644 index 00000000..9b3b26b8 --- /dev/null +++ b/src/onnx_ir/_shape_type_inference/ops/unsqueeze.py @@ -0,0 +1,94 @@ +"""Unsqueeze operation inferrer for ONNX IR nodes.""" + +import sys +from collections.abc import Collection + +import onnx_ir as ir +from onnx_ir._shape_type_inference import _common + + +class UnsqueezeInferrer(_common.NodeInferrer): + """Inferrer for Unsqueeze operations.""" + + def __init__(self, opsets: Collection[int] | None = None) -> None: + """Initialize the Unsqueeze inferrer.""" + if opsets is None: + opsets = range(sys.maxsize) + super().__init__("Unsqueeze", opsets=opsets) + + def infer(self, node: ir.Node) -> _common.InferenceResult: + """Infer the output shape and type for Unsqueeze operations.""" + if len(node.inputs) < 1 or len(node.inputs) > 2: + return _common.InferenceResult( + failure=f"Unsqueeze operation must have 1 or 2 inputs, got {len(node.inputs)}." + ) + if node.inputs[0] is None: + return _common.InferenceResult(failure="Unsqueeze operation input cannot be None.") + if len(node.outputs) != 1: + return _common.InferenceResult( + failure=f"Unsqueeze operation must have exactly one output, got {len(node.outputs)}." + ) + + input_shape = node.inputs[0].shape + if input_shape is None: + return _common.InferenceResult(failure="Unsqueeze input shape cannot be None.") + + input_rank = len(input_shape) + + # Get axes to unsqueeze + axes = None + + # Check for axes in second input (opset >= 13) + if len(node.inputs) == 2 and node.inputs[1] is not None: + if ( + hasattr(node.inputs[1], "initializer_value") + and node.inputs[1].initializer_value is not None + ): + axes = node.inputs[1].initializer_value.tolist() + if not isinstance(axes, list): + axes = [axes] + else: + # Check for axes attribute (opset < 13) + for attr in node.attributes: + if attr.name == "axes": + axes = list(attr.value.ints) + break + + if axes is None: + return _common.InferenceResult(failure="Unsqueeze operation requires axes.") + + # Calculate output rank + output_rank = input_rank + len(axes) + + # Normalize negative axes relative to output rank + normalized_axes = [] + for axis in axes: + if axis < 0: + axis += output_rank + if axis < 0 or axis >= output_rank: + return _common.InferenceResult( + failure=f"Unsqueeze axis {axis} is out of bounds for output rank {output_rank}." 
+ ) + normalized_axes.append(axis) + + # Check for duplicate axes + if len(set(normalized_axes)) != len(normalized_axes): + return _common.InferenceResult(failure="Unsqueeze axes must be unique.") + + # Build output shape by inserting 1s at specified axes + output_shape = ir.Shape([0] * output_rank) + input_axis = 0 + + for output_axis in range(output_rank): + if output_axis in normalized_axes: + # Insert dimension of size 1 + _common.set_expr(output_shape, output_axis, 1) + else: + # Copy dimension from input + input_dim_expr = _common.get_expr(input_shape, input_axis) + _common.set_expr(output_shape, output_axis, input_dim_expr) + input_axis += 1 + + return _common.InferenceResult( + values=(ir.Value(shape=output_shape, type=node.inputs[0].type),) + ) From 65e3dd24332f24c0b32bc573f8d0a0a6bdc29400 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Sun, 29 Jun 2025 17:21:07 -0700 Subject: [PATCH 11/31] concat Signed-off-by: Justin Chu --- .../_shape_type_inference/ops/concat.py | 54 +++++++------------ 1 file changed, 20 insertions(+), 34 deletions(-) diff --git a/src/onnx_ir/_shape_type_inference/ops/concat.py b/src/onnx_ir/_shape_type_inference/ops/concat.py index 60eccb39..7f65d42c 100644 --- a/src/onnx_ir/_shape_type_inference/ops/concat.py +++ b/src/onnx_ir/_shape_type_inference/ops/concat.py @@ -1,7 +1,6 @@ """Concat operation inferrer for ONNX IR nodes.""" import sys -from collections.abc import Collection import onnx_ir as ir from onnx_ir._shape_type_inference import _common @@ -10,11 +9,9 @@ class ConcatInferrer(_common.NodeInferrer): """Inferrer for Concat operations.""" - def __init__(self, opsets: Collection[int] | None = None) -> None: + def __init__(self) -> None: """Initialize the Concat inferrer.""" - if opsets is None: - opsets = range(sys.maxsize) - super().__init__("Concat", opsets=opsets) + super().__init__("Concat", opsets=range(sys.maxsize)) def infer(self, node: ir.Node) -> _common.InferenceResult: """Infer the output shape and type for Concat operations.""" @@ -30,38 +27,36 @@ def infer(self, node: ir.Node) -> _common.InferenceResult: ) # Get axis attribute - axis_attr = None - for attr in node.attributes: - if attr.name == "axis": - axis_attr = attr.value.i - break - - if axis_attr is None: + axis = node.attributes.get_int("axis") + if axis is None: return _common.InferenceResult(failure="Concat operation requires axis attribute.") # Get first input shape as base first_shape = node.inputs[0].shape if first_shape is None: return _common.InferenceResult(failure="Concat input shapes cannot be None.") + first_type = node.inputs[0].type rank = len(first_shape) if rank == 0: return _common.InferenceResult(failure="Concat inputs cannot be scalars.") # Handle negative axis - if axis_attr < 0: - axis_attr += rank + if axis < 0: + axis += rank - if axis_attr < 0 or axis_attr >= rank: + if axis < 0 or axis >= rank: return _common.InferenceResult( - failure=f"Concat axis {axis_attr} is out of bounds for rank {rank}." + failure=f"Concat axis {axis} is out of bounds for rank {rank}." 
) # Check that all inputs have compatible shapes output_shape = ir.Shape(list(first_shape)) - concat_dim_size = _common.get_expr(first_shape, axis_attr) + concat_dim_size = _common.get_expr(first_shape, axis) for i, inp in enumerate(node.inputs[1:], 1): + if inp is None: + return _common.InferenceResult(failure=f"Input {i} cannot be None.") if inp.shape is None: return _common.InferenceResult(failure=f"Input {i} shape cannot be None.") @@ -71,22 +66,13 @@ def infer(self, node: ir.Node) -> _common.InferenceResult: failure=f"All inputs must have same rank. Input {i} has rank {len(input_shape)}, expected {rank}." ) - # Check non-concat dimensions are compatible - for dim_idx in range(rank): - if dim_idx == axis_attr: - # Accumulate concat dimension - concat_dim_size = concat_dim_size + _common.get_expr(input_shape, dim_idx) - else: - # Check compatibility of other dimensions - dim1 = _common.get_expr(first_shape, dim_idx) - dim2 = _common.get_expr(input_shape, dim_idx) - # For symbolic inference, we assume they are compatible - # In practice, this would need runtime verification + # TODO(justinchuby): Check non-concat dimensions are compatible + concat_dim_size = concat_dim_size + _common.get_expr(input_shape, axis) + if inp.type != first_type: + return _common.InferenceResult( + failure=f"Input {i} type {inp.type} does not match first input type {first_type}." + ) # Set the concat dimension in output shape - _common.set_expr(output_shape, axis_attr, concat_dim_size) - - output_type = node.inputs[0].type - return _common.InferenceResult( - values=(ir.Value(shape=output_shape, type=output_type),) - ) + _common.set_expr(output_shape, axis, concat_dim_size) + return _common.InferenceResult(values=(ir.Value(shape=output_shape, type=first_type),)) From 79607709c5130913405daa94129b295b7a5ad9c9 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Sun, 29 Jun 2025 17:41:12 -0700 Subject: [PATCH 12/31] Update _maybe_convert_to_symbolic_dim Signed-off-by: Justin Chu --- src/onnx_ir/_core.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/src/onnx_ir/_core.py b/src/onnx_ir/_core.py index 5e1f13eb..d64b2154 100644 --- a/src/onnx_ir/_core.py +++ b/src/onnx_ir/_core.py @@ -46,6 +46,7 @@ import ml_dtypes import numpy as np import sympy +import sympy.utilities.misc from typing_extensions import TypeIs import onnx_ir @@ -1118,7 +1119,7 @@ class SymbolicDim(_protocols.SymbolicDimProtocol, _display.PrettyPrintable): __slots__ = ("_expr", "_value") - def __init__(self, value: str | None, /, expr: sympy.Expr | None) -> None: + def __init__(self, value: str | None, /, expr: sympy.Expr | None = None) -> None: """Initialize a symbolic dimension. Args: @@ -1134,7 +1135,7 @@ def __init__(self, value: str | None, /, expr: sympy.Expr | None) -> None: "If you are creating a Shape, use int directly instead of SymbolicDim." 
) self._value = value - self._expr: sympy.Expr | None = None + self._expr: sympy.Expr | None = expr def __eq__(self, other: object) -> bool: """Check equality with another SymbolicDim or string/None.""" @@ -1204,15 +1205,18 @@ def _maybe_convert_to_symbolic_dim( """ if dim is None or isinstance(dim, str): return SymbolicDim(dim) - if isinstance(dim, sympy.Expr): - # If the dimension is a sympy expression, we create a SymbolicDim with it - return SymbolicDim(str(dim), expr=dim) if _is_int_compatible(dim): return int(dim) + if isinstance(dim, sympy.Expr): + # If the dimension is a sympy expression, we create a SymbolicDim with it + expr = sympy.sympify(dim) + if expr.is_integer: + return sympy.utilities.misc.as_int(expr) + return SymbolicDim(str(expr), expr=sympy.sympify(expr)) if isinstance(dim, SymbolicDim): return dim raise TypeError( - f"Expected int, str, None or SymbolicDim, but value {dim!r} has type '{type(dim)}'" + f"Expected int, str, sympy.Expr, None or SymbolicDim, but value {dim!r} has type '{type(dim)}'" ) From a7704c5bd25001d7dbd02e9bc405f56881902d04 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Sun, 29 Jun 2025 17:47:02 -0700 Subject: [PATCH 13/31] reshape Signed-off-by: Justin Chu --- .../_shape_type_inference/ops/reshape.py | 71 +++++-------------- 1 file changed, 16 insertions(+), 55 deletions(-) diff --git a/src/onnx_ir/_shape_type_inference/ops/reshape.py b/src/onnx_ir/_shape_type_inference/ops/reshape.py index f98b0318..a4b43980 100644 --- a/src/onnx_ir/_shape_type_inference/ops/reshape.py +++ b/src/onnx_ir/_shape_type_inference/ops/reshape.py @@ -1,7 +1,8 @@ """Reshape operation inferrer for ONNX IR nodes.""" import sys -from collections.abc import Collection + +import sympy import onnx_ir as ir from onnx_ir._shape_type_inference import _common @@ -10,11 +11,8 @@ class ReshapeInferrer(_common.NodeInferrer): """Inferrer for Reshape operations.""" - def __init__(self, opsets: Collection[int] | None = None) -> None: - """Initialize the Reshape inferrer.""" - if opsets is None: - opsets = range(sys.maxsize) - super().__init__("Reshape", opsets=opsets) + def __init__(self) -> None: + super().__init__("Reshape", opsets=range(sys.maxsize)) def infer(self, node: ir.Node) -> _common.InferenceResult: """Infer the output shape and type for Reshape operations.""" @@ -37,51 +35,16 @@ def infer(self, node: ir.Node) -> _common.InferenceResult: # Try to get the shape values from the second input # For symbolic inference, we may not have concrete values - if ( - hasattr(shape_input, "initializer_value") - and shape_input.initializer_value is not None - ): - shape_values = shape_input.initializer_value.tolist() - return self._infer_with_shape_values( - input_shape, shape_values, node.inputs[0].type - ) - else: - # Handle symbolic case where shape is not known at compile time - shape_shape = shape_input.shape - if shape_shape is None or len(shape_shape) != 1: - return _common.InferenceResult( - failure="Reshape shape input must be a 1D tensor." - ) - - shape_rank = shape_shape[0] - if isinstance(shape_rank, int): - # Create symbolic dimensions for the output - output_shape = ir.Shape([]) - for i in range(shape_rank): - output_shape.append(ir.SymbolicDim(f"reshape_dim_{i}")) - - return _common.InferenceResult( - values=(ir.Value(shape=output_shape, type=node.inputs[0].type),) - ) - else: - return _common.InferenceResult( - failure="Cannot infer reshape output shape with symbolic rank." 
- ) - - def _infer_with_shape_values( - self, input_shape: ir.Shape, shape_values: list, input_type - ) -> _common.InferenceResult: - """Infer output shape when shape values are known.""" + shape = ir.convenience.get_const_tensor(shape_input) + if shape is None: + return _common.InferenceResult(failure="Reshape shape input is not known.") + + shape_values = shape.numpy().tolist() + # Calculate total elements in input total_elements = sympy.Integer(1) - for dim in input_shape: - if isinstance(dim, ir.SymbolicDim): - if dim.expr is not None: - total_elements *= dim.expr - else: - total_elements *= sympy.Symbol(dim.value) - else: - total_elements *= sympy.Integer(dim) + for dim in range(input_shape.rank()): + total_elements *= _common.get_expr(input_shape, dim) # Process shape values output_dims = [] @@ -106,7 +69,7 @@ def _infer_with_shape_values( output_dims.append(dim_expr) non_deferred_size *= dim_expr else: - output_dims.append(sympy.Integer(dim_value)) + output_dims.append(dim_value) non_deferred_size *= sympy.Integer(dim_value) # Calculate deferred dimension @@ -115,8 +78,6 @@ def _infer_with_shape_values( output_dims[deferred_dim_idx] = deferred_dim # Create output shape - output_shape = ir.Shape([0] * len(output_dims)) - for i, dim_expr in enumerate(output_dims): - _common.set_expr(output_shape, i, dim_expr) - - return _common.InferenceResult(values=(ir.Value(shape=output_shape, type=input_type),)) + return _common.InferenceResult( + values=(ir.Value(shape=ir.Shape(output_dims), type=node.inputs[0].type),) + ) From 922a59718e9c053aaa99df551e51760bcfd93cd3 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Sun, 29 Jun 2025 17:50:27 -0700 Subject: [PATCH 14/31] Update the way dim is set Signed-off-by: Justin Chu --- src/onnx_ir/_core.py | 2 +- src/onnx_ir/_shape_type_inference/_common.py | 21 ------------------- .../_shape_type_inference/ops/concat.py | 6 +++--- 3 files changed, 4 insertions(+), 25 deletions(-) diff --git a/src/onnx_ir/_core.py b/src/onnx_ir/_core.py index d64b2154..c454fd27 100644 --- a/src/onnx_ir/_core.py +++ b/src/onnx_ir/_core.py @@ -1357,7 +1357,7 @@ def __getitem__(self, index: slice) -> tuple[int | SymbolicDim, ...]: ... def __getitem__(self, index): return tuple(self._dims)[index] - def __setitem__(self, index: int, value: int | SymbolicDim | str | None) -> None: + def __setitem__(self, index: int, value: int | SymbolicDim | str | sympy.Expr | None) -> None: """Set the dimension at the index. Args: diff --git a/src/onnx_ir/_shape_type_inference/_common.py b/src/onnx_ir/_shape_type_inference/_common.py index 6fae2350..758d1e5a 100644 --- a/src/onnx_ir/_shape_type_inference/_common.py +++ b/src/onnx_ir/_shape_type_inference/_common.py @@ -30,27 +30,6 @@ def get_expr(shape: ir.Shape, index: int) -> sympy.Expr: return sympy.Integer(dim) -def set_expr(shape: ir.Shape, index: int, expr: sympy.Expr | int) -> None: - """Set the expression or value at a specific index in the shape. - - Args: - shape: The shape to set the expression in. - index: The index of the dimension to set. - expr: The expression or value to set at the specified index. 
- """ - from sympy.utilities.misc import as_int - - if isinstance(expr, (int, np.integer)): - shape[index] = int(expr) - return - assert isinstance(expr, sympy.Expr), f"Expected sympy.Expr or int, got {type(expr)}" - expr = sympy.sympify(expr) - if expr.is_integer: - shape[index] = as_int(expr) - return - shape[index] = ir.SymbolicDim(str(expr), expr=expr) - - @dataclasses.dataclass class InferenceResult: values: Sequence[ir.Value] | None = None diff --git a/src/onnx_ir/_shape_type_inference/ops/concat.py b/src/onnx_ir/_shape_type_inference/ops/concat.py index 7f65d42c..988a475d 100644 --- a/src/onnx_ir/_shape_type_inference/ops/concat.py +++ b/src/onnx_ir/_shape_type_inference/ops/concat.py @@ -51,7 +51,7 @@ def infer(self, node: ir.Node) -> _common.InferenceResult: ) # Check that all inputs have compatible shapes - output_shape = ir.Shape(list(first_shape)) + output_dims = list(first_shape.dims) concat_dim_size = _common.get_expr(first_shape, axis) for i, inp in enumerate(node.inputs[1:], 1): @@ -74,5 +74,5 @@ def infer(self, node: ir.Node) -> _common.InferenceResult: ) # Set the concat dimension in output shape - _common.set_expr(output_shape, axis, concat_dim_size) - return _common.InferenceResult(values=(ir.Value(shape=output_shape, type=first_type),)) + output_dims[axis] = concat_dim_size + return _common.InferenceResult(values=(ir.Value(shape=ir.Shape(output_dims), type=first_type),)) From 918384840a2b977a5ea3c4a2a93c421bcbb32ff8 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Sun, 29 Jun 2025 17:58:14 -0700 Subject: [PATCH 15/31] Simplify Signed-off-by: Justin Chu --- .../_shape_type_inference/ops/matmul.py | 50 ++++++++----------- .../_shape_type_inference/ops/standard_ops.py | 14 ++---- .../_shape_type_inference/ops/transpose.py | 22 +++----- 3 files changed, 33 insertions(+), 53 deletions(-) diff --git a/src/onnx_ir/_shape_type_inference/ops/matmul.py b/src/onnx_ir/_shape_type_inference/ops/matmul.py index ded876b6..33fade6f 100644 --- a/src/onnx_ir/_shape_type_inference/ops/matmul.py +++ b/src/onnx_ir/_shape_type_inference/ops/matmul.py @@ -1,9 +1,6 @@ """MatMul operation inferrer for ONNX IR nodes.""" import sys -from collections.abc import Collection - -import sympy import onnx_ir as ir from onnx_ir._shape_type_inference import _common @@ -13,11 +10,9 @@ class MatMulInferrer(_common.NodeInferrer): """Inferrer for MatMul operations.""" - def __init__(self, opsets: Collection[int] | None = None) -> None: + def __init__(self) -> None: """Initialize the MatMul inferrer.""" - if opsets is None: - opsets = range(sys.maxsize) - super().__init__("MatMul", opsets=opsets) + super().__init__("MatMul", opsets=range(sys.maxsize)) def infer(self, node: ir.Node) -> _common.InferenceResult: """Infer the output shape and type for MatMul operations.""" @@ -45,40 +40,39 @@ def infer(self, node: ir.Node) -> _common.InferenceResult: # Compute output shape based on matrix multiplication rules if lhs_rank == 1 and rhs_rank == 1: - # Vector dot product: (n,) × (n,) -> scalar + # Vector dot product: (n,) x (n,) -> scalar output_shape = ir.Shape([]) elif lhs_rank == 1: - # Matrix-vector: (n,) × (..., n, k) -> (..., k) - output_shape = ir.Shape(rhs_shape[:-2] + rhs_shape[-1:]) + # Matrix-vector: (n,) x (..., n, k) -> (..., k) + output_dims = list(rhs_shape.dims[:-2]) + [rhs_shape.dims[-1]] + output_shape = ir.Shape(output_dims) elif rhs_rank == 1: - # Vector-matrix: (..., m, n) × (n,) -> (..., m) - output_shape = ir.Shape(lhs_shape[:-1]) + # Vector-matrix: (..., m, n) x (n,) -> (..., m) + output_dims = 
list(lhs_shape.dims[:-1]) + output_shape = ir.Shape(output_dims) else: - # Matrix-matrix: (..., m, n) × (..., n, k) -> (..., m, k) + # Matrix-matrix: (..., m, n) x (..., n, k) -> (..., m, k) # Broadcast batch dimensions - lhs_batch = lhs_shape[:-2] - rhs_batch = rhs_shape[:-2] + lhs_batch = lhs_shape.dims[:-2] + rhs_batch = rhs_shape.dims[:-2] if lhs_batch and rhs_batch: batch_shape = broadcast_shapes_bidirectional( ir.Shape(lhs_batch), ir.Shape(rhs_batch) ) - output_shape = ir.Shape(list(batch_shape) + [lhs_shape[-2], rhs_shape[-1]]) + output_dims = list(batch_shape.dims) + [lhs_shape.dims[-2], rhs_shape.dims[-1]] + output_shape = ir.Shape(output_dims) elif lhs_batch: - output_shape = ir.Shape(list(lhs_batch) + [lhs_shape[-2], rhs_shape[-1]]) + output_dims = list(lhs_batch) + [lhs_shape.dims[-2], rhs_shape.dims[-1]] + output_shape = ir.Shape(output_dims) elif rhs_batch: - output_shape = ir.Shape(list(rhs_batch) + [lhs_shape[-2], rhs_shape[-1]]) + output_dims = list(rhs_batch) + [lhs_shape.dims[-2], rhs_shape.dims[-1]] + output_shape = ir.Shape(output_dims) else: - output_shape = ir.Shape([lhs_shape[-2], rhs_shape[-1]]) - - # Check dimension compatibility for matrix multiplication - if lhs_rank >= 1 and rhs_rank >= 1: - lhs_reduce_dim = ( - _common.get_expr(lhs_shape, -1) if lhs_rank >= 1 else sympy.Integer(1) - ) - rhs_reduce_dim = _common.get_expr(rhs_shape, -2 if rhs_rank >= 2 else 0) + output_dims = [lhs_shape.dims[-2], rhs_shape.dims[-1]] + output_shape = ir.Shape(output_dims) - # For symbolic inference, we assume dimensions are compatible - # In practice, this would need runtime verification + # For symbolic inference, we assume dimensions are compatible + # In practice, this would need runtime verification output_type = node.inputs[0].type return _common.InferenceResult( diff --git a/src/onnx_ir/_shape_type_inference/ops/standard_ops.py b/src/onnx_ir/_shape_type_inference/ops/standard_ops.py index 27b16ed5..84ed1861 100644 --- a/src/onnx_ir/_shape_type_inference/ops/standard_ops.py +++ b/src/onnx_ir/_shape_type_inference/ops/standard_ops.py @@ -78,22 +78,16 @@ def broadcast_shapes_bidirectional(shape1: ir.Shape, shape2: ir.Shape) -> ir.Sha # Add to the front to maintain right-to-left processing order new_dims.insert(0, new_dim_expr) - # Create new shape and set dimensions - new_shape = ir.Shape([0] * new_rank) - for i, expr in enumerate(new_dims): - _common.set_expr(new_shape, i, expr) - - return new_shape + # Create new shape directly + return ir.Shape(new_dims) class BinaryInferrer(_common.NodeInferrer): """Base class for binary operation inferrers.""" - def __init__(self, op_type: str, opsets: Collection[int] | None = None) -> None: + def __init__(self, op_type: str) -> None: """Initialize the binary inferrer with the operation type.""" - if opsets is None: - opsets = range(sys.maxsize) - super().__init__(op_type, opsets=opsets) + super().__init__(op_type, opsets=range(sys.maxsize)) def infer(self, node: ir.Node) -> _common.InferenceResult: """Infer the output shape and type for binary operations.""" diff --git a/src/onnx_ir/_shape_type_inference/ops/transpose.py b/src/onnx_ir/_shape_type_inference/ops/transpose.py index dd2b9ff1..dd81f145 100644 --- a/src/onnx_ir/_shape_type_inference/ops/transpose.py +++ b/src/onnx_ir/_shape_type_inference/ops/transpose.py @@ -1,7 +1,6 @@ """Transpose operation inferrer for ONNX IR nodes.""" import sys -from collections.abc import Collection import onnx_ir as ir from onnx_ir._shape_type_inference import _common @@ -10,11 +9,9 @@ class 
TransposeInferrer(_common.NodeInferrer):
     """Inferrer for Transpose operations."""
 
-    def __init__(self, opsets: Collection[int] | None = None) -> None:
+    def __init__(self) -> None:
         """Initialize the Transpose inferrer."""
-        if opsets is None:
-            opsets = range(sys.maxsize)
-        super().__init__("Transpose", opsets=opsets)
+        super().__init__("Transpose", opsets=range(sys.maxsize))
 
     def infer(self, node: ir.Node) -> _common.InferenceResult:
         """Infer the output shape and type for Transpose operations."""
@@ -36,11 +33,7 @@ def infer(self, node: ir.Node) -> _common.InferenceResult:
         rank = len(input_shape)
 
         # Get permutation from attributes
-        perm = None
-        for attr in node.attributes:
-            if attr.name == "perm":
-                perm = list(attr.value.ints)
-                break
+        perm = node.attributes.get_ints("perm")
 
         # Default permutation is reversed order
         if perm is None:
@@ -58,8 +51,8 @@ def infer(self, node: ir.Node) -> _common.InferenceResult:
         )
 
         # Apply permutation to create output shape
-        output_shape = ir.Shape([0] * rank)
-        for i, axis in enumerate(perm):
+        output_dims = []
+        for axis in perm:
             # Handle negative axis
             if axis < 0:
                 axis += rank
@@ -70,9 +63,8 @@ def infer(self, node: ir.Node) -> _common.InferenceResult:
             )
 
             # Copy dimension from input to output according to permutation
-            input_dim_expr = _common.get_expr(input_shape, axis)
-            _common.set_expr(output_shape, i, input_dim_expr)
+            output_dims.append(input_shape.dims[axis])
 
         return _common.InferenceResult(
-            values=(ir.Value(shape=output_shape, type=node.inputs[0].type),)
+            values=(ir.Value(shape=ir.Shape(output_dims), type=node.inputs[0].type),)
         )

From 9300abaf6aec4807ba2399ac55977315f2b03c6e Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Sun, 29 Jun 2025 20:20:03 -0700
Subject: [PATCH 16/31] Update

Signed-off-by: Justin Chu
---
 src/onnx_ir/_shape_type_inference/_common.py  |  62 +++++-
 .../_shape_type_inference/ops/matmul.py       |  25 +--
 .../_shape_type_inference/ops/squeeze.py      | 181 ++++++++++--------
 3 files changed, 173 insertions(+), 95 deletions(-)

diff --git a/src/onnx_ir/_shape_type_inference/_common.py b/src/onnx_ir/_shape_type_inference/_common.py
index 758d1e5a..cb84429a 100644
--- a/src/onnx_ir/_shape_type_inference/_common.py
+++ b/src/onnx_ir/_shape_type_inference/_common.py
@@ -5,7 +5,8 @@
 import abc
 import dataclasses
 from collections.abc import Collection, Sequence
-
+import functools
+from typing import Any, TypeVar, Callable
 import numpy as np
 import sympy
 
@@ -65,3 +66,62 @@ def infer(self, node: ir.Node) -> InferenceResult:
             A sequence of ONNX values containing the inferred shapes.
         """
         raise NotImplementedError
+
+
+def requires_non_none_inputs(
+    count: int, /
+) -> Callable[[Callable[[Any, ir.Node], InferenceResult]], Callable[[Any, ir.Node], InferenceResult]]:
+    """Ensure that the node has a specific number of non-None inputs.
+
+    Args:
+        count: The exact number of non-None inputs required for the node.
+
+    Returns:
+        A decorator that checks the number of inputs and their non-None status.
+    """
+
+    def decorator(
+        func: Callable[[Any, ir.Node], InferenceResult],
+    ) -> Callable[[Any, ir.Node], InferenceResult]:
+        @functools.wraps(func)
+        def wrapper(self, node: ir.Node) -> InferenceResult:
+            if len(node.inputs) != count:
+                return InferenceResult(
+                    failure=f"{node.op_type} must have {count} inputs, got {len(node.inputs)}."
+                )
+            for i, inp in enumerate(node.inputs):
+                if inp is None:
+                    return InferenceResult(failure=f"{node.op_type} input {i} cannot be None.")
+            return func(self, node)
+
+        return wrapper
+
+    return decorator
+
+
+def requires_outputs(
+    count: int, /
+) -> Callable[[Callable[[Any, ir.Node], InferenceResult]], Callable[[Any, ir.Node], InferenceResult]]:
+    """Ensure that the node has a specific number of outputs.
+
+    Args:
+        count: The exact number of outputs required for the node.
+
+    Returns:
+        A decorator that checks the number of outputs.
+    """
+
+    def decorator(
+        func: Callable[[Any, ir.Node], InferenceResult],
+    ) -> Callable[[Any, ir.Node], InferenceResult]:
+        @functools.wraps(func)
+        def wrapper(self, node: ir.Node) -> InferenceResult:
+            if len(node.outputs) != count:
+                return InferenceResult(
+                    failure=f"{node.op_type} must have {count} outputs, got {len(node.outputs)}."
+                )
+            return func(self, node)
+
+        return wrapper
+
+    return decorator
diff --git a/src/onnx_ir/_shape_type_inference/ops/matmul.py b/src/onnx_ir/_shape_type_inference/ops/matmul.py
index 33fade6f..6ebd48d8 100644
--- a/src/onnx_ir/_shape_type_inference/ops/matmul.py
+++ b/src/onnx_ir/_shape_type_inference/ops/matmul.py
@@ -14,18 +14,11 @@ def __init__(self) -> None:
         """Initialize the MatMul inferrer."""
         super().__init__("MatMul", opsets=range(sys.maxsize))
 
+    @_common.requires_non_none_inputs(2)
+    @_common.requires_outputs(1)
     def infer(self, node: ir.Node) -> _common.InferenceResult:
         """Infer the output shape and type for MatMul operations."""
-        if len(node.inputs) != 2:
-            return _common.InferenceResult(
-                failure=f"MatMul operation must have exactly two inputs, got {len(node.inputs)}."
-            )
-        if node.inputs[0] is None or node.inputs[1] is None:
-            return _common.InferenceResult(failure="MatMul operation inputs cannot be None.")
-        if len(node.outputs) != 1:
-            return _common.InferenceResult(
-                failure=f"MatMul operation must have exactly one output, got {len(node.outputs)}."
- ) + assert node.inputs[0] is not None and node.inputs[1] is not None lhs_shape = node.inputs[0].shape rhs_shape = node.inputs[1].shape @@ -44,7 +37,7 @@ def infer(self, node: ir.Node) -> _common.InferenceResult: output_shape = ir.Shape([]) elif lhs_rank == 1: # Matrix-vector: (n,) x (..., n, k) -> (..., k) - output_dims = list(rhs_shape.dims[:-2]) + [rhs_shape.dims[-1]] + output_dims = [*rhs_shape.dims[:-2], rhs_shape.dims[-1]] output_shape = ir.Shape(output_dims) elif rhs_rank == 1: # Vector-matrix: (..., m, n) x (n,) -> (..., m) @@ -56,24 +49,22 @@ def infer(self, node: ir.Node) -> _common.InferenceResult: lhs_batch = lhs_shape.dims[:-2] rhs_batch = rhs_shape.dims[:-2] if lhs_batch and rhs_batch: + # TODO(justinchuby): Ensure this is correct batch_shape = broadcast_shapes_bidirectional( ir.Shape(lhs_batch), ir.Shape(rhs_batch) ) - output_dims = list(batch_shape.dims) + [lhs_shape.dims[-2], rhs_shape.dims[-1]] + output_dims = [*batch_shape.dims, lhs_shape.dims[-2], rhs_shape.dims[-1]] output_shape = ir.Shape(output_dims) elif lhs_batch: - output_dims = list(lhs_batch) + [lhs_shape.dims[-2], rhs_shape.dims[-1]] + output_dims = [*lhs_batch, lhs_shape.dims[-2], rhs_shape.dims[-1]] output_shape = ir.Shape(output_dims) elif rhs_batch: - output_dims = list(rhs_batch) + [lhs_shape.dims[-2], rhs_shape.dims[-1]] + output_dims = [*rhs_batch, lhs_shape.dims[-2], rhs_shape.dims[-1]] output_shape = ir.Shape(output_dims) else: output_dims = [lhs_shape.dims[-2], rhs_shape.dims[-1]] output_shape = ir.Shape(output_dims) - # For symbolic inference, we assume dimensions are compatible - # In practice, this would need runtime verification - output_type = node.inputs[0].type return _common.InferenceResult( values=(ir.Value(shape=output_shape, type=output_type),) diff --git a/src/onnx_ir/_shape_type_inference/ops/squeeze.py b/src/onnx_ir/_shape_type_inference/ops/squeeze.py index 95f4c8ba..defa1a9d 100644 --- a/src/onnx_ir/_shape_type_inference/ops/squeeze.py +++ b/src/onnx_ir/_shape_type_inference/ops/squeeze.py @@ -1,101 +1,128 @@ """Squeeze operation inferrer for ONNX IR nodes.""" -import sys -from collections.abc import Collection +from __future__ import annotations +import sys +import logging +from collections.abc import Sequence import onnx_ir as ir from onnx_ir._shape_type_inference import _common +logger = logging.getLogger(__name__) + + +def _compute_output_shape_no_axes(input_shape: ir.Shape) -> ir.Shape: + """Compute output shape when no axes are specified.""" + output_dims = [] + for dim in input_shape.dims: + # For symbolic dimensions, we assume they are not 1 + # Only squeeze literal 1s + if isinstance(dim, int): + if dim == 1: + continue # Skip dimension of size 1 + else: + output_dims.append(dim) + else: + logger.warning( + "Squeeze operation has symbolic dimension %s, assuming it is not 1.", dim + ) + output_dims.append(dim) + return ir.Shape(output_dims) + + +def _normalize_axes(axes: Sequence[int], rank: int) -> set[int]: + """Normalize axes to be within the valid range for the given rank.""" + normalized_axes = set() + for axis in axes: + if axis < 0: + axis += rank + if axis < 0 or axis >= rank: + raise ValueError(f"Squeeze axis {axis} is out of bounds for rank {rank}.") + normalized_axes.add(axis) + return normalized_axes -class SqueezeInferrer(_common.NodeInferrer): - """Inferrer for Squeeze operations.""" - def __init__(self, opsets: Collection[int] | None = None) -> None: +def _compute_output_shape_with_axes(input_shape: ir.Shape, axes: set[int]) -> ir.Shape: + """Compute output 
shape when axes are specified.""" + output_dims = [dim for i, dim in enumerate(input_shape.dims) if i not in axes] + return ir.Shape(output_dims) + + +class Squeeze12Inferrer(_common.NodeInferrer): + """Inferrer for Squeeze-12 and lower. + + We assume that axes doesn't have duplicates. + """ + + def __init__(self) -> None: """Initialize the Squeeze inferrer.""" - if opsets is None: - opsets = range(sys.maxsize) - super().__init__("Squeeze", opsets=opsets) + super().__init__("Squeeze", opsets=range(13)) + @_common.requires_non_none_inputs(1) + @_common.requires_outputs(1) def infer(self, node: ir.Node) -> _common.InferenceResult: """Infer the output shape and type for Squeeze operations.""" - if len(node.inputs) < 1 or len(node.inputs) > 2: - return _common.InferenceResult( - failure=f"Squeeze operation must have 1 or 2 inputs, got {len(node.inputs)}." - ) - if node.inputs[0] is None: - return _common.InferenceResult(failure="Squeeze operation input cannot be None.") - if len(node.outputs) != 1: - return _common.InferenceResult( - failure=f"Squeeze operation must have exactly one output, got {len(node.outputs)}." - ) - - input_shape = node.inputs[0].shape + input = node.inputs[0] + assert input is not None + input_shape = input.shape if input_shape is None: - return _common.InferenceResult(failure="Squeeze input shape cannot be None.") + return _common.InferenceResult(failure="Squeeze input shape is not known.") rank = len(input_shape) # Get axes to squeeze - axes = None - - # Check for axes in second input (opset >= 13) - if len(node.inputs) == 2 and node.inputs[1] is not None: - if ( - hasattr(node.inputs[1], "initializer_value") - and node.inputs[1].initializer_value is not None - ): - axes = node.inputs[1].initializer_value.tolist() - if not isinstance(axes, list): - axes = [axes] - else: - # Check for axes attribute (opset < 13) - for attr in node.attributes: - if attr.name == "axes": - axes = list(attr.value.ints) - break + axes = node.attributes.get_ints("axes") if axes is None: - # No axes specified - squeeze all dimensions of size 1 - output_dims = [] - for i, dim in enumerate(input_shape): - dim_expr = _common.get_expr(input_shape, i) - # For symbolic dimensions, we assume they are not 1 - # Only squeeze literal 1s - if isinstance(dim, int) and dim == 1: - continue # Skip dimension of size 1 - else: - output_dims.append(dim_expr) + output_shape = _compute_output_shape_no_axes(input_shape) else: - # Normalize negative axes - normalized_axes = [] - for axis in axes: - if axis < 0: - axis += rank - if axis < 0 or axis >= rank: - return _common.InferenceResult( - failure=f"Squeeze axis {axis} is out of bounds for rank {rank}." - ) - normalized_axes.append(axis) - - # Validate that specified axes have dimension 1 (for literal dimensions) - for axis in normalized_axes: - dim = input_shape[axis] - if isinstance(dim, int) and dim != 1: - return _common.InferenceResult( - failure=f"Cannot squeeze axis {axis} with dimension {dim} (must be 1)." 
- ) - - # Build output shape excluding squeezed axes - output_dims = [] - for i in range(rank): - if i not in normalized_axes: - output_dims.append(_common.get_expr(input_shape, i)) - - # Create output shape - output_shape = ir.Shape([0] * len(output_dims)) - for i, dim_expr in enumerate(output_dims): - _common.set_expr(output_shape, i, dim_expr) + try: + axes = _normalize_axes(axes, rank) + except ValueError as e: + return _common.InferenceResult(failure=str(e)) + output_shape = _compute_output_shape_with_axes(input_shape, axes) + return _common.InferenceResult(values=(ir.Value(shape=output_shape, type=input.type),)) + + +class Squeeze13Inferrer(_common.NodeInferrer): + """Inferrer for Squeeze-13 and higher. + + We assume that axes doesn't have duplicates. + """ + def __init__(self) -> None: + """Initialize the Squeeze inferrer.""" + super().__init__("Squeeze", opsets=range(14, sys.maxsize)) + + @_common.requires_non_none_inputs(2) + @_common.requires_outputs(1) + def infer(self, node: ir.Node) -> _common.InferenceResult: + """Infer the output shape and type for Squeeze operations.""" + assert node.inputs[0] is not None + assert node.inputs[1] is not None + + input_shape = node.inputs[0].shape + if input_shape is None: + return _common.InferenceResult(failure="Squeeze input shape is not known.") + + rank = len(input_shape) + + axes_tensor = ir.convenience.get_const_tensor(node.inputs[1]) + if axes_tensor is not None: + try: + axes = _normalize_axes(axes_tensor.numpy().tolist(), rank) + except ValueError as e: + return _common.InferenceResult(failure=str(e)) + output_shape = _compute_output_shape_with_axes(input_shape, axes) + else: + axes_shape = node.inputs[1].shape + if axes_shape is None or axes_shape.is_dynamic(): + return _common.InferenceResult( + failure="Squeeze axes input shape is not known or is dynamic" + ) + removed_axes_count = axes_shape[0] + assert isinstance(removed_axes_count, int) + output_shape = ir.Shape([None] * (rank - removed_axes_count)) return _common.InferenceResult( values=(ir.Value(shape=output_shape, type=node.inputs[0].type),) ) From 8747a93c3e53866337a9cada5471bd96da79b962 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Sun, 29 Jun 2025 20:25:20 -0700 Subject: [PATCH 17/31] Handle unknown dims Signed-off-by: Justin Chu --- src/onnx_ir/_core.py | 4 +- src/onnx_ir/_shape_type_inference/_common.py | 10 +- .../_shape_type_inference/_inferencer.py | 4 +- .../_shape_type_inference/ops/concat.py | 4 +- .../_shape_type_inference/ops/squeeze.py | 3 +- .../_shape_type_inference/ops/unsqueeze.py | 184 ++++++++++++------ 6 files changed, 137 insertions(+), 72 deletions(-) diff --git a/src/onnx_ir/_core.py b/src/onnx_ir/_core.py index c454fd27..84281bf2 100644 --- a/src/onnx_ir/_core.py +++ b/src/onnx_ir/_core.py @@ -1357,7 +1357,9 @@ def __getitem__(self, index: slice) -> tuple[int | SymbolicDim, ...]: ... def __getitem__(self, index): return tuple(self._dims)[index] - def __setitem__(self, index: int, value: int | SymbolicDim | str | sympy.Expr | None) -> None: + def __setitem__( + self, index: int, value: int | SymbolicDim | str | sympy.Expr | None + ) -> None: """Set the dimension at the index. 
Args: diff --git a/src/onnx_ir/_shape_type_inference/_common.py b/src/onnx_ir/_shape_type_inference/_common.py index cb84429a..6ce375c0 100644 --- a/src/onnx_ir/_shape_type_inference/_common.py +++ b/src/onnx_ir/_shape_type_inference/_common.py @@ -27,6 +27,8 @@ def get_expr(shape: ir.Shape, index: int) -> sympy.Expr: if isinstance(dim, ir.SymbolicDim): if dim.expr is not None: return dim.expr + if dim.value is None: + return sympy.Symbol("__unknown__") return sympy.Symbol(dim.value) return sympy.Integer(dim) @@ -70,7 +72,9 @@ def infer(self, node: ir.Node) -> InferenceResult: def requires_non_none_inputs( count: int, / -) -> Callable[[Callable[[Any, ir.Node], InferenceResult]], Callable[[Any, ir.Node], InferenceResult]]: +) -> Callable[ + [Callable[[Any, ir.Node], InferenceResult]], Callable[[Any, ir.Node], InferenceResult] +]: """Ensure that the node has a specific number of non-None inputs. Args: @@ -101,7 +105,9 @@ def wrapper(self, node: ir.Node) -> InferenceResult: def requires_outputs( count: int, / -) -> Callable[[Callable[[Any, ir.Node], InferenceResult]], Callable[[Any, ir.Node], InferenceResult]]: +) -> Callable[ + [Callable[[Any, ir.Node], InferenceResult]], Callable[[Any, ir.Node], InferenceResult] +]: """Ensure that the node has a specific number of outputs. Args: diff --git a/src/onnx_ir/_shape_type_inference/_inferencer.py b/src/onnx_ir/_shape_type_inference/_inferencer.py index 8538bfe0..abb3ea02 100644 --- a/src/onnx_ir/_shape_type_inference/_inferencer.py +++ b/src/onnx_ir/_shape_type_inference/_inferencer.py @@ -949,9 +949,7 @@ def _infer_symbolic_compute_ops(self, node): funcs = { "Add": lambda l: l[0] + l[1], "Div": lambda l: ( - int(l[0] // l[1]) - if isinstance(l[0] // l[1], float) - else l[0] // l[1] + int(l[0] // l[1]) if isinstance(l[0] // l[1], float) else l[0] // l[1] ), # integer div in sympy "Equal": lambda l: l[0] == l[1], "Floor": lambda l: sympy.floor(l[0]), diff --git a/src/onnx_ir/_shape_type_inference/ops/concat.py b/src/onnx_ir/_shape_type_inference/ops/concat.py index 988a475d..318df26f 100644 --- a/src/onnx_ir/_shape_type_inference/ops/concat.py +++ b/src/onnx_ir/_shape_type_inference/ops/concat.py @@ -75,4 +75,6 @@ def infer(self, node: ir.Node) -> _common.InferenceResult: # Set the concat dimension in output shape output_dims[axis] = concat_dim_size - return _common.InferenceResult(values=(ir.Value(shape=ir.Shape(output_dims), type=first_type),)) + return _common.InferenceResult( + values=(ir.Value(shape=ir.Shape(output_dims), type=first_type),) + ) diff --git a/src/onnx_ir/_shape_type_inference/ops/squeeze.py b/src/onnx_ir/_shape_type_inference/ops/squeeze.py index defa1a9d..df4c3b92 100644 --- a/src/onnx_ir/_shape_type_inference/ops/squeeze.py +++ b/src/onnx_ir/_shape_type_inference/ops/squeeze.py @@ -2,9 +2,10 @@ from __future__ import annotations -import sys import logging +import sys from collections.abc import Sequence + import onnx_ir as ir from onnx_ir._shape_type_inference import _common diff --git a/src/onnx_ir/_shape_type_inference/ops/unsqueeze.py b/src/onnx_ir/_shape_type_inference/ops/unsqueeze.py index 9b3b26b8..25b4c5d2 100644 --- a/src/onnx_ir/_shape_type_inference/ops/unsqueeze.py +++ b/src/onnx_ir/_shape_type_inference/ops/unsqueeze.py @@ -1,93 +1,149 @@ """Unsqueeze operation inferrer for ONNX IR nodes.""" +from __future__ import annotations + +import logging import sys -from collections.abc import Collection +from collections.abc import Sequence import onnx_ir as ir from onnx_ir._shape_type_inference import _common +logger = 
logging.getLogger(__name__) + + +def _normalize_axes(axes: Sequence[int], output_rank: int) -> set[int]: + """Normalize axes to be within the valid range for the given output rank.""" + normalized_axes = set() + for axis in axes: + if axis < 0: + axis += output_rank + if axis < 0 or axis >= output_rank: + raise ValueError( + f"Unsqueeze axis {axis} is out of bounds for output rank {output_rank}." + ) + normalized_axes.add(axis) + + # Check for duplicate axes + if len(normalized_axes) != len(axes): + raise ValueError("Unsqueeze axes must be unique.") -class UnsqueezeInferrer(_common.NodeInferrer): - """Inferrer for Unsqueeze operations.""" + return normalized_axes + + +def _compute_output_shape(input_shape: ir.Shape, axes: set[int]) -> ir.Shape: + """Compute output shape by inserting 1s at specified axes.""" + input_rank = len(input_shape) + output_rank = input_rank + len(axes) + + output_dims = [] + input_axis = 0 + + for output_axis in range(output_rank): + if output_axis in axes: + # Insert dimension of size 1 + output_dims.append(1) + else: + # Copy dimension from input + output_dims.append(input_shape.dims[input_axis]) + input_axis += 1 - def __init__(self, opsets: Collection[int] | None = None) -> None: + return ir.Shape(output_dims) + + +class Unsqueeze12Inferrer(_common.NodeInferrer): + """Inferrer for Unsqueeze-12 and lower. + + In these versions, axes are provided as an attribute. + We assume that axes doesn't have duplicates. + """ + + def __init__(self) -> None: """Initialize the Unsqueeze inferrer.""" - if opsets is None: - opsets = range(sys.maxsize) - super().__init__("Unsqueeze", opsets=opsets) + super().__init__("Unsqueeze", opsets=range(13)) + @_common.requires_non_none_inputs(1) + @_common.requires_outputs(1) def infer(self, node: ir.Node) -> _common.InferenceResult: """Infer the output shape and type for Unsqueeze operations.""" - if len(node.inputs) < 1 or len(node.inputs) > 2: - return _common.InferenceResult( - failure=f"Unsqueeze operation must have 1 or 2 inputs, got {len(node.inputs)}." - ) - if node.inputs[0] is None: - return _common.InferenceResult(failure="Unsqueeze operation input cannot be None.") - if len(node.outputs) != 1: + input = node.inputs[0] + assert input is not None + input_shape = input.shape + if input_shape is None: + return _common.InferenceResult(failure="Unsqueeze input shape is not known.") + + input_rank = len(input_shape) + + # Get axes to unsqueeze from attributes + axes = node.attributes.get_ints("axes") + if axes is None: return _common.InferenceResult( - failure=f"Unsqueeze operation must have exactly one output, got {len(node.outputs)}." + failure="Unsqueeze operation requires axes attribute." ) + output_rank = input_rank + len(axes) + + try: + normalized_axes = _normalize_axes(axes, output_rank) + except ValueError as e: + return _common.InferenceResult(failure=str(e)) + + output_shape = _compute_output_shape(input_shape, normalized_axes) + return _common.InferenceResult(values=(ir.Value(shape=output_shape, type=input.type),)) + + +class Unsqueeze13Inferrer(_common.NodeInferrer): + """Inferrer for Unsqueeze-13 and higher. + + In these versions, axes are provided as a second input tensor. + We assume that axes doesn't have duplicates. 
+ """ + + def __init__(self) -> None: + """Initialize the Unsqueeze inferrer.""" + super().__init__("Unsqueeze", opsets=range(13, sys.maxsize)) + + @_common.requires_non_none_inputs(2) + @_common.requires_outputs(1) + def infer(self, node: ir.Node) -> _common.InferenceResult: + """Infer the output shape and type for Unsqueeze operations.""" + assert node.inputs[0] is not None + assert node.inputs[1] is not None + input_shape = node.inputs[0].shape if input_shape is None: - return _common.InferenceResult(failure="Unsqueeze input shape cannot be None.") + return _common.InferenceResult(failure="Unsqueeze input shape is not known.") input_rank = len(input_shape) - # Get axes to unsqueeze - axes = None - - # Check for axes in second input (opset >= 13) - if len(node.inputs) == 2 and node.inputs[1] is not None: - if ( - hasattr(node.inputs[1], "initializer_value") - and node.inputs[1].initializer_value is not None - ): - axes = node.inputs[1].initializer_value.tolist() - if not isinstance(axes, list): - axes = [axes] - else: - # Check for axes attribute (opset < 13) - for attr in node.attributes: - if attr.name == "axes": - axes = list(attr.value.ints) - break + axes_tensor = ir.convenience.get_const_tensor(node.inputs[1]) + if axes_tensor is not None: + axes = axes_tensor.numpy().tolist() + if not isinstance(axes, list): + axes = [axes] - if axes is None: - return _common.InferenceResult(failure="Unsqueeze operation requires axes.") + output_rank = input_rank + len(axes) - # Calculate output rank - output_rank = input_rank + len(axes) + try: + normalized_axes = _normalize_axes(axes, output_rank) + except ValueError as e: + return _common.InferenceResult(failure=str(e)) - # Normalize negative axes relative to output rank - normalized_axes = [] - for axis in axes: - if axis < 0: - axis += output_rank - if axis < 0 or axis >= output_rank: + output_shape = _compute_output_shape(input_shape, normalized_axes) + else: + # Handle case where axes tensor is not constant + axes_shape = node.inputs[1].shape + if axes_shape is None or axes_shape.is_dynamic(): return _common.InferenceResult( - failure=f"Unsqueeze axis {axis} is out of bounds for output rank {output_rank}." 
+ failure="Unsqueeze axes input shape is not known or is dynamic" ) - normalized_axes.append(axis) - - # Check for duplicate axes - if len(set(normalized_axes)) != len(normalized_axes): - return _common.InferenceResult(failure="Unsqueeze axes must be unique.") - - # Build output shape by inserting 1s at specified axes - output_shape = ir.Shape([0] * output_rank) - input_axis = 0 - - for output_axis in range(output_rank): - if output_axis in normalized_axes: - # Insert dimension of size 1 - _common.set_expr(output_shape, output_axis, 1) - else: - # Copy dimension from input - input_dim_expr = _common.get_expr(input_shape, input_axis) - _common.set_expr(output_shape, output_axis, input_dim_expr) - input_axis += 1 + + # We know the number of axes to insert but not their positions + added_axes_count = axes_shape[0] + assert isinstance(added_axes_count, int) + output_rank = input_rank + added_axes_count + # Create output shape with unknown dimensions + output_shape = ir.Shape([None] * output_rank) return _common.InferenceResult( values=(ir.Value(shape=output_shape, type=node.inputs[0].type),) From 92049c4a3ec2951c6bd49f204109431452f513d3 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Sun, 29 Jun 2025 20:30:02 -0700 Subject: [PATCH 18/31] Simplify Signed-off-by: Justin Chu --- src/onnx_ir/_shape_type_inference/_common.py | 6 ++-- .../_shape_type_inference/_inferencer.py | 11 +++---- .../_shape_type_inference/ops/standard_ops.py | 30 +++++-------------- .../_shape_type_inference/ops/transpose.py | 14 ++------- 4 files changed, 19 insertions(+), 42 deletions(-) diff --git a/src/onnx_ir/_shape_type_inference/_common.py b/src/onnx_ir/_shape_type_inference/_common.py index 6ce375c0..4070edbb 100644 --- a/src/onnx_ir/_shape_type_inference/_common.py +++ b/src/onnx_ir/_shape_type_inference/_common.py @@ -4,10 +4,10 @@ import abc import dataclasses -from collections.abc import Collection, Sequence import functools -from typing import Any, TypeVar, Callable -import numpy as np +from collections.abc import Collection, Sequence +from typing import Any, Callable + import sympy import onnx_ir as ir diff --git a/src/onnx_ir/_shape_type_inference/_inferencer.py b/src/onnx_ir/_shape_type_inference/_inferencer.py index abb3ea02..51d6499f 100644 --- a/src/onnx_ir/_shape_type_inference/_inferencer.py +++ b/src/onnx_ir/_shape_type_inference/_inferencer.py @@ -407,8 +407,9 @@ def _broadcast_shapes(self, shape1, shape2): if self.auto_merge_: self._add_suggested_merge([dim1, dim2], apply=True) else: + # TODO(justinchuby): Error? 
logger.warning( - f"unsupported broadcast between {str(dim1)} {str(dim2)}" + "unsupported broadcast between %s %s", dim1, dim2 ) new_shape = [new_dim, *new_shape] return new_shape @@ -627,7 +628,7 @@ def _onnx_infer_subgraph(self, node, subgraph, use_node_input=True, inc_subgraph self.auto_merge_, self.guess_output_rank_, self.verbose_, - prefix=f"{self.prefix_}_{str(self.subgraph_id_)}", + prefix=f"{self.prefix_}_{self.subgraph_id_!s}", ) if inc_subgraph_id: self.subgraph_id_ += 1 @@ -2202,7 +2203,7 @@ def handle_negative_index(index, bound): # handle sympy_data if needed, for slice in shape computation if ( node.input[0] in self.sympy_data_ - and [0] == axes + and axes == [0] and starts is not None and len(starts) == 1 and ends is not None @@ -3197,7 +3198,7 @@ def get_prereq(node): ) if self.verbose_ > 2: logger.debug( - f" {node.output[i_o]}: {str(out_shape)} {onnx.TensorProto.DataType.Name(vi.type.tensor_type.elem_type)}" + f" {node.output[i_o]}: {out_shape!s} {onnx.TensorProto.DataType.Name(vi.type.tensor_type.elem_type)}" ) if node.output[i_o] in self.sympy_data_: logger.debug( @@ -3326,7 +3327,7 @@ def get_prereq(node): ) if self.verbose_ > 2: logger.debug( - f" {node.output[i_o]}: {str(new_shape)} {vi.type.tensor_type.elem_type}" + f" {node.output[i_o]}: {new_shape!s} {vi.type.tensor_type.elem_type}" ) self.run_ = True continue # continue the inference after guess, no need to stop as no merge is needed diff --git a/src/onnx_ir/_shape_type_inference/ops/standard_ops.py b/src/onnx_ir/_shape_type_inference/ops/standard_ops.py index 84ed1861..9e21f0ca 100644 --- a/src/onnx_ir/_shape_type_inference/ops/standard_ops.py +++ b/src/onnx_ir/_shape_type_inference/ops/standard_ops.py @@ -20,21 +20,11 @@ def __init__(self, op_type: str, opsets: Collection[int] | None = None) -> None: opsets = range(sys.maxsize) super().__init__(op_type, opsets=opsets) + @_common.requires_non_none_inputs(1) + @_common.requires_outputs(1) def infer(self, node: ir.Node) -> _common.InferenceResult: """Infer the output shape and type for elementwise operations.""" - if len(node.inputs) != 1: - return _common.InferenceResult( - failure=f"Elementwise operation must have exactly one input, got {len(node.inputs)}." - ) - if node.inputs[0] is None: - return _common.InferenceResult( - failure="Elementwise operation input cannot be None." - ) - if len(node.outputs) != 1: - return _common.InferenceResult( - failure=f"Elementwise operation must have exactly one output, got {len(node.outputs)}." - ) - + assert node.inputs[0] is not None return _common.InferenceResult( (ir.Value(shape=node.inputs[0].shape, type=node.inputs[0].type),) ) @@ -89,18 +79,12 @@ def __init__(self, op_type: str) -> None: """Initialize the binary inferrer with the operation type.""" super().__init__(op_type, opsets=range(sys.maxsize)) + @_common.requires_non_none_inputs(2) + @_common.requires_outputs(1) def infer(self, node: ir.Node) -> _common.InferenceResult: """Infer the output shape and type for binary operations.""" - if len(node.inputs) != 2: - return _common.InferenceResult( - failure=f"Binary operation must have exactly two inputs, got {len(node.inputs)}." - ) - if node.inputs[0] is None or node.inputs[1] is None: - return _common.InferenceResult(failure="Binary operation inputs cannot be None.") - if len(node.outputs) != 1: - return _common.InferenceResult( - failure=f"Binary operation must have exactly one output, got {len(node.outputs)}." 
- ) + assert node.inputs[0] is not None + assert node.inputs[1] is not None first_type = node.inputs[0].type second_type = node.inputs[1].type if first_type is not None and second_type is not None and first_type != second_type: diff --git a/src/onnx_ir/_shape_type_inference/ops/transpose.py b/src/onnx_ir/_shape_type_inference/ops/transpose.py index dd81f145..cae9833e 100644 --- a/src/onnx_ir/_shape_type_inference/ops/transpose.py +++ b/src/onnx_ir/_shape_type_inference/ops/transpose.py @@ -13,19 +13,11 @@ def __init__(self) -> None: """Initialize the Transpose inferrer.""" super().__init__("Transpose", opsets=range(sys.maxsize)) + @_common.requires_non_none_inputs(1) + @_common.requires_outputs(1) def infer(self, node: ir.Node) -> _common.InferenceResult: """Infer the output shape and type for Transpose operations.""" - if len(node.inputs) != 1: - return _common.InferenceResult( - failure=f"Transpose operation must have exactly one input, got {len(node.inputs)}." - ) - if node.inputs[0] is None: - return _common.InferenceResult(failure="Transpose operation input cannot be None.") - if len(node.outputs) != 1: - return _common.InferenceResult( - failure=f"Transpose operation must have exactly one output, got {len(node.outputs)}." - ) - + assert node.inputs[0] is not None input_shape = node.inputs[0].shape if input_shape is None: return _common.InferenceResult(failure="Transpose input shape cannot be None.") From 720845e68883b816d2856e62c041c758a94c7d90 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 30 Jun 2025 07:27:21 -0700 Subject: [PATCH 19/31] Create inclusive range Signed-off-by: Justin Chu --- src/onnx_ir/_shape_type_inference/_common.py | 26 ++++++++++++++ .../_shape_type_inference/ops/concat.py | 3 +- .../_shape_type_inference/ops/constant.py | 36 +++++++++++++++++++ .../_shape_type_inference/ops/matmul.py | 3 +- .../_shape_type_inference/ops/reshape.py | 4 ++- .../_shape_type_inference/ops/squeeze.py | 3 +- .../_shape_type_inference/ops/standard_ops.py | 5 ++- .../_shape_type_inference/ops/transpose.py | 6 ++-- .../_shape_type_inference/ops/unsqueeze.py | 3 +- 9 files changed, 74 insertions(+), 15 deletions(-) create mode 100644 src/onnx_ir/_shape_type_inference/ops/constant.py diff --git a/src/onnx_ir/_shape_type_inference/_common.py b/src/onnx_ir/_shape_type_inference/_common.py index 4070edbb..01fb69d0 100644 --- a/src/onnx_ir/_shape_type_inference/_common.py +++ b/src/onnx_ir/_shape_type_inference/_common.py @@ -13,6 +13,9 @@ import onnx_ir as ir +MAX_SUPPORTED_OPSET = 23 + + def get_expr(shape: ir.Shape, index: int) -> sympy.Expr: """Get the expression or value at a specific index in the shape. @@ -57,6 +60,10 @@ def __init__(self, op_type: str, opsets: Collection[int], domain: str = "") -> N self.opsets = opsets self.domain = domain + def __repr__(self) -> str: + """Return a string representation of the node inferrer.""" + return f"{self.__class__.__name__}(op_type={self.op_type}, opsets={self.opsets}, domain={self.domain})" + @abc.abstractmethod def infer(self, node: ir.Node) -> InferenceResult: """Infer the shape for the node. @@ -131,3 +138,22 @@ def wrapper(self, node: ir.Node) -> InferenceResult: return wrapper return decorator + + +def inclusive_range(start_or_end: int = 0, end: int | None = None) -> range: + """Create an inclusive range from start to end with a given step. + + Args: + start_or_end: The starting value of the range. + end: The ending value of the range (inclusive). + + Returns: + A range object that includes both start and end. 
+    """
+    if end is None:
+        end = start_or_end
+        start = 0
+    else:
+        start = start_or_end
+
+    return range(start, end + 1)
diff --git a/src/onnx_ir/_shape_type_inference/ops/concat.py b/src/onnx_ir/_shape_type_inference/ops/concat.py
index 318df26f..4dadab86 100644
--- a/src/onnx_ir/_shape_type_inference/ops/concat.py
+++ b/src/onnx_ir/_shape_type_inference/ops/concat.py
@@ -1,6 +1,5 @@
 """Concat operation inferrer for ONNX IR nodes."""
 
-import sys
 
 import onnx_ir as ir
 from onnx_ir._shape_type_inference import _common
@@ -11,7 +10,7 @@ class ConcatInferrer(_common.NodeInferrer):
 
     def __init__(self) -> None:
         """Initialize the Concat inferrer."""
-        super().__init__("Concat", opsets=range(sys.maxsize))
+        super().__init__("Concat", opsets=_common.inclusive_range(_common.MAX_SUPPORTED_OPSET))
 
     def infer(self, node: ir.Node) -> _common.InferenceResult:
         """Infer the output shape and type for Concat operations."""
diff --git a/src/onnx_ir/_shape_type_inference/ops/constant.py b/src/onnx_ir/_shape_type_inference/ops/constant.py
new file mode 100644
index 00000000..aaf74712
--- /dev/null
+++ b/src/onnx_ir/_shape_type_inference/ops/constant.py
@@ -0,0 +1,36 @@
+"""Constant operation inferrer for ONNX IR nodes."""
+
+from __future__ import annotations
+
+import sys
+
+import onnx_ir as ir
+from onnx_ir._shape_type_inference import _common
+
+
+class ConstantInferrer(_common.NodeInferrer):
+    """Inferrer for Constant operations."""
+
+    def __init__(self) -> None:
+        """Initialize the Constant inferrer."""
+        super().__init__(
+            "Constant", opsets=_common.inclusive_range(_common.MAX_SUPPORTED_OPSET)
+        )
+
+    @_common.requires_outputs(1)
+    def infer(self, node: ir.Node) -> _common.InferenceResult:
+        """Infer the output shape and type for Constant operations."""
+        # Constant nodes have no inputs; read the tensor through the node's output
+        tensor = ir.convenience.get_const_tensor(node.outputs[0])
+        if tensor is None:
+            return _common.InferenceResult(failure="Constant tensor cannot be obtained.")
+
+        # Create shape from the tensor dimensions
+        output_shape = ir.Shape(tensor.shape)
+
+        # Get the data type from the tensor
+        output_type = ir.TensorType(tensor.dtype)
+
+        return _common.InferenceResult(
+            values=(ir.Value(shape=output_shape, type=output_type),)
+        )
diff --git a/src/onnx_ir/_shape_type_inference/ops/matmul.py b/src/onnx_ir/_shape_type_inference/ops/matmul.py
index 6ebd48d8..f414f513 100644
--- a/src/onnx_ir/_shape_type_inference/ops/matmul.py
+++ b/src/onnx_ir/_shape_type_inference/ops/matmul.py
@@ -1,6 +1,5 @@
 """MatMul operation inferrer for ONNX IR nodes."""
 
-import sys
 
 import onnx_ir as ir
 from onnx_ir._shape_type_inference import _common
@@ -12,7 +11,7 @@ class MatMulInferrer(_common.NodeInferrer):
 
     def __init__(self) -> None:
         """Initialize the MatMul inferrer."""
-        super().__init__("MatMul", opsets=range(sys.maxsize))
+        super().__init__("MatMul", opsets=_common.inclusive_range(_common.MAX_SUPPORTED_OPSET))
 
     @_common.requires_non_none_inputs(2)
     @_common.requires_outputs(1)
diff --git a/src/onnx_ir/_shape_type_inference/ops/reshape.py b/src/onnx_ir/_shape_type_inference/ops/reshape.py
index a4b43980..3754ea33 100644
--- a/src/onnx_ir/_shape_type_inference/ops/reshape.py
+++ b/src/onnx_ir/_shape_type_inference/ops/reshape.py
@@ -12,7 +12,9 @@ class ReshapeInferrer(_common.NodeInferrer):
     """Inferrer for Reshape operations."""
 
     def __init__(self) -> None:
-        super().__init__("Reshape", opsets=range(sys.maxsize))
+        super().__init__(
+            "Reshape", opsets=_common.inclusive_range(_common.MAX_SUPPORTED_OPSET)
+        )
 
     def infer(self, node: ir.Node) -> 
_common.InferenceResult:
         """Infer the output shape and type for Reshape operations."""
diff --git a/src/onnx_ir/_shape_type_inference/ops/squeeze.py b/src/onnx_ir/_shape_type_inference/ops/squeeze.py
index df4c3b92..1376923f 100644
--- a/src/onnx_ir/_shape_type_inference/ops/squeeze.py
+++ b/src/onnx_ir/_shape_type_inference/ops/squeeze.py
@@ -3,7 +3,6 @@
 from __future__ import annotations
 
 import logging
-import sys
 from collections.abc import Sequence
 
 import onnx_ir as ir
@@ -93,7 +92,7 @@ class Squeeze13Inferrer(_common.NodeInferrer):
 
     def __init__(self) -> None:
         """Initialize the Squeeze inferrer."""
-        super().__init__("Squeeze", opsets=range(14, sys.maxsize))
+        super().__init__("Squeeze", opsets=_common.inclusive_range(13, _common.MAX_SUPPORTED_OPSET))
 
     @_common.requires_non_none_inputs(2)
     @_common.requires_outputs(1)
diff --git a/src/onnx_ir/_shape_type_inference/ops/standard_ops.py b/src/onnx_ir/_shape_type_inference/ops/standard_ops.py
index 9e21f0ca..b3a4955d 100644
--- a/src/onnx_ir/_shape_type_inference/ops/standard_ops.py
+++ b/src/onnx_ir/_shape_type_inference/ops/standard_ops.py
@@ -2,7 +2,6 @@
 
 from __future__ import annotations
 
-import sys
 from collections.abc import Collection
 
 import sympy
@@ -17,7 +16,7 @@ class ElementwiseInferrer(_common.NodeInferrer):
     def __init__(self, op_type: str, opsets: Collection[int] | None = None) -> None:
         """Initialize the elementwise inferrer with the operation type."""
         if opsets is None:
-            opsets = range(sys.maxsize)
+            opsets = _common.inclusive_range(_common.MAX_SUPPORTED_OPSET)
         super().__init__(op_type, opsets=opsets)
 
     @_common.requires_non_none_inputs(1)
@@ -77,7 +76,7 @@ class BinaryInferrer(_common.NodeInferrer):
 
     def __init__(self, op_type: str) -> None:
         """Initialize the binary inferrer with the operation type."""
-        super().__init__(op_type, opsets=range(sys.maxsize))
+        super().__init__(op_type, opsets=_common.inclusive_range(_common.MAX_SUPPORTED_OPSET))
 
     @_common.requires_non_none_inputs(2)
     @_common.requires_outputs(1)
diff --git a/src/onnx_ir/_shape_type_inference/ops/transpose.py b/src/onnx_ir/_shape_type_inference/ops/transpose.py
index cae9833e..6c1ee3fb 100644
--- a/src/onnx_ir/_shape_type_inference/ops/transpose.py
+++ b/src/onnx_ir/_shape_type_inference/ops/transpose.py
@@ -1,7 +1,5 @@
 """Transpose operation inferrer for ONNX IR nodes."""
 
-import sys
-
 import onnx_ir as ir
 from onnx_ir._shape_type_inference import _common
 
@@ -11,7 +9,9 @@ class TransposeInferrer(_common.NodeInferrer):
 
     def __init__(self) -> None:
         """Initialize the Transpose inferrer."""
-        super().__init__("Transpose", opsets=range(sys.maxsize))
+        super().__init__(
+            "Transpose", opsets=_common.inclusive_range(_common.MAX_SUPPORTED_OPSET)
+        )
 
     @_common.requires_non_none_inputs(1)
     @_common.requires_outputs(1)
diff --git a/src/onnx_ir/_shape_type_inference/ops/unsqueeze.py b/src/onnx_ir/_shape_type_inference/ops/unsqueeze.py
index 25b4c5d2..f1d5c068 100644
--- a/src/onnx_ir/_shape_type_inference/ops/unsqueeze.py
+++ b/src/onnx_ir/_shape_type_inference/ops/unsqueeze.py
@@ -3,7 +3,6 @@
 from __future__ import annotations
 
 import logging
-import sys
 from collections.abc import Sequence
 
 import onnx_ir as ir
@@ -101,7 +100,7 @@ class Unsqueeze13Inferrer(_common.NodeInferrer):
 
     def __init__(self) -> None:
         """Initialize the Unsqueeze inferrer."""
-        super().__init__("Unsqueeze", opsets=range(13, sys.maxsize))
+        super().__init__("Unsqueeze", opsets=_common.inclusive_range(13, _common.MAX_SUPPORTED_OPSET))

From bae78abe3d6720a6e6bbac1f4f03560138da80ac 
Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 30 Jun 2025 07:31:19 -0700 Subject: [PATCH 20/31] WIP inference engine Signed-off-by: Justin Chu --- src/onnx_ir/_shape_type_inference/__init__.py | 17 +- src/onnx_ir/_shape_type_inference/_engine.py | 324 ++++++++++++++++++ 2 files changed, 339 insertions(+), 2 deletions(-) create mode 100644 src/onnx_ir/_shape_type_inference/_engine.py diff --git a/src/onnx_ir/_shape_type_inference/__init__.py b/src/onnx_ir/_shape_type_inference/__init__.py index 7cb7d817..dd45d913 100644 --- a/src/onnx_ir/_shape_type_inference/__init__.py +++ b/src/onnx_ir/_shape_type_inference/__init__.py @@ -1,2 +1,15 @@ -class SymbolicInferenceEngine: - pass +"""Symbolic shape and type inference for ONNX IR.""" + +__all__ = [ + "SymbolicInferenceEngine", + "InferenceError", + "NodeInferrer", + "InferenceResult", +] + + +from onnx_ir._shape_type_inference._common import InferenceResult, NodeInferrer +from onnx_ir._shape_type_inference._engine import ( + InferenceError, + SymbolicInferenceEngine, +) diff --git a/src/onnx_ir/_shape_type_inference/_engine.py b/src/onnx_ir/_shape_type_inference/_engine.py new file mode 100644 index 00000000..d79e815a --- /dev/null +++ b/src/onnx_ir/_shape_type_inference/_engine.py @@ -0,0 +1,324 @@ +"""Symbolic inference engine for ONNX IR models.""" + +from __future__ import annotations + +import enum +import logging +from collections.abc import Sequence + +import onnx_ir as ir +from onnx_ir._shape_type_inference import _common + +logger = logging.getLogger(__name__) + + +class ReconciliationPolicy(enum.StrEnum): + """Policy for reconciling inferred shapes/types with existing values.""" + + OVERWRITE = "overwrite" # Always use inferred values + IGNORE = "ignore" # Keep existing values if they exist + RECONCILE = "reconcile" # Try to merge/validate inferred vs existing + STRICT = "strict" # Fail if inferred doesn't match existing + + +class InferenceError(RuntimeError): + """Error during shape inference.""" + + +class SymbolicInferenceEngine: + """Engine for performing symbolic shape and type inference on ONNX IR models.""" + + def __init__( + self, + node_inferrers: Sequence[_common.NodeInferrer], + reconciliation_policy: str = "reconcile", + ) -> None: + """Initialize the symbolic inference engine. + + Args: + node_inferrers: List of node inferrers to use for shape inference. + reconciliation_policy: Policy for handling conflicts between inferred and existing values. + """ + self.reconciliation_policy = ReconciliationPolicy(reconciliation_policy) + self._inferrer_registry: dict[tuple[str, str], list[_common.NodeInferrer]] = {} + + # Register inferrers by (op_type, domain) + for inferrer in node_inferrers: + key = (inferrer.op_type, inferrer.domain) + if key not in self._inferrer_registry: + self._inferrer_registry[key] = [] + self._inferrer_registry[key].append(inferrer) + + logger.info(f"Initialized inference engine with {len(node_inferrers)} inferrers") + + def infer_model(self, model: ir.Model) -> None: + """Perform shape and type inference on an entire model. + + Args: + model: The ONNX IR model to perform inference on. + + Raises: + InferenceError: If inference fails for any node. 
+ """ + logger.info(f"Starting inference on model with {len(model.graph.nodes)} nodes") + + # Process nodes in topological order + for i, node in enumerate(model.graph.nodes): + try: + self._infer_node(node, model) + logger.debug(f"Successfully inferred node {i}: {node.op_type}") + except Exception as e: + error_msg = f"Failed to infer node {i} ({node.op_type}): {str(e)}" + logger.error(error_msg) + raise InferenceError(error_msg) from e + + logger.info("Model inference completed successfully") + + def _infer_node(self, node: ir.Node, model: ir.Model) -> None: + """Perform inference on a single node. + + Args: + node: The node to perform inference on. + model: The model containing the node (for context). + + Raises: + InferenceError: If no suitable inferrer is found or inference fails. + """ + # Find suitable inferrer + inferrer = self._find_inferrer(node, model) + if inferrer is None: + raise InferenceError( + f"No inferrer found for op_type '{node.op_type}' domain '{node.domain}'" + ) + + # Perform inference + result = inferrer.infer(node) + + if result.failure is not None: + raise InferenceError(f"Inference failed: {result.failure}") + + if result.values is None: + raise InferenceError("Inference returned no values") + + # Apply reconciliation policy + self._reconcile_outputs(node, result.values) + + def _find_inferrer(self, node: ir.Node, model: ir.Model) -> _common.NodeInferrer | None: + """Find a suitable inferrer for the given node. + + Args: + node: The node to find an inferrer for. + model: The model containing the node. + + Returns: + The best matching inferrer, or None if no suitable inferrer is found. + """ + key = (node.op_type, node.domain) + inferrers = self._inferrer_registry.get(key, []) + + if not inferrers: + return None + + # Get model opset version for this domain + opset_version = self._get_opset_version(model, node.domain) + + # Find inferrers that support this opset version + suitable_inferrers = [ + inferrer for inferrer in inferrers if opset_version in inferrer.opsets + ] + + if not suitable_inferrers: + logger.warning( + f"No inferrer supports opset {opset_version} for {node.op_type} " + f"(domain: {node.domain})" + ) + return None + + # Return the first suitable inferrer (could be enhanced with priority logic) + return suitable_inferrers[0] + + def _get_opset_version(self, model: ir.Model, domain: str) -> int: + """Get the opset version for a given domain in the model. + + Args: + model: The model to check. + domain: The domain to get the opset version for. + + Returns: + The opset version for the domain. + """ + # Look for opset import for this domain + for opset_import in model.opset_imports: + if opset_import.domain == domain: + return opset_import.version + + # Default to a high version if not found + return 999 + + def _reconcile_outputs(self, node: ir.Node, inferred_values: Sequence[ir.Value]) -> None: + """Reconcile inferred output values with existing node outputs. + + Args: + node: The node whose outputs to reconcile. + inferred_values: The inferred output values. + + Raises: + InferenceError: If reconciliation fails under strict policy. 
+ """ + if len(inferred_values) != len(node.outputs): + raise InferenceError( + f"Inference returned {len(inferred_values)} values but node has " + f"{len(node.outputs)} outputs" + ) + + for i, (existing_output, inferred_value) in enumerate( + zip(node.outputs, inferred_values) + ): + if existing_output is None: + # No existing output - create new one + node.outputs[i] = inferred_value + continue + + # Reconcile based on policy + if self.reconciliation_policy == ReconciliationPolicy.OVERWRITE: + node.outputs[i] = inferred_value + + elif self.reconciliation_policy == ReconciliationPolicy.IGNORE: + # Keep existing output if it has shape/type info + if existing_output.shape is None and existing_output.type is None: + node.outputs[i] = inferred_value + # Otherwise keep existing + + elif self.reconciliation_policy == ReconciliationPolicy.RECONCILE: + reconciled_output = self._reconcile_value(existing_output, inferred_value) + node.outputs[i] = reconciled_output + + elif self.reconciliation_policy == ReconciliationPolicy.STRICT: + if not self._values_compatible(existing_output, inferred_value): + raise InferenceError( + f"Output {i} mismatch: existing {existing_output} vs " + f"inferred {inferred_value}" + ) + # Keep existing in strict mode if compatible + + def _reconcile_value(self, existing: ir.Value, inferred: ir.Value) -> ir.Value: + """Reconcile an existing value with an inferred value. + + Args: + existing: The existing value. + inferred: The inferred value. + + Returns: + The reconciled value. + """ + # Start with existing value + result_shape = existing.shape + result_type = existing.type + + # Use inferred shape if existing is None or less specific + if inferred.shape is not None: + if result_shape is None: + result_shape = inferred.shape + else: + # Try to merge shapes (prefer more specific) + result_shape = self._reconcile_shapes(result_shape, inferred.shape) + + # Use inferred type if existing is None + if inferred.type is not None and result_type is None: + result_type = inferred.type + + return ir.Value(shape=result_shape, type=result_type) + + def _reconcile_shapes(self, shape1: ir.Shape, shape2: ir.Shape) -> ir.Shape: + """Reconcile two shapes by preferring more specific dimensions. + + Args: + shape1: First shape. + shape2: Second shape. + + Returns: + The reconciled shape. + """ + if len(shape1) != len(shape2): + logger.warning( + f"Shape rank mismatch: {len(shape1)} vs {len(shape2)}. Using first shape." + ) + return shape1 + + reconciled_dims = [] + for dim1, dim2 in zip(shape1.dims, shape2.dims): + # Prefer concrete dimensions over None/symbolic + if isinstance(dim1, int) and dim1 > 0: + reconciled_dims.append(dim1) + elif isinstance(dim2, int) and dim2 > 0: + reconciled_dims.append(dim2) + elif dim1 is not None: + reconciled_dims.append(dim1) + elif dim2 is not None: + reconciled_dims.append(dim2) + else: + reconciled_dims.append(None) + + return ir.Shape(reconciled_dims) + + def _values_compatible(self, value1: ir.Value, value2: ir.Value) -> bool: + """Check if two values are compatible (for strict mode). + + Args: + value1: First value. + value2: Second value. + + Returns: + True if the values are compatible. 
+ """ + # Check shape compatibility + if value1.shape is not None and value2.shape is not None: + if not self._shapes_compatible(value1.shape, value2.shape): + return False + + # Check type compatibility + if value1.type is not None and value2.type is not None: + if value1.type != value2.type: + return False + + return True + + def _shapes_compatible(self, shape1: ir.Shape, shape2: ir.Shape) -> bool: + """Check if two shapes are compatible. + + Args: + shape1: First shape. + shape2: Second shape. + + Returns: + True if the shapes are compatible. + """ + if len(shape1) != len(shape2): + return False + + for dim1, dim2 in zip(shape1.dims, shape2.dims): + # None/symbolic dimensions are compatible with anything + if dim1 is None or dim2 is None: + continue + + # Both concrete - must match + if isinstance(dim1, int) and isinstance(dim2, int): + if dim1 != dim2: + return False + + # Symbolic dimensions - for now assume compatible + # Could be enhanced with symbolic expression comparison + + return True + + def get_inferrer_info(self) -> dict[str, int]: + """Get information about registered inferrers. + + Returns: + Dictionary mapping operation types to inferrer counts. + """ + info = {} + for (op_type, domain), inferrers in self._inferrer_registry.items(): + key = f"{op_type}:{domain}" if domain else op_type + info[key] = len(inferrers) + return info From a77f4879baa4f2b0952dcf5624f9021f12e7319a Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 30 Jun 2025 08:29:21 -0700 Subject: [PATCH 21/31] Create readme Signed-off-by: Justin Chu --- src/onnx_ir/_shape_type_inference/README.md | 231 ++++++++++++++++++++ 1 file changed, 231 insertions(+) create mode 100644 src/onnx_ir/_shape_type_inference/README.md diff --git a/src/onnx_ir/_shape_type_inference/README.md b/src/onnx_ir/_shape_type_inference/README.md new file mode 100644 index 00000000..c2419b2d --- /dev/null +++ b/src/onnx_ir/_shape_type_inference/README.md @@ -0,0 +1,231 @@ +# Symbolic Shape and Type Inference + +This module provides symbolic shape and type inference for ONNX IR models, enabling compile-time analysis of tensor shapes and types with support for symbolic dimensions. + +## Overview + +The inference engine performs forward propagation through ONNX models to determine output shapes and types based on input specifications. It supports symbolic dimensions using SymPy expressions, allowing for dynamic shape analysis. + +## Key Components + +### Core Classes + +- **`SymbolicInferenceEngine`**: Main orchestrator that processes models and applies inference +- **`NodeInferrer`**: Base class for operation-specific inference logic +- **`InferenceResult`**: Container for inference results or failure information + +### Reconciliation Policies + +The engine supports different strategies for handling conflicts between inferred and existing values: + +- **`OVERWRITE`**: Always use inferred values +- **`IGNORE`**: Keep existing values if they exist +- **`RECONCILE`**: Merge inferred and existing values intelligently +- **`STRICT`**: Fail if inferred values don't match existing ones + +## Architecture + +```text +SymbolicInferenceEngine +├── NodeInferrer Registry (by op_type + domain) +├── Opset Version Matching +└── Reconciliation Logic + +NodeInferrer Implementations +├── ElementwiseInferrer (unary operations) +├── BinaryInferrer (broadcasting operations) +└── Specialized Inferrers (50+ operations) +``` + +## Inferrer Selection + +The engine selects the appropriate inferrer using a two-stage process: + +1. 
**Registry Lookup**: Inferrers are registered by `(op_type, domain)` key
+2. **Opset Matching**: Among matching inferrers, select those supporting the model's opset version
+
+```python
+# Example: For a Squeeze node with opset 14
+# - Multiple Squeeze inferrers may be registered
+# - Engine selects Squeeze13Inferrer (supports opset 13-23)
+# - Ignores Squeeze12Inferrer (supports opset 1-12)
+```
+
+## Symbolic Dimensions
+
+The system stores symbolic expressions in `ir.SymbolicDim` objects:
+
+```python
+class SymbolicDim:
+    value: str | None  # String identifier (e.g., "N", "batch_size")
+    expr: sympy.Expr | None  # SymPy expression for computed dimensions
+```
+
+Dimensions are accessed via `get_expr()`, which converts them to SymPy expressions:
+
+- `SymbolicDim(value="N")` → `sympy.Symbol("N")`
+- `SymbolicDim(expr=N*2)` → `N*2` (SymPy expression)
+- Integer dimensions → `sympy.Integer(value)`
+
+## NodeInferrer Design Decisions
+
+### Base Class Structure
+
+The `NodeInferrer` abstract base class enforces a consistent interface:
+
+```python
+class NodeInferrer(abc.ABC):
+    def __init__(self, op_type: str, opsets: Collection[int], domain: str = ""):
+        # Store operation metadata for registry matching
+        ...
+
+    @abc.abstractmethod
+    def infer(self, node: ir.Node) -> InferenceResult:
+        # Operation-specific inference logic
+        ...
+```
+
+### Design Rationale
+
+1. **Single Responsibility**: Each inferrer handles exactly one operation type
+2. **Opset Awareness**: Inferrers declare supported ONNX opset versions for compatibility
+3. **Domain Support**: Enables custom domains beyond standard ONNX operators
+4. **Validation Decorators**: `@requires_non_none_inputs(n)` and `@requires_outputs(n)` provide consistent input validation
+5. **Failure Handling**: Return `InferenceResult` with either `values` or `failure` for graceful error handling
+
+### Inheritance Patterns
+
+- **ElementwiseInferrer**: Template for unary operations that preserve input shape/type (see the sketch below)
+- **BinaryInferrer**: Template for binary operations with broadcasting logic
+- **Specialized Inferrers**: Custom logic for complex operations (Conv, Reshape, etc.)
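+
+A minimal sketch of what the `ElementwiseInferrer` template could look like, assuming the `NodeInferrer` interface and validation decorators described above (illustrative only, not necessarily the module's actual implementation):
+
+```python
+class ElementwiseInferrer(NodeInferrer):
+    """Template for unary ops that propagate the input shape and type unchanged."""
+
+    def __init__(self, op_type: str, opsets: Collection[int] = range(1, 24)):
+        super().__init__(op_type, opsets=opsets)
+
+    @requires_non_none_inputs(1)
+    @requires_outputs(1)
+    def infer(self, node: ir.Node) -> InferenceResult:
+        inp = node.inputs[0]
+        if inp.shape is None and inp.type is None:
+            # Nothing to propagate; report a failure the engine can surface
+            return InferenceResult(failure="Input has no shape or type information.")
+        # Elementwise unary ops leave both shape and type untouched
+        return InferenceResult(values=[ir.Value(shape=inp.shape, type=inp.type)])
+```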
+
+## Usage
+
+### Basic Usage
+
+```python
+from onnx_ir._shape_type_inference.factory import create_standard_inference_engine
+from onnx_ir._shape_type_inference import ReconciliationPolicy
+
+# Create engine with all standard operations
+engine = create_standard_inference_engine(ReconciliationPolicy.RECONCILE)
+
+# Perform inference on a model
+engine.infer_model(model)
+```
+
+### Custom Engine
+
+```python
+from onnx_ir._shape_type_inference import ReconciliationPolicy, SymbolicInferenceEngine
+from onnx_ir._shape_type_inference.ops.matmul import MatMulInferrer
+from onnx_ir._shape_type_inference.ops.standard_ops import BinaryInferrer
+
+# Create custom engine with specific operations
+inferrers = [
+    MatMulInferrer(),
+    BinaryInferrer("Add"),
+    BinaryInferrer("Mul"),
+]
+
+engine = SymbolicInferenceEngine(inferrers, ReconciliationPolicy.STRICT)
+```
+
+## Opset Version Support
+
+Each inferrer specifies supported ONNX opset versions to handle API changes:
+
+```python
+class Squeeze12Inferrer(NodeInferrer):
+    def __init__(self):
+        super().__init__("Squeeze", opsets=range(1, 13))
+
+class Squeeze13Inferrer(NodeInferrer):
+    def __init__(self):
+        super().__init__("Squeeze", opsets=range(13, 24))
+```
+
+## Error Handling
+
+The engine provides comprehensive error handling:
+
+- **Validation Errors**: Invalid input/output counts, missing shapes
+- **Type Mismatches**: Incompatible input types for binary operations
+- **Inference Failures**: Operation-specific inference errors
+- **Reconciliation Conflicts**: Value mismatches in strict mode
+
+## Factory Functions
+
+Pre-configured engines for common use cases:
+
+- **`create_standard_inference_engine()`**: Full operation coverage (50+ ops)
+- **`create_minimal_inference_engine()`**: Essential operations only
+
+## Subgraphs and ONNX Functions
+
+### Design Approach
+
+#### Subgraph Pre-Processing Strategy
+
+The engine uses a **subgraph-first** approach for cleaner separation of concerns:
+
+1. **Pre-Processing Phase**: Before running node inference, detect and recursively process all subgraphs
+2. **Bottom-Up Inference**: Subgraphs are fully inferred before their parent nodes
+3. **Simplified Node Logic**: Control flow inferrers (If, Loop, Scan) can assume subgraph shapes are already available
+
+```python
+class SymbolicInferenceEngine:
+    def _infer_node(self, node: ir.Node, model: ir.Model) -> None:
+        # First: recursively infer any subgraphs (attributes maps names to Attr objects)
+        for attr in node.attributes.values():
+            if isinstance(attr.value, ir.Graph):
+                self._infer_subgraph(attr.value, model)
+
+        # Then: run node-specific inference with subgraphs already processed
+        inferrer = self._find_inferrer(node, model)
+        result = inferrer.infer(node)  # Subgraph shapes already available
+```
+
+#### ONNX Function Support
+
+Functions are handled through **automatic expansion** without custom inferrer logic:
+
+1. **Function Context**: Engine maintains intermediate value mappings during function execution
+2. **Transparent Expansion**: Function calls are expanded inline and processed like regular subgraphs
+3. 
**No Custom Logic**: Users don't implement function-specific inferrers - the engine handles it automatically + +```python +class SymbolicInferenceEngine: + def _infer_function_call(self, node: ir.Node, function: ir.Function) -> InferenceResult: + # Create isolated context for function execution + function_context = self._create_function_context(node.inputs, function) + + # Process function body as a subgraph + for func_node in function.nodes: + self._infer_node_in_context(func_node, function_context) + + # Map function outputs back to caller node + return self._extract_function_outputs(function_context, function.outputs) +``` + +### Key Benefits + +1. **Cleaner Separation**: Subgraph inference is handled by the engine, not individual inferrers +2. **Automatic Function Support**: No need to implement custom logic for each function +3. **Simplified Debugging**: Each phase (subgraphs → nodes) can be debugged independently +4. **Consistent Context**: Function calls maintain proper variable scoping and type consistency + +## Extension Points + +To add support for new operations: + +1. Create a new inferrer class inheriting from `NodeInferrer` +2. Implement the `infer()` method with operation-specific logic +3. Register with the engine or add to factory functions + +```python +class CustomOpInferrer(NodeInferrer): + def __init__(self): + super().__init__("CustomOp", opsets=range(1, 24), domain="custom_domain") + + def infer(self, node: ir.Node) -> InferenceResult: + # Custom inference logic + return InferenceResult(values=[result_value]) +``` From 6686457082748b206f88c6ce4ab2b4a79e0caad5 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 30 Jun 2025 08:37:31 -0700 Subject: [PATCH 22/31] Result Signed-off-by: Justin Chu --- src/onnx_ir/_shape_type_inference/README.md | 33 ++++++++++++- src/onnx_ir/_shape_type_inference/__init__.py | 3 +- src/onnx_ir/_shape_type_inference/_common.py | 47 ++++++++++++++++--- src/onnx_ir/_shape_type_inference/_engine.py | 14 +++++- .../_shape_type_inference/ops/standard_ops.py | 8 +++- 5 files changed, 92 insertions(+), 13 deletions(-) diff --git a/src/onnx_ir/_shape_type_inference/README.md b/src/onnx_ir/_shape_type_inference/README.md index c2419b2d..29094c96 100644 --- a/src/onnx_ir/_shape_type_inference/README.md +++ b/src/onnx_ir/_shape_type_inference/README.md @@ -12,7 +12,8 @@ The inference engine performs forward propagation through ONNX models to determi - **`SymbolicInferenceEngine`**: Main orchestrator that processes models and applies inference - **`NodeInferrer`**: Base class for operation-specific inference logic -- **`InferenceResult`**: Container for inference results or failure information +- **`InferenceResult`**: Container for inference results with status and optional message +- **`InferenceStatus`**: Enum for inference operation status (SUCCESS, PARTIAL, MISSING_INFO, INVALID_NODE) ### Reconciliation Policies @@ -23,6 +24,36 @@ The engine supports different strategies for handling conflicts between inferred - **`RECONCILE`**: Merge inferred and existing values intelligently - **`STRICT`**: Fail if inferred values don't match existing ones +### Inference Status System + +The `InferenceResult` uses a status-based approach for granular error handling: + +- **`SUCCESS`**: Complete inference successful with full shape/type information +- **`PARTIAL`**: Partial information available (e.g., type only, rank only) +- **`MISSING_INFO`**: Missing required input information (shapes, types) +- **`INVALID_NODE`**: Node is invalid or malformed + +```python 
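+# NOTE: `can_infer_type_only`, `inferred_type`, and `full_value` below are
+# placeholders standing in for op-specific logic, not helpers in this module.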
+# Example usage in inferrers +def infer(self, node: ir.Node) -> InferenceResult: + if node.inputs[0].shape is None: + return InferenceResult( + status="missing_info", + msg="Input shape is required" + ) + + # Partial inference - only type available + if can_infer_type_only(): + return InferenceResult( + values=[ir.Value(type=inferred_type)], + status="partial", + msg="Shape unavailable, type only" + ) + + # Full inference (status defaults to "success") + return InferenceResult(values=[full_value]) +``` + ## Architecture ```text diff --git a/src/onnx_ir/_shape_type_inference/__init__.py b/src/onnx_ir/_shape_type_inference/__init__.py index dd45d913..b3d8fb69 100644 --- a/src/onnx_ir/_shape_type_inference/__init__.py +++ b/src/onnx_ir/_shape_type_inference/__init__.py @@ -5,10 +5,11 @@ "InferenceError", "NodeInferrer", "InferenceResult", + "InferenceStatus", ] -from onnx_ir._shape_type_inference._common import InferenceResult, NodeInferrer +from onnx_ir._shape_type_inference._common import InferenceResult, InferenceStatus, NodeInferrer from onnx_ir._shape_type_inference._engine import ( InferenceError, SymbolicInferenceEngine, diff --git a/src/onnx_ir/_shape_type_inference/_common.py b/src/onnx_ir/_shape_type_inference/_common.py index 01fb69d0..c39923d8 100644 --- a/src/onnx_ir/_shape_type_inference/_common.py +++ b/src/onnx_ir/_shape_type_inference/_common.py @@ -3,7 +3,7 @@ from __future__ import annotations import abc -import dataclasses +import enum import functools from collections.abc import Collection, Sequence from typing import Any, Callable @@ -36,10 +36,38 @@ def get_expr(shape: ir.Shape, index: int) -> sympy.Expr: return sympy.Integer(dim) -@dataclasses.dataclass +@enum.unique +class InferenceStatus(enum.Enum): + """Status of shape inference operation.""" + SUCCESS = "success" # Complete inference successful + PARTIAL = "partial" # Partial information available (e.g., type only, rank only) + MISSING_INFO = "missing_info" # Missing required input information + INVALID_NODE = "invalid_node" # Node is invalid or malformed + + class InferenceResult: - values: Sequence[ir.Value] | None = None - failure: str | None = None + """Container for inference results with status and optional message.""" + + def __init__( + self, + values: Sequence[ir.Value] | None = None, + status: str | InferenceStatus = "success", + msg: str | None = None, + ) -> None: + """Initialize inference result. + + Args: + values: Sequence of inferred values. + status: Status of inference operation (string or enum). + msg: Optional message for context. + """ + self.values = values + self.status = InferenceStatus(status) + self.msg = msg + + def __repr__(self) -> str: + """Return string representation of the result.""" + return f"InferenceResult(values={self.values}, status={self.status.value}, msg={self.msg!r})" class NodeInferrer(abc.ABC): @@ -98,11 +126,15 @@ def decorator( def wrapper(self, node: ir.Node) -> InferenceResult: if len(node.inputs) != count: return InferenceResult( - failure=f"[{node.op_type} must have {count} inputs, got {len(node.inputs)}." + status="invalid_node", + msg=f"{node.op_type} must have {count} inputs, got {len(node.inputs)}." ) for i, inp in enumerate(node.inputs): if inp is None: - return InferenceResult(failure=f"{node.op_type} input {i} cannot be None.") + return InferenceResult( + status="missing_info", + msg=f"{node.op_type} input {i} cannot be None." 
+ ) return func(self, node) return wrapper @@ -131,7 +163,8 @@ def decorator( def wrapper(self, node: ir.Node) -> InferenceResult: if len(node.outputs) != count: return InferenceResult( - failure=f"[{node.op_type} must have {count} outputs, got {len(node.outputs)}." + status="invalid_node", + msg=f"{node.op_type} must have {count} outputs, got {len(node.outputs)}." ) return func(self, node) diff --git a/src/onnx_ir/_shape_type_inference/_engine.py b/src/onnx_ir/_shape_type_inference/_engine.py index d79e815a..db26ac61 100644 --- a/src/onnx_ir/_shape_type_inference/_engine.py +++ b/src/onnx_ir/_shape_type_inference/_engine.py @@ -94,8 +94,18 @@ def _infer_node(self, node: ir.Node, model: ir.Model) -> None: # Perform inference result = inferrer.infer(node) - if result.failure is not None: - raise InferenceError(f"Inference failed: {result.failure}") + if result.status == _common.InferenceStatus.INVALID_NODE: + raise InferenceError(f"Invalid node: {result.msg}") + + if result.status == _common.InferenceStatus.MISSING_INFO: + logger.warning(f"Missing info for node {node.op_type}: {result.msg}") + # Continue with partial inference or skip + if result.values is None: + return # Skip this node + + if result.status == _common.InferenceStatus.PARTIAL: + logger.info(f"Partial inference for node {node.op_type}: {result.msg}") + # Continue with partial results if result.values is None: raise InferenceError("Inference returned no values") diff --git a/src/onnx_ir/_shape_type_inference/ops/standard_ops.py b/src/onnx_ir/_shape_type_inference/ops/standard_ops.py index b3a4955d..b218f624 100644 --- a/src/onnx_ir/_shape_type_inference/ops/standard_ops.py +++ b/src/onnx_ir/_shape_type_inference/ops/standard_ops.py @@ -88,14 +88,18 @@ def infer(self, node: ir.Node) -> _common.InferenceResult: second_type = node.inputs[1].type if first_type is not None and second_type is not None and first_type != second_type: return _common.InferenceResult( - failure=f"Input types do not match: {first_type} vs {second_type}." + status="invalid_node", + msg=f"Input types do not match: {first_type} vs {second_type}." ) # Broadcast the input shapes first_shape = node.inputs[0].shape second_shape = node.inputs[1].shape if first_shape is None or second_shape is None: - return _common.InferenceResult(failure="Input shapes cannot be None.") + return _common.InferenceResult( + status="missing_info", + msg="Input shapes cannot be None." + ) output_shape = broadcast_shapes_bidirectional(first_shape, second_shape) output_type = first_type if first_type is not None else second_type From 3207e849cb37b11d6427df5f0f2da1c3aa95a684 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 30 Jun 2025 08:43:15 -0700 Subject: [PATCH 23/31] Summary of Complete Refactoring MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit I have successfully updated the entire InferenceResult system across all files in the src/onnx_ir/_shape_type_inference/ops directory: ✅ What Was Accomplished: 1. Converted InferenceResult from dataclass to normal class with string-based status initialization 2. Updated all validation decorators in _common.py to use string status 3. Updated the engine in _engine.py to handle different status types appropriately 4. 
Updated all 8 operation files in the ops directory: - standard_ops.py (BinaryInferrer) - matmul.py (MatMulInferrer) - concat.py (ConcatInferrer) - reshape.py (ReshapeInferrer) - constant.py (ConstantInferrer) - squeeze.py (Squeeze12Inferrer, Squeeze13Inferrer) - transpose.py (TransposeInferrer) - unsqueeze.py (Unsqueeze12Inferrer, Unsqueeze13Inferrer) 5. Updated exports in __init__.py to include InferenceStatus 6. Updated documentation in README.md with examples ✅ Key Benefits: - More convenient API: status="missing_info" instead of status=InferenceStatus.MISSING_INFO - Type safety: Automatic enum conversion with clear error messages for invalid strings - Better categorization: Proper error classification (missing_info, invalid_node, partial, success) - Cleaner code: Less imports needed, more readable error handling - Graceful degradation: Engine can handle partial inference and missing information The refactoring is now complete and all files consistently use the improved InferenceResult class with string-based status initialization! Signed-off-by: Justin Chu --- .../_shape_type_inference/ops/concat.py | 45 ++++++++++++++----- .../_shape_type_inference/ops/constant.py | 2 +- .../_shape_type_inference/ops/matmul.py | 10 ++++- .../_shape_type_inference/ops/reshape.py | 27 ++++++++--- .../_shape_type_inference/ops/squeeze.py | 10 ++--- .../_shape_type_inference/ops/transpose.py | 8 ++-- .../_shape_type_inference/ops/unsqueeze.py | 12 ++--- 7 files changed, 78 insertions(+), 36 deletions(-) diff --git a/src/onnx_ir/_shape_type_inference/ops/concat.py b/src/onnx_ir/_shape_type_inference/ops/concat.py index 4dadab86..54d4c5a0 100644 --- a/src/onnx_ir/_shape_type_inference/ops/concat.py +++ b/src/onnx_ir/_shape_type_inference/ops/concat.py @@ -16,29 +16,43 @@ def infer(self, node: ir.Node) -> _common.InferenceResult: """Infer the output shape and type for Concat operations.""" if len(node.inputs) < 1: return _common.InferenceResult( - failure="Concat operation must have at least one input." + status="invalid_node", + msg="Concat operation must have at least one input." ) if any(inp is None for inp in node.inputs): - return _common.InferenceResult(failure="Concat operation inputs cannot be None.") + return _common.InferenceResult( + status="missing_info", + msg="Concat operation inputs cannot be None." + ) if len(node.outputs) != 1: return _common.InferenceResult( - failure=f"Concat operation must have exactly one output, got {len(node.outputs)}." + status="invalid_node", + msg=f"Concat operation must have exactly one output, got {len(node.outputs)}." ) # Get axis attribute axis = node.attributes.get_int("axis") if axis is None: - return _common.InferenceResult(failure="Concat operation requires axis attribute.") + return _common.InferenceResult( + status="invalid_node", + msg="Concat operation requires axis attribute." + ) # Get first input shape as base first_shape = node.inputs[0].shape if first_shape is None: - return _common.InferenceResult(failure="Concat input shapes cannot be None.") + return _common.InferenceResult( + status="missing_info", + msg="Concat input shapes cannot be None." + ) first_type = node.inputs[0].type rank = len(first_shape) if rank == 0: - return _common.InferenceResult(failure="Concat inputs cannot be scalars.") + return _common.InferenceResult( + status="invalid_node", + msg="Concat inputs cannot be scalars." 
+ ) # Handle negative axis if axis < 0: @@ -46,7 +60,8 @@ def infer(self, node: ir.Node) -> _common.InferenceResult: if axis < 0 or axis >= rank: return _common.InferenceResult( - failure=f"Concat axis {axis} is out of bounds for rank {rank}." + status="invalid_node", + msg=f"Concat axis {axis} is out of bounds for rank {rank}." ) # Check that all inputs have compatible shapes @@ -55,21 +70,29 @@ def infer(self, node: ir.Node) -> _common.InferenceResult: for i, inp in enumerate(node.inputs[1:], 1): if inp is None: - return _common.InferenceResult(failure=f"Input {i} cannot be None.") + return _common.InferenceResult( + status="missing_info", + msg=f"Input {i} cannot be None." + ) if inp.shape is None: - return _common.InferenceResult(failure=f"Input {i} shape cannot be None.") + return _common.InferenceResult( + status="missing_info", + msg=f"Input {i} shape cannot be None." + ) input_shape = inp.shape if len(input_shape) != rank: return _common.InferenceResult( - failure=f"All inputs must have same rank. Input {i} has rank {len(input_shape)}, expected {rank}." + status="invalid_node", + msg=f"All inputs must have same rank. Input {i} has rank {len(input_shape)}, expected {rank}." ) # TODO(justinchuby): Check non-concat dimensions are compatible concat_dim_size = concat_dim_size + _common.get_expr(input_shape, axis) if inp.type != first_type: return _common.InferenceResult( - failure=f"Input {i} type {inp.type} does not match first input type {first_type}." + status="invalid_node", + msg=f"Input {i} type {inp.type} does not match first input type {first_type}." ) # Set the concat dimension in output shape diff --git a/src/onnx_ir/_shape_type_inference/ops/constant.py b/src/onnx_ir/_shape_type_inference/ops/constant.py index aaf74712..b0f1cc2b 100644 --- a/src/onnx_ir/_shape_type_inference/ops/constant.py +++ b/src/onnx_ir/_shape_type_inference/ops/constant.py @@ -23,7 +23,7 @@ def infer(self, node: ir.Node) -> _common.InferenceResult: assert node.inputs[0] is not None tensor = ir.convenience.get_const_tensor(node.inputs[0]) if tensor is None: - return _common.InferenceResult(failure="Constant tensor cannot be obtained.") + return _common.InferenceResult(status="missing_info", msg="Constant tensor cannot be obtained.") # Create shape from the tensor dimensions output_shape = ir.Shape(tensor.shape) diff --git a/src/onnx_ir/_shape_type_inference/ops/matmul.py b/src/onnx_ir/_shape_type_inference/ops/matmul.py index f414f513..a13942c1 100644 --- a/src/onnx_ir/_shape_type_inference/ops/matmul.py +++ b/src/onnx_ir/_shape_type_inference/ops/matmul.py @@ -22,13 +22,19 @@ def infer(self, node: ir.Node) -> _common.InferenceResult: lhs_shape = node.inputs[0].shape rhs_shape = node.inputs[1].shape if lhs_shape is None or rhs_shape is None: - return _common.InferenceResult(failure="MatMul input shapes cannot be None.") + return _common.InferenceResult( + status="missing_info", + msg="MatMul input shapes cannot be None." + ) lhs_rank = len(lhs_shape) rhs_rank = len(rhs_shape) if lhs_rank == 0 or rhs_rank == 0: - return _common.InferenceResult(failure="MatMul inputs cannot be scalars.") + return _common.InferenceResult( + status="invalid_node", + msg="MatMul inputs cannot be scalars." 
+ ) # Compute output shape based on matrix multiplication rules if lhs_rank == 1 and rhs_rank == 1: diff --git a/src/onnx_ir/_shape_type_inference/ops/reshape.py b/src/onnx_ir/_shape_type_inference/ops/reshape.py index 3754ea33..3a6ec550 100644 --- a/src/onnx_ir/_shape_type_inference/ops/reshape.py +++ b/src/onnx_ir/_shape_type_inference/ops/reshape.py @@ -20,26 +20,37 @@ def infer(self, node: ir.Node) -> _common.InferenceResult: """Infer the output shape and type for Reshape operations.""" if len(node.inputs) != 2: return _common.InferenceResult( - failure=f"Reshape operation must have exactly two inputs, got {len(node.inputs)}." + status="invalid_node", + msg=f"Reshape operation must have exactly two inputs, got {len(node.inputs)}." ) if node.inputs[0] is None or node.inputs[1] is None: - return _common.InferenceResult(failure="Reshape operation inputs cannot be None.") + return _common.InferenceResult( + status="missing_info", + msg="Reshape operation inputs cannot be None." + ) if len(node.outputs) != 1: return _common.InferenceResult( - failure=f"Reshape operation must have exactly one output, got {len(node.outputs)}." + status="invalid_node", + msg=f"Reshape operation must have exactly one output, got {len(node.outputs)}." ) input_shape = node.inputs[0].shape shape_input = node.inputs[1] if input_shape is None: - return _common.InferenceResult(failure="Reshape input shape cannot be None.") + return _common.InferenceResult( + status="missing_info", + msg="Reshape input shape cannot be None." + ) # Try to get the shape values from the second input # For symbolic inference, we may not have concrete values shape = ir.convenience.get_const_tensor(shape_input) if shape is None: - return _common.InferenceResult(failure="Reshape shape input is not known.") + return _common.InferenceResult( + status="missing_info", + msg="Reshape shape input is not known." + ) shape_values = shape.numpy().tolist() @@ -57,7 +68,8 @@ def infer(self, node: ir.Node) -> _common.InferenceResult: if dim_value == -1: if deferred_dim_idx != -1: return _common.InferenceResult( - failure="Reshape can have at most one -1 dimension." + status="invalid_node", + msg="Reshape can have at most one -1 dimension." ) deferred_dim_idx = i output_dims.append(None) # Placeholder @@ -65,7 +77,8 @@ def infer(self, node: ir.Node) -> _common.InferenceResult: # Copy from input shape if i >= len(input_shape): return _common.InferenceResult( - failure=f"Cannot copy dimension {i} from input shape of rank {len(input_shape)}." + status="invalid_node", + msg=f"Cannot copy dimension {i} from input shape of rank {len(input_shape)}." 
) dim_expr = _common.get_expr(input_shape, i) output_dims.append(dim_expr) diff --git a/src/onnx_ir/_shape_type_inference/ops/squeeze.py b/src/onnx_ir/_shape_type_inference/ops/squeeze.py index 1376923f..e07215a7 100644 --- a/src/onnx_ir/_shape_type_inference/ops/squeeze.py +++ b/src/onnx_ir/_shape_type_inference/ops/squeeze.py @@ -66,7 +66,7 @@ def infer(self, node: ir.Node) -> _common.InferenceResult: assert input is not None input_shape = input.shape if input_shape is None: - return _common.InferenceResult(failure="Squeeze input shape is not known.") + return _common.InferenceResult(status="missing_info", msg="Squeeze input shape is not known.") rank = len(input_shape) @@ -79,7 +79,7 @@ def infer(self, node: ir.Node) -> _common.InferenceResult: try: axes = _normalize_axes(axes, rank) except ValueError as e: - return _common.InferenceResult(failure=str(e)) + return _common.InferenceResult(status="invalid_node", msg=str(e)) output_shape = _compute_output_shape_with_axes(input_shape, axes) return _common.InferenceResult(values=(ir.Value(shape=output_shape, type=input.type),)) @@ -103,7 +103,7 @@ def infer(self, node: ir.Node) -> _common.InferenceResult: input_shape = node.inputs[0].shape if input_shape is None: - return _common.InferenceResult(failure="Squeeze input shape is not known.") + return _common.InferenceResult(status="missing_info", msg="Squeeze input shape is not known.") rank = len(input_shape) @@ -112,13 +112,13 @@ def infer(self, node: ir.Node) -> _common.InferenceResult: try: axes = _normalize_axes(axes_tensor.numpy().tolist(), rank) except ValueError as e: - return _common.InferenceResult(failure=str(e)) + return _common.InferenceResult(status="invalid_node", msg=str(e)) output_shape = _compute_output_shape_with_axes(input_shape, axes) else: axes_shape = node.inputs[1].shape if axes_shape is None or axes_shape.is_dynamic(): return _common.InferenceResult( - failure="Squeeze axes input shape is not known or is dynamic" + status="missing_info", msg="Squeeze axes input shape is not known or is dynamic" ) removed_axes_count = axes_shape[0] assert isinstance(removed_axes_count, int) diff --git a/src/onnx_ir/_shape_type_inference/ops/transpose.py b/src/onnx_ir/_shape_type_inference/ops/transpose.py index 6c1ee3fb..0dafd793 100644 --- a/src/onnx_ir/_shape_type_inference/ops/transpose.py +++ b/src/onnx_ir/_shape_type_inference/ops/transpose.py @@ -20,7 +20,7 @@ def infer(self, node: ir.Node) -> _common.InferenceResult: assert node.inputs[0] is not None input_shape = node.inputs[0].shape if input_shape is None: - return _common.InferenceResult(failure="Transpose input shape cannot be None.") + return _common.InferenceResult(status="missing_info", msg="Transpose input shape cannot be None.") rank = len(input_shape) @@ -34,12 +34,12 @@ def infer(self, node: ir.Node) -> _common.InferenceResult: # Validate permutation if len(perm) != rank: return _common.InferenceResult( - failure=f"Permutation length {len(perm)} does not match input rank {rank}." + status="invalid_node", msg=f"Permutation length {len(perm)} does not match input rank {rank}." ) if sorted(perm) != list(range(rank)): return _common.InferenceResult( - failure=f"Invalid permutation {perm}. Must be a permutation of [0, 1, ..., {rank - 1}]." + status="invalid_node", msg=f"Invalid permutation {perm}. Must be a permutation of [0, 1, ..., {rank - 1}]." 
) # Apply permutation to create output shape @@ -51,7 +51,7 @@ def infer(self, node: ir.Node) -> _common.InferenceResult: if axis < 0 or axis >= rank: return _common.InferenceResult( - failure=f"Permutation axis {axis} is out of bounds for rank {rank}." + status="invalid_node", msg=f"Permutation axis {axis} is out of bounds for rank {rank}." ) # Copy dimension from input to output according to permutation diff --git a/src/onnx_ir/_shape_type_inference/ops/unsqueeze.py b/src/onnx_ir/_shape_type_inference/ops/unsqueeze.py index f1d5c068..91f1a9a2 100644 --- a/src/onnx_ir/_shape_type_inference/ops/unsqueeze.py +++ b/src/onnx_ir/_shape_type_inference/ops/unsqueeze.py @@ -69,7 +69,7 @@ def infer(self, node: ir.Node) -> _common.InferenceResult: assert input is not None input_shape = input.shape if input_shape is None: - return _common.InferenceResult(failure="Unsqueeze input shape is not known.") + return _common.InferenceResult(status="missing_info", msg="Unsqueeze input shape is not known.") input_rank = len(input_shape) @@ -77,7 +77,7 @@ def infer(self, node: ir.Node) -> _common.InferenceResult: axes = node.attributes.get_ints("axes") if axes is None: return _common.InferenceResult( - failure="Unsqueeze operation requires axes attribute." + status="invalid_node", msg="Unsqueeze operation requires axes attribute." ) output_rank = input_rank + len(axes) @@ -85,7 +85,7 @@ def infer(self, node: ir.Node) -> _common.InferenceResult: try: normalized_axes = _normalize_axes(axes, output_rank) except ValueError as e: - return _common.InferenceResult(failure=str(e)) + return _common.InferenceResult(status="invalid_node", msg=str(e)) output_shape = _compute_output_shape(input_shape, normalized_axes) return _common.InferenceResult(values=(ir.Value(shape=output_shape, type=input.type),)) @@ -111,7 +111,7 @@ def infer(self, node: ir.Node) -> _common.InferenceResult: input_shape = node.inputs[0].shape if input_shape is None: - return _common.InferenceResult(failure="Unsqueeze input shape is not known.") + return _common.InferenceResult(status="missing_info", msg="Unsqueeze input shape is not known.") input_rank = len(input_shape) @@ -126,7 +126,7 @@ def infer(self, node: ir.Node) -> _common.InferenceResult: try: normalized_axes = _normalize_axes(axes, output_rank) except ValueError as e: - return _common.InferenceResult(failure=str(e)) + return _common.InferenceResult(status="invalid_node", msg=str(e)) output_shape = _compute_output_shape(input_shape, normalized_axes) else: @@ -134,7 +134,7 @@ def infer(self, node: ir.Node) -> _common.InferenceResult: axes_shape = node.inputs[1].shape if axes_shape is None or axes_shape.is_dynamic(): return _common.InferenceResult( - failure="Unsqueeze axes input shape is not known or is dynamic" + status="missing_info", msg="Unsqueeze axes input shape is not known or is dynamic" ) # We know the number of axes to insert but not their positions From a572145df7f6022b242e4f283e86bb73a50ba25e Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 30 Jun 2025 08:44:26 -0700 Subject: [PATCH 24/31] lint Signed-off-by: Justin Chu --- src/onnx_ir/_shape_type_inference/__init__.py | 6 +++- src/onnx_ir/_shape_type_inference/_common.py | 15 +++++----- src/onnx_ir/_shape_type_inference/_engine.py | 4 +-- .../_shape_type_inference/ops/concat.py | 30 +++++++------------ .../_shape_type_inference/ops/constant.py | 6 ++-- .../_shape_type_inference/ops/matmul.py | 7 ++--- .../_shape_type_inference/ops/reshape.py | 20 +++++-------- .../_shape_type_inference/ops/squeeze.py | 11 +++++-- 
.../_shape_type_inference/ops/standard_ops.py | 5 ++-- .../_shape_type_inference/ops/transpose.py | 13 +++++--- .../_shape_type_inference/ops/unsqueeze.py | 11 +++++-- 11 files changed, 64 insertions(+), 64 deletions(-) diff --git a/src/onnx_ir/_shape_type_inference/__init__.py b/src/onnx_ir/_shape_type_inference/__init__.py index b3d8fb69..a9a892e3 100644 --- a/src/onnx_ir/_shape_type_inference/__init__.py +++ b/src/onnx_ir/_shape_type_inference/__init__.py @@ -9,7 +9,11 @@ ] -from onnx_ir._shape_type_inference._common import InferenceResult, InferenceStatus, NodeInferrer +from onnx_ir._shape_type_inference._common import ( + InferenceResult, + InferenceStatus, + NodeInferrer, +) from onnx_ir._shape_type_inference._engine import ( InferenceError, SymbolicInferenceEngine, diff --git a/src/onnx_ir/_shape_type_inference/_common.py b/src/onnx_ir/_shape_type_inference/_common.py index c39923d8..5c27830b 100644 --- a/src/onnx_ir/_shape_type_inference/_common.py +++ b/src/onnx_ir/_shape_type_inference/_common.py @@ -12,7 +12,6 @@ import onnx_ir as ir - MAX_SUPPORTED_OPSET = 23 @@ -37,10 +36,11 @@ def get_expr(shape: ir.Shape, index: int) -> sympy.Expr: @enum.unique -class InferenceStatus(enum.Enum): +class InferenceStatus(enum.StrEnum): """Status of shape inference operation.""" - SUCCESS = "success" # Complete inference successful - PARTIAL = "partial" # Partial information available (e.g., type only, rank only) + + SUCCESS = "success" # Complete inference successful + PARTIAL = "partial" # Partial information available (e.g., type only, rank only) MISSING_INFO = "missing_info" # Missing required input information INVALID_NODE = "invalid_node" # Node is invalid or malformed @@ -127,13 +127,12 @@ def wrapper(self, node: ir.Node) -> InferenceResult: if len(node.inputs) != count: return InferenceResult( status="invalid_node", - msg=f"{node.op_type} must have {count} inputs, got {len(node.inputs)}." + msg=f"{node.op_type} must have {count} inputs, got {len(node.inputs)}.", ) for i, inp in enumerate(node.inputs): if inp is None: return InferenceResult( - status="missing_info", - msg=f"{node.op_type} input {i} cannot be None." + status="missing_info", msg=f"{node.op_type} input {i} cannot be None." ) return func(self, node) @@ -164,7 +163,7 @@ def wrapper(self, node: ir.Node) -> InferenceResult: if len(node.outputs) != count: return InferenceResult( status="invalid_node", - msg=f"{node.op_type} must have {count} outputs, got {len(node.outputs)}." 
+ msg=f"{node.op_type} must have {count} outputs, got {len(node.outputs)}.", ) return func(self, node) diff --git a/src/onnx_ir/_shape_type_inference/_engine.py b/src/onnx_ir/_shape_type_inference/_engine.py index db26ac61..8f639da2 100644 --- a/src/onnx_ir/_shape_type_inference/_engine.py +++ b/src/onnx_ir/_shape_type_inference/_engine.py @@ -68,8 +68,8 @@ def infer_model(self, model: ir.Model) -> None: self._infer_node(node, model) logger.debug(f"Successfully inferred node {i}: {node.op_type}") except Exception as e: - error_msg = f"Failed to infer node {i} ({node.op_type}): {str(e)}" - logger.error(error_msg) + error_msg = f"Failed to infer node {i} ({node.op_type}): {e!s}" + logger.exception(error_msg) raise InferenceError(error_msg) from e logger.info("Model inference completed successfully") diff --git a/src/onnx_ir/_shape_type_inference/ops/concat.py b/src/onnx_ir/_shape_type_inference/ops/concat.py index 54d4c5a0..98adb1c6 100644 --- a/src/onnx_ir/_shape_type_inference/ops/concat.py +++ b/src/onnx_ir/_shape_type_inference/ops/concat.py @@ -1,6 +1,5 @@ """Concat operation inferrer for ONNX IR nodes.""" - import onnx_ir as ir from onnx_ir._shape_type_inference import _common @@ -16,42 +15,37 @@ def infer(self, node: ir.Node) -> _common.InferenceResult: """Infer the output shape and type for Concat operations.""" if len(node.inputs) < 1: return _common.InferenceResult( - status="invalid_node", - msg="Concat operation must have at least one input." + status="invalid_node", msg="Concat operation must have at least one input." ) if any(inp is None for inp in node.inputs): return _common.InferenceResult( - status="missing_info", - msg="Concat operation inputs cannot be None." + status="missing_info", msg="Concat operation inputs cannot be None." ) if len(node.outputs) != 1: return _common.InferenceResult( status="invalid_node", - msg=f"Concat operation must have exactly one output, got {len(node.outputs)}." + msg=f"Concat operation must have exactly one output, got {len(node.outputs)}.", ) # Get axis attribute axis = node.attributes.get_int("axis") if axis is None: return _common.InferenceResult( - status="invalid_node", - msg="Concat operation requires axis attribute." + status="invalid_node", msg="Concat operation requires axis attribute." ) # Get first input shape as base first_shape = node.inputs[0].shape if first_shape is None: return _common.InferenceResult( - status="missing_info", - msg="Concat input shapes cannot be None." + status="missing_info", msg="Concat input shapes cannot be None." ) first_type = node.inputs[0].type rank = len(first_shape) if rank == 0: return _common.InferenceResult( - status="invalid_node", - msg="Concat inputs cannot be scalars." + status="invalid_node", msg="Concat inputs cannot be scalars." ) # Handle negative axis @@ -61,7 +55,7 @@ def infer(self, node: ir.Node) -> _common.InferenceResult: if axis < 0 or axis >= rank: return _common.InferenceResult( status="invalid_node", - msg=f"Concat axis {axis} is out of bounds for rank {rank}." + msg=f"Concat axis {axis} is out of bounds for rank {rank}.", ) # Check that all inputs have compatible shapes @@ -71,20 +65,18 @@ def infer(self, node: ir.Node) -> _common.InferenceResult: for i, inp in enumerate(node.inputs[1:], 1): if inp is None: return _common.InferenceResult( - status="missing_info", - msg=f"Input {i} cannot be None." + status="missing_info", msg=f"Input {i} cannot be None." ) if inp.shape is None: return _common.InferenceResult( - status="missing_info", - msg=f"Input {i} shape cannot be None." 
+ status="missing_info", msg=f"Input {i} shape cannot be None." ) input_shape = inp.shape if len(input_shape) != rank: return _common.InferenceResult( status="invalid_node", - msg=f"All inputs must have same rank. Input {i} has rank {len(input_shape)}, expected {rank}." + msg=f"All inputs must have same rank. Input {i} has rank {len(input_shape)}, expected {rank}.", ) # TODO(justinchuby): Check non-concat dimensions are compatible @@ -92,7 +84,7 @@ def infer(self, node: ir.Node) -> _common.InferenceResult: if inp.type != first_type: return _common.InferenceResult( status="invalid_node", - msg=f"Input {i} type {inp.type} does not match first input type {first_type}." + msg=f"Input {i} type {inp.type} does not match first input type {first_type}.", ) # Set the concat dimension in output shape diff --git a/src/onnx_ir/_shape_type_inference/ops/constant.py b/src/onnx_ir/_shape_type_inference/ops/constant.py index b0f1cc2b..d9338ec4 100644 --- a/src/onnx_ir/_shape_type_inference/ops/constant.py +++ b/src/onnx_ir/_shape_type_inference/ops/constant.py @@ -2,8 +2,6 @@ from __future__ import annotations -import sys - import onnx_ir as ir from onnx_ir._shape_type_inference import _common @@ -23,7 +21,9 @@ def infer(self, node: ir.Node) -> _common.InferenceResult: assert node.inputs[0] is not None tensor = ir.convenience.get_const_tensor(node.inputs[0]) if tensor is None: - return _common.InferenceResult(status="missing_info", msg="Constant tensor cannot be obtained.") + return _common.InferenceResult( + status="missing_info", msg="Constant tensor cannot be obtained." + ) # Create shape from the tensor dimensions output_shape = ir.Shape(tensor.shape) diff --git a/src/onnx_ir/_shape_type_inference/ops/matmul.py b/src/onnx_ir/_shape_type_inference/ops/matmul.py index a13942c1..fb61b6d5 100644 --- a/src/onnx_ir/_shape_type_inference/ops/matmul.py +++ b/src/onnx_ir/_shape_type_inference/ops/matmul.py @@ -1,6 +1,5 @@ """MatMul operation inferrer for ONNX IR nodes.""" - import onnx_ir as ir from onnx_ir._shape_type_inference import _common from onnx_ir._shape_type_inference.ops.standard_ops import broadcast_shapes_bidirectional @@ -23,8 +22,7 @@ def infer(self, node: ir.Node) -> _common.InferenceResult: rhs_shape = node.inputs[1].shape if lhs_shape is None or rhs_shape is None: return _common.InferenceResult( - status="missing_info", - msg="MatMul input shapes cannot be None." + status="missing_info", msg="MatMul input shapes cannot be None." ) lhs_rank = len(lhs_shape) @@ -32,8 +30,7 @@ def infer(self, node: ir.Node) -> _common.InferenceResult: if lhs_rank == 0 or rhs_rank == 0: return _common.InferenceResult( - status="invalid_node", - msg="MatMul inputs cannot be scalars." + status="invalid_node", msg="MatMul inputs cannot be scalars." ) # Compute output shape based on matrix multiplication rules diff --git a/src/onnx_ir/_shape_type_inference/ops/reshape.py b/src/onnx_ir/_shape_type_inference/ops/reshape.py index 3a6ec550..80659b5e 100644 --- a/src/onnx_ir/_shape_type_inference/ops/reshape.py +++ b/src/onnx_ir/_shape_type_inference/ops/reshape.py @@ -1,7 +1,5 @@ """Reshape operation inferrer for ONNX IR nodes.""" -import sys - import sympy import onnx_ir as ir @@ -21,17 +19,16 @@ def infer(self, node: ir.Node) -> _common.InferenceResult: if len(node.inputs) != 2: return _common.InferenceResult( status="invalid_node", - msg=f"Reshape operation must have exactly two inputs, got {len(node.inputs)}." 
+ msg=f"Reshape operation must have exactly two inputs, got {len(node.inputs)}.", ) if node.inputs[0] is None or node.inputs[1] is None: return _common.InferenceResult( - status="missing_info", - msg="Reshape operation inputs cannot be None." + status="missing_info", msg="Reshape operation inputs cannot be None." ) if len(node.outputs) != 1: return _common.InferenceResult( status="invalid_node", - msg=f"Reshape operation must have exactly one output, got {len(node.outputs)}." + msg=f"Reshape operation must have exactly one output, got {len(node.outputs)}.", ) input_shape = node.inputs[0].shape @@ -39,8 +36,7 @@ def infer(self, node: ir.Node) -> _common.InferenceResult: if input_shape is None: return _common.InferenceResult( - status="missing_info", - msg="Reshape input shape cannot be None." + status="missing_info", msg="Reshape input shape cannot be None." ) # Try to get the shape values from the second input @@ -48,8 +44,7 @@ def infer(self, node: ir.Node) -> _common.InferenceResult: shape = ir.convenience.get_const_tensor(shape_input) if shape is None: return _common.InferenceResult( - status="missing_info", - msg="Reshape shape input is not known." + status="missing_info", msg="Reshape shape input is not known." ) shape_values = shape.numpy().tolist() @@ -68,8 +63,7 @@ def infer(self, node: ir.Node) -> _common.InferenceResult: if dim_value == -1: if deferred_dim_idx != -1: return _common.InferenceResult( - status="invalid_node", - msg="Reshape can have at most one -1 dimension." + status="invalid_node", msg="Reshape can have at most one -1 dimension." ) deferred_dim_idx = i output_dims.append(None) # Placeholder @@ -78,7 +72,7 @@ def infer(self, node: ir.Node) -> _common.InferenceResult: if i >= len(input_shape): return _common.InferenceResult( status="invalid_node", - msg=f"Cannot copy dimension {i} from input shape of rank {len(input_shape)}." + msg=f"Cannot copy dimension {i} from input shape of rank {len(input_shape)}.", ) dim_expr = _common.get_expr(input_shape, i) output_dims.append(dim_expr) diff --git a/src/onnx_ir/_shape_type_inference/ops/squeeze.py b/src/onnx_ir/_shape_type_inference/ops/squeeze.py index e07215a7..bfe5e582 100644 --- a/src/onnx_ir/_shape_type_inference/ops/squeeze.py +++ b/src/onnx_ir/_shape_type_inference/ops/squeeze.py @@ -66,7 +66,9 @@ def infer(self, node: ir.Node) -> _common.InferenceResult: assert input is not None input_shape = input.shape if input_shape is None: - return _common.InferenceResult(status="missing_info", msg="Squeeze input shape is not known.") + return _common.InferenceResult( + status="missing_info", msg="Squeeze input shape is not known." + ) rank = len(input_shape) @@ -103,7 +105,9 @@ def infer(self, node: ir.Node) -> _common.InferenceResult: input_shape = node.inputs[0].shape if input_shape is None: - return _common.InferenceResult(status="missing_info", msg="Squeeze input shape is not known.") + return _common.InferenceResult( + status="missing_info", msg="Squeeze input shape is not known." 
+ ) rank = len(input_shape) @@ -118,7 +122,8 @@ def infer(self, node: ir.Node) -> _common.InferenceResult: axes_shape = node.inputs[1].shape if axes_shape is None or axes_shape.is_dynamic(): return _common.InferenceResult( - status="missing_info", msg="Squeeze axes input shape is not known or is dynamic" + status="missing_info", + msg="Squeeze axes input shape is not known or is dynamic", ) removed_axes_count = axes_shape[0] assert isinstance(removed_axes_count, int) diff --git a/src/onnx_ir/_shape_type_inference/ops/standard_ops.py b/src/onnx_ir/_shape_type_inference/ops/standard_ops.py index b218f624..90b4201e 100644 --- a/src/onnx_ir/_shape_type_inference/ops/standard_ops.py +++ b/src/onnx_ir/_shape_type_inference/ops/standard_ops.py @@ -89,7 +89,7 @@ def infer(self, node: ir.Node) -> _common.InferenceResult: if first_type is not None and second_type is not None and first_type != second_type: return _common.InferenceResult( status="invalid_node", - msg=f"Input types do not match: {first_type} vs {second_type}." + msg=f"Input types do not match: {first_type} vs {second_type}.", ) # Broadcast the input shapes @@ -97,8 +97,7 @@ def infer(self, node: ir.Node) -> _common.InferenceResult: second_shape = node.inputs[1].shape if first_shape is None or second_shape is None: return _common.InferenceResult( - status="missing_info", - msg="Input shapes cannot be None." + status="missing_info", msg="Input shapes cannot be None." ) output_shape = broadcast_shapes_bidirectional(first_shape, second_shape) diff --git a/src/onnx_ir/_shape_type_inference/ops/transpose.py b/src/onnx_ir/_shape_type_inference/ops/transpose.py index 0dafd793..4b221f73 100644 --- a/src/onnx_ir/_shape_type_inference/ops/transpose.py +++ b/src/onnx_ir/_shape_type_inference/ops/transpose.py @@ -20,7 +20,9 @@ def infer(self, node: ir.Node) -> _common.InferenceResult: assert node.inputs[0] is not None input_shape = node.inputs[0].shape if input_shape is None: - return _common.InferenceResult(status="missing_info", msg="Transpose input shape cannot be None.") + return _common.InferenceResult( + status="missing_info", msg="Transpose input shape cannot be None." + ) rank = len(input_shape) @@ -34,12 +36,14 @@ def infer(self, node: ir.Node) -> _common.InferenceResult: # Validate permutation if len(perm) != rank: return _common.InferenceResult( - status="invalid_node", msg=f"Permutation length {len(perm)} does not match input rank {rank}." + status="invalid_node", + msg=f"Permutation length {len(perm)} does not match input rank {rank}.", ) if sorted(perm) != list(range(rank)): return _common.InferenceResult( - status="invalid_node", msg=f"Invalid permutation {perm}. Must be a permutation of [0, 1, ..., {rank - 1}]." + status="invalid_node", + msg=f"Invalid permutation {perm}. Must be a permutation of [0, 1, ..., {rank - 1}].", ) # Apply permutation to create output shape @@ -51,7 +55,8 @@ def infer(self, node: ir.Node) -> _common.InferenceResult: if axis < 0 or axis >= rank: return _common.InferenceResult( - status="invalid_node", msg=f"Permutation axis {axis} is out of bounds for rank {rank}." 
+ status="invalid_node", + msg=f"Permutation axis {axis} is out of bounds for rank {rank}.", ) # Copy dimension from input to output according to permutation diff --git a/src/onnx_ir/_shape_type_inference/ops/unsqueeze.py b/src/onnx_ir/_shape_type_inference/ops/unsqueeze.py index 91f1a9a2..4559f4cf 100644 --- a/src/onnx_ir/_shape_type_inference/ops/unsqueeze.py +++ b/src/onnx_ir/_shape_type_inference/ops/unsqueeze.py @@ -69,7 +69,9 @@ def infer(self, node: ir.Node) -> _common.InferenceResult: assert input is not None input_shape = input.shape if input_shape is None: - return _common.InferenceResult(status="missing_info", msg="Unsqueeze input shape is not known.") + return _common.InferenceResult( + status="missing_info", msg="Unsqueeze input shape is not known." + ) input_rank = len(input_shape) @@ -111,7 +113,9 @@ def infer(self, node: ir.Node) -> _common.InferenceResult: input_shape = node.inputs[0].shape if input_shape is None: - return _common.InferenceResult(status="missing_info", msg="Unsqueeze input shape is not known.") + return _common.InferenceResult( + status="missing_info", msg="Unsqueeze input shape is not known." + ) input_rank = len(input_shape) @@ -134,7 +138,8 @@ def infer(self, node: ir.Node) -> _common.InferenceResult: axes_shape = node.inputs[1].shape if axes_shape is None or axes_shape.is_dynamic(): return _common.InferenceResult( - status="missing_info", msg="Unsqueeze axes input shape is not known or is dynamic" + status="missing_info", + msg="Unsqueeze axes input shape is not known or is dynamic", ) # We know the number of axes to insert but not their positions From 11f895844c9bf3b22b96b654b0dd7ecbc18fdea7 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 30 Jun 2025 08:44:52 -0700 Subject: [PATCH 25/31] Removes unused shape inference code Signed-off-by: Justin Chu --- .../_shape_type_inference/_inferencer.py | 3472 ----------------- 1 file changed, 3472 deletions(-) delete mode 100644 src/onnx_ir/_shape_type_inference/_inferencer.py diff --git a/src/onnx_ir/_shape_type_inference/_inferencer.py b/src/onnx_ir/_shape_type_inference/_inferencer.py deleted file mode 100644 index 51d6499f..00000000 --- a/src/onnx_ir/_shape_type_inference/_inferencer.py +++ /dev/null @@ -1,3472 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. 
- -import argparse -import logging - -import numpy as np -import onnx -import sympy -from onnx import helper, numpy_helper, shape_inference - -logger = logging.getLogger(__name__) - - -def get_attribute(node, attr_name, default_value=None): - """Retrieve the value of an attribute from an ONNX node, returning a default if the attribute is not found.""" - found = [attr for attr in node.attribute if attr.name == attr_name] - return helper.get_attribute_value(found[0]) if found else default_value - - -def get_dim_from_proto(dim): - """Retrieve the dimension value from the ONNX protobuf object if it is a string.""" - return ( - getattr(dim, dim.WhichOneof("value")) if type(dim.WhichOneof("value")) is str else None - ) - - -def is_sequence(type_proto): - """Check if the given ONNX proto type is a sequence.""" - cls_type = type_proto.WhichOneof("value") - assert cls_type in {"tensor_type", "sequence_type"} - return cls_type == "sequence_type" - - -def get_shape_from_type_proto(type_proto): - """Extract the shape of a tensor from an ONNX type proto if available, otherwise return None.""" - assert not is_sequence(type_proto) - if type_proto.tensor_type.HasField("shape"): - return [get_dim_from_proto(d) for d in type_proto.tensor_type.shape.dim] - else: - return None # note no shape is different from shape without dim (scalar) - - -def get_elem_type_from_type_proto(type_proto): - """Return the element type from a given TypeProto object, either from sequence type or tensor type.""" - if is_sequence(type_proto): - return type_proto.sequence_type.elem_type.tensor_type.elem_type - else: - return type_proto.tensor_type.elem_type - - -def get_shape_from_value_info(vi): - """Return the shape from the given ValueInfoProto object, either from sequence type or tensor type.""" - cls_type = vi.type.WhichOneof("value") - if cls_type is None: - return None - if not is_sequence(vi.type): - return get_shape_from_type_proto(vi.type) - if vi.type.sequence_type.elem_type.WhichOneof("value") == "tensor_type": - return get_shape_from_type_proto(vi.type.sequence_type.elem_type) - else: - return None - - -def make_named_value_info(name): - """Create and return an ONNX ValueInfoProto object with the specified name.""" - vi = onnx.ValueInfoProto() - vi.name = name - return vi - - -def get_shape_from_sympy_shape(sympy_shape): - """Convert a sympy shape to a list with int, str, or None elements.""" - return [None if i is None else (int(i) if is_literal(i) else str(i)) for i in sympy_shape] - - -def is_literal(dim): - """Check if a dimension is a literal number (int, np.int64, np.int32, sympy.Integer) or has an 'is_number' - attribute. - """ - return type(dim) in {int, np.int64, np.int32, sympy.Integer} or ( - hasattr(dim, "is_number") and dim.is_number - ) - - -def handle_negative_axis(axis, rank): - """Convert a potentially negative axis to a positive axis based on the given rank.""" - assert axis < rank and axis >= -rank - return axis if axis >= 0 else rank + axis - - -def get_opset(mp, domain=None): - """Retrieve the opset version for a given model namespace, defaulting to common ONNX domains if no specific domain - is provided. 
- """ - domain = domain or ["", "onnx", "ai.onnx"] - if type(domain) != list: - domain = [domain] - for opset in mp.opset_import: - if opset.domain in domain: - return opset.version - - return None - - -def as_scalar(x): - """Convert input to scalar if input is a list with a single item or a NumPy ndarray.""" - if type(x) == list: - assert len(x) == 1 - return x[0] - elif type(x) == np.ndarray: - return x.item() - else: - return x - - -def as_list(x, keep_none): - """Convert input to list, optionally preserving None values.""" - if type(x) == list: - return x - elif type(x) == np.ndarray: - return list(x) - elif keep_none and x is None: - return None - else: - return [x] - - -def sympy_reduce_product(x): - """Reduce a list or element to a product using Sympy's Integer.""" - if type(x) == list: - value = sympy.Integer(1) - for v in x: - value = value * v - else: - value = x - return value - - -class SymbolicShapeInference: - def __init__(self, int_max, auto_merge, guess_output_rank, verbose, prefix=""): - """Initializes the SymbolicShapeInference class with configuration parameters for symbolic shape inference.""" - self.dispatcher_ = { - "Add": self._infer_symbolic_compute_ops, - "ArrayFeatureExtractor": self._infer_ArrayFeatureExtractor, - "AveragePool": self._infer_Pool, - "BatchNormalization": self._infer_BatchNormalization, - "Cast": self._infer_Cast, - "CategoryMapper": self._infer_CategoryMapper, - "Compress": self._infer_Compress, - "Concat": self._infer_Concat, - "ConcatFromSequence": self._infer_ConcatFromSequence, - "Constant": self._infer_Constant, - "ConstantOfShape": self._infer_ConstantOfShape, - "Conv": self._infer_Conv, - "CumSum": self._pass_on_shape_and_type, - "Div": self._infer_symbolic_compute_ops, - "Einsum": self._infer_Einsum, - "Expand": self._infer_Expand, - "Equal": self._infer_symbolic_compute_ops, - "Floor": self._infer_symbolic_compute_ops, - "Gather": self._infer_Gather, - "GatherElements": self._infer_GatherElements, - "GatherND": self._infer_GatherND, - "Identity": self._pass_on_shape_and_type, - "AllReduce": self._pass_on_shape_and_type, - "If": self._infer_If, - "Loop": self._infer_Loop, - "MatMul": self._infer_MatMul, - "MatMulInteger16": self._infer_MatMulInteger, - "MaxPool": self._infer_Pool, - "Max": self._infer_symbolic_compute_ops, - "MemcpyFromHost": self._pass_on_shape_and_type, - "MemcpyToHost": self._pass_on_shape_and_type, - "Min": self._infer_symbolic_compute_ops, - "MoE": self._pass_on_shape_and_type, - "Mul": self._infer_symbolic_compute_ops, - "NonMaxSuppression": self._infer_NonMaxSuppression, - "NonZero": self._infer_NonZero, - "OneHot": self._infer_OneHot, - "Pad": self._infer_Pad, - "Range": self._infer_Range, - "Reciprocal": self._pass_on_shape_and_type, - "ReduceSum": self._infer_ReduceSum, - "ReduceProd": self._infer_ReduceProd, - "Reshape": self._infer_Reshape, - "Resize": self._infer_Resize, - "Round": self._pass_on_shape_and_type, - "Scan": self._infer_Scan, - "ScatterElements": self._infer_ScatterElements, - "SequenceAt": self._infer_SequenceAt, - "SequenceInsert": self._infer_SequenceInsert, - "Shape": self._infer_Shape, - "Size": self._infer_Size, - "Slice": self._infer_Slice, - "SoftmaxCrossEntropyLoss": self._infer_SoftmaxCrossEntropyLoss, - "SoftmaxCrossEntropyLossInternal": self._infer_SoftmaxCrossEntropyLoss, - "NegativeLogLikelihoodLossInternal": self._infer_SoftmaxCrossEntropyLoss, - "Split": self._infer_Split, - "SplitToSequence": self._infer_SplitToSequence, - "Squeeze": self._infer_Squeeze, - "Sub": 
self._infer_symbolic_compute_ops, - "Tile": self._infer_Tile, - "TopK": self._infer_TopK, - "Transpose": self._infer_Transpose, - "Unsqueeze": self._infer_Unsqueeze, - "Where": self._infer_symbolic_compute_ops, - "ZipMap": self._infer_ZipMap, - "Neg": self._infer_symbolic_compute_ops, - # contrib ops: - "Attention": self._infer_Attention, - "BiasAdd": self._infer_BiasAdd, - "BiasGelu": self._infer_BiasGelu, - "BiasSplitGelu": self._infer_BiasSplitGelu, - "DecoderMaskedMultiHeadAttention": self._infer_DecoderMaskedMultiHeadAttention, - "DequantizeLinear": self._infer_DequantizeLinear, - "EmbedLayerNormalization": self._infer_EmbedLayerNormalization, - "FastGelu": self._infer_FastGelu, - "GatedRelativePositionBias": self._infer_GatedRelativePositionBias, - "Gelu": self._infer_Gelu, - "GemmFastGelu": self._infer_GemmFastGelu, - "GemmFloat8": self._infer_GemmFloat8, - "GroupNorm": self._infer_GroupNorm, - "SkipGroupNorm": self._infer_SkipGroupNorm, - "LayerNormalization": self._infer_LayerNormalization, - "LongformerAttention": self._infer_LongformerAttention, - "MultiHeadAttention": self._infer_MultiHeadAttention, - "NhwcConv": self._infer_NhwcConv, - "PackedAttention": self._infer_PackedAttention, - "PackedMultiHeadAttention": self._infer_PackedMultiHeadAttention, - "MultiScaleDeformableAttnTRT": self._infer_MultiScaleDeformableAttnTRT, - "PythonOp": self._infer_PythonOp, - "QuantizeLinear": self._infer_QuantizeLinear, - "QuickGelu": self._infer_FastGelu, - "RelativePositionBias": self._infer_RelativePositionBias, - "RemovePadding": self._infer_RemovePadding, - "RestorePadding": self._infer_RestorePadding, - "RotaryEmbedding": self._infer_RotaryEmbedding, - "SimplifiedLayerNormalization": self._infer_LayerNormalization, - "SkipLayerNormalization": self._infer_SkipLayerNormalization, - "SkipSimplifiedLayerNormalization": self._infer_SkipLayerNormalization, - } - self.aten_op_dispatcher_ = { - "embedding": self._infer_Gather, - "bitwise_or": self._infer_aten_bitwise_or, - "diagonal": self._infer_aten_diagonal, - "max_pool2d_with_indices": self._infer_aten_pool2d, - "max": self._infer_aten_minmax, - "min": self._infer_aten_minmax, - "multinomial": self._infer_aten_multinomial, - "unfold": self._infer_aten_unfold, - "argmax": self._infer_aten_argmax, - "avg_pool2d": self._infer_aten_pool2d, - "_adaptive_avg_pool2d": self._infer_aten_pool2d, - "numpy_T": self._infer_Transpose, - "native_group_norm": self._infer_aten_group_norm, - "upsample_nearest1d": self._infer_aten_upsample, - "upsample_nearest2d": self._infer_aten_upsample, - "upsample_nearest3d": self._infer_aten_upsample, - "upsample_bicubic2d": self._infer_aten_upsample, - } - self.run_ = True - self.suggested_merge_ = {} - self.symbolic_dims_ = {} - self.input_symbols_ = {} - self.auto_merge_ = auto_merge - self.guess_output_rank_ = guess_output_rank - self.verbose_ = verbose - self.int_max_ = int_max - self.subgraph_id_ = 0 - self.prefix_ = prefix - - def _add_suggested_merge(self, symbols, apply=False): - """Add suggested merges for input symbols, prioritizing literals, input symbolic dims, or existing symbolic - dims. 
- """ - assert all( - (type(s) == str and s in self.symbolic_dims_) or is_literal(s) for s in symbols - ) - symbols = set(symbols) - for k, v in self.suggested_merge_.items(): - if k in symbols: - symbols.remove(k) - symbols.add(v) - map_to = None - # if there is literal, map to it first - for s in symbols: - if is_literal(s): - map_to = s - break - # when no literals, map to input symbolic dims, then existing symbolic dims - if map_to is None: - for s in symbols: - if s in self.input_symbols_: - map_to = s - break - if map_to is None: - for s in symbols: - if type(self.symbolic_dims_[s]) == sympy.Symbol: - map_to = s - break - # when nothing to map to, use the shorter one - if map_to is None: - if self.verbose_ > 0: - logger.warning( - f"Potential unsafe merge between symbolic expressions: ({','.join(symbols)})" - ) - symbols_list = list(symbols) - lens = [len(s) for s in symbols_list] - map_to = symbols_list[lens.index(min(lens))] - symbols.remove(map_to) - - for s in symbols: - if s == map_to: - continue - if is_literal(map_to) and is_literal(s): - assert int(map_to) == int(s) - self.suggested_merge_[s] = int(map_to) if is_literal(map_to) else map_to - for k, v in self.suggested_merge_.items(): - if v == s: - self.suggested_merge_[k] = map_to - if apply and self.auto_merge_: - self._apply_suggested_merge() - - def _apply_suggested_merge(self, graph_input_only=False): - """Applies suggested merges to graph dimensions based on predefined rules in the `suggested_merge_` - dictionary. - """ - if not self.suggested_merge_: - return - for i in list(self.out_mp_.graph.input) + ( - [] if graph_input_only else list(self.out_mp_.graph.value_info) - ): - for d in i.type.tensor_type.shape.dim: - if d.dim_param in self.suggested_merge_: - v = self.suggested_merge_[d.dim_param] - if is_literal(v): - d.dim_value = int(v) - else: - d.dim_param = v - - def _preprocess(self, in_mp): - """Preprocess ONNX model by copying its structure and updating graph input and initializer dictionaries.""" - self.out_mp_ = onnx.ModelProto() - self.out_mp_.CopyFrom(in_mp) - self.graph_inputs_ = {i.name: i for i in list(self.out_mp_.graph.input)} - self.initializers_ = {i.name: i for i in self.out_mp_.graph.initializer} - self.known_vi_ = {i.name: i for i in list(self.out_mp_.graph.input)} - self.known_vi_.update( - { - i.name: helper.make_tensor_value_info(i.name, i.data_type, list(i.dims)) - for i in self.out_mp_.graph.initializer - } - ) - - def _merge_symbols(self, dims): - """Merge dimension symbols, handling automatic merging and validation of symbolic dimensions.""" - if any(type(d) != str for d in dims): - if not self.auto_merge_: - return None - unique_dims = list(set(dims)) - is_int = [is_literal(d) for d in unique_dims] - assert sum(is_int) <= 1 # if there are more than 1 unique ints, something is wrong - if sum(is_int) == 1: - int_dim = is_int.index(1) - if self.verbose_ > 0: - logger.debug( - f"dim {unique_dims[:int_dim] + unique_dims[int_dim + 1 :]} has been merged with value {unique_dims[int_dim]}" - ) - self._check_merged_dims(unique_dims, allow_broadcast=False) - return unique_dims[int_dim] - else: - if self.verbose_ > 0: - logger.debug( - f"dim {unique_dims[1:]} has been merged with dim {unique_dims[0]}" - ) - return dims[0] - if all(d == dims[0] for d in dims): - return dims[0] - merged = [self.suggested_merge_[d] if d in self.suggested_merge_ else d for d in dims] - if all(d == merged[0] for d in merged): - assert merged[0] in self.symbolic_dims_ - return merged[0] - else: - return None - - # broadcast 
from right to left, and merge symbolic dims if needed
-    def _broadcast_shapes(self, shape1, shape2):
-        """Broadcast two shapes from right to left, merging symbolic dimensions if necessary."""
-        new_shape = []
-        rank1 = len(shape1)
-        rank2 = len(shape2)
-        new_rank = max(rank1, rank2)
-        for i in range(new_rank):
-            dim1 = shape1[rank1 - 1 - i] if i < rank1 else 1
-            dim2 = shape2[rank2 - 1 - i] if i < rank2 else 1
-            if dim1 in [1, dim2]:
-                new_dim = dim2
-            elif dim2 == 1:
-                new_dim = dim1
-            else:
-                new_dim = self._merge_symbols([dim1, dim2])
-                if not new_dim:
-                    # warn about unsupported broadcast when auto merge is disabled
-                    # note that auto merge risks incorrectly merging symbols when one of them is 1
-                    # for example, 'a' = 1, 'b' = 5 at runtime is valid broadcasting, but with auto merge 'a' == 'b'
-                    if self.auto_merge_:
-                        self._add_suggested_merge([dim1, dim2], apply=True)
-                    else:
-                        # TODO(justinchuby): Error?
-                        logger.warning(
-                            "unsupported broadcast between %s %s", dim1, dim2
-                        )
-            new_shape = [new_dim, *new_shape]
-        return new_shape
-
-    def _get_shape(self, node, idx):
-        """Retrieve the shape of a tensor from a node's inputs based on known value info or initializers."""
-        name = node.input[idx]
-        if name in self.known_vi_:
-            vi = self.known_vi_[name]
-            return get_shape_from_value_info(vi)
-        else:
-            assert name in self.initializers_
-            return list(self.initializers_[name].dims)
-
-    def _try_get_shape(self, node, idx):
-        """Attempts to retrieve the shape of the input node at the specified index if available."""
-        if idx > len(node.input) - 1:
-            return None
-        name = node.input[idx]
-        if name in self.known_vi_:
-            vi = self.known_vi_[name]
-            return get_shape_from_value_info(vi)
-        if name in self.initializers_:
-            return list(self.initializers_[name].dims)
-        return None
-
-    def _get_shape_rank(self, node, idx):
-        """Return the rank (number of dimensions) of the shape of the input tensor at the specified index for a given
-        node.
-        """
-        return len(self._get_shape(node, idx))
-
-    def _get_sympy_shape(self, node, idx):
-        """Return the symbolic shape dimensions using SymPy for the given input tensor at the specified index for a
-        node.
- """ - sympy_shape = [] - for d in self._get_shape(node, idx): - if type(d) == str: - sympy_shape.append( - self.symbolic_dims_[d] - if d in self.symbolic_dims_ - else sympy.Symbol(d, integer=True, nonnegative=True) - ) - else: - assert None is not d - sympy_shape.append(d) - return sympy_shape - - def _get_value(self, node, idx): - """Retrieve the value associated with a node's input index from sympy_data_ or initializers_.""" - name = node.input[idx] - assert name in self.sympy_data_ or name in self.initializers_ - return ( - self.sympy_data_[name] - if name in self.sympy_data_ - else numpy_helper.to_array(self.initializers_[name]) - ) - - def _try_get_value(self, node, idx): - """Try to retrieve the value associated with a node's input index from sympy_data_ or initializers_.""" - if idx >= len(node.input): - return None - name = node.input[idx] - if name in self.sympy_data_ or name in self.initializers_: - return self._get_value(node, idx) - return None - - def _update_computed_dims(self, new_sympy_shape): - """Update dimensions in new_sympy_shape based on suggested merges and computational expressions.""" - for i, new_dim in enumerate(new_sympy_shape): - if not is_literal(new_dim) and type(new_dim) != str: - str_dim = str(new_dim) - if str_dim in self.suggested_merge_: - if not is_literal(self.suggested_merge_[str_dim]): - new_sympy_shape[i] = self.symbolic_dims_[ - self.suggested_merge_[str_dim] - ] - elif str_dim not in self.symbolic_dims_: - self.symbolic_dims_[str_dim] = new_dim - - def _onnx_infer_single_node(self, node): - """Performs ONNX shape inference for a single node, skipping inference for specified operation types.""" - skip_infer = node.op_type in { - "If", - "Loop", - "Scan", - "SplitToSequence", - "ZipMap", # contrib ops - "Attention", - "BiasGelu", - "EmbedLayerNormalization", - "FastGelu", - "Gelu", - "GemmFastGelu", - "LayerNormalization", - "LongformerAttention", - "DequantizeLinear", - "QuantizeLinear", - "RelativePositionBias", - "RemovePadding", - "RestorePadding", - "SimplifiedLayerNormalization", - "SkipLayerNormalization", - "SkipSimplifiedLayerNormalization", - "PackedAttention", - "PythonOp", - "MultiHeadAttention", - "GroupNorm", - "SkipGroupNorm", - "BiasSplitGelu", - "BiasAdd", - "NhwcConv", - "QuickGelu", - "RotaryEmbedding", - } - - if not skip_infer: - # Only pass initializers that satisfy the following condition: - # (1) Operator need value of some input for shape inference. - # For example, Unsqueeze in opset 13 uses the axes input to calculate shape of output. - # (2) opset version >= 9. In older version, initializer is required in graph input by onnx spec. - # (3) The initializer is not in graph input. The means the node input is "constant" in inference. 
- initializers = [] - if (get_opset(self.out_mp_) >= 9) and node.op_type == "Unsqueeze": - initializers = [ - self.initializers_[name] - for name in node.input - if (name in self.initializers_ and name not in self.graph_inputs_) - ] - - if ( - node.op_type - in { - "Add", - "Sub", - "Mul", - "Div", - "MatMul", - "MatMulInteger", - "MatMulInteger16", - "Where", - "Sum", - } - and node.output[0] in self.known_vi_ - ): - vi = self.known_vi_[node.output[0]] - out_rank = len(get_shape_from_type_proto(vi.type)) - in_shapes = [self._get_shape(node, i) for i in range(len(node.input))] - for d in range( - out_rank - - ( - 2 - if node.op_type in {"MatMul", "MatMulInteger", "MatMulInteger16"} - else 0 - ) - ): - in_dims = [ - s[len(s) - out_rank + d] for s in in_shapes if len(s) + d >= out_rank - ] - if len(in_dims) > 1: - self._check_merged_dims(in_dims, allow_broadcast=True) - - # run single node inference with self.known_vi_ shapes - tmp_graph = helper.make_graph( - [node], - "tmp", - [self.known_vi_[i] for i in node.input if i], - [make_named_value_info(i) for i in node.output], - initializers, - ) - self.tmp_mp_.graph.CopyFrom(tmp_graph) - - self.tmp_mp_ = shape_inference.infer_shapes(self.tmp_mp_) - - for i_o in range(len(node.output)): - o = node.output[i_o] - if o: # skip optional output - vi = self.out_mp_.graph.value_info.add() - if not skip_infer: - vi.CopyFrom(self.tmp_mp_.graph.output[i_o]) - else: - vi.name = o - self.known_vi_[o] = vi - - def _onnx_infer_subgraph(self, node, subgraph, use_node_input=True, inc_subgraph_id=True): - """Infer shapes and types within a subgraph for a given ONNX node using temporary graphs and known value - information. - """ - if self.verbose_ > 2: - logger.debug( - f"Inferencing subgraph of node {node.name} with output({node.output[0]}...): {node.op_type}" - ) - # node inputs are not passed directly to the subgraph - # it's up to the node dispatcher to prepare subgraph input - # for example, with Scan/Loop, subgraph input shape would be trimmed from node input shape - # besides, inputs in subgraph could shadow implicit inputs - subgraph_inputs = {i.name for i in list(subgraph.initializer) + list(subgraph.input)} - subgraph_implicit_input = { - name for name in self.known_vi_ if name not in subgraph_inputs - } - tmp_graph = helper.make_graph( - list(subgraph.node), - "tmp", - list(subgraph.input) + [self.known_vi_[i] for i in subgraph_implicit_input], - [make_named_value_info(i.name) for i in subgraph.output], - ) - tmp_graph.initializer.extend( - [i for i in self.out_mp_.graph.initializer if i.name in subgraph_implicit_input] - ) - tmp_graph.initializer.extend(subgraph.initializer) - self.tmp_mp_.graph.CopyFrom(tmp_graph) - - symbolic_shape_inference = SymbolicShapeInference( - self.int_max_, - self.auto_merge_, - self.guess_output_rank_, - self.verbose_, - prefix=f"{self.prefix_}_{self.subgraph_id_!s}", - ) - if inc_subgraph_id: - self.subgraph_id_ += 1 - - symbolic_shape_inference._preprocess(self.tmp_mp_) - symbolic_shape_inference.suggested_merge_ = self.suggested_merge_.copy() - while symbolic_shape_inference.run_: - symbolic_shape_inference._infer_impl(self.sympy_data_.copy()) - symbolic_shape_inference._update_output_from_vi() - if use_node_input: - # if subgraph uses node input, it needs to update to merged dims - subgraph.ClearField("input") - subgraph.input.extend( - symbolic_shape_inference.out_mp_.graph.input[: len(node.input)] - ) - subgraph.ClearField("output") - subgraph.output.extend(symbolic_shape_inference.out_mp_.graph.output) - 
subgraph.ClearField("value_info") - subgraph.value_info.extend(symbolic_shape_inference.out_mp_.graph.value_info) - subgraph.ClearField("node") - subgraph.node.extend(symbolic_shape_inference.out_mp_.graph.node) - # for new symbolic dims from subgraph output, add to main graph symbolic dims - subgraph_shapes = [ - get_shape_from_value_info(o) for o in symbolic_shape_inference.out_mp_.graph.output - ] - subgraph_new_symbolic_dims = { - d - for s in subgraph_shapes - if s - for d in s - if type(d) == str and d not in self.symbolic_dims_ - } - new_dims = {} - for d in subgraph_new_symbolic_dims: - assert d in symbolic_shape_inference.symbolic_dims_ - new_dims[d] = symbolic_shape_inference.symbolic_dims_[d] - self.symbolic_dims_.update(new_dims) - return symbolic_shape_inference - - def _get_int_or_float_values(self, node, broadcast=False, allow_float_values=False): - """Extracts integer or float values from a node, with options for broadcasting and allowing float values.""" - - def int_or_float(value, allow_float_values): - """Converts a value to an integer unless precision loss occurs and allow_float_values is True.""" - return value if allow_float_values and value % 1 != 0 else int(value) - - values = [self._try_get_value(node, i) for i in range(len(node.input))] - if all(v is not None for v in values): - # some shape compute is in floating point, cast to int for sympy - for i, v in enumerate(values): - if type(v) != np.ndarray: - continue - if len(v.shape) > 1: - new_v = None # ignore value for rank > 1 - elif len(v.shape) == 0: - new_v = int_or_float(v.item(), allow_float_values) - else: - assert len(v.shape) == 1 - new_v = [int_or_float(vv, allow_float_values) for vv in v] - values[i] = new_v - values_len = [len(v) if isinstance(v, list) else 0 for v in values] - max_len = max(values_len) - if max_len >= 1 and broadcast: - # broadcast - for i, v in enumerate(values): - if v is None: - continue # don't broadcast if value is unknown - if isinstance(v, list): - if len(v) < max_len: - values[i] = v * max_len - else: - assert len(v) == max_len - else: - values[i] = [v] * max_len - return values - - def _compute_on_sympy_data(self, node, op_func): - """Calculate the result using Sympy data and a specified operation function.""" - assert len(node.output) == 1 - - # Before mul & div operations - # cast inputs into integer might lose decimal part and reduce precision - # keep them as float, finish the operation, then cast the result into integer - if node.op_type in {"Mul", "Div"}: - values = self._get_int_or_float_values( - node, broadcast=True, allow_float_values=True - ) - else: - values = self._get_int_or_float_values(node, broadcast=True) - - if all(v is not None for v in values): - is_list = [isinstance(v, list) for v in values] - as_list = any(is_list) - if as_list: - self.sympy_data_[node.output[0]] = [op_func(vs) for vs in zip(*values)] - else: - self.sympy_data_[node.output[0]] = op_func(values) - - def _pass_on_sympy_data(self, node): - """Pass Sympy data through a node, validating input length or node operation type 'Reshape', 'Unsqueeze', - 'Squeeze'. 
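A condensed sketch of the value-propagation step above (helper name is mine): scalar inputs are broadcast to list length, then the op function is applied elementwise via zip:

    import sympy

    def compute(values, op_func):
        # Mirrors _compute_on_sympy_data: broadcast scalars, then apply elementwise.
        is_list = [isinstance(v, list) for v in values]
        if any(is_list):
            max_len = max(len(v) for v in values if isinstance(v, list))
            values = [v if isinstance(v, list) else [v] * max_len for v in values]
            return [op_func(vs) for vs in zip(*values)]
        return op_func(values)

    n = sympy.Symbol("n", integer=True, nonnegative=True)
    print(compute([[n, 3], 1], lambda l: l[0] + l[1]))  # [n + 1, 4]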
- """ - assert len(node.input) == 1 or node.op_type in { - "Reshape", - "Unsqueeze", - "Squeeze", - } - self._compute_on_sympy_data(node, lambda x: x[0]) - - def _pass_on_shape_and_type(self, node): - """Propagates the shape and type information from input to output for a given node.""" - vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info( - node.output[0], - get_elem_type_from_type_proto(self.known_vi_[node.input[0]].type), - self._get_shape(node, 0), - ) - ) - - def _new_symbolic_dim(self, prefix, dim): - """Create and return a new symbolic dimension, handling literal values and caching for repeated uses.""" - new_dim = f"{prefix}_d{dim}" - if new_dim in self.suggested_merge_: - v = self.suggested_merge_[new_dim] - new_symbolic_dim = sympy.Integer(int(v)) if is_literal(v) else v - else: - new_symbolic_dim = sympy.Symbol(new_dim, integer=True, nonnegative=True) - self.symbolic_dims_[new_dim] = new_symbolic_dim - return new_symbolic_dim - - def _new_symbolic_dim_from_output(self, node, out_idx=0, dim=0): - """Generates a new symbolic dimension for a given node's output using the node's operation type, prefix, and - output index. - """ - return self._new_symbolic_dim( - f"{node.op_type}{self.prefix_}_{list(self.out_mp_.graph.node).index(node)}_o{out_idx}_", - dim, - ) - - def _new_symbolic_shape(self, rank, node, out_idx=0): - """Generate a new symbolic shape for a node output based on its rank and index.""" - return [self._new_symbolic_dim_from_output(node, out_idx, i) for i in range(rank)] - - def _compute_conv_pool_shape(self, node, channels_last=False): - """Calculate the output shape of a convolutional or pooling layer node, optionally considering channels_last - format. - """ - sympy_shape = self._get_sympy_shape(node, 0) - if len(node.input) > 1: - W_shape = self._get_sympy_shape(node, 1) - rank = len(W_shape) - 2 # number of spatial axes - kernel_shape = W_shape[-rank - 1 : -1] if channels_last else W_shape[-rank:] - sympy_shape[3 if channels_last else 1] = W_shape[0] - else: - W_shape = None - kernel_shape = get_attribute(node, "kernel_shape") - rank = len(kernel_shape) - - assert len(sympy_shape) == rank + 2 - - # only need to symbolic shape inference if input has symbolic dims in spatial axes - spatial_shape = sympy_shape[-rank - 1 : -1] if channels_last else sympy_shape[-rank:] - is_symbolic_dims = [not is_literal(i) for i in spatial_shape] - - if not any(is_symbolic_dims): - shape = get_shape_from_value_info(self.known_vi_[node.output[0]]) - if len(shape) > 0: - assert len(sympy_shape) == len(shape) - if channels_last: - sympy_shape[-rank - 1 : -1] = [ - sympy.Integer(d) for d in shape[-rank - 1 : -1] - ] - else: - sympy_shape[-rank:] = [sympy.Integer(d) for d in shape[-rank:]] - return sympy_shape - - dilations = get_attribute(node, "dilations", [1] * rank) - strides = get_attribute(node, "strides", [1] * rank) - effective_kernel_shape = [(k - 1) * d + 1 for k, d in zip(kernel_shape, dilations)] - pads = get_attribute(node, "pads") - if pads is None: - pads = [0] * (2 * rank) - auto_pad = get_attribute(node, "auto_pad", b"NOTSET").decode("utf-8") - if auto_pad not in {"VALID", "NOTSET"}: - try: - residual = [sympy.Mod(d, s) for d, s in zip(sympy_shape[-rank:], strides)] - total_pads = [ - max(0, (k - s) if r == 0 else (k - r)) - for k, s, r in zip(effective_kernel_shape, strides, residual) - ] - except ( - TypeError - ): # sympy may throw TypeError: cannot determine truth value of Relational - total_pads = [ - max(0, (k - s)) for k, s in 
zip(effective_kernel_shape, strides) - ] # assuming no residual if sympy throws error - elif auto_pad == "VALID": - total_pads = [] - else: - total_pads = [0] * rank - else: - assert len(pads) == 2 * rank - total_pads = [p1 + p2 for p1, p2 in zip(pads[:rank], pads[rank:])] - - ceil_mode = get_attribute(node, "ceil_mode", 0) - for i in range(rank): - effective_input_size = sympy_shape[-rank + i + (-1 if channels_last else 0)] - if len(total_pads) > 0: - effective_input_size = effective_input_size + total_pads[i] - if ceil_mode: - strided_kernel_positions = sympy.ceiling( - (effective_input_size - effective_kernel_shape[i]) / strides[i] - ) - else: - strided_kernel_positions = ( - effective_input_size - effective_kernel_shape[i] - ) // strides[i] - sympy_shape[-rank + i + (-1 if channels_last else 0)] = ( - strided_kernel_positions + 1 - ) - return sympy_shape - - def _check_merged_dims(self, dims, allow_broadcast=True): - """Checks merged dimensions for consistency, optionally allowing broadcasting.""" - if allow_broadcast: - dims = [d for d in dims if not (is_literal(d) and int(d) <= 1)] - if any(d != dims[0] for d in dims): - self._add_suggested_merge(dims, apply=True) - - def _compute_matmul_shape(self, node, output_dtype=None): - """Compute the output shape for a matrix multiplication operation based on input shapes and optionally infer the - output data type. - """ - lhs_shape = self._get_shape(node, 0) - rhs_shape = self._get_shape(node, 1) - lhs_rank = len(lhs_shape) - rhs_rank = len(rhs_shape) - lhs_reduce_dim = 0 - rhs_reduce_dim = 0 - assert lhs_rank > 0 and rhs_rank > 0 - if lhs_rank == 1 and rhs_rank == 1: - new_shape = [] - elif lhs_rank == 1: - rhs_reduce_dim = -2 - new_shape = rhs_shape[:rhs_reduce_dim] + [rhs_shape[-1]] - elif rhs_rank == 1: - lhs_reduce_dim = -1 - new_shape = lhs_shape[:lhs_reduce_dim] - else: - lhs_reduce_dim = -1 - rhs_reduce_dim = -2 - new_shape = [ - *self._broadcast_shapes(lhs_shape[:-2], rhs_shape[:-2]), - lhs_shape[-2], - rhs_shape[-1], - ] - # merge reduce dim - self._check_merged_dims( - [lhs_shape[lhs_reduce_dim], rhs_shape[rhs_reduce_dim]], - allow_broadcast=False, - ) - if output_dtype is None: - # infer output_dtype from input type when not specified - output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type - vi = self.known_vi_[node.output[0]] - vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, new_shape)) - - def _fuse_tensor_type(self, node, out_idx, dst_type, src_type): - """Update dst_tensor_type to be compatible with src_tensor_type when dimension mismatches.""" - dst_tensor_type = ( - dst_type.sequence_type.elem_type.tensor_type - if is_sequence(dst_type) - else dst_type.tensor_type - ) - src_tensor_type = ( - src_type.sequence_type.elem_type.tensor_type - if is_sequence(src_type) - else src_type.tensor_type - ) - if dst_tensor_type.elem_type != src_tensor_type.elem_type: - node_id = node.name or node.op_type - raise ValueError( - f"For node {node_id}, dst_tensor_type.elem_type != src_tensor_type.elem_type: " - f"{onnx.onnx_pb.TensorProto.DataType.Name(dst_tensor_type.elem_type)} vs " - f"{onnx.onnx_pb.TensorProto.DataType.Name(src_tensor_type.elem_type)}" - ) - if dst_tensor_type.HasField("shape"): - for di, ds in enumerate(zip(dst_tensor_type.shape.dim, src_tensor_type.shape.dim)): - if ds[0] != ds[1]: - # create a new symbolic dimension for node/out_idx/mismatch dim id in dst_tensor_type for tensor_type - # for sequence_type, clear the dimension - new_dim = onnx.TensorShapeProto.Dimension() - if 
not is_sequence(dst_type): - new_dim.dim_param = str( - self._new_symbolic_dim_from_output(node, out_idx, di) - ) - dst_tensor_type.shape.dim[di].CopyFrom(new_dim) - else: - dst_tensor_type.CopyFrom(src_tensor_type) - - def _infer_ArrayFeatureExtractor(self, node): - """Infer and update the shape and type information for the ArrayFeatureExtractor node using input data and - indices shapes. - """ - data_shape = self._get_shape(node, 0) - indices_shape = self._get_shape(node, 1) - vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info( - node.output[0], - self.known_vi_[node.input[0]].type.tensor_type.elem_type, - data_shape[:-1] + indices_shape, - ) - ) - - def _infer_symbolic_compute_ops(self, node): - """Handles symbolic computation operations for given node based on predefined functions.""" - funcs = { - "Add": lambda l: l[0] + l[1], - "Div": lambda l: ( - int(l[0] // l[1]) if isinstance(l[0] // l[1], float) else l[0] // l[1] - ), # integer div in sympy - "Equal": lambda l: l[0] == l[1], - "Floor": lambda l: sympy.floor(l[0]), - "Max": lambda l: ( - l[1] - if is_literal(l[0]) and int(l[0]) < -self.int_max_ - else ( - l[0] - if is_literal(l[1]) and int(l[1]) < -self.int_max_ - else sympy.Max(l[0], l[1]) - ) - ), - "Min": lambda l: ( - l[1] - if is_literal(l[0]) and int(l[0]) > self.int_max_ - else ( - l[0] - if is_literal(l[1]) and int(l[1]) > self.int_max_ - else sympy.Min(l[0], l[1]) - ) - ), - "Mul": lambda l: ( - int(l[0] * l[1]) if isinstance(l[0] * l[1], float) else l[0] * l[1] - ), - "Sub": lambda l: l[0] - l[1], - "Where": lambda l: l[1] if l[0] else l[2], - "Neg": lambda l: -l[0], - } - assert node.op_type in funcs - self._compute_on_sympy_data(node, funcs[node.op_type]) - - def _infer_Cast(self, node): - """Pass node's data to SymPy representation without alteration.""" - self._pass_on_sympy_data(node) - - def _infer_CategoryMapper(self, node): - """Infer and set output tensor type for ONNX CategoryMapper nodes based on input tensor type.""" - input_type = self.known_vi_[node.input[0]].type.tensor_type.elem_type - if input_type == onnx.TensorProto.STRING: - output_type = onnx.TensorProto.INT64 - else: - output_type = onnx.TensorProto.STRING - vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info( - node.output[0], output_type, self._get_shape(node, 0) - ) - ) - - def _infer_Compress(self, node): - """Infer the output shape and type for the Compress operation based on input shape and axis attribute.""" - input_shape = self._get_shape(node, 0) - # create a new symbolic dimension for Compress output - compress_len = str(self._new_symbolic_dim_from_output(node)) - axis = get_attribute(node, "axis") - if axis is None: - # when axis is not specified, input is flattened before compress so output is 1D - output_shape = [compress_len] - else: - output_shape = input_shape - output_shape[handle_negative_axis(axis, len(input_shape))] = compress_len - vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info( - node.output[0], - self.known_vi_[node.input[0]].type.tensor_type.elem_type, - output_shape, - ) - ) - - def _infer_Concat(self, node): - """Infer the output shape and type for the Concat operation based on input node values.""" - if any(i in self.sympy_data_ or i in self.initializers_ for i in node.input): - values = self._get_int_or_float_values(node) - if all(v is not None for v in values): - assert get_attribute(node, "axis") == 0 - self.sympy_data_[node.output[0]] = [] - for i in 
range(len(node.input)): - value = values[i] - if isinstance(value, list): - self.sympy_data_[node.output[0]].extend(value) - else: - self.sympy_data_[node.output[0]].append(value) - - sympy_shape = self._get_sympy_shape(node, 0) - axis = handle_negative_axis(get_attribute(node, "axis"), len(sympy_shape)) - for i_idx in range(1, len(node.input)): - input_shape = self._get_sympy_shape(node, i_idx) - if input_shape: - sympy_shape[axis] = sympy_shape[axis] + input_shape[axis] - self._update_computed_dims(sympy_shape) - # merge symbolic dims for non-concat axes - for d in range(len(sympy_shape)): - if d == axis: - continue - dims = [ - self._get_shape(node, i_idx)[d] - for i_idx in range(len(node.input)) - if self._get_shape(node, i_idx) - ] - if all(d == dims[0] for d in dims): - continue - merged = self._merge_symbols(dims) - if type(merged) == str: - sympy_shape[d] = self.symbolic_dims_[merged] if merged else None - else: - sympy_shape[d] = merged - vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info( - node.output[0], - self.known_vi_[node.input[0]].type.tensor_type.elem_type, - get_shape_from_sympy_shape(sympy_shape), - ) - ) - - def _infer_ConcatFromSequence(self, node): - """Infers the output shape and type info for ConcatFromSequence operation in a computational graph node.""" - seq_shape = self._get_shape(node, 0) - new_axis = 1 if get_attribute(node, "new_axis") else 0 - axis = handle_negative_axis(get_attribute(node, "axis"), len(seq_shape) + new_axis) - concat_dim = str(self._new_symbolic_dim_from_output(node, 0, axis)) - new_shape = seq_shape - if new_axis: - new_shape = seq_shape[:axis] + [concat_dim] + seq_shape[axis:] - else: - new_shape[axis] = concat_dim - vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info( - node.output[0], - self.known_vi_[ - node.input[0] - ].type.sequence_type.elem_type.tensor_type.elem_type, - new_shape, - ) - ) - - def _infer_Constant(self, node): - """Infer the constant value for a given node and store it in sympy_data_.""" - t = get_attribute(node, "value") - self.sympy_data_[node.output[0]] = numpy_helper.to_array(t) - - def _infer_ConstantOfShape(self, node): - """Infer the constant tensor of a given shape from a node and update sympy_data_.""" - sympy_shape = self._get_int_or_float_values(node)[0] - vi = self.known_vi_[node.output[0]] - if sympy_shape is not None: - if type(sympy_shape) != list: - sympy_shape = [sympy_shape] - self._update_computed_dims(sympy_shape) - # update sympy data if output type is int, and shape is known - if vi.type.tensor_type.elem_type == onnx.TensorProto.INT64 and all( - is_literal(x) for x in sympy_shape - ): - self.sympy_data_[node.output[0]] = np.ones( - [int(x) for x in sympy_shape], dtype=np.int64 - ) * numpy_helper.to_array(get_attribute(node, "value", 0)) - else: - # create new dynamic shape - # note input0 is a 1D vector of shape, the new symbolic shape has the rank of the shape vector length - sympy_shape = self._new_symbolic_shape(self._get_shape(node, 0)[0], node) - - vi.CopyFrom( - helper.make_tensor_value_info( - node.output[0], - vi.type.tensor_type.elem_type, - get_shape_from_sympy_shape(sympy_shape), - ) - ) - - def _infer_Conv(self, node): - """Infers the shape of the output tensor for a convolution operation node and updates the known value info.""" - sympy_shape = self._compute_conv_pool_shape(node) - self._update_computed_dims(sympy_shape) - vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info( - 
node.output[0], - vi.type.tensor_type.elem_type, - get_shape_from_sympy_shape(sympy_shape), - ) - ) - - def _infer_NhwcConv(self, node): - """Infer the shape of the output tensor for a convolution operation with NHWC format.""" - sympy_shape = self._compute_conv_pool_shape(node, channels_last=True) - self._update_computed_dims(sympy_shape) - vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info( - node.output[0], - self.known_vi_[node.input[0]].type.tensor_type.elem_type, - get_shape_from_sympy_shape(sympy_shape), - ) - ) - - def _infer_DequantizeLinear(self, node): - """Infer output type and shape for the DequantizeLinear node based on input 1's scale data type.""" - output_dtype = self.known_vi_[node.input[1]].type.tensor_type.elem_type - - # Get the output shape from the first input. - output_shape = self._get_shape(node, 0) - - vi = self.known_vi_[node.output[0]] - vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape)) - - def _infer_QuantizeLinear(self, node): - """Infer the output data type and shape for the QuantizeLinear ONNX node, defaulting to uint8 if not - specified. - """ - # Otherwise, default to uint8 - output_dtype = onnx.TensorProto.UINT8 - if len(node.input) > 2 and node.input[2]: - output_dtype = self.known_vi_[node.input[2]].type.tensor_type.elem_type - - # Get the output shape from the first input. - output_shape = self._get_shape(node, 0) - - vi = self.known_vi_[node.output[0]] - vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape)) - - def _infer_Einsum(self, node): - """Infer the output shape and type for the Einsum operation as per ONNX standards: https://github.com/onnx/onnx/blob/623dfaa/onnx/defs/math/defs.cc#L3275.""" - equation = get_attribute(node, "equation") - equation = equation.replace(b" ", b"") - mid_index = equation.find(b"->") - left_equation = equation[:mid_index] if mid_index != -1 else equation - - num_operands = 0 - num_ellipsis = 0 - num_ellipsis_indices = 0 - - letter_to_dim = {} - - terms = left_equation.split(b",") - for term in terms: - ellipsis_index = term.find(b"...") - shape = self._get_shape(node, num_operands) - rank = len(shape) - if ellipsis_index != -1: - if num_ellipsis == 0: - num_ellipsis_indices = rank - len(term) + 3 - num_ellipsis = num_ellipsis + 1 - for i in range(1, rank + 1): - letter = term[-i] - if letter != 46: # letter != b'.' - dim = shape[-i] - if letter not in letter_to_dim or type(dim) != sympy.Symbol: - letter_to_dim[letter] = dim - num_operands = num_operands + 1 - - new_sympy_shape = [] - from collections import OrderedDict - - num_letter_occurrences = OrderedDict() - if mid_index != -1: - right_equation = equation[mid_index + 2 :] - right_ellipsis_index = right_equation.find(b"...") - if right_ellipsis_index != -1: - for i in range(num_ellipsis_indices): - new_sympy_shape.append(shape[i]) - for c in right_equation: - if c != 46: # c != b'.' 
- new_sympy_shape.append(letter_to_dim[c]) - else: - for i in range(num_ellipsis_indices): - new_sympy_shape.append(shape[i]) - for c in left_equation: - if c not in {44, 46}: # c != b',' and c != b'.': - if c in num_letter_occurrences: - num_letter_occurrences[c] = num_letter_occurrences[c] + 1 - else: - num_letter_occurrences[c] = 1 - for key, value in num_letter_occurrences.items(): - if value == 1: - new_sympy_shape.append(letter_to_dim[key]) - - output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type - vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info(node.output[0], output_dtype, new_sympy_shape) - ) - - def _infer_Expand(self, node): - """Infers and updates the output shape for the Expand operation based on broadcasted input shapes.""" - expand_to_shape = as_list(self._try_get_value(node, 1), keep_none=True) - if expand_to_shape is not None: - # new_shape's dim can come from shape value - self._update_computed_dims(expand_to_shape) - shape = self._get_shape(node, 0) - new_shape = self._broadcast_shapes( - shape, get_shape_from_sympy_shape(expand_to_shape) - ) - vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info( - node.output[0], - self.known_vi_[node.input[0]].type.tensor_type.elem_type, - new_shape, - ) - ) - - def _infer_Gather(self, node): - """Infer the output shape of the Gather operation based on the input data and indices shapes.""" - data_shape = self._get_shape(node, 0) - axis = handle_negative_axis(get_attribute(node, "axis", 0), len(data_shape)) - indices_shape = self._get_shape(node, 1) - vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info( - node.output[0], - self.known_vi_[node.input[0]].type.tensor_type.elem_type, - data_shape[:axis] + indices_shape + data_shape[axis + 1 :], - ) - ) - # for 1D input, do some sympy compute - if ( - node.input[0] in self.sympy_data_ - and len(data_shape) == 1 - and get_attribute(node, "axis", 0) == 0 - ): - idx = self._try_get_value(node, 1) - if idx is not None: - data = self.sympy_data_[node.input[0]] - if type(data) == list: - if type(idx) == np.ndarray and len(idx.shape) == 1: - self.sympy_data_[node.output[0]] = [data[int(i)] for i in idx] - else: - self.sympy_data_[node.output[0]] = data[int(idx)] - else: - assert idx in {0, -1} - self.sympy_data_[node.output[0]] = data - - def _infer_GatherElements(self, node): - """Infers the output shape and type for the GatherElements node based on input tensors and updates the node's - value information. 
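The Gather rule above, data_shape[:axis] + indices_shape + data_shape[axis + 1:], can be sanity-checked with NumPy's take, which follows the same convention:

    import numpy as np

    data = np.zeros((4, 5, 6))
    indices = np.zeros((2, 3), dtype=np.int64)
    axis = 1
    out = np.take(data, indices, axis=axis)
    expected = data.shape[:axis] + indices.shape + data.shape[axis + 1:]
    assert out.shape == expected  # (4, 2, 3, 6)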
- """ - indices_shape = self._get_shape(node, 1) - vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info( - node.output[0], - self.known_vi_[node.input[0]].type.tensor_type.elem_type, - indices_shape, - ) - ) - - def _infer_GatherND(self, node): - """Infers the output shape and type for the GatherND operation based on input data and indices shapes.""" - data_shape = self._get_shape(node, 0) - data_rank = len(data_shape) - indices_shape = self._get_shape(node, 1) - len(indices_shape) - last_index_dimension = indices_shape[-1] - batch_dims = get_attribute(node, "batch_dims", 0) - assert ( - is_literal(last_index_dimension) - and is_literal(batch_dims) - and (batch_dims + last_index_dimension) <= data_rank - ) - new_shape = indices_shape[:-1] + data_shape[batch_dims + last_index_dimension :] - vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info( - node.output[0], - self.known_vi_[node.input[0]].type.tensor_type.elem_type, - new_shape, - ) - ) - - def _infer_If(self, node): - """Infer the output shape for an If node, handling constant conditions to ensure shape consistency between - branches. - """ - subgraphs = [ - get_attribute(node, "then_branch"), - get_attribute(node, "else_branch"), - ] - cond = self._try_get_value(node, 0) - if cond is not None: - if as_scalar(cond) > 0: - subgraphs[1].CopyFrom(subgraphs[0]) - else: - subgraphs[0].CopyFrom(subgraphs[1]) - - for i_sub, subgraph in enumerate(subgraphs): - subgraph_infer = self._onnx_infer_subgraph(node, subgraph, use_node_input=False) - for i_out in range(len(node.output)): - vi = self.known_vi_[node.output[i_out]] - if i_sub == 0: - vi.CopyFrom(subgraph.output[i_out]) - vi.name = node.output[i_out] - else: - self._fuse_tensor_type(node, i_out, vi.type, subgraph.output[i_out].type) - - # pass on sympy data from subgraph, if cond is constant - if ( - cond is not None - and i_sub == (0 if as_scalar(cond) > 0 else 1) - and subgraph.output[i_out].name in subgraph_infer.sympy_data_ - ): - self.sympy_data_[vi.name] = subgraph_infer.sympy_data_[ - subgraph.output[i_out].name - ] - - def _infer_Loop(self, node): - """Infer the shape and type of variables produced by the 'Loop' operation in an ONNX graph.""" - subgraph = get_attribute(node, "body") - assert len(subgraph.input) == len(node.input) - num_loop_carried = len(node.input) - 2 # minus the length and initial loop condition - # when sequence_type is used as loop carried input - # needs to run subgraph infer twice if the tensor shape in sequence contains None - for i, si in enumerate(subgraph.input): - si_name = si.name - si.CopyFrom(self.known_vi_[node.input[i]]) - si.name = si_name - - self._onnx_infer_subgraph(node, subgraph) - - # check subgraph input/output for shape changes in loop carried variables - # for tensor_type, create new symbolic dim when changing, i.e., output = Concat(input, a) - # for sequence_type, propagate from output to input - need_second_infer = False - for i_out in range(1, num_loop_carried + 1): - so = subgraph.output[i_out] - so_shape = get_shape_from_value_info(so) - if is_sequence(so.type): - if so_shape and None in so_shape: - # copy shape from output to input - # note that loop input is [loop_len, cond, input_0, input_1, ...] - # while loop output is [cond, output_0, output_1, ...] 
- subgraph.input[i_out + 1].type.sequence_type.elem_type.CopyFrom( - so.type.sequence_type.elem_type - ) - need_second_infer = True - else: - si = subgraph.input[i_out + 1] - si_shape = get_shape_from_value_info(si) - for di, dims in enumerate(zip(si_shape, so_shape)): - if dims[0] != dims[1]: - new_dim = onnx.TensorShapeProto.Dimension() - new_dim.dim_param = str( - self._new_symbolic_dim_from_output(node, i_out, di) - ) - si.type.tensor_type.shape.dim[di].CopyFrom(new_dim) - so.type.tensor_type.shape.dim[di].CopyFrom(new_dim) - need_second_infer = True - - if need_second_infer: - if self.verbose_ > 2: - logger.debug( - f"Rerun Loop: {node.name}({node.output[0]}...), because of sequence in loop carried variables" - ) - self._onnx_infer_subgraph(node, subgraph, inc_subgraph_id=False) - - # create a new symbolic dimension for iteration dependent dimension - loop_iter_dim = str(self._new_symbolic_dim_from_output(node)) - for i in range(len(node.output)): - vi = self.known_vi_[node.output[i]] - vi.CopyFrom( - subgraph.output[i + 1] - ) # first subgraph output is condition, not in node output - if i >= num_loop_carried: - assert not is_sequence( - vi.type - ) # TODO: handle loop accumulation in sequence_type - subgraph_vi_dim = subgraph.output[i + 1].type.tensor_type.shape.dim - vi.type.tensor_type.shape.ClearField("dim") - vi_dim = vi.type.tensor_type.shape.dim - vi_dim.add().dim_param = loop_iter_dim - vi_dim.extend(list(subgraph_vi_dim)) - vi.name = node.output[i] - - def _infer_MatMul(self, node): - """Infer the output shape of a matrix multiplication node.""" - self._compute_matmul_shape(node) - - def _infer_MatMulInteger(self, node): - """Infer the output shape of an integer matrix multiplication node.""" - self._compute_matmul_shape(node, onnx.TensorProto.INT32) - - def _infer_NonMaxSuppression(self, node): - """Infer the output shape of a NonMaxSuppression node and update the value info.""" - selected = str(self._new_symbolic_dim_from_output(node)) - vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info( - node.output[0], onnx.TensorProto.INT64, [selected, 3] - ) - ) - - def _infer_NonZero(self, node): - """Infer the output shape of a NonZero node and update the value info.""" - input_rank = self._get_shape_rank(node, 0) - # create a new symbolic dimension for NonZero output - nz_len = str(self._new_symbolic_dim_from_output(node, 0, 1)) - vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info( - node.output[0], vi.type.tensor_type.elem_type, [input_rank, nz_len] - ) - ) - - def _infer_OneHot(self, node): - """Infer the shape and type of the output tensor for the OneHot node operation.""" - sympy_shape = self._get_sympy_shape(node, 0) - depth = self._try_get_value(node, 1) - axis = get_attribute(node, "axis", -1) - axis = handle_negative_axis(axis, len(sympy_shape) + 1) - new_shape = get_shape_from_sympy_shape( - sympy_shape[:axis] - + [(depth if is_literal(depth) else self._new_symbolic_dim_from_output(node))] - + sympy_shape[axis:] - ) - vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info( - node.output[0], - self.known_vi_[node.input[2]].type.tensor_type.elem_type, - new_shape, - ) - ) - - def _infer_Pad(self, node): - """Infers the output shape and type for the Pad operation based on ONNX node attributes and opset version.""" - if get_opset(self.out_mp_) <= 10: - pads = get_attribute(node, "pads") - else: - pads = self._try_get_value(node, 1) - - sympy_shape = self._get_sympy_shape(node, 0) - 
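The static branch that follows computes each output dim as d + pad_begin + pad_end; a quick NumPy check of that arithmetic (ONNX stores pads as all begins, then all ends):

    import numpy as np

    x = np.zeros((2, 5))
    pads = [1, 0, 3, 2]  # ONNX order: [begin_0, begin_1, end_0, end_1]
    rank = x.ndim
    out = np.pad(x, list(zip(pads[:rank], pads[rank:])))
    assert out.shape == tuple(
        d + b + e for d, b, e in zip(x.shape, pads[:rank], pads[rank:])
    )
    print(out.shape)  # (6, 7)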
rank = len(sympy_shape) - - if pads is not None: - assert len(pads) == 2 * rank - new_sympy_shape = [ - d + pad_up + pad_down - for d, pad_up, pad_down in zip(sympy_shape, pads[:rank], pads[rank:]) - ] - self._update_computed_dims(new_sympy_shape) - else: - # dynamic pads, create new symbolic dimensions - new_sympy_shape = self._new_symbolic_shape(rank, node) - output_tp = self.known_vi_[node.input[0]].type.tensor_type.elem_type - - vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info( - node.output[0], output_tp, get_shape_from_sympy_shape(new_sympy_shape) - ) - ) - - def _infer_Pool(self, node): - """Infer and update dimensions for pooling layers based on the input node.""" - sympy_shape = self._compute_conv_pool_shape(node) - self._update_computed_dims(sympy_shape) - for o in node.output: - if not o: - continue - vi = self.known_vi_[o] - vi.CopyFrom( - helper.make_tensor_value_info( - o, - vi.type.tensor_type.elem_type, - get_shape_from_sympy_shape(sympy_shape), - ) - ) - - def _infer_aten_bitwise_or(self, node): - """Infers the output shape for Aten bitwise OR operation based on input node shapes.""" - shape0 = self._get_shape(node, 0) - shape1 = self._get_shape(node, 1) - new_shape = self._broadcast_shapes(shape0, shape1) - t0 = self.known_vi_[node.input[0]] - vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info( - node.output[0], t0.type.tensor_type.elem_type, new_shape - ) - ) - - def _infer_aten_diagonal(self, node): - """Infers the shape of the diagonal of a tensor given a node, offset, and dimensions.""" - sympy_shape = self._get_sympy_shape(node, 0) - rank = len(sympy_shape) - offset = self._try_get_value(node, 1) - dim1 = self._try_get_value(node, 2) - dim2 = self._try_get_value(node, 3) - - assert offset is not None and dim1 is not None and dim2 is not None - dim1 = handle_negative_axis(dim1, rank) - dim2 = handle_negative_axis(dim2, rank) - - new_shape = [val for dim, val in enumerate(sympy_shape) if dim not in {dim1, dim2}] - shape1 = sympy_shape[dim1] - shape2 = sympy_shape[dim2] - if offset >= 0: - diag_shape = sympy.Max(0, sympy.Min(shape1, shape2 - offset)) - else: - diag_shape = sympy.Max(0, sympy.Min(shape1 + offset, shape2)) - new_shape.append(diag_shape) - - if node.output[0]: - vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info( - node.output[0], - self.known_vi_[node.input[0]].type.tensor_type.elem_type, - get_shape_from_sympy_shape(new_shape), - ) - ) - - def _infer_aten_multinomial(self, node): - """Infers the output shape and type for the PyTorch multinomial operation in an ONNX graph node.""" - sympy_shape = self._get_sympy_shape(node, 0) - rank = len(sympy_shape) - assert rank in {1, 2} - num_samples = self._try_get_value(node, 1) - di = rank - 1 - last_dim = num_samples or str(self._new_symbolic_dim_from_output(node, 0, di)) - output_shape = sympy_shape[:-1] + [last_dim] - vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info( - node.output[0], - onnx.TensorProto.INT64, - get_shape_from_sympy_shape(output_shape), - ) - ) - - def _infer_aten_pool2d(self, node): - """Infer the output shape of a 2D pooling operation in an ONNX graph node.""" - sympy_shape = self._get_sympy_shape(node, 0) - assert len(sympy_shape) == 4 - sympy_shape[-2:] = [self._new_symbolic_dim_from_output(node, 0, i) for i in {2, 3}] - self._update_computed_dims(sympy_shape) - for i, o in enumerate(node.output): - if not o: - continue - vi = self.known_vi_[o] - elem_type = ( 
- onnx.TensorProto.INT64 - if i == 1 - else self.known_vi_[node.input[0]].type.tensor_type.elem_type - ) - vi.CopyFrom( - helper.make_tensor_value_info( - o, elem_type, get_shape_from_sympy_shape(sympy_shape) - ) - ) - - def _infer_aten_minmax(self, node): - """Infer the output shape and type for the ATen MinMax operation in an ONNX node.""" - vi = self.known_vi_[node.output[0]] - if len(node.input) == 1: - vi.CopyFrom( - helper.make_tensor_value_info( - node.output[0], - self.known_vi_[node.input[0]].type.tensor_type.elem_type, - [], - ) - ) - else: - assert len(node.input) == 3 - keepdim = self._try_get_value(node, 2) - assert keepdim is not None # can only handle known keepdim case. - dim = self._try_get_value(node, 1) - if dim is None: - rank = self._get_shape_rank(node, 0) - output_shape = self._new_symbolic_shape(rank if keepdim else rank - 1, node) - else: - shape = self._get_sympy_shape(node, 0) - dim = handle_negative_axis(dim, len(shape)) - output_shape = shape[:dim] - if keepdim: - output_shape += [1] - output_shape += shape[dim + 1 :] - - output_shape = get_shape_from_sympy_shape(output_shape) - vi.CopyFrom( - helper.make_tensor_value_info( - node.output[0], - self.known_vi_[node.input[0]].type.tensor_type.elem_type, - output_shape, - ) - ) - vi1 = self.known_vi_[node.output[1]] - vi1.CopyFrom( - helper.make_tensor_value_info( - node.output[1], onnx.TensorProto.INT64, output_shape - ) - ) - - def _infer_aten_unfold(self, node): - """Infer the tensor shape for the 'aten::unfold' operation based on input shape and parameters dimension, size, and step.""" - sympy_shape = self._get_sympy_shape(node, 0) - dimension = self._try_get_value(node, 1) - size = self._try_get_value(node, 2) - step = self._try_get_value(node, 3) - if dimension is not None and size is not None and step is not None: - assert dimension < len(sympy_shape) - sympy_shape[dimension] = (sympy_shape[dimension] - size) // step + 1 - sympy_shape.append(size) - else: - rank = len(sympy_shape) - sympy_shape = self._new_symbolic_shape(rank + 1, node) - self._update_computed_dims(sympy_shape) - if node.output[0]: - vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info( - node.output[0], - self.known_vi_[node.input[0]].type.tensor_type.elem_type, - get_shape_from_sympy_shape(sympy_shape), - ) - ) - - def _infer_aten_argmax(self, node): - """Infers the output shape for the ONNX ATen argmax operation.""" - new_shape = None - if not node.input[1]: - # The argmax of the flattened input is returned. - new_shape = [] - else: - dim = self._try_get_value(node, 1) - keepdim = self._try_get_value(node, 2) - if keepdim is not None: - sympy_shape = self._get_sympy_shape(node, 0) - if dim is not None: - dim = handle_negative_axis(dim, len(sympy_shape)) - if keepdim: - sympy_shape[dim] = 1 - else: - del sympy_shape[dim] - else: - rank = len(sympy_shape) - sympy_shape = self._new_symbolic_shape(rank if keepdim else rank - 1, node) - self._update_computed_dims(sympy_shape) - new_shape = get_shape_from_sympy_shape(sympy_shape) - if node.output[0] and new_shape is not None: - vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info( - node.output[0], onnx.TensorProto.INT64, new_shape - ) - ) - - def _infer_aten_group_norm(self, node): - """Infers the output shapes and types for the ATen GroupNorm operation based on the provided node - information. 
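The unfold arithmetic above, (d - size) // step + 1 windows plus a trailing size dim, matches NumPy's sliding windows (a sanity check, not part of the patch; needs NumPy >= 1.20):

    import numpy as np

    d, size, step = 10, 4, 3
    x = np.arange(d)
    windows = np.lib.stride_tricks.sliding_window_view(x, size)[::step]
    assert windows.shape == ((d - size) // step + 1, size)  # (3, 4)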
- """ - self._propagate_shape_and_type(node) - input_shape = self._get_shape(node, 0) - N = input_shape[0] if input_shape is not None and len(input_shape) != 0 else None - group = self._try_get_value(node, 6) - output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type - for i in {1, 2}: - if node.output[i]: - vi = self.known_vi_[node.output[i]] - vi.CopyFrom( - helper.make_tensor_value_info( - node.output[i], - output_dtype, - [ - ( - N - if N is not None - else str(self._new_symbolic_dim_from_output(node, i, 0)) - ), - ( - as_scalar(group) - if group is not None - else str(self._new_symbolic_dim_from_output(node, i, 1)) - ), - ], - ) - ) - - def _infer_aten_upsample(self, node): - """Infers the output shape for an aten::upsample operation based on the input shape and specified upsampling parameters.""" - new_shape = None - input_shape = self._get_shape(node, 0) - if input_shape is not None: - new_shape = input_shape[:2] - output_size = self._try_get_value(node, 1) - if output_size is not None: - new_shape += [ - dim_size.item() if type(dim_size) == np.int64 else dim_size - for dim_size in output_size - ] - else: - rank = len(input_shape) - new_shape += [ - str(self._new_symbolic_dim_from_output(node, 0, i)) for i in range(2, rank) - ] - if node.output[0] and new_shape is not None: - output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type - vi = self.known_vi_[node.output[0]] - vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, new_shape)) - - def _infer_BatchNormalization(self, node): - """Propagate the shape and type information for the BatchNormalization node.""" - self._propagate_shape_and_type(node) - - # this works for opsets < 14 and 14 since we check i < len(node.output) in the loop - for i in {1, 2, 3, 4}: - if i < len(node.output) and node.output[i]: - # all of these parameters have the same shape as the 1st input - self._propagate_shape_and_type(node, input_index=1, output_index=i) - - def _infer_Range(self, node): - """Infers the shape and type for Range nodes based on the provided start, limit, and delta values.""" - vi = self.known_vi_[node.output[0]] - input_data = self._get_int_or_float_values(node) - if all(i is not None for i in input_data): - start = as_scalar(input_data[0]) - limit = as_scalar(input_data[1]) - delta = as_scalar(input_data[2]) - new_sympy_shape = [sympy.Max(sympy.ceiling((limit - start) / delta), 0)] - else: - new_sympy_shape = [self._new_symbolic_dim_from_output(node)] - self._update_computed_dims(new_sympy_shape) - vi.CopyFrom( - helper.make_tensor_value_info( - node.output[0], - self.known_vi_[node.input[0]].type.tensor_type.elem_type, - get_shape_from_sympy_shape(new_sympy_shape), - ) - ) - - def _infer_ReduceSum(self, node): - """Infer output shape for ReduceSum operation based on input shape, axes, and keep_dims attribute.""" - keep_dims = get_attribute(node, "keepdims", 1) - if get_opset(self.out_mp_) >= 13 and len(node.input) > 1: - # ReduceSum changes axes to input[1] in opset 13 - axes = self._try_get_value(node, 1) - vi = self.known_vi_[node.output[0]] - if axes is None: - assert keep_dims # can only handle keep_dims==True when axes is unknown, by generating new ranks - vi.CopyFrom( - helper.make_tensor_value_info( - node.output[0], - self.known_vi_[node.input[0]].type.tensor_type.elem_type, - get_shape_from_sympy_shape( - self._new_symbolic_shape(self._get_shape_rank(node, 0), node) - ), - ) - ) - else: - shape = self._get_shape(node, 0) - output_shape = [] - axes = [handle_negative_axis(a, 
len(shape)) for a in axes] - for i, d in enumerate(shape): - if i in axes: - if keep_dims: - output_shape.append(1) - else: - output_shape.append(d) - vi.CopyFrom( - helper.make_tensor_value_info( - node.output[0], - self.known_vi_[node.input[0]].type.tensor_type.elem_type, - output_shape, - ) - ) - - def _infer_ReduceProd(self, node): - """Infer the ReduceProd operation on a node, considering axes and keep dimensions attributes.""" - axes = get_attribute(node, "axes") - keep_dims = get_attribute(node, "keepdims", 1) - if keep_dims == 0 and axes == [0]: - data = self._get_int_or_float_values(node)[0] - if data is not None: - self.sympy_data_[node.output[0]] = sympy_reduce_product(data) - - def _infer_RelativePositionBias(self, node): - """Infers the relative position bias for a given ONNX node.""" - seq_len = self._try_get_value(node, 1) - real_seq_len = self._try_get_value(node, 2) - if seq_len is None or real_seq_len is None: - return - num_heads = self._get_sympy_shape(node, 0)[1] - - new_shape = [1, num_heads, str(seq_len), str(real_seq_len)] - - output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type - vi = self.known_vi_[node.output[0]] - vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, new_shape)) - - def _infer_Reshape(self, node): - """Infer the output shape for the Reshape operation based on the provided input shape and reshape parameters.""" - shape_value = self._try_get_value(node, 1) - vi = self.known_vi_[node.output[0]] - if shape_value is None: - shape_shape = self._get_shape(node, 1) - assert len(shape_shape) == 1 - shape_rank = shape_shape[0] - assert is_literal(shape_rank) - vi.CopyFrom( - helper.make_tensor_value_info( - node.output[0], - vi.type.tensor_type.elem_type, - get_shape_from_sympy_shape(self._new_symbolic_shape(shape_rank, node)), - ) - ) - else: - input_sympy_shape = self._get_sympy_shape(node, 0) - total = 1 - for d in input_sympy_shape: - total = total * d - new_sympy_shape = [] - deferred_dim_idx = -1 - non_deferred_size = 1 - for i, d in enumerate(shape_value): - if type(d) == sympy.Symbol or d != 0: - new_sympy_shape.append(d) - else: - new_sympy_shape.append(input_sympy_shape[i]) - non_deferred_size = non_deferred_size * input_sympy_shape[i] - if d == -1: - deferred_dim_idx = i - elif d != 0: - non_deferred_size = non_deferred_size * d - - assert new_sympy_shape.count(-1) < 2 - if -1 in new_sympy_shape: - new_dim = total // non_deferred_size - new_sympy_shape[deferred_dim_idx] = new_dim - - self._update_computed_dims(new_sympy_shape) - vi.CopyFrom( - helper.make_tensor_value_info( - node.output[0], - vi.type.tensor_type.elem_type, - get_shape_from_sympy_shape(new_sympy_shape), - ) - ) - - self._pass_on_sympy_data(node) - - def _infer_Resize(self, node): - """Infers and updates the shape of the output tensor for a Resize node based on scales or sizes.""" - vi = self.known_vi_[node.output[0]] - input_sympy_shape = self._get_sympy_shape(node, 0) - if get_opset(self.out_mp_) <= 10: - scales = self._try_get_value(node, 1) - if scales is not None: - new_sympy_shape = [ - sympy.simplify(sympy.floor(d * s)) - for d, s in zip(input_sympy_shape, scales) - ] - self._update_computed_dims(new_sympy_shape) - vi.CopyFrom( - helper.make_tensor_value_info( - node.output[0], - self.known_vi_[node.input[0]].type.tensor_type.elem_type, - get_shape_from_sympy_shape(new_sympy_shape), - ) - ) - else: - roi = self._try_get_value(node, 1) - scales = self._try_get_value(node, 2) - sizes = self._try_get_value(node, 3) - if sizes is not None: 
- new_sympy_shape = [sympy.simplify(sympy.floor(s)) for s in sizes] - self._update_computed_dims(new_sympy_shape) - elif scales is not None: - rank = len(scales) - if ( - get_attribute(node, "coordinate_transformation_mode") - == "tf_crop_and_resize" - ): - assert len(roi) == 2 * rank - roi_start = list(roi)[:rank] - roi_end = list(roi)[rank:] - else: - roi_start = [0] * rank - roi_end = [1] * rank - if isinstance(scales, np.ndarray): - scales = scales.tolist() - else: - scales = list(scales) - new_sympy_shape = [ - (sympy.floor(d * (end - start) * scale)) - for d, start, end, scale in zip( - input_sympy_shape, roi_start, roi_end, scales - ) - ] - self._update_computed_dims(new_sympy_shape) - else: - new_sympy_shape = self._new_symbolic_shape(self._get_shape_rank(node, 0), node) - - vi.CopyFrom( - helper.make_tensor_value_info( - node.output[0], - self.known_vi_[node.input[0]].type.tensor_type.elem_type, - get_shape_from_sympy_shape(new_sympy_shape), - ) - ) - - def _infer_Scan(self, node): - """Infer shape and type information for the ONNX 'Scan' operator node.""" - subgraph = get_attribute(node, "body") - num_scan_inputs = get_attribute(node, "num_scan_inputs") - scan_input_axes = get_attribute(node, "scan_input_axes", [0] * num_scan_inputs) - num_scan_states = len(node.input) - num_scan_inputs - scan_input_axes = [ - handle_negative_axis(ax, self._get_shape_rank(node, i + num_scan_states)) - for i, ax in enumerate(scan_input_axes) - ] - # We may have cases where the subgraph has optional inputs that appear in both subgraph's input and initializer, - # but not in the node's input. In such cases, the input model might be invalid, but let's skip those optional inputs. - assert len(subgraph.input) >= len(node.input) - subgraph_inputs = subgraph.input[: len(node.input)] - for i, si in enumerate(subgraph_inputs): - subgraph_name = si.name - si.CopyFrom(self.known_vi_[node.input[i]]) - if i >= num_scan_states: - scan_input_dim = si.type.tensor_type.shape.dim - scan_input_dim.remove(scan_input_dim[scan_input_axes[i - num_scan_states]]) - si.name = subgraph_name - self._onnx_infer_subgraph(node, subgraph) - num_scan_outputs = len(node.output) - num_scan_states - scan_output_axes = get_attribute(node, "scan_output_axes", [0] * num_scan_outputs) - scan_input_dim = get_shape_from_type_proto(self.known_vi_[node.input[-1]].type)[ - scan_input_axes[-1] - ] - for i, o in enumerate(node.output): - vi = self.known_vi_[o] - if i >= num_scan_states: - shape = get_shape_from_type_proto(subgraph.output[i].type) - new_dim = handle_negative_axis( - scan_output_axes[i - num_scan_states], len(shape) + 1 - ) - shape = shape[:new_dim] + [scan_input_dim] + shape[new_dim:] - vi.CopyFrom( - helper.make_tensor_value_info( - o, subgraph.output[i].type.tensor_type.elem_type, shape - ) - ) - else: - vi.CopyFrom(subgraph.output[i]) - vi.name = o - - def _infer_ScatterElements(self, node): - """Infer the output shape and type for ScatterElements node and update known value infos.""" - data_shape = self._get_shape(node, 0) - vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info( - node.output[0], - self.known_vi_[node.input[0]].type.tensor_type.elem_type, - data_shape, - ) - ) - - def _infer_SequenceAt(self, node): - """Infers the shape and type for the output of the 'SequenceAt' ONNX operation, handling symbolic dimensions if - necessary. 
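The scales branch reduces to floor(d * scale) per axis (with an extra (roi_end - roi_start) factor only in tf_crop_and_resize mode); in plain Python:

    import math

    input_shape = [1, 3, 17, 31]
    scales = [1.0, 1.0, 2.0, 0.5]
    # tf_crop_and_resize would additionally multiply by (roi_end - roi_start) per axis
    output = [math.floor(d * s) for d, s in zip(input_shape, scales)]
    print(output)  # [1, 3, 34, 15]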
- """ - seq_shape = self._get_shape(node, 0) - if seq_shape is not None: - vi = self.known_vi_[node.output[0]] - for di, d in enumerate(seq_shape): - if d is not None: - continue - new_dim = onnx.TensorShapeProto.Dimension() - new_dim.dim_param = str(self._new_symbolic_dim_from_output(node, 0, di)) - vi.type.tensor_type.shape.dim[di].CopyFrom(new_dim) - - def _infer_SequenceInsert(self, node): - """Workaround ONNX's shape inference bug by inferring sequence insertion shapes and types for the provided - node. - """ - vi_seq = self.known_vi_[node.input[0]] - vi_tensor = self.known_vi_[node.input[1]] - vi_out_seq = self.known_vi_[node.output[0]] - vi_out_seq.CopyFrom(vi_seq) - vi_out_seq.name = node.output[0] - self._fuse_tensor_type(node, 0, vi_out_seq.type, vi_tensor.type) - - def _infer_Shape(self, node): - """Infers and sets the symbolic shape for the output node in the computation graph.""" - self.sympy_data_[node.output[0]] = self._get_sympy_shape(node, 0) - - def _infer_Size(self, node): - """Infers and sets the size of the output node by computing the product of its shape in the computation - graph. - """ - sympy_shape = self._get_sympy_shape(node, 0) - self.sympy_data_[node.output[0]] = sympy_reduce_product(sympy_shape) - self.known_vi_[node.output[0]].CopyFrom( - helper.make_tensor_value_info(node.output[0], onnx.TensorProto.INT64, []) - ) - - def _infer_Slice(self, node): - """Infer the shape and value information for the Slice node using SymPy and ONNX helper methods.""" - - # even when the relation holds for both `a` and `b`. - # - # When given `expr` of form `min(a, b) + ...`, this function returns `[a + ..., b + ...]`, - # so that we can prove inequalities for both expressions separately. - # - # If the number of `min(...)` subexpressions is not exactly one, this function just returns `[expr]`. - def flatten_min(expr): - """Returns a list with expressions split by min() for inequality proof or original expr if no single min() - found. - """ - assert isinstance(expr, sympy.Add), f"Expected a sum of two arguments, got {expr}" - min_positions = [ - idx for idx in range(len(expr.args)) if isinstance(expr.args[idx], sympy.Min) - ] - if len(min_positions) == 1: - min_pos = min_positions[0] - - def replace_min_with_arg(arg_idx): - """Replace the sympy.Min() function at a specified position in a sympy.Add() expression with one of - its arguments. 
- """ - replaced = list(expr.args) - assert isinstance(replaced[min_pos], sympy.Min), ( - f"Expected a sympy.Min() at position {min_pos}, got {replaced[min_pos]}" - ) - assert len(replaced[min_pos].args) == 2, ( - f"Expected a sympy.Min() with exactly 2 arguments, got {replaced[min_pos]}" - ) - replaced[min_pos] = replaced[min_pos].args[arg_idx] - return sympy.Add(*replaced) - - return [ - replace_min_with_arg(0), - replace_min_with_arg(1), - ] - return [expr] - - def less_equal(x, y): - """Returns True if x is less than or equal to y, otherwise False.""" - try: - return x <= y - except TypeError: - pass - try: - return y >= x - except TypeError: - pass - try: - return -x >= -y - except TypeError: - pass - try: - return -y <= -x - except TypeError: - pass - try: - return y - x >= 0 - except TypeError: - # the last attempt; this may raise TypeError - return all(d >= 0 for d in flatten_min(y - x)) - - def handle_negative_index(index, bound): - """Normalizes a negative index to be in [0, bound).""" - try: - if not less_equal(0, index): - if is_literal(index) and index <= -self.int_max_: - # this case is handled separately - return index - return bound + index - except TypeError: - logger.warning(f"Cannot determine if {index} < 0") - return index - - if get_opset(self.out_mp_) <= 9: - axes = get_attribute(node, "axes") - starts = get_attribute(node, "starts") - ends = get_attribute(node, "ends") - if not axes: - axes = list(range(len(starts))) - steps = [1] * len(axes) - else: - starts = as_list(self._try_get_value(node, 1), keep_none=True) - ends = as_list(self._try_get_value(node, 2), keep_none=True) - axes = self._try_get_value(node, 3) - steps = self._try_get_value(node, 4) - if axes is None and (starts is not None or ends is not None): - axes = list(range(len(starts if starts is not None else ends))) - if steps is None and (starts is not None or ends is not None): - steps = [1] * len(starts if starts is not None else ends) - axes = as_list(axes, keep_none=True) - steps = as_list(steps, keep_none=True) - - new_sympy_shape = self._get_sympy_shape(node, 0) - if starts is None or ends is None: - if axes is None: - for i in range(len(new_sympy_shape)): - new_sympy_shape[i] = self._new_symbolic_dim_from_output(node, 0, i) - else: - new_sympy_shape = get_shape_from_sympy_shape(new_sympy_shape) - for i in axes: - new_sympy_shape[i] = self._new_symbolic_dim_from_output(node, 0, i) - else: - for i, s, e, t in zip(axes, starts, ends, steps): - if is_literal(e): - e = handle_negative_index(e, new_sympy_shape[i]) - if is_literal(e): - if e >= self.int_max_: - e = new_sympy_shape[i] - elif e <= -self.int_max_: - e = 0 if s > 0 else -1 - elif is_literal(new_sympy_shape[i]): - if e < 0: - e = max(0, e + new_sympy_shape[i]) - e = min(e, new_sympy_shape[i]) - else: - if e > 0: - e = ( - sympy.Min(e, new_sympy_shape[i]) if e > 1 else e - ) # special case for slicing first to make computation easier - else: - if is_literal(new_sympy_shape[i]): - if new_sympy_shape[i] < 0: - e = sympy.Min(e, new_sympy_shape[i]) - else: - try: - if not less_equal(e, new_sympy_shape[i]): - e = new_sympy_shape[i] - except Exception: - logger.warning( - f"Unable to determine if {e} <= {new_sympy_shape[i]}, treat as equal" - ) - e = new_sympy_shape[i] - - s = handle_negative_index(s, new_sympy_shape[i]) - if is_literal(new_sympy_shape[i]) and is_literal(s): - s = max(0, min(s, new_sympy_shape[i])) - - new_sympy_shape[i] = sympy.simplify((e - s + t + (-1 if t > 0 else 1)) // t) - - self._update_computed_dims(new_sympy_shape) - - vi 
= self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info( - node.output[0], - vi.type.tensor_type.elem_type, - get_shape_from_sympy_shape(new_sympy_shape), - ) - ) - - # handle sympy_data if needed, for slice in shape computation - if ( - node.input[0] in self.sympy_data_ - and axes == [0] - and starts is not None - and len(starts) == 1 - and ends is not None - and len(ends) == 1 - and steps is not None - and len(steps) == 1 - ): - input_sympy_data = self.sympy_data_[node.input[0]] - if type(input_sympy_data) == list or ( - type(input_sympy_data) == np.array and len(input_sympy_data.shape) == 1 - ): - self.sympy_data_[node.output[0]] = input_sympy_data[ - starts[0] : ends[0] : steps[0] - ] - - def _infer_SoftmaxCrossEntropyLoss(self, node): - """Infer the softmax cross-entropy loss for a given node in the computation graph.""" - vi = self.known_vi_[node.output[0]] - elem_type = self.known_vi_[node.input[0]].type.tensor_type.elem_type - - # If output type is explicit specified in attribute, we use it as output tensor type. - specified_output_type = get_attribute(node, "output_type", None) - if specified_output_type is not None: - elem_type = specified_output_type - - vi.type.tensor_type.elem_type = elem_type - vi.type.tensor_type.shape.CopyFrom(onnx.TensorShapeProto()) - - if len(node.output) > 1: - data_shape = self._get_shape(node, 0) - vi = self.known_vi_[node.output[1]] - vi.CopyFrom(helper.make_tensor_value_info(vi.name, elem_type, data_shape)) - - def _infer_Split_Common(self, node, make_value_info_func): - """Infers the output shape for the Split operator given an ONNX node and a function to create tensor value - info. - """ - input_sympy_shape = self._get_sympy_shape(node, 0) - axis = handle_negative_axis(get_attribute(node, "axis", 0), len(input_sympy_shape)) - op_set = get_opset(self.out_mp_) - - # Depending on op-version 'split' are provided as attribute or via 2nd input - if op_set < 13: - split = get_attribute(node, "split") - assert self._try_get_value(node, 1) is None - else: - split = self._try_get_value(node, 1) - assert get_attribute(node, "split") is None - - if split is None: - num_outputs = len(node.output) - split = [input_sympy_shape[axis] / sympy.Integer(num_outputs)] * num_outputs - self._update_computed_dims(split) - else: - split = [sympy.Integer(s) for s in split] - - for i_o in range(len(split)): - vi = self.known_vi_[node.output[i_o]] - vi.CopyFrom( - make_value_info_func( - node.output[i_o], - self.known_vi_[node.input[0]].type.tensor_type.elem_type, - get_shape_from_sympy_shape( - input_sympy_shape[:axis] + [split[i_o]] + input_sympy_shape[axis + 1 :] - ), - ) - ) - self.known_vi_[vi.name] = vi - - def _infer_Split(self, node): - """Infers the output shapes and types for the Split operation node.""" - self._infer_Split_Common(node, helper.make_tensor_value_info) - - def _infer_SplitToSequence(self, node): - """Infers the output shapes and types for the SplitToSequence operation node.""" - self._infer_Split_Common(node, helper.make_sequence_value_info) - - def _infer_Squeeze(self, node): - """Infers the output shapes and types for the Squeeze operation node.""" - input_shape = self._get_shape(node, 0) - op_set = get_opset(self.out_mp_) - - # Depending on op-version 'axes' are provided as attribute or via 2nd input - if op_set < 13: - axes = get_attribute(node, "axes") - assert self._try_get_value(node, 1) is None - else: - axes = self._try_get_value(node, 1) - assert get_attribute(node, "axes") is None - - if axes is None: - # No axes 
have been provided (neither via attribute nor via input). - # In this case the 'Shape' op should remove all axis with dimension 1. - # For symbolic dimensions we guess they are !=1. - output_shape = [s for s in input_shape if s != 1] - if self.verbose_ > 0: - symbolic_dimensions = [s for s in input_shape if type(s) != int] - if symbolic_dimensions: - logger.debug( - f"Symbolic dimensions in input shape of op: '{node.op_type}' node: '{node.name}'. " - f"Assuming the following dimensions are never equal to 1: {symbolic_dimensions}" - ) - else: - axes = [handle_negative_axis(a, len(input_shape)) for a in axes] - output_shape = [] - for i in range(len(input_shape)): - if i not in axes: - output_shape.append(input_shape[i]) - else: - assert input_shape[i] == 1 or type(input_shape[i]) != int - if self.verbose_ > 0 and type(input_shape[i]) != int: - logger.debug( - f"Symbolic dimensions in input shape of op: '{node.op_type}' node: '{node.name}'. " - f"Assuming the dimension '{input_shape[i]}' at index {i} of the input to be equal to 1." - ) - - vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info( - node.output[0], - self.known_vi_[node.input[0]].type.tensor_type.elem_type, - output_shape, - ) - ) - self._pass_on_sympy_data(node) - - def _infer_Tile(self, node): - """Infers the output shape for the Tile operation in a computation graph based on input shape and repeat - values. - """ - repeats_value = self._try_get_value(node, 1) - new_sympy_shape = [] - if repeats_value is not None: - input_sympy_shape = self._get_sympy_shape(node, 0) - for i, d in enumerate(input_sympy_shape): - new_dim = d * repeats_value[i] - new_sympy_shape.append(new_dim) - self._update_computed_dims(new_sympy_shape) - else: - new_sympy_shape = self._new_symbolic_shape(self._get_shape_rank(node, 0), node) - vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info( - node.output[0], - vi.type.tensor_type.elem_type, - get_shape_from_sympy_shape(new_sympy_shape), - ) - ) - - def _infer_TopK(self, node): - """Infers the output shape for the TopK operation in an ONNX graph node based on input shape and specified - axis. - """ - rank = self._get_shape_rank(node, 0) - axis = handle_negative_axis(get_attribute(node, "axis", -1), rank) - new_shape = self._get_shape(node, 0) - - if get_opset(self.out_mp_) <= 9: - k = get_attribute(node, "k") - else: - k = self._get_int_or_float_values(node)[1] - - k = self._new_symbolic_dim_from_output(node) if k is None else as_scalar(k) - if type(k) in {int, str}: - new_shape[axis] = k - else: - new_sympy_shape = self._get_sympy_shape(node, 0) - new_sympy_shape[axis] = k - self._update_computed_dims( - new_sympy_shape - ) # note that TopK dim could be computed in sympy_data, so need to update computed_dims when it enters shape - new_shape = get_shape_from_sympy_shape(new_sympy_shape) - - for i_o in range(len(node.output)): - vi = self.known_vi_[node.output[i_o]] - vi.CopyFrom( - helper.make_tensor_value_info( - node.output[i_o], vi.type.tensor_type.elem_type, new_shape - ) - ) - - def _infer_Transpose(self, node): - """Infer and update the shape information for a Transpose node based on its input shape and permutation - attributes. 
- """ - if node.input[0] in self.sympy_data_: - data_shape = self._get_shape(node, 0) - perm = get_attribute(node, "perm", reversed(list(range(len(data_shape))))) - input_data = self.sympy_data_[node.input[0]] - self.sympy_data_[node.output[0]] = ( - np.transpose(np.array(input_data).reshape(*data_shape), axes=tuple(perm)) - .flatten() - .tolist() - ) - - def _infer_Unsqueeze(self, node): - """Infers the output shape for the Unsqueeze operation based on the input shape and operator set.""" - input_shape = self._get_shape(node, 0) - op_set = get_opset(self.out_mp_) - - # Depending on op-version 'axes' are provided as attribute or via 2nd input - if op_set < 13: - axes = get_attribute(node, "axes") - assert self._try_get_value(node, 1) is None - else: - axes = self._try_get_value(node, 1) - assert get_attribute(node, "axes") is None - - output_rank = len(input_shape) + len(axes) - axes = [handle_negative_axis(a, output_rank) for a in axes] - - input_axis = 0 - output_shape = [] - for i in range(output_rank): - if i in axes: - output_shape.append(1) - else: - output_shape.append(input_shape[input_axis]) - input_axis += 1 - - vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info( - node.output[0], - self.known_vi_[node.input[0]].type.tensor_type.elem_type, - output_shape, - ) - ) - - self._pass_on_sympy_data(node) - - def _infer_ZipMap(self, node): - """Infer the type of keys for a ZipMap node based on its class labels attribute.""" - map_key_type = None - if get_attribute(node, "classlabels_int64s") is not None: - map_key_type = onnx.TensorProto.INT64 - elif get_attribute(node, "classlabels_strings") is not None: - map_key_type = onnx.TensorProto.STRING - - assert map_key_type is not None - new_vi = onnx.ValueInfoProto() - new_vi.name = node.output[0] - new_vi.type.sequence_type.elem_type.map_type.value_type.tensor_type.elem_type = ( - onnx.TensorProto.FLOAT - ) - new_vi.type.sequence_type.elem_type.map_type.key_type = map_key_type - vi = self.known_vi_[node.output[0]] - vi.CopyFrom(new_vi) - - def _infer_Attention(self, node): - """Infer shape and data type for ONNX Attention node outputs given input shapes and attributes.""" - shape = self._get_shape(node, 0) - shape_weights = self._get_shape(node, 1) - shape_bias = self._try_get_shape(node, 2) - if shape_bias is not None: - assert len(shape_bias) == 1 - tripled_hidden_size = shape_bias[0] if shape_bias is not None else shape_weights[1] - if shape and len(shape) == 3: - qkv_hidden_sizes_attr = get_attribute(node, "qkv_hidden_sizes") - if qkv_hidden_sizes_attr is not None: - assert len(qkv_hidden_sizes_attr) == 3 - shape[2] = int(qkv_hidden_sizes_attr[2]) - elif isinstance(tripled_hidden_size, int): - shape[2] = int(tripled_hidden_size / 3) - output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type - vi = self.known_vi_[node.output[0]] - vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, shape)) - - if len(node.output) > 1: - # input shape: (batch_size, sequence_length, hidden_size) - # past shape: (2, batch_size, num_heads, past_sequence_length, head_size) - # mask shape: (batch_size, total_sequence_length) or (batch_size, sequence_length, total_sequence_length) or (batch_size, 1, max_seq_len, max_seq_len) - # present shape: (2, batch_size, num_heads, total_sequence_length, head_size), where total_sequence_length=sequence_length+past_sequence_length - input_shape = self._get_shape(node, 0) - past_shape = ( - self._get_shape(node, 4) if len(node.input) > 4 and node.input[4] 
else [] - ) - mask_shape = ( - self._get_shape(node, 3) if len(node.input) > 3 and node.input[3] else [] - ) - - if past_shape and len(past_shape) == 5: - if mask_shape and len(mask_shape) in {2, 3}: - past_shape[3] = mask_shape[-1] - elif input_shape and len(input_shape) == 3: - if isinstance(input_shape[1], int) and isinstance(past_shape[3], int): - past_shape[3] = input_shape[1] + past_shape[3] - else: - past_shape[3] = f"{past_shape[3]}+{input_shape[1]}" - vi = self.known_vi_[node.output[1]] - vi.CopyFrom( - helper.make_tensor_value_info(vi.name, output_dtype, past_shape) - ) - else: - num_heads = get_attribute(node, "num_heads") - head_size = input_shape[2] // num_heads - present_shape = [ - 2, - input_shape[0], - num_heads, - input_shape[1], - head_size, - ] - vi = self.known_vi_[node.output[1]] - vi.CopyFrom( - helper.make_tensor_value_info(vi.name, output_dtype, present_shape) - ) - - def _infer_GatedRelativePositionBias(self, node): - """Infer the shape for gated relative position bias given the node attributes.""" - # query_layer: (token_count, num_heads x head_size) - # token_offset: (batch_size, seq_len) - # Otherwise: - # query_layer: (batch_size, seq_len, num_heads x head_size) - # token_offset: None - # Output shape: (batch_size, num_heads, seq_len, seq_len) - num_heads = get_attribute(node, "num_heads") - - token_offset_shape = self._try_get_shape(node, 6) - if token_offset_shape is not None: - output_shape = [ - token_offset_shape[0], - num_heads, - token_offset_shape[1], - token_offset_shape[1], - ] - else: - query_layer_shape = self._get_shape(node, 0) - assert query_layer_shape is not None and len(query_layer_shape) == 3 - output_shape = [ - query_layer_shape[0], - num_heads, - query_layer_shape[1], - query_layer_shape[1], - ] - - output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type - vi = self.known_vi_[node.output[0]] - vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape)) - - def _infer_PackedAttention(self, node): - """Infer shape and data type for PackedAttention nodes in a given computational graph.""" - shape = self._get_shape(node, 0) - shape_weights = self._get_shape(node, 1) - shape_bias = self._try_get_shape(node, 2) - if shape_bias is not None: - assert len(shape_bias) == 1 - tripled_hidden_size = shape_bias[0] if shape_bias is not None else shape_weights[1] - if shape and len(shape) == 2: - qkv_hidden_sizes_attr = get_attribute(node, "qkv_hidden_sizes") - if qkv_hidden_sizes_attr is not None: - assert len(qkv_hidden_sizes_attr) == 3 - shape[1] = int(qkv_hidden_sizes_attr[2]) - elif isinstance(tripled_hidden_size, int): - shape[1] = int(tripled_hidden_size / 3) - output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type - vi = self.known_vi_[node.output[0]] - vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, shape)) - - def _infer_PackedMultiHeadAttention(self, node): - """Infer the output shape for PackedMultiHeadAttention node in the computational graph.""" - shape_value = self._try_get_shape(node, 2) - if shape_value is not None and len(shape_value) == 2: - output_shape = shape_value - else: - shape_query = self._get_shape(node, 0) - assert shape_query is not None and len(shape_query) == 4 - output_shape = [shape_query[0], shape_query[1] * shape_query[3]] - - output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type - vi = self.known_vi_[node.output[0]] - vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape)) - - def 
_infer_MultiScaleDeformableAttnTRT(self, node): - shape_value = self._try_get_shape(node, 0) - sampling_locations = self._try_get_shape(node, 3) - output_shape = shape_value - output_shape[1] = sampling_locations[1] - output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type - vi = self.known_vi_[node.output[0]] - vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape)) - - def _infer_RemovePadding(self, node): - """Infers the shape and data type for the output tensor after removing padding.""" - shape = self._get_shape(node, 0) - if shape and len(shape) == 3: - output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type - vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info( - node.output[0], output_dtype, ["token_count", shape[2]] - ) - ) - - vi_token_offset = self.known_vi_[node.output[1]] - vi_token_offset.CopyFrom( - helper.make_tensor_value_info( - node.output[1], onnx.TensorProto.INT32, [shape[0], shape[1]] - ) - ) - - vi_cumulated_seq_len = self.known_vi_[node.output[2]] - vi_cumulated_seq_len.CopyFrom( - helper.make_tensor_value_info( - node.output[2], onnx.TensorProto.INT32, ["batch_size + 1"] - ) - ) - - vi_max_seq_len = self.known_vi_[node.output[3]] - vi_max_seq_len.CopyFrom( - helper.make_tensor_value_info(node.output[3], onnx.TensorProto.INT32, [1]) - ) - - def _infer_RestorePadding(self, node): - """Infers the output shape and type for the RestorePadding operation.""" - shape_input = self._get_shape(node, 0) - shape_token_offset = self._get_shape(node, 1) - if ( - shape_input - and len(shape_input) == 2 - and shape_token_offset - and len(shape_token_offset) == 2 - ): - output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type - vi = self.known_vi_[node.output[0]] - - output_shape = [ - shape_token_offset[0], - shape_token_offset[1], - shape_input[1], - ] - vi.CopyFrom( - helper.make_tensor_value_info(node.output[0], output_dtype, output_shape) - ) - - def _infer_BiasGelu(self, node): - """Propagate shape and type information for BiasGelu node during inference.""" - self._propagate_shape_and_type(node) - - def _infer_MultiHeadAttention(self, node): - """Propagate shape and type information for MultiHeadAttention node during inference.""" - # Q, K and V without packing: - # Input 0 (query) has shape (batch_size, sequence_length, hidden_size) - # Input 1 (key) has shape (batch_size, kv_sequence_length, hidden_size) or (batch_size, num_heads, kv_sequence_length, head_size) - # Input 2 (value) has shape (batch_size, kv_sequence_length, v_hidden_size) or (batch_size, num_heads, kv_sequence_length, head_size) - # Packed KV: - # Input 0 (query) has shape (batch_size, sequence_length, hidden_size) - # Input 1 (batch_size, kv_sequence_length, num_heads, 2, head_size) - # Input 2 nullptr - # Packed QKV: - # Input 0 (batch_size, sequence_length, num_heads, 3, head_size) - # Input 1 nullptr - # Input 2 nullptr - - query_shape = self._get_shape(node, 0) - total_sequence_length = None - output_dtype = None - if query_shape is not None: - if len(query_shape) == 3: - key_shape = self._try_get_shape(node, 1) - # By default, hidden size is same for Q/K/V. Only need check v_hidden_size when value is provided. 
- output_shape = query_shape - if key_shape is not None and len(key_shape) == 3: - value_shape = self._try_get_shape(node, 2) - if value_shape is not None and len(value_shape) == 3: - output_shape[2] = value_shape[2] - total_sequence_length = key_shape[1] - - output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type - vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info(node.output[0], output_dtype, output_shape) - ) - - elif len(query_shape) == 5: - if isinstance(query_shape[2], int) and isinstance(query_shape[4], int): - output_shape = [ - query_shape[0], - query_shape[1], - query_shape[2] * query_shape[4], - ] - else: - output_shape = [ - query_shape[0], - query_shape[1], - f"{query_shape[2]}*{query_shape[4]}", - ] - - total_sequence_length = query_shape[1] - - output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type - vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info(node.output[0], output_dtype, output_shape) - ) - - if len(node.output) > 1: - batch_size = query_shape[0] - num_heads = get_attribute(node, "num_heads") - - head_size = None - if len(query_shape) == 3: - head_size = ( - int(query_shape[2] / num_heads) - if isinstance(query_shape[2], int) - else f"{query_shape[2]}/{num_heads}" - ) - else: - head_size = query_shape[4] - - past_shape = self._try_get_shape(node, 6) - - if past_shape is not None: - if isinstance(past_shape[2], int) and isinstance( - total_sequence_length, int - ): - total_sequence_length = past_shape[2] + total_sequence_length - else: - total_sequence_length = f"{past_shape[2]}+{total_sequence_length}" - - present_shape = [ - batch_size, - num_heads, - total_sequence_length, - head_size, - ] - - assert output_dtype is not None - if len(node.output) > 2 and node.output[1] and node.output[2]: - vi = self.known_vi_[node.output[1]] - vi.CopyFrom( - helper.make_tensor_value_info(vi.name, output_dtype, present_shape) - ) - vi = self.known_vi_[node.output[2]] - vi.CopyFrom( - helper.make_tensor_value_info(vi.name, output_dtype, present_shape) - ) - - def _infer_DecoderMaskedMultiHeadAttention(self, node): - """Infers the output shape of the DecoderMaskedMultiHeadAttention node based on input shapes and attributes in - the computational graph. 
- """ - # Q, K and V without packing: - # Input 0 (query) has shape (batch_size, 1, hidden_size) - # Input 5 (past_key) if exists has shape (batch_size, num_heads, max_sequence_length, head_size) - - query_shape = self._get_shape(node, 0) - if query_shape is not None: - output_shape = query_shape - output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type - assert output_dtype is not None - vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info(node.output[0], output_dtype, output_shape) - ) - - if len(node.output) > 2 and node.output[1] and node.output[2]: - past_shape = self._try_get_shape(node, 5) - if past_shape is not None: - vi = self.known_vi_[node.output[1]] - vi.CopyFrom( - helper.make_tensor_value_info(vi.name, output_dtype, past_shape) - ) - vi = self.known_vi_[node.output[2]] - vi.CopyFrom( - helper.make_tensor_value_info(vi.name, output_dtype, past_shape) - ) - - def _infer_FastGelu(self, node): - """Infers the output shapes and types for the FastGelu node using shape propagation.""" - self._propagate_shape_and_type(node) - - def _infer_Gelu(self, node): - """Infers the output shapes and types for the Gelu node using shape propagation.""" - self._propagate_shape_and_type(node) - - def _infer_QuickGelu(self, node): - """Infers the output shapes and types for the QuickGelu node using shape propagation.""" - self._propagate_shape_and_type(node) - - def _infer_GemmFastGelu(self, node): - """Infers the output shapes and types for the GemmFastGelu node using matrix multiplication shape - computation. - """ - self._compute_matmul_shape(node) - - def _infer_GemmFloat8(self, node): - """Infers the output shapes and types for the GemmFloat8 node using matrix multiplication shape computation.""" - self._compute_matmul_shape(node) - - def _infer_LayerNormalization(self, node): - """Infers the output shapes and types for the LayerNormalization node, including handling mean and variance - outputs. 
- """ - self._propagate_shape_and_type(node) - if len(node.output) > 1: - axis = get_attribute(node, "axis") - if axis is None: - axis = -1 - x_shape = self._get_shape(node, 0) - if x_shape is not None: - rank = len(x_shape) - axis = handle_negative_axis(axis, rank) - mean_shape = x_shape[:axis] + [1 for _ in range(rank - axis)] - mean_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type - if mean_dtype in { - onnx.TensorProto.FLOAT16, - onnx.TensorProto.BFLOAT16, - }: - mean_dtype = onnx.TensorProto.FLOAT - vi = self.known_vi_[node.output[1]] - vi.CopyFrom( - helper.make_tensor_value_info(node.output[1], mean_dtype, mean_shape) - ) - if len(node.output) > 2: - vi = self.known_vi_[node.output[2]] - vi.CopyFrom( - helper.make_tensor_value_info(node.output[2], mean_dtype, mean_shape) - ) - - def _infer_LongformerAttention(self, node): - """Infer and propagate shape and type information for a LongformerAttention node.""" - self._propagate_shape_and_type(node) - - def _infer_EmbedLayerNormalization(self, node): - """Infer and propagate shape and type information for an EmbedLayerNormalization node.""" - input_ids_shape = self._get_shape(node, 0) - word_embedding_shape = self._get_shape(node, 2) - assert len(input_ids_shape) == 2 and len(word_embedding_shape) == 2 - output_shape = [*input_ids_shape, word_embedding_shape[1]] - - word_embedding_dtype = self.known_vi_[node.input[2]].type.tensor_type.elem_type - vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info(node.output[0], word_embedding_dtype, output_shape) - ) - - if len(node.output) > 1 and node.output[1]: - mask_index_shape = [input_ids_shape[0]] - vi = self.known_vi_[node.output[1]] - vi.CopyFrom( - helper.make_tensor_value_info( - node.output[1], onnx.TensorProto.INT32, mask_index_shape - ) - ) - - if len(node.output) > 2: - # Optional output of add before layer normalization is done - # shape is same as the output - vi = self.known_vi_[node.output[2]] - vi.CopyFrom( - helper.make_tensor_value_info( - node.output[2], word_embedding_dtype, output_shape - ) - ) - - def _infer_SkipLayerNormalization(self, node): - """Infer the output shape and type for a node with SkipLayerNormalization in an ONNX model.""" - self._propagate_shape_and_type(node) - - # If the SkipLayerNormalization node contains the optional - # output for inference, infer the shape and type for it too - if len(node.output) > 3: - self._propagate_shape_and_type(node, 0, 3) - - def _infer_GroupNorm(self, node): - """Infer the shape and type for Group Normalization in an ONNX model.""" - self._propagate_shape_and_type(node) - - def _infer_SkipGroupNorm(self, node): - """Infer the shape and type for Skip Group Normalization in an ONNX model.""" - self._propagate_shape_and_type(node, 0, 0) - if len(node.output) > 1: - self._propagate_shape_and_type(node, 0, 1) - - def _infer_BiasSplitGelu(self, node): - """Infer the shape and type for Bias Split Gelu in an ONNX model.""" - input_shape = self._get_shape(node, 0) - bias_shape = self._get_shape(node, 1) - if input_shape and bias_shape and isinstance(bias_shape[0], int): - output_shape = input_shape - output_shape[2] = int(bias_shape[0] / 2) - vi = self.known_vi_[node.output[0]] - output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type - vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, output_shape)) - - def _infer_BiasAdd(self, node): - """Infer the output shape and type for a BiasAdd node by propagating input shape and type information.""" - 
self._propagate_shape_and_type(node) - - def _infer_RotaryEmbedding(self, node): - """Infer the output shape and type for a RotaryEmbedding node by appropriately propagating input shape and type - information. - """ - if len(node.output) == 1: - self._propagate_shape_and_type(node) - elif len(node.output) == 2: - # Extraneous constant nodes outputted by RotaryEmbedding function made with `export_modules_as_functions` - self._propagate_shape_and_type(node, input_index=1, output_index=0) - self._propagate_shape_and_type(node, input_index=0, output_index=1) # true output - elif len(node.output) == 3: - # Extraneous constant nodes outputted by RotaryEmbedding function made with `export_modules_as_functions` - self._propagate_shape_and_type(node, input_index=1, output_index=0) - self._propagate_shape_and_type(node, input_index=1, output_index=1) - self._propagate_shape_and_type(node, input_index=0, output_index=2) # true output - - def _infer_PythonOp(self, node): - """Infer and propagate the shape and type information for a PythonOp node in the computation graph.""" - output_tensor_types = get_attribute(node, "output_tensor_types") - assert output_tensor_types, ( - f"PythonOp '{node.name}' has no output_tensor_types attribute." - ) - output_tensor_ranks = get_attribute(node, "output_tensor_ranks") - assert output_tensor_ranks, ( - f"PythonOp '{node.name}' has no output_tensor_ranks attribute." - ) - - from onnxruntime.capi._pybind_state import get_shape_inference_function - - func_name = get_attribute(node, "func_name").decode() - shape_inferer = get_shape_inference_function(func_name) - - # Set the context output separately. - # The first output is torch.autograd.Function's context. - vi = self.known_vi_[node.output[0]] - vi.CopyFrom(helper.make_tensor_value_info(node.output[0], onnx.TensorProto.INT64, [])) - - if shape_inferer is not None: - input_shapes = [] - input_dtypes = [] - for input_index in range(len(node.input)): - shape = self._get_shape(node, input_index) - input_shapes.append(shape) - input_dtype = self.known_vi_[ - node.input[input_index] - ].type.tensor_type.elem_type - input_dtypes.append(input_dtype) - output_shapes, output_dtypes = shape_inferer(node, input_shapes, input_dtypes) - assert len(output_shapes) == len(output_dtypes) == (len(node.output) - 1), ( - f"PythonOp '{func_name}' returned {len(output_shapes)} shapes and {len(output_dtypes)} dtypes, " - f"but expected {len(node.output) - 1} outputs." - ) - for i in range(len(node.output) - 1): - output_index = i + 1 - vi = self.known_vi_[node.output[output_index]] - vi.CopyFrom( - helper.make_tensor_value_info( - node.output[output_index], output_dtypes[i], output_shapes[i] - ) - ) - else: - # General shape inference for PythonOp. - # Outputs after torch.autograd.Function's context are tensors. - # We assume their ranks are fixed for different model inputs. - for i in range(len(node.output) - 1): - # Process the i-th tensor outputs. 
- vi = self.known_vi_[node.output[i + 1]] - sympy_shape = self._new_symbolic_shape(output_tensor_ranks[i], node) - shape = get_shape_from_sympy_shape(sympy_shape) - value_info = helper.make_tensor_value_info( - node.output[i + 1], output_tensor_types[i], shape - ) - vi.CopyFrom(value_info) - - def _propagate_shape_and_type(self, node, input_index=0, output_index=0): - """Propagates the shape and type information from input to output tensors in a given node.""" - shape = self._get_shape(node, input_index) - output_dtype = self.known_vi_[node.input[input_index]].type.tensor_type.elem_type - vi = self.known_vi_[node.output[output_index]] - vi.CopyFrom( - helper.make_tensor_value_info(node.output[output_index], output_dtype, shape) - ) - - def _is_none_dim(self, dim_value): - """Check if dimension value is a string representing an unknown dimension that is not in symbolic_dims_.""" - if type(dim_value) != str: - return False - return dim_value not in self.symbolic_dims_ if "unk__" in dim_value else False - - def _is_shape_contains_none_dim(self, out_shape): - """Check if any dimension in the given shape contains the 'None' dimension and return it if found.""" - for out in out_shape: - if self._is_none_dim(out): - return out - return None - - def _infer_impl(self, start_sympy_data=None): - """Infer implementation details and update symbolic data and input symbols.""" - self.sympy_data_ = start_sympy_data or {} - self.out_mp_.graph.ClearField("value_info") - self._apply_suggested_merge(graph_input_only=True) - self.input_symbols_ = set() - for i in self.out_mp_.graph.input: - input_shape = get_shape_from_value_info(i) - if input_shape is None: - continue - - if is_sequence(i.type): - input_dims = i.type.sequence_type.elem_type.tensor_type.shape.dim - else: - input_dims = i.type.tensor_type.shape.dim - - for i_dim, dim in enumerate(input_shape): - if dim is None: - # some models use None for symbolic dim in input, replace it with a string - input_dims[i_dim].dim_param = str(self._new_symbolic_dim(i.name, i_dim)) - - self.input_symbols_.update([d for d in input_shape if type(d) == str]) - - for s in self.input_symbols_: - if s in self.suggested_merge_: - s_merge = self.suggested_merge_[s] - assert s_merge in self.symbolic_dims_ - self.symbolic_dims_[s] = self.symbolic_dims_[s_merge] - else: - # Since inputs are not produced by other ops, we can assume positivity - self.symbolic_dims_[s] = sympy.Symbol(s, integer=True, positive=True) - # create a temporary ModelProto for single node inference - # note that we remove initializer to have faster inference - # for tensor ops like Reshape/Tile/Expand that read initializer, we need to do sympy computation based inference anyways - self.tmp_mp_ = onnx.ModelProto() - self.tmp_mp_.CopyFrom(self.out_mp_) - self.tmp_mp_.graph.ClearField("initializer") - - # compute prerequisite for node for topological sort - # node with subgraphs may have dependency on implicit inputs, which will affect topological sort - prereq_for_node = {} # map from node to all its inputs, including implicit ones in subgraph - - def get_prereq(node): - """Compute and return the prerequisite inputs for a given node, including implicit inputs from subgraphs.""" - names = {i for i in node.input if i} - subgraphs = [] - if node.op_type == "If": - subgraphs = [ - get_attribute(node, "then_branch"), - get_attribute(node, "else_branch"), - ] - elif node.op_type in {"Loop", "Scan"}: - subgraphs = [get_attribute(node, "body")] - for g in subgraphs: - g_outputs_and_initializers = {i.name for i in 
g.initializer} - g_prereq = set() - for n in g.node: - g_outputs_and_initializers.update(n.output) - for n in g.node: - g_prereq.update( - [i for i in get_prereq(n) if i not in g_outputs_and_initializers] - ) - names.update(g_prereq) - # remove subgraph inputs from g_prereq since those are local-only - for i in g.input: - if i.name in names: - names.remove(i.name) - return names - - for n in self.tmp_mp_.graph.node: - prereq_for_node[n.output[0]] = get_prereq(n) - - # topological sort nodes, note there might be dead nodes so we check if all graph outputs are reached to terminate - sorted_nodes = [] - sorted_known_vi = { - i.name - for i in list(self.out_mp_.graph.input) + list(self.out_mp_.graph.initializer) - } - if any(o.name in sorted_known_vi for o in self.out_mp_.graph.output): - # Loop/Scan will have some graph output in graph inputs, so don't do topological sort - sorted_nodes = self.out_mp_.graph.node - else: - while any(o.name not in sorted_known_vi for o in self.out_mp_.graph.output): - old_sorted_nodes_len = len(sorted_nodes) - for node in self.out_mp_.graph.node: - if node.output[0] not in sorted_known_vi and all( - i in sorted_known_vi for i in prereq_for_node[node.output[0]] if i - ): - sorted_known_vi.update(node.output) - sorted_nodes.append(node) - if old_sorted_nodes_len == len(sorted_nodes) and not all( - o.name in sorted_known_vi for o in self.out_mp_.graph.output - ): - raise Exception("Invalid model with cyclic graph") - - for node in sorted_nodes: - assert all([i in self.known_vi_ for i in node.input if i]) - self._onnx_infer_single_node(node) - known_aten_op = False - if node.op_type in self.dispatcher_: - self.dispatcher_[node.op_type](node) - elif node.op_type == "ConvTranspose": - # onnx shape inference ops like ConvTranspose may have empty shape for symbolic input - # before adding symbolic compute for them - # mark the output type as UNDEFINED to allow guessing of rank - vi = self.known_vi_[node.output[0]] - if len(vi.type.tensor_type.shape.dim) == 0: - vi.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED - elif node.op_type == "ATen" and node.domain == "org.pytorch.aten": - for attr in node.attribute: - # TODO: Is overload_name needed? - if attr.name == "operator": - aten_op_name = ( - attr.s.decode("utf-8") if isinstance(attr.s, bytes) else attr.s - ) - if aten_op_name in self.aten_op_dispatcher_: - known_aten_op = True - self.aten_op_dispatcher_[aten_op_name](node) - break - - if self.verbose_ > 2: - logger.debug(node.op_type + ": " + node.name) - for i, name in enumerate(node.input): - logger.debug( - " Input {}: {} {}".format( - i, name, "initializer" if name in self.initializers_ else "" - ) - ) - - # onnx automatically merge dims with value, i.e. 
Mul(['aaa', 'bbb'], [1000, 1]) -> [1000, 'bbb'] - # symbolic shape inference needs to apply merge of 'aaa' -> 1000 in this case - if node.op_type in { - "Add", - "Sub", - "Mul", - "Div", - "MatMul", - "MatMulInteger", - "MatMulInteger16", - "Where", - "Sum", - }: - vi = self.known_vi_[node.output[0]] - out_rank = len(get_shape_from_type_proto(vi.type)) - in_shapes = [self._get_shape(node, i) for i in range(len(node.input))] - for d in range( - out_rank - - ( - 2 - if node.op_type in {"MatMul", "MatMulInteger", "MatMulInteger16"} - else 0 - ) - ): - in_dims = [ - s[len(s) - out_rank + d] for s in in_shapes if len(s) + d >= out_rank - ] - if len(in_dims) > 1: - self._check_merged_dims(in_dims, allow_broadcast=True) - - for i_o in range(len(node.output)): - # Special cases: - # 1) We do not care about the training related outputs of SkipLayerNormalization - # 2) We do not care about the extraneous constant outputs in RotaryEmbedding because - # the RotaryEmbedding op created during export can be replaced by the RotaryEmbedding - # contrib op - if node.op_type in { - "SkipLayerNormalization", - "SkipSimplifiedLayerNormalization", - } and i_o in {1, 2}: - continue - if node.op_type == "RotaryEmbedding" and len(node.output) > 1: - # Skip symbolic shape inference for RotaryEmbedding functions that have extraneous outputs - # generated by `export_modules_as_functions` - continue - - vi = self.known_vi_[node.output[i_o]] - out_type = vi.type - out_type_kind = out_type.WhichOneof("value") - - # do not process shape for non-tensors - if out_type_kind not in {"tensor_type", "sparse_tensor_type", None}: - if self.verbose_ > 2: - if out_type_kind == "sequence_type": - seq_cls_type = out_type.sequence_type.elem_type.WhichOneof("value") - if seq_cls_type == "tensor_type": - logger.debug( - " {}: sequence of {} {}".format( - node.output[i_o], - str(get_shape_from_value_info(vi)), - onnx.TensorProto.DataType.Name( - vi.type.sequence_type.elem_type.tensor_type.elem_type - ), - ) - ) - else: - logger.debug( - f" {node.output[i_o]}: sequence of {seq_cls_type}" - ) - else: - logger.debug(f" {node.output[i_o]}: {out_type_kind}") - continue - - out_shape = get_shape_from_value_info(vi) - out_type_undefined = ( - out_type.tensor_type.elem_type == onnx.TensorProto.UNDEFINED - ) - if self.verbose_ > 2: - logger.debug( - f" {node.output[i_o]}: {out_shape!s} {onnx.TensorProto.DataType.Name(vi.type.tensor_type.elem_type)}" - ) - if node.output[i_o] in self.sympy_data_: - logger.debug( - " Sympy Data: " + str(self.sympy_data_[node.output[i_o]]) - ) - - # onnx >= 1.11.0, use unk__#index instead of None when the shape dim is uncertain - if ( - out_shape is not None - and (None in out_shape or self._is_shape_contains_none_dim(out_shape)) - ) or out_type_undefined: - if self.auto_merge_: - if node.op_type in { - "Add", - "Sub", - "Mul", - "Div", - "MatMul", - "MatMulInteger", - "MatMulInteger16", - "Concat", - "Where", - "Sum", - "Equal", - "Less", - "Greater", - "LessOrEqual", - "GreaterOrEqual", - "Min", - "Max", - }: - shapes = [self._get_shape(node, i) for i in range(len(node.input))] - if node.op_type in { - "MatMul", - "MatMulInteger", - "MatMulInteger16", - } and ( - None in out_shape - or self._is_shape_contains_none_dim(out_shape) - ): - if None in out_shape: - idx = out_shape.index(None) - else: - idx = out_shape.index( - self._is_shape_contains_none_dim(out_shape) - ) - dim_idx = [len(s) - len(out_shape) + idx for s in shapes] - # only support auto merge for MatMul for dim < rank-2 when rank > 2 - assert 
len(shapes[0]) > 2 and dim_idx[0] < len(shapes[0]) - 2 - assert len(shapes[1]) > 2 and dim_idx[1] < len(shapes[1]) - 2 - elif node.op_type == "Expand": - # auto merge for cases like Expand([min(batch, 1), min(seq, 512)], [batch, seq]) - shapes = [ - self._get_shape(node, 0), - self._get_value(node, 1), - ] - else: - shapes = [] - - if shapes: - for idx in range(len(out_shape)): - if out_shape[idx] is not None and not self._is_none_dim( - out_shape[idx] - ): - continue - # note that the broadcasting rule aligns from right to left - # if a tensor has a lower rank (dim_idx[idx] < 0), it would automatically broadcast and need no merge - dim_idx = [len(s) - len(out_shape) + idx for s in shapes] - if dim_idx: - self._add_suggested_merge( - [ - s[i] if is_literal(s[i]) else str(s[i]) - for s, i in zip(shapes, dim_idx) - if i >= 0 - ] - ) - self.run_ = True - else: - self.run_ = False - else: - self.run_ = False - - # create new dynamic dims for ops not handled by symbolic shape inference - if ( - not self.run_ - and node.op_type not in self.dispatcher_ - and not known_aten_op - ): - is_unknown_op = out_type_undefined and ( - out_shape is None or len(out_shape) == 0 - ) - if is_unknown_op: - # unknown op to ONNX, maybe from higher opset or other domain - # only guess the output rank from input 0 when using guess_output_rank option - out_rank = ( - self._get_shape_rank(node, 0) - if self.guess_output_rank_ - else -1 - ) - else: - # valid ONNX op, but not handled by symbolic shape inference, just assign dynamic shape - out_rank = len(out_shape) - - if out_rank >= 0: - new_shape = self._new_symbolic_shape(out_rank, node, i_o) - if out_type_undefined: - # guess output data type from input vi if not defined - out_dtype = self.known_vi_[ - node.input[0] - ].type.tensor_type.elem_type - else: - # otherwise, use original data type - out_dtype = vi.type.tensor_type.elem_type - vi.CopyFrom( - helper.make_tensor_value_info( - vi.name, - out_dtype, - get_shape_from_sympy_shape(new_shape), - ) - ) - - if self.verbose_ > 0: - if is_unknown_op: - logger.debug( - f"Possible unknown op: {node.op_type} node: {node.name}, guessing {vi.name} shape" - ) - if self.verbose_ > 2: - logger.debug( - f" {node.output[i_o]}: {new_shape!s} {vi.type.tensor_type.elem_type}" - ) - self.run_ = True - continue # continue the inference after guess, no need to stop as no merge is needed - - if self.verbose_ > 0 or not self.auto_merge_ or out_type_undefined: - logger.debug( - "Stopping at incomplete shape inference at " - + node.op_type - + ": " - + node.name - ) - logger.debug("node inputs:") - for i in node.input: - if i in self.known_vi_: - logger.debug(self.known_vi_[i]) - else: - logger.debug(f"not in known_vi_ for {i}") - logger.debug("node outputs:") - for o in node.output: - if o in self.known_vi_: - logger.debug(self.known_vi_[o]) - else: - logger.debug(f"not in known_vi_ for {o}") - if self.auto_merge_ and not out_type_undefined: - logger.debug("Merging: " + str(self.suggested_merge_)) - return False - - self.run_ = False - return True - - def _update_output_from_vi(self): - """Update output attributes using known value information dictionary.""" - for output in self.out_mp_.graph.output: - if output.name in self.known_vi_: - output.CopyFrom(self.known_vi_[output.name]) - - @staticmethod - def infer_shapes( - in_mp, int_max=2**31 - 1, auto_merge=False, guess_output_rank=False, verbose=0 - ): - """Perform symbolic shape inference on an ONNX model using the specified options to handle model shapes - efficiently. 
- """ - onnx_opset = get_opset(in_mp) - if (not onnx_opset) or onnx_opset < 7: - logger.warning("Only support models of onnx opset 7 and above.") - return None - symbolic_shape_inference = SymbolicShapeInference( - int_max, auto_merge, guess_output_rank, verbose - ) - all_shapes_inferred = False - symbolic_shape_inference._preprocess(in_mp) - while symbolic_shape_inference.run_: - all_shapes_inferred = symbolic_shape_inference._infer_impl() - symbolic_shape_inference._update_output_from_vi() - if not all_shapes_inferred: - raise Exception("Incomplete symbolic shape inference") - return symbolic_shape_inference.out_mp_ - - -def parse_arguments(): - """Parses command-line arguments for ONNX model transformation options.""" - parser = argparse.ArgumentParser() - parser.add_argument("--input", required=True, help="The input model file") - parser.add_argument("--output", help="The output model file") - parser.add_argument( - "--auto_merge", - help="Automatically merge symbolic dims when confliction happens", - action="store_true", - default=False, - ) - parser.add_argument( - "--int_max", - help="maximum value for integer to be treated as boundless for ops like slice", - type=int, - default=2**31 - 1, - ) - parser.add_argument( - "--guess_output_rank", - help="guess output rank to be the same as input 0 for unknown ops", - action="store_true", - default=False, - ) - parser.add_argument( - "--verbose", - help="Prints detailed logs of inference, 0: turn off, 1: warnings, 3: detailed", - type=int, - default=0, - ) - parser.add_argument( - "--save_as_external_data", - help="Saving an ONNX model to external data", - action="store_true", - default=False, - ) - parser.add_argument( - "--all_tensors_to_one_file", - help="Saving all the external data to one file", - action="store_true", - default=False, - ) - parser.add_argument( - "--external_data_location", - help="The file location to save the external file", - default="./", - ) - parser.add_argument( - "--external_data_size_threshold", - help="The size threshold for external data", - type=int, - default=1024, - ) - return parser.parse_args() - - -if __name__ == "__main__": - args = parse_arguments() - logger.info(f"input model: {args.input}") - if args.output: - logger.info(f"output model {args.output}") - logger.info("Doing symbolic shape inference...") - out_mp = SymbolicShapeInference.infer_shapes( - onnx.load(args.input), - args.int_max, - args.auto_merge, - args.guess_output_rank, - args.verbose, - ) - if args.output and out_mp: - if args.save_as_external_data: - onnx.save_model( - out_mp, - args.output, - save_as_external_data=True, - all_tensors_to_one_file=args.all_tensors_to_one_file, - location=args.external_data_location, - size_threshold=args.external_data_size_threshold, - convert_attribute=False, - ) - else: - onnx.save(out_mp, args.output) - logger.info("Done!") From f3c70dafb66c522156c7e7125855e4d9826c9410 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 30 Jun 2025 09:00:16 -0700 Subject: [PATCH 26/31] Summary of Shape Simplifications I've successfully updated all operation files to use direct slicing on ir.Shape objects instead of accessing .dims and then taking slices. Here are the specific changes made: 1. 
Signed-off-by: Justin Chu

--- .../_shape_type_inference/ops/concat.py | 2 +- .../_shape_type_inference/ops/matmul.py | 21 ++++++++----------- .../_shape_type_inference/ops/squeeze.py | 4 ++-- .../_shape_type_inference/ops/transpose.py | 2 +- .../_shape_type_inference/ops/unsqueeze.py | 2 +- 5 files changed, 14 insertions(+), 17 deletions(-) diff --git a/src/onnx_ir/_shape_type_inference/ops/concat.py b/src/onnx_ir/_shape_type_inference/ops/concat.py index 98adb1c6..015ab8ec 100644 --- a/src/onnx_ir/_shape_type_inference/ops/concat.py +++ b/src/onnx_ir/_shape_type_inference/ops/concat.py @@ -59,7 +59,7 @@ def infer(self, node: ir.Node) -> _common.InferenceResult: ) # Check that all inputs have compatible shapes - output_dims = list(first_shape.dims) + output_dims = list(first_shape) concat_dim_size = _common.get_expr(first_shape, axis) for i, inp in enumerate(node.inputs[1:], 1): diff --git a/src/onnx_ir/_shape_type_inference/ops/matmul.py b/src/onnx_ir/_shape_type_inference/ops/matmul.py index fb61b6d5..d7ec1366 100644 --- a/src/onnx_ir/_shape_type_inference/ops/matmul.py +++ b/src/onnx_ir/_shape_type_inference/ops/matmul.py @@ -39,32 +39,29 @@ def infer(self, node: ir.Node) -> _common.InferenceResult: output_shape = ir.Shape([]) elif lhs_rank == 1: # Vector-matrix: (n,) x (..., n, k) -> (..., k) - output_dims = [*rhs_shape.dims[:-2], rhs_shape.dims[-1]] + output_dims = [*rhs_shape[:-2], rhs_shape[-1]] output_shape = ir.Shape(output_dims) elif rhs_rank == 1: # Matrix-vector: (..., m, n) x (n,) -> (..., m) - output_dims = list(lhs_shape.dims[:-1]) - output_shape = ir.Shape(output_dims) + output_shape = ir.Shape(lhs_shape[:-1]) else: # Matrix-matrix: (..., m, n) x (..., n, k) -> (..., m, k) # Broadcast batch dimensions - lhs_batch = lhs_shape.dims[:-2] - rhs_batch = rhs_shape.dims[:-2] + lhs_batch = lhs_shape[:-2] + rhs_batch = rhs_shape[:-2] if lhs_batch and rhs_batch: # TODO(justinchuby): Ensure this is correct - batch_shape = broadcast_shapes_bidirectional( - ir.Shape(lhs_batch), ir.Shape(rhs_batch) - ) - output_dims = [*batch_shape.dims, 
lhs_shape.dims[-2], rhs_shape.dims[-1]] + batch_shape = broadcast_shapes_bidirectional(ir.Shape(lhs_batch), ir.Shape(rhs_batch)) + output_dims = [*batch_shape, lhs_shape[-2], rhs_shape[-1]] output_shape = ir.Shape(output_dims) elif lhs_batch: - output_dims = [*lhs_batch, lhs_shape.dims[-2], rhs_shape.dims[-1]] + output_dims = [*lhs_batch, lhs_shape[-2], rhs_shape[-1]] output_shape = ir.Shape(output_dims) elif rhs_batch: - output_dims = [*rhs_batch, lhs_shape.dims[-2], rhs_shape.dims[-1]] + output_dims = [*rhs_batch, lhs_shape[-2], rhs_shape[-1]] output_shape = ir.Shape(output_dims) else: - output_dims = [lhs_shape.dims[-2], rhs_shape.dims[-1]] + output_dims = [lhs_shape[-2], rhs_shape[-1]] output_shape = ir.Shape(output_dims) output_type = node.inputs[0].type diff --git a/src/onnx_ir/_shape_type_inference/ops/squeeze.py b/src/onnx_ir/_shape_type_inference/ops/squeeze.py index bfe5e582..3494e563 100644 --- a/src/onnx_ir/_shape_type_inference/ops/squeeze.py +++ b/src/onnx_ir/_shape_type_inference/ops/squeeze.py @@ -14,7 +14,7 @@ def _compute_output_shape_no_axes(input_shape: ir.Shape) -> ir.Shape: """Compute output shape when no axes are specified.""" output_dims = [] - for dim in input_shape.dims: + for dim in input_shape: # For symbolic dimensions, we assume they are not 1 # Only squeeze literal 1s if isinstance(dim, int): @@ -44,7 +44,7 @@ def _normalize_axes(axes: Sequence[int], rank: int) -> set[int]: def _compute_output_shape_with_axes(input_shape: ir.Shape, axes: set[int]) -> ir.Shape: """Compute output shape when axes are specified.""" - output_dims = [dim for i, dim in enumerate(input_shape.dims) if i not in axes] + output_dims = [dim for i, dim in enumerate(input_shape) if i not in axes] return ir.Shape(output_dims) diff --git a/src/onnx_ir/_shape_type_inference/ops/transpose.py b/src/onnx_ir/_shape_type_inference/ops/transpose.py index 4b221f73..7f0a93a7 100644 --- a/src/onnx_ir/_shape_type_inference/ops/transpose.py +++ b/src/onnx_ir/_shape_type_inference/ops/transpose.py @@ -60,7 +60,7 @@ def infer(self, node: ir.Node) -> _common.InferenceResult: ) # Copy dimension from input to output according to permutation - output_dims.append(input_shape.dims[axis]) + output_dims.append(input_shape[axis]) return _common.InferenceResult( values=(ir.Value(shape=ir.Shape(output_dims), type=node.inputs[0].type),) diff --git a/src/onnx_ir/_shape_type_inference/ops/unsqueeze.py b/src/onnx_ir/_shape_type_inference/ops/unsqueeze.py index 4559f4cf..4d8aa7ef 100644 --- a/src/onnx_ir/_shape_type_inference/ops/unsqueeze.py +++ b/src/onnx_ir/_shape_type_inference/ops/unsqueeze.py @@ -44,7 +44,7 @@ def _compute_output_shape(input_shape: ir.Shape, axes: set[int]) -> ir.Shape: output_dims.append(1) else: # Copy dimension from input - output_dims.append(input_shape.dims[input_axis]) + output_dims.append(input_shape[input_axis]) input_axis += 1 return ir.Shape(output_dims) From 4b6d80d36c2c12952a24ee6d6a3abc87381875d9 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 30 Jun 2025 09:00:21 -0700 Subject: [PATCH 27/31] Create factory Signed-off-by: Justin Chu --- src/onnx_ir/_shape_type_inference/factory.py | 217 +++++++++++++++++++ 1 file changed, 217 insertions(+) create mode 100644 src/onnx_ir/_shape_type_inference/factory.py diff --git a/src/onnx_ir/_shape_type_inference/factory.py b/src/onnx_ir/_shape_type_inference/factory.py new file mode 100644 index 00000000..86c44bfe --- /dev/null +++ b/src/onnx_ir/_shape_type_inference/factory.py @@ -0,0 +1,217 @@ +"""Factory functions for creating inference 
Signed-off-by: Justin Chu
---
 src/onnx_ir/_shape_type_inference/factory.py | 217 +++++++++++++++++++
 1 file changed, 217 insertions(+)
 create mode 100644 src/onnx_ir/_shape_type_inference/factory.py

diff --git a/src/onnx_ir/_shape_type_inference/factory.py b/src/onnx_ir/_shape_type_inference/factory.py
new file mode 100644
index 00000000..86c44bfe
--- /dev/null
+++ b/src/onnx_ir/_shape_type_inference/factory.py
@@ -0,0 +1,217 @@
+"""Factory functions for creating inference engines with standard inferrers."""
+
+from __future__ import annotations
+
+from onnx_ir._shape_type_inference._engine import ReconciliationPolicy, SymbolicInferenceEngine
+from onnx_ir._shape_type_inference.ops.concat import ConcatInferrer
+from onnx_ir._shape_type_inference.ops.constant import ConstantInferrer
+from onnx_ir._shape_type_inference.ops.matmul import MatMulInferrer
+from onnx_ir._shape_type_inference.ops.reshape import ReshapeInferrer
+from onnx_ir._shape_type_inference.ops.squeeze import Squeeze12Inferrer, Squeeze13Inferrer
+from onnx_ir._shape_type_inference.ops.standard_ops import BinaryInferrer, ElementwiseInferrer
+from onnx_ir._shape_type_inference.ops.transpose import TransposeInferrer
+from onnx_ir._shape_type_inference.ops.unsqueeze import (
+    Unsqueeze12Inferrer,
+    Unsqueeze13Inferrer,
+)
+
+
+def create_standard_inference_engine(
+    reconciliation_policy: ReconciliationPolicy = ReconciliationPolicy.RECONCILE,
+) -> SymbolicInferenceEngine:
+    """Create a SymbolicInferenceEngine with all standard operation inferrers.
+
+    Args:
+        reconciliation_policy: Policy for handling conflicts between inferred and existing values.
+
+    Returns:
+        A configured SymbolicInferenceEngine.
+    """
+    inferrers = []
+
+    # Core tensor operations
+    inferrers.extend(
+        [
+            ConstantInferrer(),
+            ReshapeInferrer(),
+            TransposeInferrer(),
+            # Squeeze/Unsqueeze with opset versions
+            Squeeze12Inferrer(),
+            Squeeze13Inferrer(),
+            Unsqueeze12Inferrer(),
+            Unsqueeze13Inferrer(),
+        ]
+    )
+
+    # Tensor manipulation
+    inferrers.extend(
+        [
+            # GatherInferrer(),
+            # GatherElementsInferrer(),
+            # GatherNDInferrer(),
+            # ScatterElementsInferrer(),
+            # ExpandInferrer(),
+            # SliceInferrer(),
+            # SplitInferrer(),
+            ConcatInferrer(),
+            # PadInferrer(),
+            # TileInferrer(),
+            # WhereInferrer(),
+            # OneHotInferrer(),
+            # CompressInferrer(),
+        ]
+    )
+
+    # Mathematical operations
+    inferrers.extend(
+        [
+            MatMulInferrer(),
+            # EinsumInferrer(),
+            # ReduceSumInferrer(),
+            # ReduceProdInferrer(),
+        ]
+    )
+
+    # Generation operations
+    inferrers.extend(
+        [
+            # RangeInferrer(),
+            # ConstantOfShapeInferrer(),
+            # NonZeroInferrer(),
+        ]
+    )
+
+    # Pooling and convolution
+    inferrers.extend(
+        [
+            # ConvInferrer(),
+            # AveragePoolInferrer(),
+            # MaxPoolInferrer(),
+            # BatchNormalizationInferrer(),
+        ]
+    )
+
+    # Sequence operations
+    inferrers.extend(
+        [
+            # ConcatFromSequenceInferrer(),
+            # SplitToSequenceInferrer(),
+            # SequenceAtInferrer(),
+            # SequenceInsertInferrer(),
+        ]
+    )
+
+    # Control flow
+    inferrers.extend(
+        [
+            # IfInferrer(),
+            # LoopInferrer(),
+            # ScanInferrer(),
+        ]
+    )
+
+    # ML-specific operations
+    inferrers.extend(
+        [
+            # TopKInferrer(),
+            # NonMaxSuppressionInferrer(),
+            # SoftmaxCrossEntropyLossInferrer(),
+            # GroupNormInferrer(),
+            # GeluInferrer(),
+        ]
+    )
+
+    # Utility operations
+    inferrers.extend(
+        [
+            # ArrayFeatureExtractorInferrer(),
+            # CategoryMapperInferrer(),
+            # ZipMapInferrer(),
+            # CumSumInferrer(),
+            # ResizeInferrer(),
+        ]
+    )
+
+    # Elementwise operations (covers many unary ops)
+    elementwise_ops = [
+        "Abs",
+        "Acos",
+        "Acosh",
+        "Asin",
+        "Asinh",
+        "Atan",
+        "Atanh",
+        "Ceil",
+        "Cos",
+        "Cosh",
+        "Erf",
+        "Exp",
+        "Floor",
+        "Log",
+        "Neg",
+        "Reciprocal",
+        "Relu",
+        "Round",
+        "Sigmoid",
+        "Sign",
+        "Sin",
+        "Sinh",
+        "Sqrt",
+        "Tan",
+        "Tanh",
+        "Identity",
+        "IsInf",
+        "IsNaN",
+    ]
+    for op_type in elementwise_ops:
+        inferrers.append(ElementwiseInferrer(op_type))
+
+    # Binary operations (covers broadcasting ops)
+    binary_ops = [
+        "Add",
+        "Sub",
+        "Mul",
+        "Div",
+        "Pow",
+        "Max",
+        "Min",
+        "Equal",
+        "Greater",
+        "GreaterOrEqual",
+        "Less",
+        "LessOrEqual",
+        "And",
+        "Or",
+        "Xor",
+    ]
+    for op_type in binary_ops:
+        inferrers.append(BinaryInferrer(op_type))
+
+    return SymbolicInferenceEngine(inferrers, reconciliation_policy)
+
+
+def create_minimal_inference_engine(
+    reconciliation_policy: ReconciliationPolicy = ReconciliationPolicy.RECONCILE,
+) -> SymbolicInferenceEngine:
+    """Create a minimal SymbolicInferenceEngine with only essential inferrers.
+
+    Args:
+        reconciliation_policy: Policy for handling conflicts between inferred and existing values.
+
+    Returns:
+        A minimal SymbolicInferenceEngine.
+    """
+    inferrers = [
+        # Core essentials
+        ConstantInferrer(),
+        ReshapeInferrer(),
+        TransposeInferrer(),
+        MatMulInferrer(),
+        ConcatInferrer(),
+        # Basic elementwise and binary
+        ElementwiseInferrer("Identity"),
+        BinaryInferrer("Add"),
+        BinaryInferrer("Mul"),
+    ]
+
+    return SymbolicInferenceEngine(inferrers, reconciliation_policy)

From e03733bd8412773b7fc41b1b9dfadc4def965598 Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Mon, 30 Jun 2025 11:42:34 -0700
Subject: [PATCH 28/31] Use Enum
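For context, a minimal sketch of what switching from enum.StrEnum to enum.Enum changes (assuming the motivation is to support Python versions before 3.11, where enum.StrEnum does not exist):

    import enum

    @enum.unique
    class InferenceStatus(enum.Enum):
        SUCCESS = "success"

    # Construction from the string value still works:
    assert InferenceStatus("success") is InferenceStatus.SUCCESS
    # But unlike StrEnum, members no longer compare equal to plain strings:
    assert InferenceStatus.SUCCESS != "success"
    assert InferenceStatus.SUCCESS.value == "success"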
+ "Max", + "Min", + "Equal", + "Greater", + "GreaterOrEqual", + "Less", + "LessOrEqual", + "And", + "Or", + "Xor", + ] + for op_type in binary_ops: + inferrers.append(BinaryInferrer(op_type)) + + return SymbolicInferenceEngine(inferrers, reconciliation_policy) + + +def create_minimal_inference_engine( + reconciliation_policy: ReconciliationPolicy = ReconciliationPolicy.RECONCILE, +) -> SymbolicInferenceEngine: + """Create a minimal SymbolicInferenceEngine with only essential inferrers. + + Args: + reconciliation_policy: Policy for handling conflicts between inferred and existing values. + + Returns: + A minimal SymbolicInferenceEngine. + """ + inferrers = [ + # Core essentials + ConstantInferrer(), + ReshapeInferrer(), + TransposeInferrer(), + MatMulInferrer(), + ConcatInferrer(), + # Basic elementwise and binary + ElementwiseInferrer("Identity"), + BinaryInferrer("Add"), + BinaryInferrer("Mul"), + ] + + return SymbolicInferenceEngine(inferrers, reconciliation_policy) From e03733bd8412773b7fc41b1b9dfadc4def965598 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 30 Jun 2025 11:42:34 -0700 Subject: [PATCH 28/31] Use Enum Signed-off-by: Justin Chu --- src/onnx_ir/_shape_type_inference/_common.py | 2 +- src/onnx_ir/_shape_type_inference/_engine.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/onnx_ir/_shape_type_inference/_common.py b/src/onnx_ir/_shape_type_inference/_common.py index 5c27830b..5a0fea9a 100644 --- a/src/onnx_ir/_shape_type_inference/_common.py +++ b/src/onnx_ir/_shape_type_inference/_common.py @@ -36,7 +36,7 @@ def get_expr(shape: ir.Shape, index: int) -> sympy.Expr: @enum.unique -class InferenceStatus(enum.StrEnum): +class InferenceStatus(enum.Enum): """Status of shape inference operation.""" SUCCESS = "success" # Complete inference successful diff --git a/src/onnx_ir/_shape_type_inference/_engine.py b/src/onnx_ir/_shape_type_inference/_engine.py index 8f639da2..25ea6615 100644 --- a/src/onnx_ir/_shape_type_inference/_engine.py +++ b/src/onnx_ir/_shape_type_inference/_engine.py @@ -12,7 +12,7 @@ logger = logging.getLogger(__name__) -class ReconciliationPolicy(enum.StrEnum): +class ReconciliationPolicy(enum.Enum): """Policy for reconciling inferred shapes/types with existing values.""" OVERWRITE = "overwrite" # Always use inferred values From 5a34891eca6df053ecc0f3388141fbb2745f6223 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 30 Jun 2025 11:45:37 -0700 Subject: [PATCH 29/31] Update logging calls Signed-off-by: Justin Chu --- src/onnx_ir/_shape_type_inference/_engine.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/src/onnx_ir/_shape_type_inference/_engine.py b/src/onnx_ir/_shape_type_inference/_engine.py index 25ea6615..257166a8 100644 --- a/src/onnx_ir/_shape_type_inference/_engine.py +++ b/src/onnx_ir/_shape_type_inference/_engine.py @@ -49,7 +49,7 @@ def __init__( self._inferrer_registry[key] = [] self._inferrer_registry[key].append(inferrer) - logger.info(f"Initialized inference engine with {len(node_inferrers)} inferrers") + logger.info("Initialized inference engine with %s inferrers", len(node_inferrers)) def infer_model(self, model: ir.Model) -> None: """Perform shape and type inference on an entire model. @@ -60,15 +60,15 @@ def infer_model(self, model: ir.Model) -> None: Raises: InferenceError: If inference fails for any node. 
""" - logger.info(f"Starting inference on model with {len(model.graph.nodes)} nodes") + logger.info("Starting inference on model with %s nodes", len(model.graph.nodes)) # Process nodes in topological order for i, node in enumerate(model.graph.nodes): try: self._infer_node(node, model) - logger.debug(f"Successfully inferred node {i}: {node.op_type}") + logger.debug("Successfully inferred node %s: %s", i, node.op_type) except Exception as e: - error_msg = f"Failed to infer node {i} ({node.op_type}): {e!s}" + error_msg = f"Failed to infer node {i} ({node.op_type}): {e}" logger.exception(error_msg) raise InferenceError(error_msg) from e @@ -98,13 +98,13 @@ def _infer_node(self, node: ir.Node, model: ir.Model) -> None: raise InferenceError(f"Invalid node: {result.msg}") if result.status == _common.InferenceStatus.MISSING_INFO: - logger.warning(f"Missing info for node {node.op_type}: {result.msg}") + logger.warning("Missing info for node %s: %s", node.op_type, result.msg) # Continue with partial inference or skip if result.values is None: return # Skip this node if result.status == _common.InferenceStatus.PARTIAL: - logger.info(f"Partial inference for node {node.op_type}: {result.msg}") + logger.info("Partial inference for node %s: %s", node.op_type, result.msg) # Continue with partial results if result.values is None: @@ -139,8 +139,8 @@ def _find_inferrer(self, node: ir.Node, model: ir.Model) -> _common.NodeInferrer if not suitable_inferrers: logger.warning( - f"No inferrer supports opset {opset_version} for {node.op_type} " - f"(domain: {node.domain})" + "No inferrer supports opset %s for %s (domain: %s)", + opset_version, node.op_type, node.domain ) return None @@ -251,7 +251,8 @@ def _reconcile_shapes(self, shape1: ir.Shape, shape2: ir.Shape) -> ir.Shape: """ if len(shape1) != len(shape2): logger.warning( - f"Shape rank mismatch: {len(shape1)} vs {len(shape2)}. Using first shape." + "Shape rank mismatch: %s vs %s. Using first shape.", + len(shape1), len(shape2) ) return shape1 From ab09107cd6d0671310a5555e33da41e5628652e1 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 30 Jun 2025 12:00:30 -0700 Subject: [PATCH 30/31] Working on engine Signed-off-by: Justin Chu --- src/onnx_ir/_shape_type_inference/_common.py | 6 +- src/onnx_ir/_shape_type_inference/_engine.py | 57 ++++++++----------- .../_shape_type_inference/ops/matmul.py | 4 +- 3 files changed, 31 insertions(+), 36 deletions(-) diff --git a/src/onnx_ir/_shape_type_inference/_common.py b/src/onnx_ir/_shape_type_inference/_common.py index 5a0fea9a..d2da4671 100644 --- a/src/onnx_ir/_shape_type_inference/_common.py +++ b/src/onnx_ir/_shape_type_inference/_common.py @@ -76,17 +76,21 @@ class NodeInferrer(abc.ABC): This class provides a common interface for all node inferrers. """ - def __init__(self, op_type: str, opsets: Collection[int], domain: str = "") -> None: + def __init__( + self, op_type: str, opsets: Collection[int], domain: str = "", overload: str = "" + ) -> None: """Initialize the node inferrer. Args: op_type: The type of the operation. opsets: A collection of ONNX opset versions supported by this inferrer. domain: The domain of the operation, default is an empty string. + overload: The overload identifier for the operation, default is an empty string. 
""" self.op_type = op_type self.opsets = opsets self.domain = domain + self.overload = overload def __repr__(self) -> str: """Return a string representation of the node inferrer.""" diff --git a/src/onnx_ir/_shape_type_inference/_engine.py b/src/onnx_ir/_shape_type_inference/_engine.py index 257166a8..8dc4c57b 100644 --- a/src/onnx_ir/_shape_type_inference/_engine.py +++ b/src/onnx_ir/_shape_type_inference/_engine.py @@ -4,7 +4,7 @@ import enum import logging -from collections.abc import Sequence +from collections.abc import Iterable, Sequence import onnx_ir as ir from onnx_ir._shape_type_inference import _common @@ -30,7 +30,7 @@ class SymbolicInferenceEngine: def __init__( self, - node_inferrers: Sequence[_common.NodeInferrer], + node_inferrers: Iterable[_common.NodeInferrer], reconciliation_policy: str = "reconcile", ) -> None: """Initialize the symbolic inference engine. @@ -40,16 +40,12 @@ def __init__( reconciliation_policy: Policy for handling conflicts between inferred and existing values. """ self.reconciliation_policy = ReconciliationPolicy(reconciliation_policy) - self._inferrer_registry: dict[tuple[str, str], list[_common.NodeInferrer]] = {} + self._inferrer_registry: dict[ir.OperatorIdentifier, list[_common.NodeInferrer]] = {} # Register inferrers by (op_type, domain) for inferrer in node_inferrers: - key = (inferrer.op_type, inferrer.domain) - if key not in self._inferrer_registry: - self._inferrer_registry[key] = [] - self._inferrer_registry[key].append(inferrer) - - logger.info("Initialized inference engine with %s inferrers", len(node_inferrers)) + key = (inferrer.domain, inferrer.op_type, inferrer.overload) + self._inferrer_registry.setdefault(key, []).append(inferrer) def infer_model(self, model: ir.Model) -> None: """Perform shape and type inference on an entire model. @@ -60,10 +56,10 @@ def infer_model(self, model: ir.Model) -> None: Raises: InferenceError: If inference fails for any node. """ - logger.info("Starting inference on model with %s nodes", len(model.graph.nodes)) + logger.info("Starting inference on model with %s nodes", len(model.graph)) # Process nodes in topological order - for i, node in enumerate(model.graph.nodes): + for i, node in enumerate(model.graph): try: self._infer_node(node, model) logger.debug("Successfully inferred node %s: %s", i, node.op_type) @@ -123,14 +119,24 @@ def _find_inferrer(self, node: ir.Node, model: ir.Model) -> _common.NodeInferrer Returns: The best matching inferrer, or None if no suitable inferrer is found. 
""" - key = (node.op_type, node.domain) + key = (node.domain, node.op_type, node.overload) inferrers = self._inferrer_registry.get(key, []) if not inferrers: return None # Get model opset version for this domain - opset_version = self._get_opset_version(model, node.domain) + if node.version is not None: + opset_version = node.version + elif node.graph is not None and node.domain in node.graph.opset_imports: + opset_version = node.graph.opset_imports[node.domain] + else: + # Fallback to model-level opset import + if node.domain not in model.opset_imports: + raise InferenceError( + f"No opset import found for domain '{node.domain}' in model" + ) + opset_version = model.opset_imports[node.domain] # Find inferrers that support this opset version suitable_inferrers = [ @@ -140,31 +146,15 @@ def _find_inferrer(self, node: ir.Node, model: ir.Model) -> _common.NodeInferrer if not suitable_inferrers: logger.warning( "No inferrer supports opset %s for %s (domain: %s)", - opset_version, node.op_type, node.domain + opset_version, + node.op_type, + node.domain, ) return None # Return the first suitable inferrer (could be enhanced with priority logic) return suitable_inferrers[0] - def _get_opset_version(self, model: ir.Model, domain: str) -> int: - """Get the opset version for a given domain in the model. - - Args: - model: The model to check. - domain: The domain to get the opset version for. - - Returns: - The opset version for the domain. - """ - # Look for opset import for this domain - for opset_import in model.opset_imports: - if opset_import.domain == domain: - return opset_import.version - - # Default to a high version if not found - return 999 - def _reconcile_outputs(self, node: ir.Node, inferred_values: Sequence[ir.Value]) -> None: """Reconcile inferred output values with existing node outputs. @@ -251,8 +241,7 @@ def _reconcile_shapes(self, shape1: ir.Shape, shape2: ir.Shape) -> ir.Shape: """ if len(shape1) != len(shape2): logger.warning( - "Shape rank mismatch: %s vs %s. Using first shape.", - len(shape1), len(shape2) + "Shape rank mismatch: %s vs %s. 
             )
             return shape1

diff --git a/src/onnx_ir/_shape_type_inference/ops/matmul.py b/src/onnx_ir/_shape_type_inference/ops/matmul.py
index d7ec1366..45d31ea9 100644
--- a/src/onnx_ir/_shape_type_inference/ops/matmul.py
+++ b/src/onnx_ir/_shape_type_inference/ops/matmul.py
@@ -51,7 +51,9 @@ def infer(self, node: ir.Node) -> _common.InferenceResult:
             rhs_batch = rhs_shape[:-2]
             if lhs_batch and rhs_batch:
                 # TODO(justinchuby): Ensure this is correct
-                batch_shape = broadcast_shapes_bidirectional(ir.Shape(lhs_batch), ir.Shape(rhs_batch))
+                batch_shape = broadcast_shapes_bidirectional(
+                    ir.Shape(lhs_batch), ir.Shape(rhs_batch)
+                )
                 output_dims = [*batch_shape, lhs_shape[-2], rhs_shape[-1]]
                 output_shape = ir.Shape(output_dims)
             elif lhs_batch:

From 925623390c6614b371350c62467a0b57c7d972be Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Mon, 30 Jun 2025 13:15:18 -0700
Subject: [PATCH 31/31] todo

Signed-off-by: Justin Chu
---
 src/onnx_ir/_shape_type_inference/_engine.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/onnx_ir/_shape_type_inference/_engine.py b/src/onnx_ir/_shape_type_inference/_engine.py
index 8dc4c57b..bdf8e2ba 100644
--- a/src/onnx_ir/_shape_type_inference/_engine.py
+++ b/src/onnx_ir/_shape_type_inference/_engine.py
@@ -91,6 +91,7 @@ def _infer_node(self, node: ir.Node, model: ir.Model) -> None:
         result = inferrer.infer(node)
 
         if result.status == _common.InferenceStatus.INVALID_NODE:
+            # TODO: Print the node information
            raise InferenceError(f"Invalid node: {result.msg}")
 
         if result.status == _common.InferenceStatus.MISSING_INFO:
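Taken together, the series is meant to be driven roughly like this (a sketch only; it assumes onnx_ir exposes a load() helper, uses a hypothetical model path, and relies on the factory and policy names introduced above):

    import onnx_ir as ir

    from onnx_ir._shape_type_inference._engine import ReconciliationPolicy
    from onnx_ir._shape_type_inference.factory import create_standard_inference_engine

    model = ir.load("model.onnx")  # hypothetical model path
    engine = create_standard_inference_engine(
        reconciliation_policy=ReconciliationPolicy.OVERWRITE
    )
    engine.infer_model(model)  # raises InferenceError on the first node that fails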