1
+ #include " torch/torch.h"
2
+ #include " core/util/prelude.h"
3
+ #include " core/conversion/converters/converters.h"
4
+ #include " core/conversion/tensorcontainer/TensorContainer.h"
5
+ #include " NvInfer.h"
6
+
7
+ #include < ATen/ATen.h>
8
+ #include < vector>
9
+
10
+ namespace trtorch {
11
+ namespace core {
12
+ namespace conversion {
13
+ namespace converters {
14
+ namespace impl {
15
+ namespace {
16
+
17
+ auto stack_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
18
+ .pattern({
19
+ " aten::stack(Tensor[] tensors, int dim=0) -> (Tensor)" ,
20
+ [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
21
+ auto in = args[0 ].IValue ()->toListRef ();
22
+ auto dim = args[1 ].unwrapToInt ();
23
+
24
+ std::vector<nvinfer1::ITensor*> tensors;
25
+
26
+ for (auto t : in) {
27
+ nvinfer1::ITensor* itensor;
28
+
29
+ if (t.isTensor ()) {
30
+ auto weight = Weights (ctx, t.toTensor ());
31
+
32
+ auto const_layer = ctx->net ->addConstant (weight.shape , weight.data );
33
+ TRTORCH_CHECK (const_layer, " Unable to create constant layer from node: " << *n);
34
+
35
+ itensor = const_layer->getOutput (0 );
36
+ } else {
37
+ auto cont = t.toCustomClass <TensorContainer>();
38
+ itensor = cont->tensor ();
39
+ }
40
+
41
+ auto shuffle_layer = ctx->net ->addShuffle (*itensor);
42
+ TRTORCH_CHECK (shuffle_layer, " Unable to create shuffle layer from node: " << *n);
43
+ shuffle_layer->setReshapeDimensions (util::unsqueezeDims (itensor->getDimensions (), dim));
44
+
45
+ tensors.push_back (shuffle_layer->getOutput (0 ));
46
+ }
47
+
48
+ auto concat_layer = ctx->net ->addConcatenation (tensors.data (), tensors.size ());
49
+ TRTORCH_CHECK (concat_layer, " Unable to create concatenation layer from node: " << *n);
50
+ concat_layer->setAxis (static_cast <int >(dim));
51
+ auto out = ctx->AssociateValueAndTensor (n->outputs ()[0 ], concat_layer->getOutput (0 ));
52
+
53
+ LOG_DEBUG (" Output tensor shape: " << out->getDimensions ());
54
+
55
+ return true ;
56
+ }
57
+ });
58
+
59
+ } // namespace
60
+ } // namespace impl
61
+ } // namespace converters
62
+ } // namespace conversion
63
+ } // namespace core
64
+ } // namespace trtorch
0 commit comments