Implement lowering of aten.atleast_2d #3546

Merged: 1 commit, Aug 14, 2024
23 changes: 23 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -10171,6 +10171,29 @@ def Torch_AtenAtleast1dOp : Torch_Op<"aten.atleast_1d", [
}];
}

def Torch_AtenAtleast2dOp : Torch_Op<"aten.atleast_2d", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::atleast_2d : (Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self
);
let results = (outs
AnyTorchOptionalTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenAtleast2dOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 1, 1);
}
void AtenAtleast2dOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 1, 1);
}
}];
}
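For context, the new op mirrors `torch.atleast_2d`, which promotes a tensor to at least rank 2 by prepending size-1 dimensions and leaves higher-rank tensors untouched. A quick illustration of the expected semantics (not part of this diff):

```python
import torch

# Rank-0 and rank-1 inputs gain leading size-1 dimensions;
# rank >= 2 inputs pass through unchanged.
print(torch.atleast_2d(torch.tensor(1.0)).shape)  # torch.Size([1, 1])
print(torch.atleast_2d(torch.rand(4)).shape)      # torch.Size([1, 4])
print(torch.atleast_2d(torch.rand(4, 4)).shape)   # torch.Size([4, 4])
```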

def Torch_AtenEinsumOp : Torch_Op<"aten.einsum", [
AllowsTypeRefinement,
HasValueSemantics,
26 changes: 26 additions & 0 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -10351,6 +10351,28 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" }\n"
" return %2 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.atleast_2d\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
" %int0 = torch.constant.int 0\n"
" %int1 = torch.constant.int 1\n"
" %0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
" %1 = torch.aten.eq.int %0, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" %2 = torch.prim.If %1 -> (!torch.list<int>) {\n"
" %3 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>\n"
" torch.prim.If.yield %3 : !torch.list<int>\n"
" } else {\n"
" %3 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
" %4 = torch.aten.eq.int %3, %int1 : !torch.int, !torch.int -> !torch.bool\n"
" %5 = torch.prim.If %4 -> (!torch.list<int>) {\n"
" %6 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
" %7 = torch.prim.ListConstruct %int1, %6 : (!torch.int, !torch.int) -> !torch.list<int>\n"
" torch.prim.If.yield %7 : !torch.list<int>\n"
" } else {\n"
" torch.prim.If.yield %arg0 : !torch.list<int>\n"
" }\n"
" torch.prim.If.yield %5 : !torch.list<int>\n"
" }\n"
" return %2 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.stack\"(%arg0: !torch.list<list<int>>, %arg1: !torch.int) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.stack(%arg0, %arg1) : (!torch.list<list<int>>, !torch.int) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
@@ -14558,6 +14580,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.atleast_2d\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.einsum\"(%arg0: !torch.str, %arg1: !torch.list<tuple<int, int>>, %arg2: !torch.optional<list<int>>) -> !torch.int {\n"
" %true = torch.constant.bool true\n"
" %none = torch.constant.none\n"
48 changes: 48 additions & 0 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -1514,6 +1514,53 @@ class DecomposeAtenAtleast1dOp : public OpRewritePattern<AtenAtleast1dOp> {
};
} // namespace

namespace {
// Decompose aten.atleast_2d into: aten.atleast_1d and aten.unsqueeze. See
// https://github.com/pytorch/pytorch/blob/9a8ab778d34bd24c5caceb340837483decc4c311/torch/_refs/__init__.py#L2604
// def atleast_2d(
// arg: Union[TensorLikeType, Sequence[TensorLikeType]], *args:
// TensorLikeType
// ) -> Union[TensorLikeType, Tuple[TensorLikeType, ...]]:
// """Reference implementation of :func:`torch.atleast_2d`."""
// if not args and isinstance(arg, collections.abc.Sequence):
// args_ = arg
// else:
// assert not isinstance(arg, collections.abc.Sequence)
// args_ = (arg,) + args
// unsqueeze_atleast_1d = partial(_unsqueeze_atleast, atleast_1d, 0)
// res = tuple(a if a.ndim >= 2 else unsqueeze_atleast_1d(a) for a in args_)
// return res if len(res) > 1 else res[0]
class DecomposeAtenAtleast2dOp : public OpRewritePattern<AtenAtleast2dOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenAtleast2dOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value input = op.getSelf();
Type opType = op.getType();

auto inputType = cast<BaseTensorType>(input.getType());
SmallVector<int64_t> inputShape(inputType.getSizes());

if (inputShape.size() >= 2) {
rewriter.replaceOp(op, input);
return success();
}
auto atleast1dResShape =
inputShape.empty() ? SmallVector<int64_t, 1>{1} : inputShape;
auto atleast1dResType = rewriter.getType<ValueTensorType>(
atleast1dResShape, inputType.getOptionalDtype());
auto atleast1dRes =
rewriter.create<AtenAtleast1dOp>(loc, atleast1dResType, input);
Value zero = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(0));
rewriter.replaceOpWithNewOp<AtenUnsqueezeOp>(op, opType, atleast1dRes,
zero);
return success();
}
};
} // namespace
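The pattern above reduces `aten.atleast_2d` to existing ops rather than lowering it directly: rank >= 2 inputs are replaced by the input itself, while lower-rank inputs go through `aten.atleast_1d` followed by an `aten.unsqueeze` at dim 0. A minimal PyTorch-level sketch of the same rewrite (illustrative only, not code from this patch):

```python
import torch

def decomposed_atleast_2d(x: torch.Tensor) -> torch.Tensor:
    # Mirrors DecomposeAtenAtleast2dOp: pass-through for rank >= 2,
    # otherwise atleast_1d followed by an unsqueeze at dim 0.
    if x.dim() >= 2:
        return x
    return torch.atleast_1d(x).unsqueeze(0)

assert decomposed_atleast_2d(torch.tensor(1.0)).shape == (1, 1)
assert decomposed_atleast_2d(torch.rand(4)).shape == (1, 4)
assert decomposed_atleast_2d(torch.rand(4, 4)).shape == (4, 4)
```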

namespace {
// Decompose AtenEinsumOp to AtenMatmulOp, and supports possible reduce
// operation and permute operation. Currently, this pass doesn't support
@@ -8967,6 +9014,7 @@ class DecomposeComplexOpsPass
addPatternIfTargetOpIsIllegal<DecomposeAtenRreluOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenCeluOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenAtleast1dOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenAtleast2dOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenEinsumOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenTraceOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenHardswishOp>(patterns);
1 change: 1 addition & 0 deletions lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
@@ -395,6 +395,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
target.addIllegalOp<Aten_SoftmaxBackwardDataOp>();
target.addIllegalOp<AtenTanhBackwardOp>();
target.addIllegalOp<AtenAtleast1dOp>();
target.addIllegalOp<AtenAtleast2dOp>();
target.addIllegalOp<AtenEinsumOp>();
target.addIllegalOp<AtenTraceOp>();
target.addIllegalOp<AtenAddmmOp>();
9 changes: 9 additions & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
@@ -842,6 +842,9 @@
"TypeConversionUint8ToF32Module_basic",
"Atleast1dModule0dInput_basic",
"Atleast1dModule1dInput_basic",
"Atleast2dModule0dInput_basic",
"Atleast2dModule1dInput_basic",
"Atleast2dModule2dInput_basic",
"AtenLinear1D_basic",
"AtenLinear2D_basic",
"AtenLinear3DBias_basic",
@@ -1513,6 +1516,9 @@
"TensorSplitSections_ListUnpackModule_basic",
"Atleast1dModule0dInput_basic",
"Atleast1dModule1dInput_basic",
"Atleast2dModule0dInput_basic",
"Atleast2dModule1dInput_basic",
"Atleast2dModule2dInput_basic",
"AtenLinear2D_basic",
"AtenLinear3DBias_basic",
"ElementwiseAddScalar_NumToTensorFloat_Module_basic",
@@ -1993,6 +1999,9 @@
"AtenLinearVecMatBias_basic",
"Atleast1dModule0dInput_basic",
"Atleast1dModule1dInput_basic",
"Atleast2dModule0dInput_basic",
"Atleast2dModule1dInput_basic",
"Atleast2dModule2dInput_basic",
"MaxPool1dEmptyStrideStaticModule_basic",
"MaxPool1dStaticCeilModeTrueModule_basic",
"MaxPool1dStaticModule_basic",
@@ -2066,6 +2066,15 @@ def aten〇atleast_1d〡shape(self: List[int]) -> List[int]:
else:
return self

def aten〇atleast_2d〡shape(self: List[int]) -> List[int]:
if len(self) == 0:
return [1, 1]
elif len(self) == 1:
x = self[0]
return [1, x]
else:
return self
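The MLIR strings added to AbstractInterpLibrary.cpp above are generated from these Python shape/dtype functions, so the two must agree. A few spot checks of the shape rule (a hypothetical standalone version, for illustration only):

```python
from typing import List

def atleast_2d_shape(self: List[int]) -> List[int]:
    # Same rule as aten〇atleast_2d〡shape above.
    if len(self) == 0:
        return [1, 1]
    elif len(self) == 1:
        return [1, self[0]]
    return self

assert atleast_2d_shape([]) == [1, 1]      # 0-d -> [1, 1]
assert atleast_2d_shape([4]) == [1, 4]     # 1-d -> [1, n]
assert atleast_2d_shape([4, 4]) == [4, 4]  # rank >= 2 unchanged
```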

def aten〇stack〡shape(tensors: List[List[int]], dim: int = 0) -> List[int]:
return upstream_shape_functions.stack(tensors, dim)

@@ -5118,6 +5127,11 @@ def aten〇atleast_1d〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
self_rank, self_dtype = self_rank_dtype
return self_dtype

@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
def aten〇atleast_2d〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
self_rank, self_dtype = self_rank_dtype
return self_dtype

@check_dtype_function(
[Invocation("i,j->ij", [TensorOfShape(1, dtype=torch.float32),
TensorOfShape(1, dtype=torch.int32)]),])
@@ -784,6 +784,7 @@ def emit_with_mutating_variants(key, **kwargs):
emit("aten::argmin : (Tensor, int?, bool) -> (Tensor)")
emit("aten::one_hot : (Tensor, int) -> (Tensor)")
emit("aten::atleast_1d : (Tensor) -> (Tensor)")
emit("aten::atleast_2d : (Tensor) -> (Tensor)")
emit("aten::einsum : (str, Tensor[], int[]?) -> (Tensor)")
emit("aten::trace : (Tensor) -> (Tensor)")
emit("aten::bucketize.Tensor : (Tensor, Tensor, bool, bool) -> (Tensor)")
63 changes: 63 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/reshape_like.py
@@ -1504,3 +1504,66 @@ def forward(self, x):
@register_test_case(module_factory=lambda: Atleast1dModule1dInput())
def Atleast1dModule1dInput_basic(module, tu: TestUtils):
module.forward(tu.rand(4))


# ==============================================================================


class Atleast2dModule0dInput(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args(
[
None,
([], torch.float32, True),
]
)
def forward(self, x):
return torch.ops.aten.atleast_2d(x)


@register_test_case(module_factory=lambda: Atleast2dModule0dInput())
def Atleast2dModule0dInput_basic(module, tu: TestUtils):
module.forward(tu.rand())


class Atleast2dModule1dInput(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args(
[
None,
([4], torch.float32, True),
]
)
def forward(self, x):
return torch.ops.aten.atleast_2d(x)


@register_test_case(module_factory=lambda: Atleast2dModule1dInput())
def Atleast2dModule1dInput_basic(module, tu: TestUtils):
module.forward(tu.rand(4))


class Atleast2dModule2dInput(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args(
[
None,
([4, 4], torch.float32, True),
]
)
def forward(self, x):
return torch.ops.aten.atleast_2d(x)


@register_test_case(module_factory=lambda: Atleast2dModule2dInput())
def Atleast2dModule2dInput_basic(module, tu: TestUtils):
module.forward(tu.rand(4, 4))