Skip to content

Commit

Permalink
feat(aten::ones): Adding support for aten::ones
Browse files Browse the repository at this point in the history
Signed-off-by: Naren Dasan <naren@narendasan.com>
Signed-off-by: Naren Dasan <narens@nvidia.com>
  • Loading branch information
narendasan committed May 27, 2021
1 parent a7d2b5e commit 2b45a3d
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 0 deletions.
14 changes: 14 additions & 0 deletions core/conversion/evaluators/aten.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,20 @@ auto aten_registrations TRTORCH_UNUSED =
auto out_tensor = torch::zeros(args.at(n->input(0)).unwrapToIntList().vec(), options);
return out_tensor;
}})
.evaluator({c10::Symbol::fromQualString("aten::ones"),
            // Schema: aten::ones(int[] size, *, int? dtype=None, int? layout=None,
            //                    Device? device=None, bool? pin_memory=None) -> (Tensor)
            [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
              // Results are materialized on the GPU with the default strided layout;
              // dtype is only applied when the caller actually supplied one.
              auto tensor_options = torch::TensorOptions().layout(torch::kStrided).device(torch::kCUDA);

              // n->input(1) is the optional dtype argument of the schema.
              auto& dtype_arg = args.at(n->input(1));
              if (!dtype_arg.isNone() && !dtype_arg.IValue()->isNone()) {
                tensor_options = tensor_options.dtype(c10::ScalarType(dtype_arg.unwrapToInt()));
              }

              // n->input(0) is the requested shape (the int[] size argument).
              auto shape = args.at(n->input(0)).unwrapToIntList().vec();
              return torch::ones(shape, tensor_options);
            }})
.evaluator({c10::Symbol::fromQualString("aten::slice"),
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
c10::List<c10::IValue> list = args.at(n->input(0)).IValue()->to<c10::List<c10::IValue>>();
Expand Down
39 changes: 39 additions & 0 deletions tests/core/conversion/evaluators/test_aten_evaluators.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,45 @@ TEST(Evaluators, DivFloatEvaluatesCorrectly) {
ASSERT_TRUE(jit_results[0] == trt_results[0]);
}

TEST(Evaluators, OnesEvaluatesCorrectly) {
  // aten::ones with every optional argument left as None; the evaluator's
  // output must match what the TorchScript interpreter produces for the
  // same graph.
  const auto ir = R"IR(
graph(%x.1 : Tensor):
%2 : None = prim::Constant() # :0:0
%3 : int[] = aten::size(%x.1) # <string>:7:9
%z.1 : Tensor = aten::ones(%3, %2, %2, %2, %2) # experiments/test_zeros.py:8:12
return (%z.1))IR";

  auto parsed_graph = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(ir, parsed_graph.get());

  // Random positive input; only its shape feeds aten::ones via aten::size.
  auto input = at::randint(1, 10, {1, 5, 5, 5}, {at::kCUDA});

  auto jit_out = trtorch::tests::util::EvaluateGraphJIT(parsed_graph, {input});
  auto trt_out = trtorch::tests::util::EvaluateGraph(parsed_graph->block(), {input});

  ASSERT_TRUE(at::equal(jit_out[0].toTensor().to(at::kCUDA), trt_out[0].toTensor()));
}

TEST(Evaluators, OnesDataTypeEvaluatesCorrectly) {
  // aten::ones with an explicit dtype argument (ScalarType 5 == Half/Float16);
  // exercises the evaluator's dtype branch and compares against JIT execution.
  const auto ir = R"IR(
graph(%x.1 : Tensor):
%2 : int = prim::Constant[value=5]() # :0:0 (Float16)
%3 : None = prim::Constant() # :0:0
%4 : int[] = aten::size(%x.1) # <string>:7:9
%z.1 : Tensor = aten::ones(%4, %2, %3, %3, %3) # experiments/test_zeros.py:8:12
return (%z.1))IR";

  auto parsed_graph = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(ir, parsed_graph.get());

  // Random positive input; only its shape is consumed via aten::size.
  auto input = at::randint(1, 10, {1, 5, 5, 5}, {at::kCUDA});

  auto jit_out = trtorch::tests::util::EvaluateGraphJIT(parsed_graph, {input});
  auto trt_out = trtorch::tests::util::EvaluateGraph(parsed_graph->block(), {input});

  ASSERT_TRUE(at::equal(jit_out[0].toTensor().to(at::kCUDA), trt_out[0].toTensor()));
}

TEST(Evaluators, ZerosEvaluatesCorrectly) {
const auto graph = R"IR(
graph(%x.1 : Tensor):
Expand Down

0 comments on commit 2b45a3d

Please sign in to comment.