diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc index 434a7e7653a5..a222f44a0983 100644 --- a/src/relax/transform/fuse_ops.cc +++ b/src/relax/transform/fuse_ops.cc @@ -427,9 +427,7 @@ class FunctionCreator : public ExprMutator { } for (const Expr& arg : call->args) { - if (GetStructInfoAs(arg) != nullptr) { - // The argument is fully referenced. Thus we remove it from the mapping. - partially_used_tuple_params_.erase(arg.get()); + if (arg.as()) { const Tuple& tup_args = Downcast(arg); for (const Expr& tup_arg : tup_args->fields) { CheckDefAndUpdateParam(tup_arg); @@ -438,6 +436,10 @@ class FunctionCreator : public ExprMutator { } else { CheckDefAndUpdateParam(arg); } + if (GetStructInfoAs(arg) != nullptr) { + // The argument is fully referenced. Thus we remove it from the mapping. + partially_used_tuple_params_.erase(arg.get()); + } } } } else if (var_binding->value.as()) {