diff --git a/core/conversion/evaluators/aten.cpp b/core/conversion/evaluators/aten.cpp index 66dc93d9df..3715d91adc 100644 --- a/core/conversion/evaluators/aten.cpp +++ b/core/conversion/evaluators/aten.cpp @@ -98,7 +98,12 @@ DEFINE_GENERIC_TWO_INPUT_EVALUATOR( "aten::ge.float_int(float a, int b) -> (bool)", })); -DEFINE_TWO_INPUT_SIMPLE_EVALUATOR(and, "aten::__and__", a&& b, bool, {"aten::__and__(int a, int b) -> (bool)"}); +DEFINE_TWO_INPUT_SIMPLE_EVALUATOR( + and, + "aten::__and__", + a&& b, + bool, + std::set({"aten::__and__(int a, int b) -> (bool)", "aten::__and__.bool(bool a, bool b) -> (bool)"})); DEFINE_TWO_INPUT_SIMPLE_EVALUATOR(or, "aten::__or__", a || b, bool, {"aten::__or__(int a, int b) -> (bool)"}); DEFINE_TWO_INPUT_SIMPLE_EVALUATOR( xor, diff --git a/tests/core/conversion/evaluators/test_aten_evaluators.cpp b/tests/core/conversion/evaluators/test_aten_evaluators.cpp index e6ed40072b..94631ebe9c 100644 --- a/tests/core/conversion/evaluators/test_aten_evaluators.cpp +++ b/tests/core/conversion/evaluators/test_aten_evaluators.cpp @@ -540,5 +540,39 @@ TEST(Evaluators, EqStrResultIsFalseEvaluatesCorrectly) { auto jit_results = trtorch::tests::util::EvaluateGraphJIT(g, {}); auto trt_results = trtorch::tests::util::EvaluateGraph(g->block(), {}); + ASSERT_TRUE(jit_results[0] == trt_results[0]); +} + +TEST(Evaluators, AndBoolResultIsTrueEvaluatesCorrectly) { + const auto graph = R"IR( + graph(): + %1 : bool = prim::Constant[value=1]() + %2 : bool = prim::Constant[value=1]() + %3 : bool = aten::__and__(%1, %2) + return (%3))IR"; + + auto g = std::make_shared(); + torch::jit::parseIR(graph, g.get()); + + auto jit_results = trtorch::tests::util::EvaluateGraphJIT(g, {}); + auto trt_results = trtorch::tests::util::EvaluateGraph(g->block(), {}); + + ASSERT_TRUE(jit_results[0] == trt_results[0]); +} + +TEST(Evaluators, AndBoolResultIsFalseEvaluatesCorrectly) { + const auto graph = R"IR( + graph(): + %1 : bool = prim::Constant[value=1]() + %2 : bool = prim::Constant[value=0]() + %3 : bool = aten::__and__(%1, %2) + return (%3))IR"; + + auto g = std::make_shared(); + torch::jit::parseIR(graph, g.get()); + + auto jit_results = trtorch::tests::util::EvaluateGraphJIT(g, {}); + auto trt_results = trtorch::tests::util::EvaluateGraph(g->block(), {}); + ASSERT_TRUE(jit_results[0] == trt_results[0]); } \ No newline at end of file