Skip to content

Commit

Permalink
feat: support aten::__and__.bool evaluator
Browse files Browse the repository at this point in the history
Signed-off-by: Ruoqian Guo <ruoqiang@nvidia.com>
  • Loading branch information
ruoqianguo committed Oct 29, 2021
1 parent 5643972 commit 6d73e43
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 1 deletion.
7 changes: 6 additions & 1 deletion core/conversion/evaluators/aten.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,12 @@ DEFINE_GENERIC_TWO_INPUT_EVALUATOR(
"aten::ge.float_int(float a, int b) -> (bool)",
}));

DEFINE_TWO_INPUT_SIMPLE_EVALUATOR(and, "aten::__and__", a&& b, bool, {"aten::__and__(int a, int b) -> (bool)"});
DEFINE_TWO_INPUT_SIMPLE_EVALUATOR(
and,
"aten::__and__",
a&& b,
bool,
std::set<std::string>({"aten::__and__(int a, int b) -> (bool)", "aten::__and__.bool(bool a, bool b) -> (bool)"}));
DEFINE_TWO_INPUT_SIMPLE_EVALUATOR(or, "aten::__or__", a || b, bool, {"aten::__or__(int a, int b) -> (bool)"});
DEFINE_TWO_INPUT_SIMPLE_EVALUATOR(
xor,
Expand Down
34 changes: 34 additions & 0 deletions tests/core/conversion/evaluators/test_aten_evaluators.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -540,5 +540,39 @@ TEST(Evaluators, EqStrResultIsFalseEvaluatesCorrectly) {
auto jit_results = trtorch::tests::util::EvaluateGraphJIT(g, {});
auto trt_results = trtorch::tests::util::EvaluateGraph(g->block(), {});

ASSERT_TRUE(jit_results[0] == trt_results[0]);
}

TEST(Evaluators, AndBoolResultIsTrueEvaluatesCorrectly) {
const auto graph = R"IR(
graph():
%1 : bool = prim::Constant[value=1]()
%2 : bool = prim::Constant[value=1]()
%3 : bool = aten::__and__(%1, %2)
return (%3))IR";

auto g = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(graph, g.get());

auto jit_results = trtorch::tests::util::EvaluateGraphJIT(g, {});
auto trt_results = trtorch::tests::util::EvaluateGraph(g->block(), {});

ASSERT_TRUE(jit_results[0] == trt_results[0]);
}

TEST(Evaluators, AndBoolResultIsFalseEvaluatesCorrectly) {
const auto graph = R"IR(
graph():
%1 : bool = prim::Constant[value=1]()
%2 : bool = prim::Constant[value=0]()
%3 : bool = aten::__and__(%1, %2)
return (%3))IR";

auto g = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(graph, g.get());

auto jit_results = trtorch::tests::util::EvaluateGraphJIT(g, {});
auto trt_results = trtorch::tests::util::EvaluateGraph(g->block(), {});

ASSERT_TRUE(jit_results[0] == trt_results[0]);
}

0 comments on commit 6d73e43

Please sign in to comment.