diff --git a/csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/fuse_select_assign.cpp b/csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/fuse_select_assign.cpp index 2ff4d09f70..01cf6e3e3d 100644 --- a/csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/fuse_select_assign.cpp +++ b/csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/fuse_select_assign.cpp @@ -13,6 +13,16 @@ using torch::jit::Block; using torch::jit::IValue; using torch::jit::Node; +bool RemoveBoolCast(Node* node) { + auto bottom_node = node->input()->node(); + if (bottom_node->kind() != Symbol::onnx("Greater") && + bottom_node->kind() != Symbol::onnx("Less")) { + return false; + } + node->output()->replaceAllUsesWith(bottom_node->output()); + return true; +} + bool FuseSelectAssign(Node* node, std::unordered_map& params, std::unordered_map& vmap, SubgraphMatcher& matcher) { auto values_map = matcher.values_map(); @@ -106,7 +116,9 @@ void FuseSelectAssign(Block* block, std::unordered_map& par FuseSelectAssign(block, params, vmap, matcher); } - if (matcher.matchesSubgraphFromAnchorNode(node)) { + if (node->kind() == Symbol::onnx("Cast") && node->i(Symbol::attr("to")) == 9) { + RemoveBoolCast(node); + } else if (matcher.matchesSubgraphFromAnchorNode(node)) { FuseSelectAssign(node, params, vmap, matcher); } }