diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc index 434a7e7653a5..bfc278b9c779 100644 --- a/src/relax/transform/fuse_ops.cc +++ b/src/relax/transform/fuse_ops.cc @@ -427,17 +427,18 @@ class FunctionCreator : public ExprMutator { } for (const Expr& arg : call->args) { - if (GetStructInfoAs(arg) != nullptr) { - // The argument is fully referenced. Thus we remove it from the mapping. - partially_used_tuple_params_.erase(arg.get()); - const Tuple& tup_args = Downcast(arg); - for (const Expr& tup_arg : tup_args->fields) { + if (auto tuple = arg.as()) { + for (const Expr& tup_arg : tuple->fields) { CheckDefAndUpdateParam(tup_arg); ICHECK(GetStructInfoAs(tup_arg) == nullptr); } } else { CheckDefAndUpdateParam(arg); } + if (GetStructInfoAs(arg) != nullptr) { + // The argument is fully referenced. Thus we remove it from the mapping. + partially_used_tuple_params_.erase(arg.get()); + } } } } else if (var_binding->value.as()) {