Skip to content

Commit

Permalink
[OnnxToTorch][GridSample] Add support for border padding mode
Browse files Browse the repository at this point in the history
  • Loading branch information
Ax9D committed Oct 25, 2024
1 parent 6946b24 commit 45947d3
Showing 1 changed file with 11 additions and 3 deletions.
14 changes: 11 additions & 3 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -140,12 +140,19 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
}

std::string padding;
int64_t paddingModeInt;
if (binder.customOpNameStringAttr(padding, "padding_mode", "zeros"))
return rewriter.notifyMatchFailure(binder.op,
"padding_mode bind failure");
if (padding != "zeros")
if (padding == "zeros") {
paddingModeInt = 0;
} else if (padding == "border") {
paddingModeInt = 1;
} else {
return rewriter.notifyMatchFailure(
binder.op, "currently only padding_mode : zeros supported");
binder.op,
"currently only padding_mode : zeros and border supported");
}
int64_t align;
if (binder.s64IntegerAttr(align, "align_corners", 0))
return rewriter.notifyMatchFailure(binder.op,
Expand All @@ -157,7 +164,8 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(

Value paddingMode = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
rewriter.getIntegerAttr(rewriter.getIntegerType(64),
paddingModeInt));

bool alignMode = align;
Value alignCorners = rewriter.create<Torch::ConstantBoolOp>(
Expand Down

0 comments on commit 45947d3

Please sign in to comment.