Skip to content

Commit

Permalink
feat(aten::softmax): Adding support for any neg index
Browse files Browse the repository at this point in the history
Signed-off-by: Naren Dasan <naren@narendasan.com>
Signed-off-by: Naren Dasan <narens@nvidia.com>
  • Loading branch information
narendasan committed Feb 8, 2021
1 parent 658ee4f commit abc29a2
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 3 deletions.
4 changes: 2 additions & 2 deletions core/conversion/converters/impl/softmax.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ static auto softmax_registrations TRTORCH_UNUSED = RegisterNodeConversionPattern

int64_t dim = args[1].IValue()->toInt();
LOG_DEBUG("Softmax original dim " << dim);
if (dim == -1) {
dim = shape.size() - 1;
if (dim < 0) {
dim = shape.size() + dim;
}
LOG_DEBUG("Softmax converted dim " << dim);
auto softmax = ctx->net->addSoftMax(*in);
Expand Down
28 changes: 27 additions & 1 deletion tests/core/conversion/converters/test_softmax.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ TEST(Converters, ATenSoftmaxNDConvertsCorrectlyAbove3DIndex) {
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
}

TEST(Converters, ATenSoftmaxNDConvertsCorrectlyNegtiveIndex) {
TEST(Converters, ATenSoftmaxNDConvertsCorrectlyNegtiveOneIndex) {
const auto graph = R"IR(
graph(%0 : Tensor):
%1 : None = prim::Constant()
Expand All @@ -100,5 +100,31 @@ TEST(Converters, ATenSoftmaxNDConvertsCorrectlyNegtiveIndex) {

auto trt = trt_results[0].reshape_as(jit_results[0]);

ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
}

TEST(Converters, ATenSoftmaxNDConvertsCorrectlyNegtiveIndex) {
const auto graph = R"IR(
graph(%0 : Tensor):
%1 : None = prim::Constant()
%2 : int = prim::Constant[value=-2]()
%3 : Tensor = aten::softmax(%0, %2, %1)
return (%3))IR";

auto g = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(graph, &*g);

auto in = at::randint(0, 5, {1, 2, 2, 2, 2}, {at::kCUDA});

auto jit_in = at::clone(in);
auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_in});

auto trt_in = at::clone(in);
params = trtorch::core::conversion::get_named_params(g->inputs(), {});
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in});

auto trt = trt_results[0].reshape_as(jit_results[0]);

ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
}

0 comments on commit abc29a2

Please sign in to comment.