Skip to content

Commit 415378e

Browse files
committed
feat(//core/conversion/converters/impl): added support for aten::stack
Signed-off-by: Abhiram Iyer <abhirami@nvidia.com> Signed-off-by: Abhiram Iyer <abhi.iyer.ai@gmail.com>
1 parent 97c8f52 commit 415378e

File tree

2 files changed

+66
-1
lines changed

2 files changed

+66
-1
lines changed

core/conversion/converters/BUILD

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@ cc_library(
2929
"impl/softmax.cpp",
3030
"impl/unary.cpp",
3131
"impl/interpolate.cpp",
32-
"impl/select.cpp"
32+
"impl/select.cpp",
33+
"impl/stack.cpp"
3334
],
3435
deps = [
3536
"@tensorrt//:nvinfer",
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
#include "torch/torch.h"
2+
#include "core/util/prelude.h"
3+
#include "core/conversion/converters/converters.h"
4+
#include "core/conversion/tensorcontainer/TensorContainer.h"
5+
#include "NvInfer.h"
6+
7+
#include <ATen/ATen.h>
8+
#include <vector>
9+
10+
namespace trtorch {
11+
namespace core {
12+
namespace conversion {
13+
namespace converters {
14+
namespace impl {
15+
namespace {
16+
17+
auto stack_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
18+
.pattern({
19+
"aten::stack(Tensor[] tensors, int dim=0) -> (Tensor)",
20+
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
21+
auto in = args[0].IValue()->toListRef();
22+
auto dim = args[1].unwrapToInt();
23+
24+
std::vector<nvinfer1::ITensor*> tensors;
25+
26+
for (auto t : in) {
27+
nvinfer1::ITensor* itensor;
28+
29+
if (t.isTensor()) {
30+
auto weight = Weights(ctx, t.toTensor());
31+
32+
auto const_layer = ctx->net->addConstant(weight.shape, weight.data);
33+
TRTORCH_CHECK(const_layer, "Unable to create constant layer from node: " << *n);
34+
35+
itensor = const_layer->getOutput(0);
36+
} else {
37+
auto cont = t.toCustomClass<TensorContainer>();
38+
itensor = cont->tensor();
39+
}
40+
41+
auto shuffle_layer = ctx->net->addShuffle(*itensor);
42+
TRTORCH_CHECK(shuffle_layer, "Unable to create shuffle layer from node: " << *n);
43+
shuffle_layer->setReshapeDimensions(util::unsqueezeDims(itensor->getDimensions(), dim));
44+
45+
tensors.push_back(shuffle_layer->getOutput(0));
46+
}
47+
48+
auto concat_layer = ctx->net->addConcatenation(tensors.data(), tensors.size());
49+
TRTORCH_CHECK(concat_layer, "Unable to create concatenation layer from node: " << *n);
50+
concat_layer->setAxis(static_cast<int>(dim));
51+
auto out = ctx->AssociateValueAndTensor(n->outputs()[0], concat_layer->getOutput(0));
52+
53+
LOG_DEBUG("Output tensor shape: " << out->getDimensions());
54+
55+
return true;
56+
}
57+
});
58+
59+
} // namespace
60+
} // namespace impl
61+
} // namespace converters
62+
} // namespace conversion
63+
} // namespace core
64+
} // namespace trtorch

0 commit comments

Comments
 (0)