Skip to content

Commit

Permalink
feat(aten::__range_length): Adding range length evaluator
Browse files Browse the repository at this point in the history
Signed-off-by: Naren Dasan <naren@narendasan.com>
Signed-off-by: Naren Dasan <narens@nvidia.com>
  • Loading branch information
narendasan committed Feb 3, 2022
1 parent 4fd886d commit 11c4608
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 1 deletion.
20 changes: 19 additions & 1 deletion core/conversion/evaluators/aten.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -718,7 +718,25 @@ auto aten_registrations TORCHTRT_UNUSED =
torch::jit::pop(stack, output);
return output;
},
EvalOptions().validSchemas({"aten::format(str self, ...) -> (str)"})});
EvalOptions().validSchemas({"aten::format(str self, ...) -> (str)"})})
.evaluator({c10::Symbol::fromQualString("aten::__range_length"),
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
auto lo = args.at(n->input(0)).unwrapToInt();
auto hi = args.at(n->input(1)).unwrapToInt();
auto step = args.at(n->input(2)).unwrapToInt();

if (step == 0) {
TORCHTRT_THROW_ERROR("aten::__range_length() arg 3 must not be zero");
}
if (step > 0 && lo < hi) {
return 1 + (hi - 1 - lo) / step;
} else if (step < 0 && lo > hi) {
return 1 + (lo - 1 - hi) / (0 - step);
} else {
return 0;
}
},
EvalOptions().validSchemas({"aten::__range_length(int lo, int hi, int step) -> int"})});
} // namespace
} // namespace evaluators
} // namespace conversion
Expand Down
36 changes: 36 additions & 0 deletions tests/core/conversion/evaluators/test_aten_evaluators.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -665,4 +665,40 @@ TEST(Evaluators, AtenFormatRaiseExceptionEvaluatesCorrectly) {
} else {
ASSERT_TRUE(false);
}
}

TEST(Evaluators, RangeLengthEvaluatesCorrectly) {
const auto graph = R"IR(
graph():
%1 : int = prim::Constant[value=1]()
%2 : int = prim::Constant[value=10]()
%3 : int = prim::Constant[value=2]()
%4 : int = aten::__range_length(%1, %2, %3)
return (%3))IR";

auto g = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(graph, g.get());

auto jit_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {});
auto trt_results = torch_tensorrt::tests::util::EvaluateGraph(g->block(), {});

ASSERT_TRUE(jit_results[0] == trt_results[0]);
}

TEST(Evaluators, RangeLengthNegEvaluatesCorrectly) {
const auto graph = R"IR(
graph():
%1 : int = prim::Constant[value=10]()
%2 : int = prim::Constant[value=1]()
%3 : int = prim::Constant[value=-2]()
%4 : int = aten::__range_length(%1, %2, %3)
return (%3))IR";

auto g = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(graph, g.get());

auto jit_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {});
auto trt_results = torch_tensorrt::tests::util::EvaluateGraph(g->block(), {});

ASSERT_TRUE(jit_results[0] == trt_results[0]);
}

0 comments on commit 11c4608

Please sign in to comment.