Skip to content
18 changes: 18 additions & 0 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -4565,6 +4565,23 @@ def _impl_v1(cls, inputs, attr, params):
"Attempting to unify ranks but this may produce incorrect results."
)
warnings.warn(warning_msg)
# Skip constant If node to avoid irrational broadcast
if isinstance(inputs[0], tvm.relay.expr.Constant):
predicate = inputs[0].data.asnumpy()[0]
node_name = attr["tvm_custom"]["name"]
warn_msg_begin = f"Predicate of If node {node_name} is always "
if predicate == np.bool_(True):
warnings.warn(
warn_msg_begin
+ "true so only then branch would be executed. Removing else branch. "
)
else_expr = then_expr
elif predicate == np.bool_(False):
warnings.warn(
warn_msg_begin
+ "false so only else branch would be executed. Removing then branch. "
)
then_expr = else_expr
if len(then_shape) < len(else_shape):
then_expr = _op.broadcast_to_like(then_expr, else_expr)
else:
Expand Down Expand Up @@ -6529,6 +6546,7 @@ def _impl_v11(cls, inputs, attr, params):
# compatible operators that do NOT require any conversion.
_identity_list = []


# _convert_map defines maps of name to converter functor(callable)
# for 1 to 1 mapping, use Renamer if nothing but name is different
# use AttrCvt if attributes need to be converted
Expand Down