@@ -71,35 +71,34 @@ auto select_registrations TRTORCH_UNUSED =
71
71
RegisterNodeConversionPatterns ()
72
72
.pattern({" aten::select.int(Tensor(a) self, int dim, int index) -> (Tensor(a))" ,
73
73
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
74
- auto in = args[0 ].ITensor ( );
74
+ auto in = args[0 ].ITensorOrFreeze (ctx );
75
75
auto maxDim = static_cast <int64_t >(in->getDimensions ().nbDims );
76
76
auto axis = args[1 ].unwrapToInt ();
77
77
axis = axis < 0 ? axis + maxDim : axis;
78
78
auto ind = (int32_t )args[2 ].unwrapToInt ();
79
79
80
80
// index to access needs to be an at::Tensor
81
81
at::Tensor indices = torch::tensor ({ind}).to (torch::kI32 );
82
- auto weights = Weights (ctx, indices);
83
-
84
- // IConstantLayer to convert indices from Weights to ITensor
85
- auto const_layer = ctx->net ->addConstant (weights.shape , weights.data );
86
- TRTORCH_CHECK (const_layer, " Unable to create constant layer from node: " << *n);
87
- auto const_out = const_layer->getOutput (0 );
82
+ auto const_out = tensor_to_const (ctx, indices);
88
83
89
84
// IGatherLayer takes in input tensor, the indices, and the axis
90
85
// of input tensor to take indices from
91
86
auto gather_layer = ctx->net ->addGather (*in, *const_out, axis);
92
87
TRTORCH_CHECK (gather_layer, " Unable to create gather layer from node: " << *n);
93
- auto gather_out = gather_layer->getOutput (0 );
88
+ auto out = gather_layer->getOutput (0 );
94
89
95
- // IShuffleLayer removes redundant dimensions
96
- auto shuffle_layer = ctx->net ->addShuffle (*gather_out);
97
- TRTORCH_CHECK (shuffle_layer, " Unable to create shuffle layer from node: " << *n);
98
- shuffle_layer->setReshapeDimensions (util::squeezeDims (gather_out->getDimensions (), axis));
99
- shuffle_layer->setName (util::node_info (n).c_str ());
100
- auto shuffle_out = shuffle_layer->getOutput (0 );
90
+ LOG_DEBUG (" Gather tensor shape: " << out->getDimensions ());
101
91
102
- auto out = ctx->AssociateValueAndTensor (n->outputs ()[0 ], shuffle_out);
92
+ if (out->getDimensions ().nbDims != 1 ) {
93
+ // IShuffleLayer removes redundant dimensions
94
+ auto shuffle_layer = ctx->net ->addShuffle (*out);
95
+ TRTORCH_CHECK (shuffle_layer, " Unable to create shuffle layer from node: " << *n);
96
+ shuffle_layer->setReshapeDimensions (util::squeezeDims (out->getDimensions (), axis));
97
+ shuffle_layer->setName (util::node_info (n).c_str ());
98
+ out = shuffle_layer->getOutput (0 );
99
+ }
100
+
101
+ out = ctx->AssociateValueAndTensor (n->outputs ()[0 ], out);
103
102
104
103
LOG_DEBUG (" Output tensor shape: " << out->getDimensions ());
105
104
@@ -253,15 +252,14 @@ auto select_registrations TRTORCH_UNUSED =
253
252
" aten::masked_fill.Scalar(Tensor self, Tensor mask, Scalar value) -> (Tensor)" ,
254
253
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
255
254
auto self = args[0 ].ITensorOrFreeze (ctx);
256
- LOG_DEBUG (args[1 ].unwrapToTensor ());
257
255
auto mask = castITensor (ctx, args[1 ].ITensorOrFreeze (ctx), nvinfer1::DataType::kBOOL );
256
+ mask = addPadding (ctx, n, mask, self->getDimensions ().nbDims , false , true );
258
257
auto val = args[2 ].unwrapToScalar ().to <float >();
259
- LOG_DEBUG (torch::full (util::toVec (self->getDimensions ()), val));
260
258
auto val_t = tensor_to_const (ctx, torch::full (util::toVec (self->getDimensions ()), val));
261
259
262
260
TRTORCH_CHECK (util::broadcastable (self->getDimensions (), mask->getDimensions (), /* multidirectional=*/ false ), " Self and mask tensors are not broadcastable" );
263
261
264
- auto new_layer = ctx->net ->addSelect (*mask, *self , *val_t );
262
+ auto new_layer = ctx->net ->addSelect (*mask, *val_t , *self );
265
263
TRTORCH_CHECK (new_layer, " Unable to create layer for aten::masked_fill" );
266
264
267
265
new_layer->setName (util::node_info (n).c_str ());
0 commit comments