From a7a02457920cc538e23e79d0d64ba571e5791f4f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Petar=20Kapri=C5=A1?=
Date: Wed, 12 Jun 2024 16:11:31 +0200
Subject: [PATCH] Implement lowering of aten.atleast_2d

---
 .../Dialect/Torch/IR/GeneratedTorchOps.td     | 23 +++++++
 .../Transforms/AbstractInterpLibrary.cpp      | 26 ++++++++
 .../Torch/Transforms/DecomposeComplexOps.cpp  | 48 ++++++++++++++
 .../Transforms/LowerToBackendContract.cpp     |  1 +
 projects/pt1/e2e_testing/xfail_sets.py        |  9 +++
 .../build_tools/abstract_interp_lib_gen.py    | 14 +++++
 .../build_tools/torch_ods_gen.py              |  1 +
 .../test_suite/reshape_like.py                | 63 +++++++++++++++++++
 8 files changed, 185 insertions(+)

diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index 626e259fef10..75dd21d0acba 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -10171,6 +10171,29 @@ def Torch_AtenAtleast1dOp : Torch_Op<"aten.atleast_1d", [
   }];
 }
 
+def Torch_AtenAtleast2dOp : Torch_Op<"aten.atleast_2d", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::atleast_2d : (Tensor) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self
+  );
+  let results = (outs
+    AnyTorchOptionalTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenAtleast2dOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 1, 1);
+    }
+    void AtenAtleast2dOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 1, 1);
+    }
+  }];
+}
+
 def Torch_AtenEinsumOp : Torch_Op<"aten.einsum", [
     AllowsTypeRefinement,
     HasValueSemantics,
diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
index 218c6840da70..9010c8afd4ac 100644
--- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -10351,6 +10351,28 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    }\n"
 "    return %2 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.atleast_2d\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
+"    %int0 = torch.constant.int 0\n"
+"    %int1 = torch.constant.int 1\n"
+"    %0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
+"    %1 = torch.aten.eq.int %0, %int0 : !torch.int, !torch.int -> !torch.bool\n"
+"    %2 = torch.prim.If %1 -> (!torch.list<int>) {\n"
+"      %3 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>\n"
+"      torch.prim.If.yield %3 : !torch.list<int>\n"
+"    } else {\n"
+"      %3 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
+"      %4 = torch.aten.eq.int %3, %int1 : !torch.int, !torch.int -> !torch.bool\n"
+"      %5 = torch.prim.If %4 -> (!torch.list<int>) {\n"
+"        %6 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
+"        %7 = torch.prim.ListConstruct %int1, %6 : (!torch.int, !torch.int) -> !torch.list<int>\n"
+"        torch.prim.If.yield %7 : !torch.list<int>\n"
+"      } else {\n"
+"        torch.prim.If.yield %arg0 : !torch.list<int>\n"
+"      }\n"
+"      torch.prim.If.yield %5 : !torch.list<int>\n"
+"    }\n"
+"    return %2 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.stack\"(%arg0: !torch.list<list<int>>, %arg1: !torch.int) -> !torch.list<int> {\n"
 "    %0 = call @__torch__.torch.jit._shape_functions.stack(%arg0, %arg1) : (!torch.list<list<int>>, !torch.int) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
 "  }\n"
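Note on the hunk above: AbstractInterpLibrary.cpp is generated from the Python
shape/dtype transfer functions added to abstract_interp_lib_gen.py further down
in this patch. The rule the new shape function encodes, as a minimal standalone
Python sketch (the function name here is illustrative, not from the patch):

    from typing import List

    def atleast_2d_shape(self: List[int]) -> List[int]:
        if len(self) == 0:        # 0-d input -> 1x1 result
            return [1, 1]
        elif len(self) == 1:      # 1-d input [n] -> [1, n]
            return [1, self[0]]
        else:                     # rank >= 2 passes through unchanged
            return self

    assert atleast_2d_shape([]) == [1, 1]
    assert atleast_2d_shape([4]) == [1, 4]
    assert atleast_2d_shape([4, 4]) == [4, 4]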
@@ -14558,6 +14580,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
 "    return %0#1 : !torch.int\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.atleast_2d\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    return %0#1 : !torch.int\n"
+"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.einsum\"(%arg0: !torch.str, %arg1: !torch.list<tuple<int, int>>, %arg2: !torch.optional<list<int>>) -> !torch.int {\n"
 "    %true = torch.constant.bool true\n"
 "    %none = torch.constant.none\n"
diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index 2af330280871..c261c3459c16 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -1514,6 +1514,53 @@ class DecomposeAtenAtleast1dOp : public OpRewritePattern<AtenAtleast1dOp> {
 };
 } // namespace
 
+namespace {
+// Decompose aten.atleast_2d into: aten.atleast_1d and aten.unsqueeze. See
+// https://github.com/pytorch/pytorch/blob/9a8ab778d34bd24c5caceb340837483decc4c311/torch/_refs/__init__.py#L2604
+// def atleast_2d(
+//     arg: Union[TensorLikeType, Sequence[TensorLikeType]], *args:
+// TensorLikeType
+// ) -> Union[TensorLikeType, Tuple[TensorLikeType, ...]]:
+//     """Reference implementation of :func:`torch.atleast_2d`."""
+//     if not args and isinstance(arg, collections.abc.Sequence):
+//         args_ = arg
+//     else:
+//         assert not isinstance(arg, collections.abc.Sequence)
+//         args_ = (arg,) + args
+//     unsqueeze_atleast_1d = partial(_unsqueeze_atleast, atleast_1d, 0)
+//     res = tuple(a if a.ndim >= 2 else unsqueeze_atleast_1d(a) for a in args_)
+//     return res if len(res) > 1 else res[0]
+class DecomposeAtenAtleast2dOp : public OpRewritePattern<AtenAtleast2dOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenAtleast2dOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Value input = op.getSelf();
+    Type opType = op.getType();
+
+    auto inputType = cast<BaseTensorType>(input.getType());
+    SmallVector<int64_t> inputShape(inputType.getSizes());
+
+    if (inputShape.size() >= 2) {
+      rewriter.replaceOp(op, input);
+      return success();
+    }
+    auto atleast1dResShape =
+        inputShape.empty() ? SmallVector<int64_t>{1} : inputShape;
+    auto atleast1dResType = rewriter.getType<ValueTensorType>(
+        atleast1dResShape, inputType.getOptionalDtype());
+    auto atleast1dRes =
+        rewriter.create<AtenAtleast1dOp>(loc, atleast1dResType, input);
+    Value zero = rewriter.create<Torch::ConstantIntOp>(
+        loc, rewriter.getI64IntegerAttr(0));
+    rewriter.replaceOpWithNewOp<AtenUnsqueezeOp>(op, opType, atleast1dRes,
+                                                 zero);
+    return success();
+  }
+};
+} // namespace
+
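The decomposition above never reshapes directly: inputs of rank >= 2 are
returned unchanged, and lower-rank inputs are promoted with aten.atleast_1d and
then unsqueezed at dim 0. An illustrative eager-mode PyTorch sketch of the same
rewrite (not code from this patch):

    import torch

    def atleast_2d_decomposition(t: torch.Tensor) -> torch.Tensor:
        # Mirrors DecomposeAtenAtleast2dOp: rank >= 2 is a no-op.
        if t.dim() >= 2:
            return t
        t1 = torch.atleast_1d(t)       # 0-d -> shape [1]; 1-d unchanged
        return torch.unsqueeze(t1, 0)  # prepend the new leading dim

    for shape in [(), (4,), (4, 4)]:
        x = torch.rand(shape)
        assert torch.equal(atleast_2d_decomposition(x), torch.atleast_2d(x))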
 namespace {
 // Decompose AtenEinsumOp to AtenMatmulOp, and supports possible reduce
 // operation and permute operation. Currently, this pass doesn't support
@@ -8967,6 +9014,7 @@ class DecomposeComplexOpsPass
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenAtleast1dOp>(patterns);
+    addPatternIfTargetOpIsIllegal<DecomposeAtenAtleast2dOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenEinsumOp>(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
index bbce3926eb9e..9620a1158ae4 100644
--- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
+++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
@@ -395,6 +395,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
   target.addIllegalOp();
   target.addIllegalOp();
   target.addIllegalOp<AtenAtleast1dOp>();
+  target.addIllegalOp<AtenAtleast2dOp>();
   target.addIllegalOp<AtenEinsumOp>();
   target.addIllegalOp();
   target.addIllegalOp();
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index 90b01c804098..7af1bf42389c 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -842,6 +842,9 @@
     "TypeConversionUint8ToF32Module_basic",
     "Atleast1dModule0dInput_basic",
     "Atleast1dModule1dInput_basic",
+    "Atleast2dModule0dInput_basic",
+    "Atleast2dModule1dInput_basic",
+    "Atleast2dModule2dInput_basic",
     "AtenLinear1D_basic",
     "AtenLinear2D_basic",
     "AtenLinear3DBias_basic",
@@ -1513,6 +1516,9 @@
     "TensorSplitSections_ListUnpackModule_basic",
     "Atleast1dModule0dInput_basic",
     "Atleast1dModule1dInput_basic",
+    "Atleast2dModule0dInput_basic",
+    "Atleast2dModule1dInput_basic",
+    "Atleast2dModule2dInput_basic",
     "AtenLinear2D_basic",
     "AtenLinear3DBias_basic",
     "ElementwiseAddScalar_NumToTensorFloat_Module_basic",
@@ -1993,6 +1999,9 @@
     "AtenLinearVecMatBias_basic",
     "Atleast1dModule0dInput_basic",
     "Atleast1dModule1dInput_basic",
+    "Atleast2dModule0dInput_basic",
+    "Atleast2dModule1dInput_basic",
+    "Atleast2dModule2dInput_basic",
     "MaxPool1dEmptyStrideStaticModule_basic",
     "MaxPool1dStaticCeilModeTrueModule_basic",
     "MaxPool1dStaticModule_basic",
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
index e7b6a0efec4e..dcd4558da8c8 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
@@ -2066,6 +2066,15 @@ def aten〇atleast_1d〡shape(self: List[int]) -> List[int]:
     else:
         return self
 
+def aten〇atleast_2d〡shape(self: List[int]) -> List[int]:
+    if len(self) == 0:
+        return [1, 1]
+    elif len(self) == 1:
+        x = self[0]
+        return [1, x]
+    else:
+        return self
+
 def aten〇stack〡shape(tensors: List[List[int]], dim: int = 0) -> List[int]:
     return upstream_shape_functions.stack(tensors, dim)
 
@@ -5118,6 +5127,11 @@ def aten〇atleast_1d〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
     self_rank, self_dtype = self_rank_dtype
     return self_dtype
 
+@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
+def aten〇atleast_2d〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
+    self_rank, self_dtype = self_rank_dtype
+    return self_dtype
+
 @check_dtype_function(
     [Invocation("i,j->ij", [TensorOfShape(1, dtype=torch.float32), TensorOfShape(1, dtype=torch.int32)]),])
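The dtype transfer function is the identity: atleast_2d only adds dimensions
and never changes the element type. A quick eager-mode check (illustrative,
not part of the generated test suite):

    import torch

    y = torch.atleast_2d(torch.tensor(7, dtype=torch.int32))
    assert y.shape == (1, 1) and y.dtype == torch.int32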
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
index 07ab6dcc145c..574f7003c7db 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
@@ -784,6 +784,7 @@ def emit_with_mutating_variants(key, **kwargs):
     emit("aten::argmin : (Tensor, int?, bool) -> (Tensor)")
     emit("aten::one_hot : (Tensor, int) -> (Tensor)")
     emit("aten::atleast_1d : (Tensor) -> (Tensor)")
+    emit("aten::atleast_2d : (Tensor) -> (Tensor)")
     emit("aten::einsum : (str, Tensor[], int[]?) -> (Tensor)")
     emit("aten::trace : (Tensor) -> (Tensor)")
     emit("aten::bucketize.Tensor : (Tensor, Tensor, bool, bool) -> (Tensor)")
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reshape_like.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reshape_like.py
index 3ef4978e1957..f2ef7bc5317d 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reshape_like.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reshape_like.py
@@ -1504,3 +1504,66 @@ def forward(self, x):
 @register_test_case(module_factory=lambda: Atleast1dModule1dInput())
 def Atleast1dModule1dInput_basic(module, tu: TestUtils):
     module.forward(tu.rand(4))
+
+
+# ==============================================================================
+
+
+class Atleast2dModule0dInput(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([], torch.float32, True),
+        ]
+    )
+    def forward(self, x):
+        return torch.ops.aten.atleast_2d(x)
+
+
+@register_test_case(module_factory=lambda: Atleast2dModule0dInput())
+def Atleast2dModule0dInput_basic(module, tu: TestUtils):
+    module.forward(tu.rand())
+
+
+class Atleast2dModule1dInput(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([4], torch.float32, True),
+        ]
+    )
+    def forward(self, x):
+        return torch.ops.aten.atleast_2d(x)
+
+
+@register_test_case(module_factory=lambda: Atleast2dModule1dInput())
+def Atleast2dModule1dInput_basic(module, tu: TestUtils):
+    module.forward(tu.rand(4))
+
+
+class Atleast2dModule2dInput(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([4, 4], torch.float32, True),
+        ]
+    )
+    def forward(self, x):
+        return torch.ops.aten.atleast_2d(x)
+
+
+@register_test_case(module_factory=lambda: Atleast2dModule2dInput())
+def Atleast2dModule2dInput_basic(module, tu: TestUtils):
+    module.forward(tu.rand(4, 4))
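The three new e2e tests pick one input rank per branch of the decomposition:
0-d hits the [1, 1] case, 1-d the unsqueeze case, and 2-d the pass-through
case. Their expected eager-mode behavior:

    import torch

    assert torch.ops.aten.atleast_2d(torch.rand(())).shape == (1, 1)
    assert torch.ops.aten.atleast_2d(torch.rand(4)).shape == (1, 4)
    assert torch.ops.aten.atleast_2d(torch.rand(4, 4)).shape == (4, 4)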