1+ #include " torch/torch.h"
2+ #include " core/util/prelude.h"
3+ #include " core/conversion/converters/converters.h"
4+ #include " core/conversion/tensorcontainer/TensorContainer.h"
5+ #include " NvInfer.h"
6+
7+ #include < ATen/ATen.h>
8+ #include < vector>
9+
10+ namespace trtorch {
11+ namespace core {
12+ namespace conversion {
13+ namespace converters {
14+ namespace impl {
15+ namespace {
16+
17+ auto stack_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
18+ .pattern({
19+ " aten::stack(Tensor[] tensors, int dim=0) -> (Tensor)" ,
20+ [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
21+ auto in = args[0 ].IValue ()->toListRef ();
22+ auto dim = args[1 ].unwrapToInt ();
23+
24+ std::vector<nvinfer1::ITensor*> tensors;
25+
26+ for (auto t : in) {
27+ nvinfer1::ITensor* itensor;
28+
29+ if (t.isTensor ()) {
30+ auto weight = Weights (ctx, t.toTensor ());
31+
32+ auto const_layer = ctx->net ->addConstant (weight.shape , weight.data );
33+ TRTORCH_CHECK (const_layer, " Unable to create constant layer from node: " << *n);
34+
35+ itensor = const_layer->getOutput (0 );
36+ } else {
37+ auto cont = t.toCustomClass <TensorContainer>();
38+ itensor = cont->tensor ();
39+ }
40+
41+ auto shuffle_layer = ctx->net ->addShuffle (*itensor);
42+ TRTORCH_CHECK (shuffle_layer, " Unable to create shuffle layer from node: " << *n);
43+ shuffle_layer->setReshapeDimensions (util::unsqueezeDims (itensor->getDimensions (), dim));
44+
45+ tensors.push_back (shuffle_layer->getOutput (0 ));
46+ }
47+
48+ auto concat_layer = ctx->net ->addConcatenation (tensors.data (), tensors.size ());
49+ TRTORCH_CHECK (concat_layer, " Unable to create concatenation layer from node: " << *n);
50+ concat_layer->setAxis (static_cast <int >(dim));
51+ auto out = ctx->AssociateValueAndTensor (n->outputs ()[0 ], concat_layer->getOutput (0 ));
52+
53+ LOG_DEBUG (" Output tensor shape: " << out->getDimensions ());
54+
55+ return true ;
56+ }
57+ });
58+
59+ } // namespace
60+ } // namespace impl
61+ } // namespace converters
62+ } // namespace conversion
63+ } // namespace core
64+ } // namespace trtorch
0 commit comments