Skip to content

Commit

Permalink
remove bool cast
Browse files Browse the repository at this point in the history
  • Loading branch information
grimoire committed Jun 27, 2022
1 parent 6ee6e66 commit e3af5bc
Showing 1 changed file with 13 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,16 @@ using torch::jit::Block;
using torch::jit::IValue;
using torch::jit::Node;

bool RemoveBoolCast(Node* node) {
auto bottom_node = node->input()->node();
if (bottom_node->kind() != Symbol::onnx("Greater") &&
bottom_node->kind() != Symbol::onnx("Less")) {
return false;
}
node->output()->replaceAllUsesWith(bottom_node->output());
return true;
}

bool FuseSelectAssign(Node* node, std::unordered_map<std::string, Tensor>& params,
std::unordered_map<std::string, Value*>& vmap, SubgraphMatcher& matcher) {
auto values_map = matcher.values_map();
Expand Down Expand Up @@ -106,7 +116,9 @@ void FuseSelectAssign(Block* block, std::unordered_map<std::string, Tensor>& par
FuseSelectAssign(block, params, vmap, matcher);
}

if (matcher.matchesSubgraphFromAnchorNode(node)) {
if (node->kind() == Symbol::onnx("Cast") && node->i(Symbol::attr("to")) == 9) {
RemoveBoolCast(node);
} else if (matcher.matchesSubgraphFromAnchorNode(node)) {
FuseSelectAssign(node, params, vmap, matcher);
}
}
Expand Down

0 comments on commit e3af5bc

Please sign in to comment.