@@ -14,27 +14,27 @@ namespace converters {
1414namespace impl {
1515namespace {
1616auto max_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pattern(
17- {" aten::max.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)" ,
18- [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
19- auto self = args[0 ].ITensorOrFreeze (ctx);
20- auto dim = args[1 ].unwrapToInt ();
21- auto selfDim = util::toVec (self->getDimensions ());
22- if (dim < 0 ) {
23- dim = selfDim.size () + dim;
24- }
25- uint32_t shiftDim = 1 << dim;
26- auto TopKOperation = nvinfer1::TopKOperation::kMAX ;
27- auto new_layer = ctx->net ->addTopK (*self, TopKOperation, 1 , shiftDim);
28- TORCHTRT_CHECK (new_layer, " Unable to create max layer from node: " << *n);
17+ {" aten::max.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)" ,
18+ [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
19+ auto self = args[0 ].ITensorOrFreeze (ctx);
20+ auto dim = args[1 ].unwrapToInt ();
21+ auto selfDim = util::toVec (self->getDimensions ());
22+ if (dim < 0 ) {
23+ dim = selfDim.size () + dim;
24+ }
25+ uint32_t shiftDim = 1 << dim;
26+ auto TopKOperation = nvinfer1::TopKOperation::kMAX ;
27+ auto new_layer = ctx->net ->addTopK (*self, TopKOperation, 1 , shiftDim);
28+ TORCHTRT_CHECK (new_layer, " Unable to create max layer from node: " << *n);
2929
30- auto out0 = ctx->AssociateValueAndTensor (n->outputs ()[0 ], new_layer->getOutput (0 ));
31- auto out1 = ctx->AssociateValueAndTensor (n->outputs ()[1 ], new_layer->getOutput (1 ));
30+ auto out0 = ctx->AssociateValueAndTensor (n->outputs ()[0 ], new_layer->getOutput (0 ));
31+ auto out1 = ctx->AssociateValueAndTensor (n->outputs ()[1 ], new_layer->getOutput (1 ));
3232
33- LOG_DEBUG (" Output tensor(0) shape: " << out0->getDimensions ());
34- LOG_DEBUG (" Output tensor(1) shape: " << out1->getDimensions ());
33+ LOG_DEBUG (" Output tensor(0) shape: " << out0->getDimensions ());
34+ LOG_DEBUG (" Output tensor(1) shape: " << out1->getDimensions ());
3535
36- return true ;
37- }});
36+ return true ;
37+ }});
3838} // namespace
3939} // namespace impl
4040} // namespace converters
0 commit comments