Skip to content

Commit

Permalink
feat: Add converter files for torch::max
Browse files Browse the repository at this point in the history
Signed-off-by: hongwei03 <hongwei03@corp.netease.com>
  • Loading branch information
Yiran Liu authored and p517332051 committed Mar 23, 2022
1 parent 569bcde commit f628aca
Showing 1 changed file with 1 addition and 5 deletions.
6 changes: 1 addition & 5 deletions core/conversion/converters/impl/max.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,13 @@ auto max_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns().patter
{"aten::max.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)",
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
auto self = args[0].ITensorOrFreeze(ctx);
auto k = 1;
auto dim = args[1].unwrapToInt();
auto largest = true;
auto selfDim = util::toVec(self->getDimensions());
if (dim < 0) {
dim = selfDim.size() + dim;
}
uint32_t shiftDim = 1 << dim;

auto TopKOperation = largest ? (nvinfer1::TopKOperation::kMAX) : (nvinfer1::TopKOperation::kMIN);

auto TopKOperation = nvinfer1::TopKOperation::kMAX;
auto new_layer = ctx->net->addTopK(*self, TopKOperation, 1, shiftDim);
TORCHTRT_CHECK(new_layer, "Unable to create max layer from node: " << *n);

Expand Down

0 comments on commit f628aca

Please sign in to comment.